AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
policy_extraction_impl.h
#ifndef GUARD_INCLUDE_PROBFD_ABSTRACTIONS_POLICY_EXTRACTION_H
#error "This file should only be included from policy_extraction.h"
#endif

#include "probfd/pdbs/projection_operator.h"
#include "probfd/pdbs/projection_state_space.h"

#include "probfd/policies/vector_multi_policy.h"

#include "probfd/distribution.h"
#include "probfd/transition.h"

#include "downward/utils/rng.h"

#include <algorithm>
#include <cassert>
#include <deque>
#include <memory>
#include <ranges>
#include <span>
#include <unordered_set>
#include <utility>
#include <vector>

namespace probfd {

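/// Extracts an abstract optimal policy from the value table.
///
/// The policy graph is first expanded greedily in the forward direction from
/// the initial state; a backward pass from the goal states then selects, for
/// each reached state, greedy operators that lie on a path to a goal. If
/// wildcard is false, a single greedy operator is chosen at random per state;
/// otherwise all equivalent greedy operators are kept.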
template <typename State, typename Action>
std::unique_ptr<MultiPolicy<State, Action>> compute_optimal_projection_policy(
    MDP<State, Action>& mdp,
    std::span<const value_t> value_table,
    param_type<State> initial_state,
    utils::RandomNumberGenerator& rng,
    bool wildcard)
{
    using PredecessorEdge = std::pair<StateID, Action>;

    std::deque<StateID> open = {initial_state};
    std::unordered_set<StateID> closed = {initial_state};

    std::vector<std::vector<PredecessorEdge>> predecessors(value_table.size());
    std::vector<StateID> goals;

    // Build the greedy policy graph
    while (!open.empty()) {
        StateID s = open.front();
        open.pop_front();

        const State state = mdp.get_state(s);

        const value_t value = value_table[s];

        // Skip states in which termination is optimal
        const value_t term_cost = mdp.get_termination_info(state).get_cost();
        if (value == term_cost) {
            goals.push_back(s);
            continue;
        }

        // Generate operators...
        std::vector<Transition<Action>> transitions;
        mdp.generate_all_transitions(state, transitions);

        // Select the greedy operators and add their successors
        for (const auto& [op, successor_dist] : transitions) {
            value_t op_value = mdp.get_action_cost(op) +
                               successor_dist.expectation(value_table);

            if (!is_approx_equal(value, op_value)) continue;

            for (const StateID succ : successor_dist.support()) {
                if (closed.insert(succ).second) {
                    open.push_back(succ);
                }

                predecessors[succ].emplace_back(s, op);
            }
        }
    }

    auto policy = std::make_unique<policies::VectorMultiPolicy<State, Action>>(
        &mdp,
        value_table.size());

    // Do regression search with duplicate checking through the constructed
    // graph, expanding predecessors randomly to select an optimal policy
    assert(open.empty());
    open.insert(open.end(), goals.begin(), goals.end());
    closed.clear();
    closed.insert(goals.begin(), goals.end());

    while (!open.empty()) {
        // Choose a random successor
        auto it = rng.choose(open);
        StateID s = *it;

        std::swap(*it, open.back());
        open.pop_back();

        // Consider predecessors in random order
        rng.shuffle(predecessors[s]);

        for (const auto& [pstate_id, sel_op] : predecessors[s]) {
            if (!closed.insert(pstate_id).second) continue;
            open.push_back(pstate_id);

            const value_t parent_cost = value_table[pstate_id];

            const State pstate = mdp.get_state(pstate_id);

            // Collect all equivalent greedy operators
            const value_t cost_sel_op = mdp.get_action_cost(sel_op);
            Distribution<StateID> sel_successor_dist;
            mdp.generate_action_transitions(pstate, sel_op, sel_successor_dist);

            std::vector<PolicyDecision<Action>> decisions;

            std::vector<Transition<Action>> transitions;
            mdp.generate_all_transitions(pstate, transitions);

            for (const auto& [op, successor_dist] : transitions) {
                if (successor_dist == sel_successor_dist &&
                    mdp.get_action_cost(op) == cost_sel_op) {
                    decisions.emplace_back(op, Interval(parent_cost));
                }
            }

            // If not wildcard, randomly pick one
            if (!wildcard) decisions = {*rng.choose(decisions)};

            (*policy)[pstate_id] = std::move(decisions);

            assert(!(*policy)[pstate_id].empty());
        }
    }

    return policy;
}

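/// Extracts an abstract greedy policy from the value table, which may not be
/// optimal if traps exist. If wildcard is false, a single greedy operator is
/// chosen at random per state; otherwise all equivalent greedy operators are
/// kept.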
template <typename State, typename Action>
std::unique_ptr<MultiPolicy<State, Action>> compute_greedy_projection_policy(
    MDP<State, Action>& mdp,
    std::span<const value_t> value_table,
    param_type<State> initial_state,
    utils::RandomNumberGenerator& rng,
    bool wildcard)
{
    auto policy = std::make_unique<policies::VectorMultiPolicy<State, Action>>(
        &mdp,
        value_table.size());

    std::deque<StateID> open{initial_state};
    std::unordered_set<StateID> closed{initial_state};

    // Build the greedy policy graph
    while (!open.empty()) {
        StateID s = open.front();
        open.pop_front();

        const value_t value = value_table[s];

        const State state = mdp.get_state(s);

        // Skip states in which termination is optimal
        const value_t term_cost = mdp.get_termination_info(state).get_cost();
        if (value == term_cost) {
            continue;
        }

        // Generate operators...
        std::vector<Transition<Action>> transitions;
        mdp.generate_all_transitions(state, transitions);

        if (transitions.empty()) {
            continue;
        }

        // Look at the (greedy) operators in random order.
        rng.shuffle(transitions);

        // Find the first greedy transition
        auto it = std::ranges::find_if(transitions, [&](const auto& transition) {
            const auto& [op, successor_dist] = transition;
            const value_t op_value = mdp.get_action_cost(op) +
                                     successor_dist.expectation(value_table);

            return is_approx_equal(value, op_value);
        });

        assert(it != transitions.end());

        const value_t cost_greedy = mdp.get_action_cost(it->action);

        // Generate successors
        for (const StateID succ : it->successor_dist.support()) {
            if (!closed.insert(succ).second) continue;
            open.push_back(succ);
        }

        // Collect all equivalent greedy operators
        std::vector<PolicyDecision<Action>> decisions;
        decisions.emplace_back(it->action, Interval(value));

        for (const auto& [op, successor_dist] :
             std::ranges::subrange(std::next(it), transitions.end())) {
            const value_t cost_op = mdp.get_action_cost(op);

            if (successor_dist == it->successor_dist &&
                cost_op == cost_greedy) {
                decisions.emplace_back(op, Interval(value));
            }
        }

        // If not wildcard, randomly pick one
        if (!wildcard) decisions = {*rng.choose(decisions)};
        (*policy)[s] = std::move(decisions);

        assert(!(*policy)[s].empty());
    }

    return policy;
}

} // namespace probfd
Referenced symbols

- probfd::compute_optimal_projection_policy(MDP<State, Action>& mdp, std::span<const value_t> value_table, param_type<State> initial_state, utils::RandomNumberGenerator& rng, bool wildcard): Extracts an abstract optimal policy from the value table.
- probfd::compute_greedy_projection_policy(MDP<State, Action>& mdp, std::span<const value_t> value_table, param_type<State> initial_state, utils::RandomNumberGenerator& rng, bool wildcard): Extracts an abstract greedy policy from the value table, which may not be optimal if traps exist.
- probfd::MDP<State, Action>: Basic interface for MDPs.
  - get_state(StateID state_id): Gets the state mapped to a given state ID.
  - get_action_cost(param_type<Action> action): Gets the cost of an action.
  - get_termination_info(param_type<State> state): Returns the cost to terminate in a given state and checks whether the state is a goal.
  - generate_all_transitions: Generates all applicable actions and their corresponding successor distributions for a given state.
  - generate_action_transitions(param_type<State> state, param_type<Action> action, Distribution<StateID>& result): Generates the successor distribution for a given state and action.
- probfd::TerminationInfo::get_cost(): Obtains the cost paid upon termination in the state.
- probfd::Distribution: A convenience class that represents a finite probability distribution.
- probfd::Interval: Represents a closed interval over the extended reals as a pair of lower and upper bounds.
- probfd::StateID: Represents a state within a StateIDMap, analogous to Fast Downward's StateID type.
- probfd::value_t: Typedef for the state value type (double).
- probfd::param_type<T>: Alias template defining the best way to pass a parameter of a given type (T itself if cheap to copy, const T& otherwise).
- probfd::is_approx_equal(value_t v1, value_t v2, value_t epsilon = g_epsilon): Returns true iff |v1 - v2| <= epsilon.
- probfd: The top-level namespace of probabilistic Fast Downward.
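For orientation, a minimal usage sketch follows. It is not part of the header: MyState, MyAction and the extract_policy wrapper are hypothetical, the value table is assumed to have been computed beforehand (e.g. by value iteration over the projection), and the header is assumed to be reachable as probfd/abstractions/policy_extraction.h, matching its include guard.

#include "probfd/abstractions/policy_extraction.h"

#include "downward/utils/rng.h"

#include <memory>
#include <span>

using namespace probfd;

// Hypothetical concrete types; any MDP<State, Action> implementation works.
std::unique_ptr<MultiPolicy<MyState, MyAction>> extract_policy(
    MDP<MyState, MyAction>& mdp,
    std::span<const value_t> value_table,
    param_type<MyState> initial_state)
{
    // Seeded RNG; a fixed seed makes the random tie-breaking among
    // equivalent greedy operators reproducible.
    utils::RandomNumberGenerator rng(42);

    // wildcard = true keeps every equivalent greedy operator per state,
    // yielding a genuine multi-policy; wildcard = false samples one of them.
    return compute_optimal_projection_policy(
        mdp,
        value_table,
        initial_state,
        rng,
        /*wildcard=*/true);
}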