AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
ta_topological_value_iteration_impl.h
1#ifndef GUARD_INCLUDE_PROBFD_ALGORITHMS_TA_TOPOLOGICAL_VALUE_ITERATION_H
2#error "This file should only be included from ta_topological_value_iteration.h"
3#endif
4
5#include "probfd/algorithms/utils.h"
6
7#include "probfd/utils/guards.h"
8#include "probfd/utils/iterators.h"
9#include "probfd/utils/not_implemented.h"
10
11#include "probfd/cost_function.h"
12#include "probfd/evaluator.h"
13#include "probfd/progress_report.h"
14
15#include "downward/utils/countdown_timer.h"
16
17#include <type_traits>
18
20
21inline void Statistics::print(std::ostream& out) const
22{
23 out << " Expanded state(s): " << expanded_states << std::endl;
24 out << " Terminal state(s): " << terminal_states << std::endl;
25 out << " Goal state(s): " << goal_states << std::endl;
26 out << " Pruned state(s): " << pruned << std::endl;
27 out << " Maximal SCCs: " << sccs << " (" << singleton_sccs
28 << " are singleton)" << std::endl;
29 out << " Bellman backups: " << bellman_backups << std::endl;
30
31 out << " Time spent initializing state data: " << initialize_state_timer
32 << std::endl;
33 out << " Time spent expanding successors: " << successor_handling_timer
34 << std::endl;
35 out << " Time spent handling SCCs: " << scc_handling_timer << std::endl;
36 out << " Time spent backtracking: " << backtracking_timer << std::endl;
37 out << " Time spent running VI on SCCs: " << vi_timer << std::endl;
38 out << " Time spent decomposing ECs: " << decomposition_timer << std::endl;
39 out << " Time spent in solvability analysis: " << solvability_timer
40 << std::endl;
41}
42
43template <typename State, typename Action, bool UseInterval>
44auto TATopologicalValueIteration<State, Action, UseInterval>::StateInfo::
45 get_status() const
46{
47 return explored ? (stack_id < UNDEF ? ONSTACK : CLOSED) : NEW;
48}
49
50template <typename State, typename Action, bool UseInterval>
51auto TATopologicalValueIteration<State, Action, UseInterval>::StateInfo::
52 get_ecd_status() const
53{
54 return explored ? (ecd_stack_id < UNDEF_ECD ? ONSTACK : CLOSED) : NEW;
55}
56
57template <typename State, typename Action, bool UseInterval>
58TATopologicalValueIteration<State, Action, UseInterval>::ExplorationInfo::
59 ExplorationInfo(
60 StateID state_id,
61 StackInfo& stack_info,
62 unsigned int stackidx)
63 : state_id(state_id)
64 , stack_info(stack_info)
65 , stackidx(stackidx)
66 , lowlink(stackidx)
67{
68}
69
70template <typename State, typename Action, bool UseInterval>
71bool TATopologicalValueIteration<State, Action, UseInterval>::ExplorationInfo::
72 next_transition(MDPType& mdp)
73{
74 aops.pop_back();
75 transition.clear();
76
77 assert(q_value.scc_successors.empty());
78
79 return !aops.empty() &&
80 forward_non_loop_transition(mdp, mdp.get_state(state_id));
81}
82
83template <typename State, typename Action, bool UseInterval>
84bool TATopologicalValueIteration<State, Action, UseInterval>::ExplorationInfo::
85 forward_non_loop_transition(MDPType& mdp, const State& state)
86{
87 do {
88 mdp.generate_action_transitions(state, aops.back(), transition);
89 const value_t self_loop_prob = transition.remove_if_normalize(state_id);
90
91 if (!transition.empty()) {
92 successor = transition.begin();
93
94 const value_t normalization = 1.0_vt / (1.0_vt - self_loop_prob);
95 const auto cost = normalization * mdp.get_action_cost(aops.back());
96 if (cost != 0.0_vt) has_all_zero = false;
97 q_value.conv_part = AlgorithmValueType(cost);
98 return true;
99 }
100
101 aops.pop_back();
102 transition.clear();
103 } while (!aops.empty());
104
105 return false;
106}
107
108template <typename State, typename Action, bool UseInterval>
109bool TATopologicalValueIteration<State, Action, UseInterval>::ExplorationInfo::
110 next_successor()
111{
112 if (++successor != transition.end()) {
113 return true;
114 }
115
116 const bool exits_only_solvable =
117 q_value.conv_part != AlgorithmValueType(INFINITE_VALUE);
118
119 const size_t num_scc_succs = q_value.scc_successors.size();
120
121 if (num_scc_succs == 0) {
122 // Universally exiting -> Not part of scc
123 // Update converged portion of q value and ignore this
124 // transition
125 set_min(stack_info.conv_part, q_value.conv_part);
126
127 if (exits_only_solvable) {
128 ++stack_info.active_exit_transitions;
129 ++stack_info.active_transitions;
130 }
131 } else {
132 const bool leaves_scc = num_scc_succs != transition.size();
133
134 if (leaves_scc || q_value.conv_part != AlgorithmValueType(0_vt)) {
135 // Only some exiting or cost is non-zero ->
136 // Not part of an end component
137 // Add the transition to the set of non-EC transitions
138 stack_info.add_non_ec_transition(std::move(q_value));
139 } else {
140 // Otherwise add to EC transitions
141 stack_info.ec_transitions.emplace_back(std::move(q_value));
142 }
143
144 if (exits_only_solvable) {
145 if (leaves_scc) {
146 ++stack_info.active_exit_transitions;
147 }
148 ++stack_info.active_transitions;
149 }
150 stack_info.transition_flags.emplace_back(
151 exits_only_solvable && leaves_scc,
152 exits_only_solvable);
153 }
154
155 assert(q_value.scc_successors.empty());
156
157 return false;
158}
159
160template <typename State, typename Action, bool UseInterval>
161ItemProbabilityPair<StateID>
162TATopologicalValueIteration<State, Action, UseInterval>::ExplorationInfo::
163 get_current_successor()
164{
165 return *successor;
166}
167
168template <typename State, typename Action, bool UseInterval>
169template <typename ValueStore>
170auto TATopologicalValueIteration<State, Action, UseInterval>::QValueInfo::
171 compute_q_value(ValueStore& value_store) const -> AlgorithmValueType
172{
173 AlgorithmValueType res = conv_part;
174
175 for (auto& [state_id, prob] : scc_successors) {
176 res += prob * value_store[state_id];
177 }
178
179 return res;
180}
181
182template <typename State, typename Action, bool UseInterval>
183TATopologicalValueIteration<State, Action, UseInterval>::StackInfo::StackInfo(
184 StateID state_id,
185 AlgorithmValueType& value_ref)
186 : state_id(state_id)
187 , value(&value_ref)
188{
189}
190
191template <typename State, typename Action, bool UseInterval>
192void TATopologicalValueIteration<State, Action, UseInterval>::StackInfo::
193 add_non_ec_transition(QValueInfo&& info)
194{
195 auto [it, inserted] = non_ec_transitions.insert(std::move(info));
196 if (!inserted) {
197 // Not moved if insertion didn't take place
198 set_min(it->conv_part, info.conv_part);
199 info.scc_successors.clear();
200 }
201}
202
203template <typename State, typename Action, bool UseInterval>
204TATopologicalValueIteration<State, Action, UseInterval>::ECDExplorationInfo::
205 ECDExplorationInfo(StackInfo& stack_info, unsigned stackidx)
206 : stack_info(stack_info)
207 , stackidx(stackidx)
208 , lowlink(stackidx)
209{
210}
211
/// Advances to the next EC transition of the current state during the
/// end-component decomposition.
///
/// The unprocessed EC transitions are the range [action, end).
///  - If the transition just processed did not leave the current sub-SCC
///    (leaves_scc is false), it is kept and the exploration moves on to the
///    next one.
///  - Otherwise it is moved into the state's non-EC transitions. The last
///    unprocessed transition is moved into its slot, shrinking the range by
///    one. If the removed transition also reached a state still on the ECD
///    stack (remains_scc), `recurse` is set so the sub-SCC gets decomposed
///    again.
///
/// @return true if another transition is to be explored (`successor` then
///         points to its first SCC successor). Returns false once all
///         transitions are processed; the slots [end, ec_transitions.end())
///         vacated by the removals are then erased.
template <typename State, typename Action, bool UseInterval>
bool TATopologicalValueIteration<State, Action, UseInterval>::
    ECDExplorationInfo::next_transition()
{
    assert(!action->scc_successors.empty());
    assert(action < end);

    if (!leaves_scc) {
        // Transition stays inside the sub-SCC: keep it, go to the next one.
        if (++action != end) {
            successor = action->scc_successors.begin();
            assert(!action->scc_successors.empty());

            remains_scc = false;

            return true;
        }
    } else {
        // Transition (partially) leaves the sub-SCC, so it cannot be part of
        // an end component of it.
        if (remains_scc) {
            recurse = true;
        }

        stack_info.add_non_ec_transition(std::move(*action));

        // Swap-remove: note that `--end` shrinks the range as a side effect
        // of this comparison. If the removed transition was not the last
        // one, continue with the transition moved into its slot.
        if (action != --end) {
            *action = std::move(*end);
            assert(!action->scc_successors.empty());
            successor = action->scc_successors.begin();

            leaves_scc = false;
            remains_scc = false;

            return true;
        }
    }

    // All transitions processed; drop the vacated (moved-from) slots.
    auto& ect = stack_info.ec_transitions;
    ect.erase(end, ect.end());

    return false;
}
252
253template <typename State, typename Action, bool UseInterval>
254bool TATopologicalValueIteration<State, Action, UseInterval>::
255 ECDExplorationInfo::next_successor()
256{
257 assert(!action->scc_successors.empty());
258 return ++successor != action->scc_successors.end();
259}
260
261template <typename State, typename Action, bool UseInterval>
262ItemProbabilityPair<StateID>
263TATopologicalValueIteration<State, Action, UseInterval>::ECDExplorationInfo::
264 get_current_successor()
265{
266 return *successor;
267}
268
269template <typename State, typename Action, bool UseInterval>
270Interval TATopologicalValueIteration<State, Action, UseInterval>::solve(
271 MDPType& mdp,
272 EvaluatorType& heuristic,
273 param_type<State> state,
274 ProgressReport,
275 double max_time)
276{
277 storage::PerStateStorage<AlgorithmValueType> value_store;
278 return this
279 ->solve(mdp, heuristic, mdp.get_state_id(state), value_store, max_time);
280}
281
282template <typename State, typename Action, bool UseInterval>
284 std::ostream& out) const
285{
286 statistics_.print(out);
287}
288
289template <typename State, typename Action, bool UseInterval>
295
296template <typename State, typename Action, bool UseInterval>
298 MDPType& mdp,
299 const EvaluatorType& heuristic,
300 StateID init_state_id,
301 auto& value_store,
302 double max_time)
303{
304 statistics_ = Statistics();
305
306 // scope_exit _([this] { statistics_.print(std::cout); });
307
308 utils::CountdownTimer timer(max_time);
309
310 push_state(
311 init_state_id,
312 state_information_[init_state_id],
313 value_store[init_state_id]);
314
315 for (;;) {
316 ExplorationInfo* explore;
317
318 do {
319 explore = &exploration_stack_.back();
320 } while (initialize_state(mdp, heuristic, *explore, value_store) &&
321 successor_loop(mdp, *explore, value_store, timer));
322
323 // Iterative backtracking
324 do {
325 const unsigned stack_id = explore->stackidx;
326 const unsigned lowlink = explore->lowlink;
327
328 assert(stack_id >= lowlink);
329
330 const bool backtrack_from_scc = stack_id == lowlink;
331
332 // Check if an SCC was found.
333 if (backtrack_from_scc) {
334 scc_found(value_store, *explore, stack_id, timer);
335
336 if (stack_id == 0) {
337 assert(stack_.empty());
338 assert(exploration_stack_.size() == 1);
339 exploration_stack_.pop_back();
340
341 if constexpr (UseInterval) {
342 return value_store[init_state_id];
343 } else {
344 return Interval(
345 value_store[init_state_id],
346 INFINITE_VALUE);
347 }
348 }
349 }
350
351 assert(exploration_stack_.size() > 1);
352
353 timer.throw_if_expired();
354
355 TimerScope _(statistics_.backtracking_timer);
356
357 const ExplorationInfo& successor = *explore--;
358
359 const auto [succ_id, prob] = explore->get_current_successor();
360
361 if (backtrack_from_scc) {
362 const AlgorithmValueType value = value_store[succ_id];
363 explore->q_value.conv_part += prob * value;
364 explore->exit_interval.lower =
365 std::min(explore->exit_interval.lower, value);
366 explore->exit_interval.upper =
367 std::max(explore->exit_interval.upper, value);
368 } else {
369 explore->lowlink = std::min(explore->lowlink, lowlink);
370 explore->exit_interval.lower = std::min(
371 explore->exit_interval.lower,
372 successor.exit_interval.lower);
373 explore->exit_interval.upper = std::max(
374 explore->exit_interval.upper,
375 successor.exit_interval.upper);
376 explore->has_all_zero =
377 explore->has_all_zero && successor.has_all_zero;
378
379 explore->q_value.scc_successors.emplace_back(succ_id, prob);
380 successor.stack_info.parents.emplace_back(
381 explore->stackidx,
382 explore->stack_info.transition_flags.size());
383 }
384
385 exploration_stack_.pop_back();
386 } while (
387 (!explore->next_successor() && !explore->next_transition(mdp)) ||
388 !successor_loop(mdp, *explore, value_store, timer));
389 }
390}
391
392template <typename State, typename Action, bool UseInterval>
394 MDPType&,
395 EvaluatorType&,
398 double) -> std::unique_ptr<PolicyType>
399{
400 not_implemented();
401}
402
403template <typename State, typename Action, bool UseInterval>
404void TATopologicalValueIteration<State, Action, UseInterval>::push_state(
405 StateID state_id,
406 StateInfo& state_info,
407 AlgorithmValueType& value)
408{
409 const std::size_t stack_size = stack_.size();
410 exploration_stack_.emplace_back(
411 state_id,
412 stack_.emplace_back(state_id, value),
413 stack_size);
414 state_info.explored = 1;
415 state_info.stack_id = stack_size;
416}
417
418template <typename State, typename Action, bool UseInterval>
419bool TATopologicalValueIteration<State, Action, UseInterval>::successor_loop(
420 MDPType& mdp,
421 ExplorationInfo& explore,
422 auto& value_store,
423 utils::CountdownTimer& timer)
424{
425 TimerScope _(statistics_.successor_handling_timer);
426
427 do {
428 timer.throw_if_expired();
429
430 const auto [succ_id, prob] = explore.get_current_successor();
431 assert(succ_id != explore.state_id);
432
433 StateInfo& succ_info = state_information_[succ_id];
434
435 switch (succ_info.get_status()) {
436 default: abort();
437 case StateInfo::NEW: {
438 push_state(succ_id, succ_info, value_store[succ_id]);
439 return true; // recursion on new state
440 }
441
442 case StateInfo::CLOSED: {
443 const AlgorithmValueType value = value_store[succ_id];
444 explore.q_value.conv_part += prob * value;
445 explore.exit_interval.lower =
446 std::min(explore.exit_interval.lower, value);
447 explore.exit_interval.upper =
448 std::max(explore.exit_interval.upper, value);
449 break;
450 }
451
452 case StateInfo::ONSTACK:
453 unsigned succ_stack_id = succ_info.stack_id;
454 explore.lowlink = std::min(explore.lowlink, succ_stack_id);
455 explore.q_value.scc_successors.emplace_back(succ_id, prob);
456
457 auto& parents = stack_[succ_stack_id].parents;
458 parents.emplace_back(
459 explore.stackidx,
460 explore.stack_info.transition_flags.size());
461 }
462 } while (explore.next_successor() || explore.next_transition(mdp));
463
464 return false;
465}
466
467template <typename State, typename Action, bool UseInterval>
468bool TATopologicalValueIteration<State, Action, UseInterval>::initialize_state(
469 MDPType& mdp,
470 const EvaluatorType& heuristic,
471 ExplorationInfo& exp_info,
472 auto& value_store)
473{
474 assert(
475 state_information_[exp_info.state_id].get_status() ==
476 StateInfo::ONSTACK);
477
478 TimerScope _(statistics_.initialize_state_timer);
479
480 const State state = mdp.get_state(exp_info.state_id);
481
482 const TerminationInfo state_term = mdp.get_termination_info(state);
483 const value_t t_cost = state_term.get_cost();
484 const value_t estimate = heuristic.evaluate(state);
485
486 exp_info.stack_info.conv_part = AlgorithmValueType(t_cost);
487 exp_info.exit_interval = Interval(t_cost);
488
489 AlgorithmValueType& state_value = value_store[exp_info.state_id];
490
491 if constexpr (UseInterval) {
492 state_value.lower = estimate;
493 state_value.upper = t_cost;
494 } else {
495 state_value = estimate;
496 }
497
498 if (t_cost != INFINITE_VALUE) {
499 ++exp_info.stack_info.active_exit_transitions;
500 ++exp_info.stack_info.active_transitions;
501 }
502
503 if (state_term.is_goal_state()) {
504 ++statistics_.goal_states;
505 } else if (estimate == t_cost) {
506 ++statistics_.pruned;
507 return false;
508 }
509
510 mdp.generate_applicable_actions(state, exp_info.aops);
511
512 const size_t num_aops = exp_info.aops.size();
513
514 exp_info.stack_info.ec_transitions.reserve(num_aops);
515
516 ++statistics_.expanded_states;
517
518 if (exp_info.aops.empty()) {
519 ++statistics_.terminal_states;
520 } else if (exp_info.forward_non_loop_transition(mdp, state)) {
521 return true;
522 }
523
524 return false;
525}
526
/// Processes a maximal SCC found by the main exploration and computes the
/// values of its states.
///
/// The SCC consists of the entries of stack_ from index @p stack_idx to the
/// end. The steps are:
///   1. If exp_info.exit_interval.lower is INFINITE_VALUE, or the exit
///      interval is a single point and exp_info.has_all_zero holds, every
///      state is updated with exit_interval.lower. Nothing else is done.
///   2. A singleton SCC is updated with its conv_part. Nothing else is done.
///   3. If some state has non-EC transitions, the SCC is decomposed into end
///      components by find_and_decompose_sccs(). Otherwise the whole SCC is
///      one end component and is collapsed into its first state (spider
///      construction).
///   4. Solvability analysis assigns INFINITE_VALUE to every state that
///      cannot reach an exit via active transitions.
///   5. Value iteration over the non-EC transitions runs until a full sweep
///      reports no change (UseInterval == false), or until every state's
///      bounds are equal (UseInterval == true).
/// On every path, each state's stack_id is reset to StateInfo::UNDEF and the
/// SCC's entries are removed from stack_.
///
/// @param value_store Per-state value storage read and written here.
/// @param exp_info    Exploration info of the SCC root.
/// @param stack_idx   Position of the SCC root in stack_.
/// @param timer       Checked via throw_if_expired() during solvability
///                    analysis and VI, and passed to the decomposition.
template <typename State, typename Action, bool UseInterval>
void TATopologicalValueIteration<State, Action, UseInterval>::scc_found(
    auto& value_store,
    ExplorationInfo& exp_info,
    unsigned int stack_idx,
    utils::CountdownTimer& timer)
{
    using namespace std::views;

    TimerScope _(statistics_.scc_handling_timer);

    // The SCC is the top part of the stack, starting at the root.
    auto scc = stack_ | drop(stack_idx);

    assert(!scc.empty());

    ++statistics_.sccs;

    // Trivial case: every state of the SCC gets the lower exit bound.
    if (exp_info.exit_interval.lower == INFINITE_VALUE ||
        (exp_info.exit_interval.lower == exp_info.exit_interval.upper &&
         exp_info.has_all_zero)) {
        for (StackInfo& stk_info : scc) {
            StateInfo& state_info = state_information_[stk_info.state_id];
            assert(state_info.get_status() == StateInfo::ONSTACK);
            update(*stk_info.value, exp_info.exit_interval.lower);
            state_info.stack_id = StateInfo::UNDEF;
        }

        stack_.erase(scc.begin(), scc.end());
        return;
    }

    if (scc.size() == 1) {
        // For singleton SCCs, we only have transitions which are
        // self-loops or go to a state that is topologically greater.
        // The state value is therefore the base value.
        StackInfo& single = scc.front();
        StateInfo& state_info = state_information_[single.state_id];
        assert(state_info.get_status() == StateInfo::ONSTACK);
        update(*single.value, single.conv_part);
        state_info.stack_id = StateInfo::UNDEF;
        ++statistics_.singleton_sccs;
        stack_.pop_back();
        return;
    }

    if (std::ranges::any_of(scc, [](const StackInfo& stk_info) {
            return !stk_info.non_ec_transitions.empty();
        })) {
        // Run recursive EC Decomposition
        TimerScope _(statistics_.decomposition_timer);

        // Hand the SCC's state IDs to the decomposition via scc_.
        scc_.reserve(scc.size());

        for (const auto state_id : scc | transform(&StackInfo::state_id)) {
            scc_.push_back(state_id);
        }

        // Run decomposition
        find_and_decompose_sccs(timer);

        for (StackInfo& stack_info : scc) {
            StateInfo& state_info = state_information_[stack_info.state_id];
            state_info.stack_id = StateInfo::UNDEF;
            assert(state_info.get_status() == StateInfo::CLOSED);
        }

        assert(exploration_stack_ecd_.empty());
    } else {
        // We found an end component, patch it
        StackInfo& repr_stk = scc.front();
        const StateID scc_repr_id = repr_stk.state_id;
        state_information_[scc_repr_id].stack_id = StateInfo::UNDEF;

        // Spider construction: every other member hands its converged part to
        // the representative and keeps only a zero-cost transition to it.
        for (StackInfo& succ_stk : scc | drop(1)) {
            state_information_[succ_stk.state_id].stack_id = StateInfo::UNDEF;

            // Move all non-EC transitions to representative state.
            // NOTE(review): in this branch the any_of check above guarantees
            // that non_ec_transitions is empty, so this loop moves nothing.
            auto& tr = succ_stk.non_ec_transitions;
            for (auto it = tr.begin(); it != tr.end();) {
                repr_stk.add_non_ec_transition(
                    std::move(tr.extract(it++).value()));
            }

            // Free memory
            std::decay_t<decltype(tr)>().swap(tr);

            set_min(repr_stk.conv_part, succ_stk.conv_part);

            succ_stk.conv_part = AlgorithmValueType(INFINITE_VALUE);

            // Connect to representative state with zero cost action
            succ_stk.non_ec_transitions.emplace(
                0.0_vt,
                std::vector{ItemProbabilityPair<StateID>(scc_repr_id, 1.0_vt)});
        }
    }

    // Permutation of the SCC-local indices 0..size-1, split into three
    // consecutive regions:
    //   [begin, solvable_beg)                  unsolvable states
    //   [solvable_beg, solvable_exits_beg)     solvable, no active exit
    //   [solvable_exits_beg, end)              solvable with an active exit
    // scc_index_to_local maps a local index to its position in `partition`,
    // so states can be moved between regions by swapping in O(1).
    class Partition {
        std::vector<std::vector<int>::iterator> scc_index_to_local;
        std::vector<int> partition;
        std::vector<int>::iterator solvable_beg;
        std::vector<int>::iterator solvable_exits_beg;

    public:
        // Starts with every state in the "solvable with exit" region.
        explicit Partition(std::size_t size)
            : scc_index_to_local(size)
            , partition(size, 0)
        {
            for (unsigned int i = 0; i != size; ++i) {
                scc_index_to_local[i] = partition.begin() + i;
                partition[i] = static_cast<int>(i);
            }

            solvable_beg = partition.begin();
            solvable_exits_beg = partition.begin();
        }

        auto solvable_begin() { return solvable_beg; }
        auto solvable_end() { return partition.end(); }

        auto solvable()
        {
            return std::ranges::subrange(solvable_begin(), solvable_end());
        }

        [[nodiscard]]
        bool has_solvable() const
        {
            return solvable_beg != partition.end();
        }

        // Moves a solvable non-exit state into the unsolvable region.
        void demote_unsolvable(int s)
        {
            assert(scc_index_to_local[s] >= solvable_beg);
            assert(scc_index_to_local[s] < solvable_exits_beg);

            auto local = scc_index_to_local[s];
            std::swap(scc_index_to_local[*solvable_beg], scc_index_to_local[s]);
            std::swap(*solvable_beg, *local);

            ++solvable_beg;

            assert(scc_index_to_local[s] < solvable_beg);
        }

        // Moves an exit state directly into the unsolvable region.
        void demote_exit_unsolvable(int s)
        {
            demote_exit_solvable(s);
            demote_unsolvable(s);
        }

        // Moves an exit state into the solvable non-exit region.
        void demote_exit_solvable(int s)
        {
            assert(scc_index_to_local[s] >= solvable_exits_beg);

            auto local = scc_index_to_local[s];
            std::swap(
                scc_index_to_local[*solvable_exits_beg],
                scc_index_to_local[s]);
            std::swap(*solvable_exits_beg, *local);

            ++solvable_exits_beg;

            assert(scc_index_to_local[s] >= solvable_beg);
            assert(scc_index_to_local[s] < solvable_exits_beg);
        }

        // Moves an unsolvable state back into the solvable non-exit region.
        // Returns false (and does nothing) if it is not unsolvable.
        bool promote_solvable(int s)
        {
            if (!is_unsolvable(s)) {
                return false;
            }

            --solvable_beg;

            auto local = scc_index_to_local[s];
            std::swap(scc_index_to_local[*solvable_beg], scc_index_to_local[s]);
            std::swap(*solvable_beg, *local);

            assert(scc_index_to_local[s] >= solvable_beg);
            assert(scc_index_to_local[s] < solvable_exits_beg);

            return true;
        }

        // Moves every solvable non-exit state into the unsolvable region.
        void mark_non_exit_states_unsolvable()
        {
            solvable_beg = solvable_exits_beg;
        }

        bool is_unsolvable(int s)
        {
            return scc_index_to_local[s] < solvable_beg;
        }
    };

    {
        TimerScope _(statistics_.solvability_timer);

        // Set the value of unsolvable states of this SCC to INFINITE_VALUE.
        // Start by partitioning states into inactive states, active exits and
        // active non-exits.
        // The partition is initialized optimistically, i.e., all states start
        // out as active exits.
        Partition partition(scc.size());

        for (std::size_t i = 0; i != scc.size(); ++i) {
            StackInfo& info = scc[i];

            assert(
                info.active_transitions != 0 ||
                info.active_exit_transitions == 0);

            // Transform to local indices
            for (auto& parent_info : info.parents) {
                parent_info.parent_idx -= stack_idx;
            }

            if (info.active_exit_transitions == 0) {
                if (info.active_transitions > 0) {
                    partition.demote_exit_solvable(i);
                } else {
                    // No active transition at all: unsolvable right away.
                    value_store[info.state_id] =
                        AlgorithmValueType(INFINITE_VALUE);
                    partition.demote_exit_unsolvable(i);
                }
            }
        }

        if (partition.has_solvable()) {
            // Compute the set of solvable states of this SCC.
            for (;;) {
                timer.throw_if_expired();

                // Collect states that can currently reach an exit and mark
                // other states unsolvable.
                auto unsolv_it = partition.solvable_begin();

                partition.mark_non_exit_states_unsolvable();

                // Backward reachability from the exits. Promoting a parent
                // decrements solvable_begin(), so newly promoted states are
                // visited by this same loop.
                for (auto it = partition.solvable_end();
                     it != partition.solvable_begin();) {
                    for (const auto& [parent_idx, tr_idx] :
                         scc[*--it].parents) {
                        StackInfo& pinfo = scc[parent_idx];

                        if (pinfo.transition_flags[tr_idx].is_active) {
                            partition.promote_solvable(parent_idx);
                        }
                    }
                }

                // No new unsolvable states -> stop.
                if (unsolv_it == partition.solvable_begin()) break;

                // Run fixpoint iteration starting with the new unsolvable
                // states that could not reach an exit anymore.
                do {
                    timer.throw_if_expired();

                    StackInfo& scc_elem = scc[*unsolv_it];

                    // The state was marked unsolvable.
                    assert(partition.is_unsolvable(*unsolv_it));

                    value_store[scc_elem.state_id] =
                        AlgorithmValueType(INFINITE_VALUE);

                    // Deactivate the parents' transitions into this state;
                    // parents that lose all active (exit) transitions are
                    // demoted, which extends this loop's range.
                    for (const auto& [parent_idx, tr_idx] : scc_elem.parents) {
                        StackInfo& pinfo = scc[parent_idx];
                        auto& transition_flags = pinfo.transition_flags[tr_idx];

                        assert(
                            !transition_flags.is_active_exiting ||
                            transition_flags.is_active);

                        if (partition.is_unsolvable(parent_idx)) continue;

                        if (transition_flags.is_active_exiting) {
                            transition_flags.is_active_exiting = false;
                            transition_flags.is_active = false;

                            --pinfo.active_transitions;
                            --pinfo.active_exit_transitions;

                            if (pinfo.active_transitions == 0) {
                                partition.demote_exit_unsolvable(parent_idx);
                            } else if (pinfo.active_exit_transitions == 0) {
                                partition.demote_exit_solvable(parent_idx);
                            }
                        } else if (transition_flags.is_active) {
                            transition_flags.is_active = false;

                            --pinfo.active_transitions;

                            if (pinfo.active_transitions == 0) {
                                partition.demote_unsolvable(parent_idx);
                            }
                        }
                    }
                } while (++unsolv_it != partition.solvable_begin());
            }
        }
    }

    // Now run VI on the SCC until convergence
    {
        TimerScope _(statistics_.vi_timer);

        // Flat per-state copy of the data needed by the Bellman backups.
        struct VIInfo {
            AlgorithmValueType* value;
            AlgorithmValueType conv_part;
            std::vector<QValueInfo> transitions;
        };

        std::vector<VIInfo> table;
        table.reserve(scc.size());

        // Move (extract) the non-EC transitions out of each stack entry.
        for (auto it = scc.begin(); it != scc.end(); ++it) {
            auto& t = table.emplace_back(it->value, it->conv_part);
            auto& tr = it->non_ec_transitions;
            t.transitions.reserve(tr.size());
            for (auto it2 = tr.begin(); it2 != tr.end();) {
                t.transitions.push_back(std::move(tr.extract(it2++).value()));
            }
        }

        bool converged;

        do {
            timer.throw_if_expired();

            converged = true;
            auto it = table.begin();

            // One Gauss-Seidel sweep: later states see updated values of
            // earlier ones within the same sweep.
            do {
                AlgorithmValueType v = it->conv_part;

                for (const QValueInfo& info : it->transitions) {
                    set_min(v, info.compute_q_value(value_store));
                }

                if constexpr (UseInterval) {
                    update(*it->value, v);
                    if (!it->value->bounds_equal()) converged = false;
                } else {
                    if (update(*it->value, v)) converged = false;
                }

                ++statistics_.bellman_backups;
            } while (++it != table.end());
        } while (!converged);
    }

    stack_.erase(scc.begin(), scc.end());
}
884
/// Decomposes the states currently listed in scc_ into end components.
///
/// Runs an iterative Tarjan-style DFS over the EC transitions of these
/// states (initialize_ecd / push_successor_ecd). Each sub-SCC found is
/// passed to scc_found_ecd, which may register it in decomposition_queue_
/// for another round. The outer loop repeats with the next SCC popped from
/// decomposition_queue_ until none is left. scc_ is empty on return.
///
/// @param timer Checked via throw_if_expired() during backtracking and
///              passed on to push_successor_ecd.
template <typename State, typename Action, bool UseInterval>
void TATopologicalValueIteration<State, Action, UseInterval>::
    find_and_decompose_sccs(utils::CountdownTimer& timer)
{
    do {
        // Reset the explored flag so every state of the current SCC reads as
        // NEW for the ECD exploration.
        for (const StateID state_id : scc_) {
            state_information_[state_id].explored = 0;
        }

        for (const StateID state_id : scc_) {
            StateInfo& state_info = state_information_[state_id];
            if (state_info.get_ecd_status() != StateInfo::NEW) continue;

            // Start a new DFS rooted at this state (ECD stack index 0).
            state_info.explored = 1;
            state_info.ecd_stack_id = 0;
            exploration_stack_ecd_.emplace_back(stack_[state_info.stack_id], 0);
            stack_ecd_.emplace_back(state_id);

            for (;;) {
                ECDExplorationInfo* e;

                // DFS recursion. The pointer is re-read every iteration
                // because pushing may invalidate it.
                do {
                    e = &exploration_stack_ecd_.back();
                } while (initialize_ecd(*e) && push_successor_ecd(*e, timer));

                // Iterative backtracking
                do {
                    const unsigned int stck = e->stackidx;
                    const unsigned int lowlink = e->lowlink;

                    assert(stck >= lowlink);

                    const bool backtracked_from_scc = stck == lowlink;

                    if (backtracked_from_scc) {
                        scc_found_ecd(*e);

                        if (stck == 0) {
                            // The DFS root closed its SCC: this DFS is done.
                            assert(stack_ecd_.empty());
                            assert(exploration_stack_ecd_.size() == 1);
                            exploration_stack_ecd_.pop_back();
                            goto break_outer;
                        }
                    }

                    assert(exploration_stack_ecd_.size() > 1);

                    timer.throw_if_expired();

                    // Step from the finished child to its parent record.
                    // NOTE(review): pointer decrement assumes the two records
                    // are adjacent in memory — confirm the container type.
                    const ECDExplorationInfo& successor = *e--;

                    if (backtracked_from_scc) {
                        // The child's SCC is closed, so the parent's current
                        // transition leaves the parent's sub-SCC.
                        e->leaves_scc = true;
                    } else {
                        e->lowlink = std::min(e->lowlink, lowlink);
                        e->recurse = e->recurse || successor.recurse;
                        e->remains_scc = true;
                    }

                    exploration_stack_ecd_.pop_back();
                } while ((!e->next_successor() && !e->next_transition()) ||
                         !push_successor_ecd(*e, timer));
            }

        break_outer:;
        }

        scc_.clear();
    } while (decomposition_queue_.pop_scc(scc_));

    assert(scc_.empty());
}
958
959template <typename State, typename Action, bool UseInterval>
960bool TATopologicalValueIteration<State, Action, UseInterval>::
961 push_successor_ecd(ECDExplorationInfo& e, utils::CountdownTimer& timer)
962{
963 do {
964 timer.throw_if_expired();
965
966 const StateID succ_id = e.get_current_successor().item;
967 StateInfo& succ_info = state_information_[succ_id];
968
969 switch (succ_info.get_ecd_status()) {
970 case StateInfo::NEW: {
971 const auto stack_size = stack_ecd_.size();
972 succ_info.explored = 1;
973 succ_info.ecd_stack_id = stack_size;
974 exploration_stack_ecd_.emplace_back(
975 stack_[succ_info.stack_id],
976 stack_size);
977 stack_ecd_.emplace_back(succ_id);
978 return true;
979 }
980
981 case StateInfo::CLOSED: e.leaves_scc = true; break;
982
983 case StateInfo::ONSTACK:
984 e.lowlink = std::min(e.lowlink, succ_info.ecd_stack_id);
985 e.remains_scc = true;
986 }
987 } while (e.next_successor() || e.next_transition());
988
989 return false;
990}
991
992template <typename State, typename Action, bool UseInterval>
993bool TATopologicalValueIteration<State, Action, UseInterval>::initialize_ecd(
994 ECDExplorationInfo& exp_info)
995{
996 StackInfo& stack_info = exp_info.stack_info;
997
998 if (stack_info.ec_transitions.empty()) {
999 return false;
1000 }
1001
1002 exp_info.action = stack_info.ec_transitions.begin();
1003 exp_info.end = stack_info.ec_transitions.end();
1004 exp_info.successor = exp_info.action->scc_successors.begin();
1005
1006 return true;
1007}
1008
1009template <typename State, typename Action, bool UseInterval>
1010void TATopologicalValueIteration<State, Action, UseInterval>::scc_found_ecd(
1011 ECDExplorationInfo& e)
1012{
1013 namespace vws = std::views;
1014
1015 auto scc = stack_ecd_ | std::views::drop(e.stackidx);
1016
1017 if (scc.size() == 1) {
1018 state_information_[scc.front()].ecd_stack_id = StateInfo::UNDEF_ECD;
1019 } else if (e.recurse) {
1020 decomposition_queue_.register_new_scc();
1021 for (const StateID state_id : scc) {
1022 decomposition_queue_.add_scc_state(state_id);
1023 state_information_[state_id].ecd_stack_id = StateInfo::UNDEF_ECD;
1024 }
1025 } else {
1026 // We found an end component, patch it
1027 const StateID scc_repr_id = scc.front();
1028 StateInfo& repr_state_info = state_information_[scc_repr_id];
1029 StackInfo& repr_stk = stack_[repr_state_info.stack_id];
1030
1031 repr_state_info.ecd_stack_id = StateInfo::UNDEF_ECD;
1032
1033 // Spider construction
1034 for (const StateID state_id : scc | std::views::drop(1)) {
1035 StateInfo& state_info = state_information_[state_id];
1036 StackInfo& succ_stk = stack_[state_info.stack_id];
1037
1038 state_info.ecd_stack_id = StateInfo::UNDEF_ECD;
1039
1040 // Move all non-EC transitions to representative state
1041 auto& tr = succ_stk.non_ec_transitions;
1042 for (auto it = tr.begin(); it != tr.end();) {
1043 repr_stk.add_non_ec_transition(
1044 std::move(tr.extract(it++).value()));
1045 }
1046
1047 // Free memory
1048 std::decay_t<decltype(tr)>().swap(tr);
1049
1050 set_min(repr_stk.conv_part, succ_stk.conv_part);
1051 succ_stk.conv_part = AlgorithmValueType(INFINITE_VALUE);
1052
1053 // Connect to representative state with zero cost action
1054 succ_stk.non_ec_transitions.emplace(
1055 0.0_vt,
1056 std::vector{ItemProbabilityPair<StateID>(scc_repr_id, 1.0_vt)});
1057 }
1058 }
1059
1060 stack_ecd_.erase(scc.begin(), scc.end());
1061
1062 assert(stack_ecd_.size() == e.stackidx);
1063}
1064
1065} // namespace probfd::algorithms::ta_topological_vi
A registry for print functions related to search progress.
Definition progress_report.h:33
Implements a trap-aware variant of Topological Value Iteration.
Definition ta_topological_value_iteration.h:63
Namespace dedicated to trap-aware Topological Value Iteration (TATVI).
Definition ta_topological_value_iteration.h:24
bool update(Interval &lhs, Interval rhs, value_t epsilon=g_epsilon)
Intersects two intervals and assigns the result to the left operand.
bool set_min(Interval &lhs, Interval rhs)
Computes the assignments lhs.lower <- min(lhs.lower, rhs.lower) and lower <- min(lhs....
double value_t
Typedef for the state value type.
Definition aliases.h:7
typename std::conditional_t< is_cheap_to_copy_v< T >, T, const T & > param_type
Alias template defining the best way to pass a parameter of a given type.
Definition type_traits.h:25
Represents a closed interval over the extended reals as a pair of lower and upper bound.
Definition interval.h:12
value_t lower
The lower bound of the interval.
Definition interval.h:13
value_t upper
The upper bound of the interval.
Definition interval.h:14
A StateID represents a state within a StateIDMap. Just like Fast Downward's StateID type,...
Definition types.h:22
Topological value iteration statistics.
Definition ta_topological_value_iteration.h:29