1#ifndef GUARD_INCLUDE_PROBFD_ABSTRACTIONS_POLICY_EXTRACTION_H
2#error "This file should only be included from policy_extraction.h"
5#include "probfd/pdbs/projection_operator.h"
6#include "probfd/pdbs/projection_state_space.h"
8#include "probfd/policies/vector_multi_policy.h"
10#include "probfd/distribution.h"
11#include "probfd/transition.h"
13#include "downward/utils/rng.h"
17#include <unordered_set>
23template <
typename State,
typename Action>
26 std::span<const value_t> value_table,
28 utils::RandomNumberGenerator& rng,
31 using PredecessorEdge = std::pair<State, Action>;
33 std::deque<StateID> open = {initial_state};
34 std::unordered_set<StateID> closed = {initial_state};
36 std::vector<std::vector<PredecessorEdge>> predecessors(value_table.size());
37 std::vector<StateID> goals;
40 while (!open.empty()) {
46 const value_t value = value_table[s];
50 if (value == term_cost) {
56 std::vector<Transition<Action>> transitions;
60 for (
const auto& [op, successor_dist] : transitions) {
62 successor_dist.expectation(value_table);
66 for (
const StateID succ : successor_dist.support()) {
67 if (closed.insert(succ).second) {
71 predecessors[succ].emplace_back(s, op);
76 auto policy = std::make_unique<policies::VectorMultiPolicy<State, Action>>(
83 open.insert(open.end(), goals.begin(), goals.end());
85 closed.insert(goals.begin(), goals.end());
87 while (!open.empty()) {
89 auto it = rng.choose(open);
92 std::swap(*it, open.back());
96 rng.shuffle(predecessors[s]);
98 for (
const auto& [pstate_id, sel_op] : predecessors[s]) {
99 if (!closed.insert(pstate_id).second)
continue;
100 open.push_back(pstate_id);
102 const value_t parent_cost = value_table[pstate_id];
104 const State pstate = mdp.
get_state(pstate_id);
111 std::vector<PolicyDecision<Action>> decisions;
113 std::vector<Transition<Action>> transitions;
116 for (
const auto& [op, successor_dist] : transitions) {
117 if (successor_dist == sel_successor_dist &&
119 decisions.emplace_back(op,
Interval(parent_cost));
124 if (!wildcard) decisions = {*rng.choose(decisions)};
126 (*policy)[pstate_id] = std::move(decisions);
128 assert(!(*policy)[pstate_id].empty());
135template <
typename State,
typename Action>
138 std::span<const value_t> value_table,
140 utils::RandomNumberGenerator& rng,
143 auto policy = std::make_unique<policies::VectorMultiPolicy<State, Action>>(
147 std::deque<StateID> open{initial_state};
148 std::unordered_set<StateID> closed{initial_state};
151 while (!open.empty()) {
155 const value_t value = value_table[s];
161 if (value == term_cost) {
166 std::vector<Transition<Action>> transitions;
169 if (transitions.empty()) {
174 rng.shuffle(transitions);
177 auto it = std::ranges::find(transitions, [&](
const auto& transition) {
178 const auto& [op, successor_dist] = transition;
180 successor_dist.expectation(value_table);
185 assert(it != transitions.end());
190 for (
const StateID succ : it->successor_dist.support()) {
191 if (!closed.insert(succ).second)
continue;
192 open.push_back(succ);
196 std::vector<PolicyDecision<Action>> decisions;
197 decisions.emplace_back(it->action,
Interval(value));
199 for (
const auto& [op, successor_dist] :
200 std::ranges::subrange(std::next(it), transitions.end())) {
203 if (successor_dist == it->successor_dist &&
204 cost_op == cost_greedy) {
205 decisions.emplace_back(op,
Interval(value));
210 if (!wildcard) decisions = {*rng.choose(decisions)};
211 (*policy)[s] = std::move(decisions);
213 assert(!(*policy)[s].empty());
virtual value_t get_action_cost(param_type< Action > action)=0
Gets the cost of an action.
virtual TerminationInfo get_termination_info(param_type< State > state)=0
Returns the cost to terminate in a given state and checks whether a state is a goal.
A convenience class that represents a finite probability distribution.
Definition task_state_space.h:27
Basic interface for MDPs.
Definition mdp_algorithm.h:14
virtual void generate_all_transitions(param_type< State > state, std::vector< Action > &aops, std::vector< Distribution< StateID > > &successors)=0
Generates all applicable actions and their corresponding successor distributions for a given state.
virtual State get_state(StateID state_id)=0
Get the state mapped to a given state ID.
virtual void generate_action_transitions(param_type< State > state, param_type< Action > action, Distribution< StateID > &result)=0
Generates the successor distribution for a given state and action.
value_t get_cost() const
Obtains the cost paid upon termination in the state.
Definition cost_function.h:41
The top-level namespace of probabilistic Fast Downward.
Definition command_line.h:8
std::unique_ptr< MultiPolicy< State, Action > > compute_greedy_projection_policy(MDP< State, Action > &mdp, std::span< const value_t > value_table, param_type< State > initial_state, utils::RandomNumberGenerator &rng, bool wildcard)
Extracts an abstract greedy policy from the value table, which may not be optimal if traps exist.
Definition policy_extraction_impl.h:136
std::unique_ptr< MultiPolicy< State, Action > > compute_optimal_projection_policy(MDP< State, Action > &mdp, std::span< const value_t > value_table, param_type< State > initial_state, utils::RandomNumberGenerator &rng, bool wildcard)
Extract an abstract optimal policy from the value table.
Definition policy_extraction_impl.h:24
double value_t
Typedef for the state value type.
Definition aliases.h:7
typename std::conditional_t< is_cheap_to_copy_v< T >, T, const T & > param_type
Alias template defining the best way to pass a parameter of a given type.
Definition type_traits.h:25
bool is_approx_equal(value_t v1, value_t v2, value_t epsilon=g_epsilon)
Equivalent to $|v_1 - v_2| \le \epsilon$.
Represents a closed interval over the extended reals as a pair of lower and upper bound.
Definition interval.h:12
A StateID represents a state within a StateIDMap. Just like Fast Downward's StateID type,...
Definition types.h:22