AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
topological_value_iteration_impl.h
1#ifndef GUARD_INCLUDE_PROBFD_ALGORITHMS_TOPOLOGICAL_VALUE_ITERATION_H
2#error "This file should only be included from topological_value_iteration.h"
3#endif
4
5#include "probfd/algorithms/utils.h"
6
7#include "probfd/policies/map_policy.h"
8
9#include "probfd/evaluator.h"
10#include "probfd/progress_report.h"
11
12#include "downward/utils/countdown_timer.h"
13
14#include <type_traits>
15
17
18inline void Statistics::print(std::ostream& out) const
19{
20 out << " Expanded state(s): " << expanded_states << std::endl;
21 out << " Terminal state(s): " << terminal_states << std::endl;
22 out << " Goal state(s): " << goal_states << std::endl;
23 out << " Pruned state(s): " << pruned << std::endl;
24 out << " Maximal SCCs: " << sccs << " (" << singleton_sccs
25 << " are singleton)" << std::endl;
26 out << " Bellman backups: " << bellman_backups << std::endl;
27}
28
29template <typename State, typename Action, bool UseInterval>
30TopologicalValueIteration<State, Action, UseInterval>::ExplorationInfo::
31 ExplorationInfo(StateID state_id, StackInfo& stack_info, unsigned stackidx)
32 : state_id(state_id)
33 , stack_info(stack_info)
34 , stackidx(stackidx)
35 , lowlink(stackidx)
36{
37}
38
39template <typename State, typename Action, bool UseInterval>
40void TopologicalValueIteration<State, Action, UseInterval>::ExplorationInfo::
41 update_lowlink(unsigned upd)
42{
43 lowlink = std::min(lowlink, upd);
44}
45
46template <typename State, typename Action, bool UseInterval>
47bool TopologicalValueIteration<State, Action, UseInterval>::ExplorationInfo::
48 next_transition(MDPType& mdp)
49{
50 aops.pop_back();
51 transition.clear();
52
53 self_loop_prob = 0_vt;
54
55 return !aops.empty() &&
56 forward_non_loop_transition(mdp, mdp.get_state(state_id));
57}
58
59template <typename State, typename Action, bool UseInterval>
60bool TopologicalValueIteration<State, Action, UseInterval>::ExplorationInfo::
61 next_successor()
62{
63 ++successor;
64 if (forward_non_loop_successor()) return true;
65
66 auto& tinfo = stack_info.nconv_qs.back();
67
68 if (tinfo.finalize_transition(self_loop_prob)) {
69 if (set_min(stack_info.conv_part, tinfo.conv_part)) {
70 stack_info.best_converged = tinfo.action;
71 }
72 stack_info.nconv_qs.pop_back();
73 }
74
75 return false;
76}
77
78template <typename State, typename Action, bool UseInterval>
79bool TopologicalValueIteration<State, Action, UseInterval>::ExplorationInfo::
80 forward_non_loop_transition(MDPType& mdp, const State& state)
81{
82 do {
83 mdp.generate_action_transitions(state, aops.back(), transition);
84 successor = transition.begin();
85
86 if (forward_non_loop_successor()) {
87 stack_info.nconv_qs.emplace_back(
88 aops.back(),
89 mdp.get_action_cost(aops.back()));
90 return true;
91 }
92
93 aops.pop_back();
94 transition.clear();
95 } while (!aops.empty());
96
97 return false;
98}
99
100template <typename State, typename Action, bool UseInterval>
101bool TopologicalValueIteration<State, Action, UseInterval>::ExplorationInfo::
102 forward_non_loop_successor()
103{
104 do {
105 if (successor->item != state_id) {
106 return true;
107 }
108
109 self_loop_prob += successor->probability;
110 } while (++successor != transition.end());
111
112 return false;
113}
114
115template <typename State, typename Action, bool UseInterval>
116Action& TopologicalValueIteration<State, Action, UseInterval>::ExplorationInfo::
117 get_current_action()
118{
119 return aops.back();
120}
121
122template <typename State, typename Action, bool UseInterval>
123ItemProbabilityPair<StateID>
124TopologicalValueIteration<State, Action, UseInterval>::ExplorationInfo::
125 get_current_successor()
126{
127 return *successor;
128}
129
130template <typename State, typename Action, bool UseInterval>
131TopologicalValueIteration<State, Action, UseInterval>::QValueInfo::QValueInfo(
132 Action action,
133 value_t action_cost)
134 : action(action)
135 , conv_part(action_cost)
136{
137}
138
139template <typename State, typename Action, bool UseInterval>
140bool TopologicalValueIteration<State, Action, UseInterval>::QValueInfo::
141 finalize_transition(value_t self_loop_prob)
142{
143 if (self_loop_prob != 0_vt) {
144 // Apply self-loop normalization
145 const value_t normalization = 1_vt / (1_vt - self_loop_prob);
146
147 conv_part *= normalization;
148
149 for (auto& pair : nconv_successors) {
150 pair.probability *= normalization;
151 }
152 }
153
154 return nconv_successors.empty();
155}
156
157template <typename State, typename Action, bool UseInterval>
158auto TopologicalValueIteration<State, Action, UseInterval>::QValueInfo::
159 compute_q_value() const -> AlgorithmValueType
160{
161 AlgorithmValueType res = conv_part;
162
163 for (auto& [value, prob] : nconv_successors) {
164 res += prob * (*value);
165 }
166
167 return res;
168}
169
170template <typename State, typename Action, bool UseInterval>
171TopologicalValueIteration<State, Action, UseInterval>::StackInfo::StackInfo(
172 StateID state_id,
173 AlgorithmValueType& value_ref)
174 : state_id(state_id)
175 , value(&value_ref)
176{
177}
178
179template <typename State, typename Action, bool UseInterval>
180bool TopologicalValueIteration<State, Action, UseInterval>::StackInfo::
181 update_value()
182{
183 AlgorithmValueType v = conv_part;
184 best_action = best_converged;
185
186 for (const QValueInfo& info : nconv_qs) {
187 if (set_min(v, info.compute_q_value())) {
188 best_action = info.action;
189 }
190 }
191
192 if constexpr (UseInterval) {
193 update(*value, v);
194 return !value->bounds_approximately_equal();
195 } else {
196 return update(*value, v);
197 }
199
200template <typename State, typename Action, bool UseInterval>
202 TopologicalValueIteration(bool expand_goals)
203 : expand_goals_(expand_goals)
204{
205}
206
// Policy entry point: solves the MDP from `state` and returns the policy
// collected during the run.
//
// A fresh per-state value store and a MapPolicy are created, solve() is run
// with the policy attached, and the policy is returned to the caller.
//
// NOTE(review): this listing omits embedded lines 208 and 212 (the declarator
// carrying the function name, and one parameter between `state` and
// `max_time`), so the exact name and full parameter list cannot be read here.
207template <typename State, typename Action, bool UseInterval>
209 MDPType& mdp,
210 EvaluatorType& heuristic,
211 param_type<State> state,
213 double max_time) -> std::unique_ptr<PolicyType>
214{
// State values only live for the duration of this call.
215 storage::PerStateStorage<AlgorithmValueType> value_store;
216 std::unique_ptr<MapPolicy> policy(new MapPolicy(&mdp));
// The value returned by solve() is discarded; only the policy is of interest.
217 this->solve(
218 mdp,
219 heuristic,
220 mdp.get_state_id(state),
221 value_store,
222 max_time,
223 policy.get());
224 return policy;
225}
226
// Value entry point: solves the MDP from `state` and returns the result of
// solve() for that state.
//
// A fresh per-state value store is created for the run. No policy argument is
// passed to solve().
//
// NOTE(review): this listing omits embedded lines 228 and 232 (the declarator
// carrying the return type and function name, and one parameter between
// `state` and `max_time`), so they cannot be read here.
227template <typename State, typename Action, bool UseInterval>
229 MDPType& mdp,
230 EvaluatorType& heuristic,
231 param_type<State> state,
233 double max_time)
234{
// State values only live for the duration of this call.
235 storage::PerStateStorage<AlgorithmValueType> value_store;
236 return this
237 ->solve(mdp, heuristic, mdp.get_state_id(state), value_store, max_time);
238}
239
240template <typename State, typename Action, bool UseInterval>
242 std::ostream& out) const
243{
244 statistics_.print(out);
245}
246
247template <typename State, typename Action, bool UseInterval>
253
// Runs the search from `init_state_id` and returns the value computed for it.
//
// The state space is explored depth-first with an explicit exploration stack.
// Lowlinks are maintained per explored state; whenever a state's lowlink
// equals its own stack index, the states from that index upwards are handed
// to scc_found(), which finishes their values before backtracking continues.
//
// mdp           - the MDP to solve
// heuristic     - evaluator used to initialise state values
// init_state_id - id of the state to solve from
// value_store   - per-state value storage, indexed by state id
// max_time      - time limit handed to the countdown timer
// policy        - forwarded to scc_found(); tested against null there
//
// With UseInterval the stored value of the initial state is returned as is,
// otherwise it is wrapped as Interval(value, INFINITE_VALUE).
//
// NOTE(review): this listing omits embedded line 256 (the declarator carrying
// the return type and the qualified function name).
254template <typename State, typename Action, bool UseInterval>
255template <typename ValueStore>
257 MDPType& mdp,
258 EvaluatorType& heuristic,
259 StateID init_state_id,
260 ValueStore& value_store,
261 double max_time,
262 MapPolicy* policy)
263{
264 utils::CountdownTimer timer(max_time);
265
266 StateInfo& iinfo = state_information_[init_state_id];
267 AlgorithmValueType& init_value = value_store[init_state_id];
268
269 push_state(init_state_id, iinfo, init_value);
270
271 for (;;) {
272 ExplorationInfo* explore;
273
// Descend: initialise the state on top of the exploration stack and walk
// its successors. successor_loop() returns true after pushing a new state,
// in which case that state becomes the new top and is initialised next.
274 do {
275 explore = &exploration_stack_.back();
276 } while (initialize_state(mdp, heuristic, *explore, value_store) &&
277 successor_loop(mdp, *explore, value_store, timer));
278
// Backtrack: the top state has no successors left to explore.
279 do {
280 // Check if an SCC was found.
281 const unsigned stack_id = explore->stackidx;
282 const unsigned lowlink = explore->lowlink;
283 const bool backtrack_from_scc = stack_id == lowlink;
284
285 if (backtrack_from_scc) {
// The SCC consists of all stack entries from this state's index upwards.
286 scc_found(stack_ | std::views::drop(stack_id), policy, timer);
287 }
288
289 exploration_stack_.pop_back();
290
// The initial state was just popped: the search is complete.
291 if (exploration_stack_.empty()) {
292 if constexpr (UseInterval) {
293 return init_value;
294 } else {
295 return Interval(init_value, INFINITE_VALUE);
296 }
297 }
298
299 timer.throw_if_expired();
300
301 explore = &exploration_stack_.back();
302
// The parent's current successor is the state that was just popped.
303 const auto [succ_id, prob] = explore->get_current_successor();
304 AlgorithmValueType& s_value = value_store[succ_id];
305 QValueInfo& tinfo = explore->stack_info.nconv_qs.back();
306
307 if (backtrack_from_scc) {
// The child's SCC was processed by scc_found(): fold its value into the
// converged part of the parent's current Q-value.
308 tinfo.conv_part += prob * s_value;
309 } else {
// The child stays on the stack: propagate its lowlink and keep a pointer
// to its value for later Bellman backups.
310 explore->update_lowlink(lowlink);
311 tinfo.nconv_successors.emplace_back(&s_value, prob);
312 }
// Keep backtracking while the parent has neither a further outcome nor a
// further action, or while continuing its successor loop pushes no new
// state. Leaving this loop means a new state was pushed; the outer loop
// then descends into it.
313 } while (
314 (!explore->next_successor() && !explore->next_transition(mdp)) ||
315 !successor_loop(mdp, *explore, value_store, timer));
316 }
317}
318
// Pushes a state onto both the stack (`stack_`) and the exploration stack,
// and marks it as being on the stack.
//
// The new stack entry is created from `state_id` and `state_value`; its
// index (the stack size before the push) becomes the state's stack id and the
// stack index of its exploration record.
//
// NOTE(review): this listing omits embedded line 320 (the declarator carrying
// the return type and the qualified function name).
319template <typename State, typename Action, bool UseInterval>
321 StateID state_id,
322 StateInfo& state_info,
323 AlgorithmValueType& state_value)
324{
// Captured before the push, so this is the index of the new entry.
325 const std::size_t stack_size = stack_.size();
326 exploration_stack_.emplace_back(
327 state_id,
328 stack_.emplace_back(state_id, state_value),
329 stack_size);
330 state_info.stack_id = stack_size;
331 state_info.status = StateInfo::ONSTACK;
332}
333
334template <typename State, typename Action, bool UseInterval>
335bool TopologicalValueIteration<State, Action, UseInterval>::initialize_state(
336 MDPType& mdp,
337 EvaluatorType& heuristic,
338 ExplorationInfo& exp_info,
339 auto& value_store)
340{
341 assert(state_information_[exp_info.state_id].status == StateInfo::NEW);
342
343 const State state = mdp.get_state(exp_info.state_id);
344
345 const TerminationInfo state_eval = mdp.get_termination_info(state);
346 const value_t t_cost = state_eval.get_cost();
347 const value_t estimate = heuristic.evaluate(state);
348
349 exp_info.stack_info.conv_part = AlgorithmValueType(t_cost);
350
351 AlgorithmValueType& state_value = value_store[exp_info.state_id];
352
353 if constexpr (UseInterval) {
354 state_value.lower = estimate;
355 state_value.upper = t_cost;
356 } else {
357 state_value = estimate;
358 }
359
360 if (state_eval.is_goal_state()) {
361 ++statistics_.goal_states;
362
363 if (!expand_goals_) {
364 ++statistics_.pruned;
365 return false;
366 }
367 } else if (estimate == t_cost) {
368 ++statistics_.pruned;
369 return false;
370 }
371
372 mdp.generate_applicable_actions(state, exp_info.aops);
373
374 const size_t num_aops = exp_info.aops.size();
375
376 exp_info.stack_info.nconv_qs.reserve(num_aops);
377
378 ++statistics_.expanded_states;
379
380 if (exp_info.aops.empty()) {
381 ++statistics_.terminal_states;
382 } else if (exp_info.forward_non_loop_transition(mdp, state)) {
383 return true;
384 }
385
386 return false;
387}
388
// Processes the remaining successors of `explore`, starting at its current
// action and outcome, until an unexplored successor is found or all actions
// are exhausted.
//
// Returns true if a NEW successor was pushed onto the stack (the caller must
// explore it next), and false once every remaining outcome of every remaining
// action has been handled.
389template <typename State, typename Action, bool UseInterval>
390template <typename ValueStore>
391bool TopologicalValueIteration<State, Action, UseInterval>::successor_loop(
392 MDPType& mdp,
393 ExplorationInfo& explore,
394 ValueStore& value_store,
395 utils::CountdownTimer& timer)
396{
397 do {
398 assert(!explore.stack_info.nconv_qs.empty());
// Q-value record of the action currently being explored.
399 QValueInfo& tinfo = explore.stack_info.nconv_qs.back();
400
401 do {
402 timer.throw_if_expired();
403
404 const auto [succ_id, prob] = explore.get_current_successor();
// Self-loop outcomes are skipped before this point.
405 assert(succ_id != explore.state_id);
406 StateInfo& succ_info = state_information_[succ_id];
407 AlgorithmValueType& s_value = value_store[succ_id];
408
409 switch (succ_info.status) {
// Any status other than the three handled below aborts the program.
410 default: abort();
411 case StateInfo::NEW: {
412 push_state(succ_id, succ_info, s_value);
413 return true; // recursion on new state
414 }
415
// Closed successor: fold its value into the converged part of the Q-value.
416 case StateInfo::CLOSED: tinfo.conv_part += prob * s_value; break;
417
// Successor on the stack: lower the lowlink and keep a pointer to its value
// for later Bellman backups.
418 case StateInfo::ONSTACK:
419 explore.update_lowlink(succ_info.stack_id);
420 tinfo.nconv_successors.emplace_back(&s_value, prob);
421 }
422 } while (explore.next_successor());
423 } while (explore.next_transition(mdp));
424
425 return false;
426}
427
428template <typename State, typename Action, bool UseInterval>
429void TopologicalValueIteration<State, Action, UseInterval>::scc_found(
430 auto scc,
431 MapPolicy* policy,
432 utils::CountdownTimer& timer)
433{
434 assert(!scc.empty());
435
436 ++statistics_.sccs;
437
438 if (scc.size() == 1) {
439 // Singleton SCCs can only transition to a child SCC. The state
440 // value has already converged due to topological ordering.
441 ++statistics_.singleton_sccs;
442 StackInfo& single = scc.front();
443 StateInfo& state_info = state_information_[single.state_id];
444 update(*single.value, single.conv_part);
445 assert(state_info.status == StateInfo::ONSTACK);
446 state_info.status = StateInfo::CLOSED;
447 } else {
448 // Mark all states as closed
449 for (StackInfo& stk_info : scc) {
450 StateInfo& state_info = state_information_[stk_info.state_id];
451 assert(state_info.status == StateInfo::ONSTACK);
452 assert(!stk_info.nconv_qs.empty());
453 state_info.status = StateInfo::CLOSED;
454 }
455
456 // Now run VI on the SCC until convergence
457 bool converged;
458
459 do {
460 timer.throw_if_expired();
461
462 converged = true;
463 auto it = scc.begin();
464
465 do {
466 if (it->update_value()) converged = false;
467 ++statistics_.bellman_backups;
468 } while (++it != scc.end());
469 } while (!converged);
470
471 // Extract a policy from this SCC
472 if (policy) {
473 for (StackInfo& stk_info : scc) {
474 if constexpr (UseInterval) {
475 policy->emplace_decision(
476 stk_info.state_id,
477 *stk_info.best_action,
478 *stk_info.value);
479 } else {
480 policy->emplace_decision(
481 stk_info.state_id,
482 *stk_info.best_action,
483 Interval(*stk_info.value, INFINITE_VALUE));
484 }
485 }
486 }
487 }
488
489 stack_.erase(scc.begin(), scc.end());
490}
491
492} // namespace probfd::algorithms::topological_vi
A registry for print functions related to search progress.
Definition progress_report.h:33
Specifies the termination cost and goal status of a state.
Definition cost_function.h:13
bool is_goal_state() const
Check if this state is a goal.
Definition cost_function.h:34
value_t get_cost() const
Obtains the cost paid upon termination in the state.
Definition cost_function.h:41
Implements Topological Value Iteration (Dai et al., JAIR 2011).
Definition topological_value_iteration.h:68
Statistics get_statistics() const
Retrieve the algorithm statistics.
Definition topological_value_iteration_impl.h:249
void print_statistics(std::ostream &out) const override
Prints algorithm statistics to the specified output stream.
Definition topological_value_iteration_impl.h:241
Namespace dedicated to Topological Value Iteration (TVI).
Definition topological_value_iteration.h:27
bool update(Interval &lhs, Interval rhs, value_t epsilon=g_epsilon)
Intersects two intervals and assigns the result to the left operand.
bool set_min(Interval &lhs, Interval rhs)
Computes the assignments lhs.lower <- min(lhs.lower, rhs.lower) and lower <- min(lhs....
double value_t
Typedef for the state value type.
Definition aliases.h:7
typename std::conditional_t< is_cheap_to_copy_v< T >, T, const T & > param_type
Alias template defining the best way to pass a parameter of a given type.
Definition type_traits.h:25
Represents a closed interval over the extended reals as a pair of lower and upper bound.
Definition interval.h:12
A StateID represents a state within a StateIDMap. Just like Fast Downward's StateID type,...
Definition types.h:22
Topological value iteration statistics.
Definition topological_value_iteration.h:32