AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
acyclic_value_iteration_impl.h
1#ifndef GUARD_INCLUDE_PROBFD_ALGORITHMS_ACYCLIC_VALUE_ITERATION_H
2#error "This file should only be included from acyclic_value_iteration.h"
3#endif
4
#include "probfd/algorithms/utils.h"

#include "probfd/policies/map_policy.h"

#include "probfd/evaluator.h"
#include "probfd/mdp.h"

#include "downward/utils/countdown_timer.h"

#include <memory>
13
15
16template <typename State, typename Action>
18 const State&)
19{
20 ++state_expansions;
21}
22
23template <typename State, typename Action>
25{
26 ++goal_states;
27}
28
29template <typename State, typename Action>
31{
32 ++terminal_states;
33}
34
35template <typename State, typename Action>
37{
38 ++pruned_states;
39}
40
41template <typename State, typename Action>
42void StatisticsObserver<State, Action>::print(std::ostream& out) const
43{
44 out << " Expanded state(s): " << state_expansions << std::endl;
45 out << " Pruned state(s): " << pruned_states << std::endl;
46 out << " Terminal state(s): " << terminal_states << std::endl;
47 out << " Goal state(s): " << goal_states << std::endl;
48}
49
50namespace internal {
51
52template <typename State, typename Action>
53void AcyclicValueIterationObserverCollection<State, Action>::register_observer(
54 std::shared_ptr<Observer> observer)
55{
56 observers_.push_back(std::move(observer));
57}
58
59template <typename State, typename Action>
60void AcyclicValueIterationObserverCollection<State, Action>::
61 notify_state_selected_for_expansion(const State& state)
62{
63 for (auto& observer : observers_) {
64 observer.on_state_selected_for_expansion(state);
65 }
66}
67
68template <typename State, typename Action>
69void AcyclicValueIterationObserverCollection<State, Action>::notify_goal_state(
70 const State& state)
71{
72 for (auto& observer : observers_) {
73 observer.on_goal_state(state);
74 }
75}
76
77template <typename State, typename Action>
78void AcyclicValueIterationObserverCollection<State, Action>::
79 notify_terminal_state(const State& state)
80{
81 for (auto& observer : observers_) {
82 observer.on_terminal_state(state);
83 }
84}
85
86template <typename State, typename Action>
87void AcyclicValueIterationObserverCollection<State, Action>::
88 notify_pruned_state(const State& state)
89{
90 for (auto& observer : observers_) {
91 observer.on_pruned_state(state);
92 }
93}
94
95} // namespace internal
96
97template <typename State, typename Action>
98AcyclicValueIteration<State, Action>::IncrementalExpansionInfo::
99 IncrementalExpansionInfo(StateID state_id, StateInfo& state_info)
100 : state_id(state_id)
101 , state_info(state_info)
102{
103}
104
105template <typename State, typename Action>
106void AcyclicValueIteration<State, Action>::IncrementalExpansionInfo::
107 setup_transition(MDPType& mdp)
108{
109 assert(transition.empty());
110 auto& next_action = remaining_aops.back();
111 t_value = mdp.get_action_cost(next_action);
112 const State state = mdp.get_state(state_id);
113 mdp.generate_action_transitions(state, next_action, transition);
114 successor = transition.begin();
115}
116
117template <typename State, typename Action>
118void AcyclicValueIteration<State, Action>::IncrementalExpansionInfo::
119 backtrack_successor(value_t probability, StateInfo& succ_info)
120{
121 // Update transition Q-value
122 t_value += probability * succ_info.value;
123 succ_info.status = StateInfo::CLOSED;
124}
125
126template <typename State, typename Action>
127bool AcyclicValueIteration<State, Action>::IncrementalExpansionInfo::advance(
128 MDPType& mdp,
129 MapPolicy* policy)
130{
131 return next_successor() || next_transition(mdp, policy);
132}
133
134template <typename State, typename Action>
135bool AcyclicValueIteration<State, Action>::IncrementalExpansionInfo::
136 next_successor()
137{
138 assert(successor != transition.end());
139 if (++successor != transition.end()) {
140 return true;
141 }
142
143 finalize_transition();
144
145 return false;
146}
147
148template <typename State, typename Action>
149bool AcyclicValueIteration<State, Action>::IncrementalExpansionInfo::
150 next_transition(MDPType& mdp, MapPolicy* policy)
151{
152 assert(!remaining_aops.empty());
153 remaining_aops.pop_back();
154
155 if (remaining_aops.empty()) {
156 finalize_expansion(policy);
157 return false;
158 }
159
160 transition.clear();
161 setup_transition(mdp);
162
163 return true;
164}
165
166template <typename State, typename Action>
167void AcyclicValueIteration<State, Action>::IncrementalExpansionInfo::
168 finalize_transition()
169{
170 // Minimum Q-value
171 if (t_value < state_info.value) {
172 state_info.best_action = remaining_aops.back();
173 state_info.value = t_value;
174 }
175}
176
177template <typename State, typename Action>
178void AcyclicValueIteration<State, Action>::IncrementalExpansionInfo::
179 finalize_expansion(MapPolicy* policy)
180{
181 if (!policy || !state_info.best_action) return;
182 policy->emplace_decision(
183 state_id,
184 *state_info.best_action,
185 Interval(state_info.value));
186}
187
188template <typename State, typename Action>
189auto AcyclicValueIteration<State, Action>::compute_policy(
190 MDPType& mdp,
191 EvaluatorType& heuristic,
192 param_type<State> initial_state,
193 ProgressReport,
194 double max_time) -> std::unique_ptr<PolicyType>
195{
196 std::unique_ptr<MapPolicy> policy(new MapPolicy(&mdp));
197 this->solve(mdp, heuristic, initial_state, max_time, policy.get());
198 return policy;
199}
200
201template <typename State, typename Action>
202Interval AcyclicValueIteration<State, Action>::solve(
203 MDPType& mdp,
204 EvaluatorType& heuristic,
205 param_type<State> initial_state,
206 ProgressReport,
207 double max_time)
208{
209 return solve(mdp, heuristic, initial_state, max_time, nullptr);
210}
211
212template <typename State, typename Action>
213Interval AcyclicValueIteration<State, Action>::solve(
214 MDPType& mdp,
215 EvaluatorType& heuristic,
216 param_type<State> initial_state,
217 double max_time,
218 MapPolicy* policy)
219{
220 utils::CountdownTimer timer(max_time);
221
222 const StateID initial_state_id = mdp.get_state_id(initial_state);
223 StateInfo& iinfo = state_infos_[initial_state_id];
224
225 expansion_stack_.emplace(initial_state_id, iinfo);
226 iinfo.status = StateInfo::ON_STACK;
227
228 IncrementalExpansionInfo* e;
229
230 for (;;) {
231 do {
232 e = &expansion_stack_.top();
233 } while (expand_state(mdp, heuristic, *e) &&
234 push_successor(mdp, policy, *e, timer));
235
236 do {
237 expansion_stack_.pop();
238
239 if (expansion_stack_.empty()) {
240 return Interval(iinfo.value);
241 }
242
243 timer.throw_if_expired();
244
245 e = &expansion_stack_.top();
246
247 const auto [succ_id, probability] = *e->successor;
248 e->backtrack_successor(probability, state_infos_[succ_id]);
249 } while (!e->advance(mdp, policy) ||
250 !push_successor(mdp, policy, *e, timer));
251 }
252}
253
254template <typename State, typename Action>
255void AcyclicValueIteration<State, Action>::register_observer(
256 std::shared_ptr<Observer> observer)
257{
258 observers_.register_observer(std::move(observer));
259}
260
261template <typename State, typename Action>
262bool AcyclicValueIteration<State, Action>::push_successor(
263 MDPType& mdp,
264 MapPolicy* policy,
265 IncrementalExpansionInfo& e,
266 utils::CountdownTimer& timer)
267{
268 do {
269 timer.throw_if_expired();
270
271 const auto [succ_id, probability] = *e.successor;
272 StateInfo& succ_info = state_infos_[succ_id];
273
274 if (succ_info.status == StateInfo::ON_STACK) {
275 std::cerr << "State space is not acyclic!" << std::endl;
276 utils::exit_with(utils::ExitCode::SEARCH_CRITICAL_ERROR);
277 }
278
279 if (succ_info.status == StateInfo::NEW) {
280 expansion_stack_.emplace(succ_id, succ_info);
281 succ_info.status = StateInfo::ON_STACK;
282 return true; // DFS recursion
283 }
284
285 assert(succ_info.status == StateInfo::CLOSED);
286
287 e.backtrack_successor(probability, succ_info);
288 } while (e.advance(mdp, policy));
289
290 return false;
291}
292
293template <typename State, typename Action>
294bool AcyclicValueIteration<State, Action>::expand_state(
295 MDPType& mdp,
296 EvaluatorType& heuristic,
297 IncrementalExpansionInfo& e_info)
298{
299 const State state = mdp.get_state(e_info.state_id);
300 StateInfo& succ_info = e_info.state_info;
301
302 assert(succ_info.status == StateInfo::ON_STACK);
303
304 const TerminationInfo term_info = mdp.get_termination_info(state);
305 const value_t term_value = term_info.get_cost();
306
307 succ_info.value = term_value;
308
309 if (term_info.is_goal_state()) {
310 observers_.notify_state_selected_for_expansion(state);
311 observers_.notify_goal_state(state);
312 return false;
313 }
314
315 if (heuristic.evaluate(state) == term_value) {
316 observers_.notify_pruned_state(state);
317 return false;
318 }
319
320 observers_.notify_state_selected_for_expansion(state);
321
322 assert(e_info.remaining_aops.empty());
323
324 mdp.generate_applicable_actions(state, e_info.remaining_aops);
325 if (e_info.remaining_aops.empty()) {
326 observers_.notify_terminal_state(state);
327 return false;
328 }
329
330 e_info.setup_transition(mdp);
331
332 return true;
333}
334
335} // namespace probfd::algorithms::acyclic_vi
Namespace dedicated to the acyclic value iteration algorithm.
Definition acyclic_value_iteration.h:22
double value_t
Typedef for the state value type.
Definition aliases.h:7
typename std::conditional_t< is_cheap_to_copy_v< T >, T, const T & > param_type
Alias template defining the best way to pass a parameter of a given type.
Definition type_traits.h:25
An observer that collects basic statistics of the acyclic value iteration algorithm.
Definition acyclic_value_iteration.h:58
void on_state_selected_for_expansion(const State &)
Called when the algorithm selects a state for expansion.
Definition acyclic_value_iteration_impl.h:17
void on_goal_state(const State &)
Called when a goal state is encountered during the expansion check.
Definition acyclic_value_iteration_impl.h:24
void on_pruned_state(const State &)
Called when a state is pruned during the expansion check.
Definition acyclic_value_iteration_impl.h:36
void on_terminal_state(const State &)
Called when a terminal state is encountered during the expansion check.
Definition acyclic_value_iteration_impl.h:30