AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
trap_aware_dfhs_impl.h
1#ifndef GUARD_INCLUDE_PROBFD_ALGORITHMS_TRAP_AWARE_DFHS_H
2#error "This file should only be included from trap_aware_dfhs.h"
3#endif
4
5#include "probfd/algorithms/open_list.h"
6
7#include "probfd/policies/map_policy.h"
8
9#include "probfd/quotients/quotient_max_heuristic.h"
10
11#include "probfd/utils/guards.h"
12
13#include "downward/utils/countdown_timer.h"
14
15#include <cassert>
16#include <iterator>
17#include <ranges>
18
20
21namespace internal {
22
23inline void Statistics::print(std::ostream& out) const
24{
25 out << " Iterations: " << iterations << std::endl;
26 out << " Traps: " << traps << std::endl;
27 out << " Bellman backups (forward): " << fw_updates << std::endl;
28 out << " Bellman backups (backward): " << bw_updates << std::endl;
29 out << " State re-expansions: " << reexpansions << std::endl;
30 out << " Trap removal time: " << trap_timer << std::endl;
31}
32
33inline void Statistics::register_report(ProgressReport& report) const
34{
35 report.register_print([this](std::ostream& out) {
36 out << "iteration=" << iterations << ", traps=" << traps;
37 });
38}
39
40} // namespace internal
41
42template <typename State, typename Action, bool UseInterval>
43bool TADFHSImpl<State, Action, UseInterval>::ExplorationInformation::
44 next_successor()
45{
46 successors.pop_back();
47 return !successors.empty();
48}
49
50template <typename State, typename Action, bool UseInterval>
51StateID
52TADFHSImpl<State, Action, UseInterval>::ExplorationInformation::get_successor()
53 const
54{
55 return successors.back();
56}
57
58template <typename State, typename Action, bool UseInterval>
59void TADFHSImpl<State, Action, UseInterval>::ExplorationInformation::update(
60 const ExplorationInformation& other)
61{
62 if (!other.value_converged) value_converged = false;
63 if (!other.all_solved) all_solved = false;
64 if (!other.is_trap) is_trap = false;
65}
66
67template <typename State, typename Action, bool UseInterval>
68void TADFHSImpl<State, Action, UseInterval>::ExplorationInformation::clear()
69{
70 value_converged = true;
71 all_solved = true;
72 is_trap = true;
73}
74
75template <typename State, typename Action, bool UseInterval>
76TADFHSImpl<State, Action, UseInterval>::TADFHSImpl(
77 std::shared_ptr<QuotientPolicyPicker> policy_chooser,
78 bool forward_updates,
79 BacktrackingUpdateType backtrack_update_type,
80 bool cutoff_tip,
81 bool cutoff_inconsistent,
82 bool terminate_exploration_on_cutoff,
83 bool label_solved,
84 bool reexpand_traps)
85 : Base(policy_chooser)
86 , forward_updates_(forward_updates)
87 , backtrack_update_type_(backtrack_update_type)
88 , cutoff_tip_(cutoff_tip)
89 , cutoff_inconsistent_(cutoff_inconsistent)
90 , terminate_exploration_on_cutoff_(terminate_exploration_on_cutoff)
91 , label_solved_(label_solved)
92 , reexpand_traps_(reexpand_traps)
93 , stack_index_(NEW)
94{
95}
96
97template <typename State, typename Action, bool UseInterval>
98Interval TADFHSImpl<State, Action, UseInterval>::solve_quotient(
99 QuotientSystem& quotient,
100 QEvaluator& heuristic,
101 param_type<QState> qstate,
102 ProgressReport& progress,
103 double max_time)
104{
105 utils::CountdownTimer timer(max_time);
106
107 Base::initialize_initial_state(quotient, heuristic, qstate);
108
109 const StateID state_id = quotient.get_state_id(qstate);
110 const StateInfo& state_info = this->state_infos_[state_id];
111
112 progress.register_bound("v", [&state_info]() {
113 return as_interval(state_info.value);
114 });
115
116 statistics_.register_report(progress);
117
118 if (!label_solved_) {
119 dfhs_vi_driver(quotient, heuristic, state_id, progress, timer);
120 } else {
121 dfhs_label_driver(quotient, heuristic, state_id, progress, timer);
122 }
123
124 return state_info.get_bounds();
125}
126
127template <typename State, typename Action, bool UseInterval>
128void TADFHSImpl<State, Action, UseInterval>::print_statistics(
129 std::ostream& out) const
130{
131 Base::print_statistics(out);
132 statistics_.print(out);
133}
134
135template <typename State, typename Action, bool UseInterval>
136void TADFHSImpl<State, Action, UseInterval>::dfhs_vi_driver(
137 QuotientSystem& quotient,
138 QEvaluator& heuristic,
139 const StateID state,
140 ProgressReport& progress,
141 utils::CountdownTimer& timer)
142{
143 UpdateResult vi_res{true, true};
144 do {
145 const bool solved =
146 policy_exploration(quotient, heuristic, state, timer);
147 if (solved) {
148 vi_res = value_iteration(quotient, visited_states_, timer);
149 }
150 visited_states_.clear();
151 ++statistics_.iterations;
152 progress.print();
153 } while (vi_res.value_changed || vi_res.policy_changed);
154}
155
156template <typename State, typename Action, bool UseInterval>
157void TADFHSImpl<State, Action, UseInterval>::dfhs_label_driver(
158 QuotientSystem& quotient,
159 QEvaluator& heuristic,
160 const StateID state,
161 ProgressReport& progress,
162 utils::CountdownTimer& timer)
163{
164 bool solved;
165 do {
166 solved = policy_exploration(quotient, heuristic, state, timer) &&
167 this->state_infos_[state].is_solved();
168 visited_states_.clear();
169 ++statistics_.iterations;
170 progress.print();
171 } while (!solved);
172}
173
174template <typename State, typename Action, bool UseInterval>
175void TADFHSImpl<State, Action, UseInterval>::enqueue(
176 QuotientSystem& quotient,
177 ExplorationInformation& einfo,
178 StateID state,
179 QAction action,
180 const Distribution<StateID>& successor_dist)
181{
182 stack_.back().action = action;
183
184 einfo.successors.reserve(successor_dist.size());
185
186 for (const StateID item : successor_dist.support()) {
187 if (item == state) continue;
188 einfo.successors.push_back(item);
189 }
190
191 assert(!einfo.successors.empty());
192 einfo.is_trap = quotient.get_action_cost(action) == 0;
193}
194
195template <typename State, typename Action, bool UseInterval>
196bool TADFHSImpl<State, Action, UseInterval>::advance(
197 QuotientSystem& quotient,
198 ExplorationInformation& einfo)
199{
200 using enum BacktrackingUpdateType;
201
202 if (terminated_) {
203 einfo.value_converged = false;
204 einfo.all_solved = false;
205 } else if (einfo.next_successor()) {
206 return true;
207 }
208
209 if (backtrack_update_type_ == SINGLE ||
210 (backtrack_update_type_ == ON_DEMAND && !einfo.value_converged)) {
211 ++statistics_.bw_updates;
212
213 const QState state = quotient.get_state(einfo.state);
214 const value_t termination_cost =
215 quotient.get_termination_info(state).get_cost();
216
217 ClearGuard _(transitions_, qvalues_);
218 this->generate_non_tip_transitions(quotient, state, transitions_);
219
220 StateInfo& state_info = this->state_infos_[einfo.state];
221 auto value = this->compute_bellman_and_greedy(
222 quotient,
223 einfo.state,
224 transitions_,
225 termination_cost,
226 qvalues_);
227
228 bool value_changed = this->update_value(state_info, value);
229 bool policy_changed = this->update_policy(
230 state_info,
231 this->select_greedy_transition(
232 quotient,
233 state_info.get_policy(),
234 transitions_));
235 einfo.value_converged = einfo.value_converged && !value_changed;
236 einfo.all_solved =
237 einfo.all_solved && !value_changed && !policy_changed;
238 terminated_ = terminated_ || (terminate_exploration_on_cutoff_ &&
239 cutoff_inconsistent_ && value_changed);
240 }
241
242 return false;
243}
244
245template <typename State, typename Action, bool UseInterval>
246bool TADFHSImpl<State, Action, UseInterval>::push_successor(
247 QuotientSystem& quotient,
248 ExplorationInformation& einfo,
249 utils::CountdownTimer& timer)
250{
251 do {
252 timer.throw_if_expired();
253
254 const StateID succ = quotient.translate_state_id(einfo.get_successor());
255
256 const int succ_status = stack_index_[succ];
257
258 if (succ_status == NEW) {
259 push(succ);
260 return true;
261 } else if (succ_status == CLOSED) {
262 einfo.is_trap = false;
263 if (label_solved_) {
264 einfo.all_solved =
265 einfo.all_solved && this->state_infos_[succ].is_solved();
266 }
267 } else {
268 // is on stack
269 assert(succ_status >= 0);
270 einfo.lowlink = std::min(einfo.lowlink, succ_status);
271 }
272 } while (advance(quotient, einfo));
273
274 return false;
275}
276
277template <typename State, typename Action, bool UseInterval>
278void TADFHSImpl<State, Action, UseInterval>::push(StateID state_id)
279{
280 queue_.emplace_back(state_id, stack_.size());
281 stack_index_[state_id] = stack_.size();
282 stack_.emplace_back(state_id);
283}
284
285template <typename State, typename Action, bool UseInterval>
286bool TADFHSImpl<State, Action, UseInterval>::initialize(
287 QuotientSystem& quotient,
288 QEvaluator& heuristic,
289 ExplorationInformation& einfo)
290{
291 assert(!terminated_);
292
293 const StateID state_id = einfo.state;
294
295 StateInfo& state_info = this->state_infos_[state_id];
296 if (state_info.is_solved()) {
297 assert(label_solved_ || state_info.is_goal_or_terminal());
298 einfo.is_trap = false;
299 return false;
300 }
301
302 const bool tip = state_info.is_on_fringe();
303
304 if (tip || forward_updates_) {
305 ClearGuard _(transitions_, qvalues_);
306
307 const QState state = quotient.get_state(einfo.state);
308 const value_t termination_cost =
309 quotient.get_termination_info(state).get_cost();
310
311 if (tip) {
312 this->expand_and_initialize(
313 quotient,
314 heuristic,
315 state,
316 state_info,
317 transitions_);
318 } else {
319 this->generate_non_tip_transitions(quotient, state, transitions_);
320 }
321
322 ++statistics_.fw_updates;
323
324 auto value = this->compute_bellman_and_greedy(
325 quotient,
326 einfo.state,
327 transitions_,
328 termination_cost,
329 qvalues_);
330
331 auto transition = this->select_greedy_transition(
332 quotient,
333 state_info.get_policy(),
334 transitions_);
335
336 bool value_changed = this->update_value(state_info, value);
337 this->update_policy(state_info, transition);
338
339 einfo.value_converged = einfo.value_converged && !value_changed;
340 einfo.all_solved = einfo.all_solved && !value_changed;
341 const bool cutoff =
342 (cutoff_tip_ && tip) || (cutoff_inconsistent_ && value_changed);
343 terminated_ = terminate_exploration_on_cutoff_ && cutoff;
344
345 if (!transition) {
346 einfo.is_trap = false;
347 return false;
348 }
349
350 if (cutoff) {
351 einfo.is_trap = false;
352 einfo.value_converged = false;
353 einfo.all_solved = false;
354 return false;
355 }
356
357 enqueue(
358 quotient,
359 einfo,
360 state_id,
361 transition->action,
362 transition->successor_dist);
363 } else {
364 auto action = state_info.get_policy();
365 if (!action.has_value()) return false;
366
367 const QState state = quotient.get_state(state_id);
368 quotient.generate_action_transitions(state, *action, transition_);
369 enqueue(quotient, einfo, state_id, *action, transition_);
370 transition_.clear();
371 }
372
373 return true;
374}
375
376template <typename State, typename Action, bool UseInterval>
377bool TADFHSImpl<State, Action, UseInterval>::policy_exploration(
378 QuotientSystem& quotient,
379 QEvaluator& heuristic,
380 StateID start_state,
381 utils::CountdownTimer& timer)
382{
383 assert(visited_states_.empty());
384 terminated_ = false;
385
386 push(start_state);
387
388 ExplorationInformation* einfo;
389
390 for (;;) {
391 do {
392 einfo = &queue_.back();
393 } while (initialize(quotient, heuristic, *einfo) &&
394 push_successor(quotient, *einfo, timer));
395
396 do {
397 const int last_lowlink = einfo->lowlink;
398
399 // Is SCC root?
400 if (einfo->lowlink == stack_index_[einfo->state]) {
401 auto scc = stack_ | std::views::drop(last_lowlink);
402
403 if (scc.size() > 1 && einfo->is_trap) {
404 ++this->statistics_.traps;
405
406 const StateID state_id = einfo->state;
407
408 // Collapse trap and reset quotient state data
409 TimerScope scope(statistics_.trap_timer);
410
411 quotient.build_quotient(scc, *scc.begin());
412 StateInfo& state_info = this->state_infos_[state_id];
413 state_info.update_policy(std::nullopt);
414 state_info.set_on_fringe();
415
416 // re-push trap if enabled
417 if (reexpand_traps_) {
418 stack_.erase(scc.begin(), scc.end());
419 queue_.pop_back();
420 push(state_id);
421 break;
422 }
423
424 stack_index_[state_id] = CLOSED;
425
426 einfo->value_converged = false;
427 einfo->all_solved = false;
428 } else {
429 for (const auto state_id :
430 scc | std::views::transform(&StackInfo::state_id)) {
431 stack_index_[state_id] = CLOSED;
432
433 if (!einfo->all_solved) continue;
434
435 StateInfo& mem_info = this->state_infos_[state_id];
436 if (mem_info.is_solved()) continue;
437
438 if (label_solved_) {
439 mem_info.set_solved();
440 } else {
441 visited_states_.push_back(state_id);
442 }
443 }
444 }
445
446 einfo->is_trap = false;
447 stack_.erase(scc.begin(), scc.end());
448 }
449
450 ExplorationInformation bt_einfo = std::move(*einfo);
451 queue_.pop_back();
452
453 if (queue_.empty()) {
454 assert(stack_.empty());
455 stack_index_.clear();
456 return einfo->all_solved;
457 }
458
459 timer.throw_if_expired();
460
461 einfo = &queue_.back();
462
463 einfo->lowlink = std::min(last_lowlink, einfo->lowlink);
464 einfo->update(bt_einfo);
465 } while (!advance(quotient, *einfo));
466 }
467}
468
469template <typename State, typename Action, bool UseInterval>
470auto TADFHSImpl<State, Action, UseInterval>::value_iteration(
471 QuotientSystem& quotient,
472 const std::ranges::input_range auto& range,
473 utils::CountdownTimer& timer) -> UpdateResult
474{
475 UpdateResult updated_all(false, false);
476 bool value_changed_for_any;
477 bool policy_changed_for_any;
478
479 do {
480 value_changed_for_any = false;
481 policy_changed_for_any = false;
482
483 for (const StateID id : range) {
484 timer.throw_if_expired();
485
486 const QState state = quotient.get_state(id);
487 const value_t termination_cost =
488 quotient.get_termination_info(state).get_cost();
489
490 ClearGuard _(transitions_, qvalues_);
491 this->generate_non_tip_transitions(quotient, state, transitions_);
492
493 StateInfo& state_info = this->state_infos_[id];
494 const auto value = this->compute_bellman_and_greedy(
495 quotient,
496 id,
497 transitions_,
498 termination_cost,
499 qvalues_);
500
501 bool value_changed = this->update_value(state_info, value);
502 bool policy_changed = this->update_policy(
503 state_info,
504 this->select_greedy_transition(
505 quotient,
506 state_info.get_policy(),
507 transitions_));
508 value_changed_for_any = value_changed_for_any || value_changed;
509 policy_changed_for_any = policy_changed_for_any || policy_changed;
510 }
511
512 updated_all.value_changed =
513 updated_all.value_changed || value_changed_for_any;
514 updated_all.policy_changed =
515 updated_all.policy_changed || policy_changed_for_any;
516 } while (value_changed_for_any && !policy_changed_for_any);
517
518 return updated_all;
519}
520
521template <typename State, typename Action, bool UseInterval>
522TADepthFirstHeuristicSearch<State, Action, UseInterval>::
523 TADepthFirstHeuristicSearch(
524 std::shared_ptr<QuotientPolicyPicker> policy_chooser,
525 bool forward_updates,
526 BacktrackingUpdateType backtrack_update_type,
527 bool cutoff_tip,
528 bool cutoff_inconsistent,
529 bool stop_exploration_inconsistent,
530 bool label_solved,
531 bool reexpand_removed_traps)
532 : algorithm_(
533 std::move(policy_chooser),
534 forward_updates,
535 backtrack_update_type,
536 cutoff_tip,
537 cutoff_inconsistent,
538 stop_exploration_inconsistent,
539 label_solved,
540 reexpand_removed_traps)
541{
542}
543
544template <typename State, typename Action, bool UseInterval>
545Interval TADepthFirstHeuristicSearch<State, Action, UseInterval>::solve(
546 MDPType& mdp,
547 EvaluatorType& heuristic,
548 param_type<State> state,
549 ProgressReport progress,
550 double max_time)
551{
552 QuotientSystem quotient(mdp);
553 quotients::QuotientMaxHeuristic<State, Action> qheuristic(heuristic);
554 return algorithm_.solve_quotient(
555 quotient,
556 qheuristic,
557 quotient.translate_state(state),
558 progress,
559 max_time);
560}
561
562template <typename State, typename Action, bool UseInterval>
563auto TADepthFirstHeuristicSearch<State, Action, UseInterval>::compute_policy(
564 MDPType& mdp,
565 EvaluatorType& heuristic,
566 param_type<State> state,
567 ProgressReport progress,
568 double max_time) -> std::unique_ptr<PolicyType>
569{
570 QuotientSystem quotient(mdp);
571 quotients::QuotientMaxHeuristic<State, Action> qheuristic(heuristic);
572
573 QState qinit = quotient.translate_state(state);
574 algorithm_.solve_quotient(quotient, qheuristic, qinit, progress, max_time);
575
576 /*
577 * The quotient policy only specifies the optimal actions between
578 * traps. We need to supplement the optimal actions within the
579 * traps, i.e. the actions which point every other member state of
580 * the trap towards that trap member state that owns the optimal
581 * quotient action.
582 *
583 * We fully explore the quotient policy starting from the initial
584 * state and compute the optimal 'inner' actions for each trap. To
585 * this end, we first generate the sub-MDP of the trap. Afterwards,
586 * we expand the trap graph backwards from the state that has the
587 * optimal quotient action. For each encountered state, we select
588 * the action with which it is encountered first as the policy
589 * action.
590 */
591 using MapPolicy = policies::MapPolicy<State, Action>;
592 std::unique_ptr<MapPolicy> policy(new MapPolicy(&mdp));
593
594 const StateID initial_state_id = quotient.get_state_id(qinit);
595
596 std::deque<StateID> queue({initial_state_id});
597 std::set<StateID> visited({initial_state_id});
598
599 do {
600 const StateID quotient_id = queue.front();
601 const QState quotient_state = quotient.get_state(quotient_id);
602 queue.pop_front();
603
604 const auto& state_info = algorithm_.state_infos_[quotient_id];
605
606 std::optional quotient_action = state_info.get_policy();
607
608 // Terminal states have no policy decision.
609 if (!quotient_action) {
610 continue;
611 }
612
613 const Interval quotient_bound = as_interval(state_info.value);
614
615 const StateID exiting_id = quotient_action->state_id;
616
617 policy->emplace_decision(
618 exiting_id,
619 quotient_action->action,
620 quotient_bound);
621
622 // Nothing else needs to be done if the trap has only one state.
623 if (quotient_state.num_members() != 1) {
624 std::unordered_map<StateID, std::set<QAction>> parents;
625
626 // Build the inverse graph
627 std::vector<QAction> inner_actions;
628 quotient_state.get_collapsed_actions(inner_actions);
629
630 for (const QAction& qaction : inner_actions) {
631 StateID source_id = qaction.state_id;
632 Action action = qaction.action;
633
634 const State source = mdp.get_state(source_id);
635
636 Distribution<StateID> successors;
637 mdp.generate_action_transitions(source, action, successors);
638
639 for (const StateID succ_id : successors.support()) {
640 parents[succ_id].insert(qaction);
641 }
642 }
643
644 // Now traverse the inverse graph starting from the exiting
645 // state
646 std::deque<StateID> inverse_queue({exiting_id});
647 std::set<StateID> inverse_visited({exiting_id});
648
649 do {
650 const StateID next_id = inverse_queue.front();
651 inverse_queue.pop_front();
652
653 for (const auto& [pred_id, act] : parents[next_id]) {
654 if (inverse_visited.insert(pred_id).second) {
655 policy->emplace_decision(pred_id, act, quotient_bound);
656 inverse_queue.push_back(pred_id);
657 }
658 }
659 } while (!inverse_queue.empty());
660 }
661
662 // Push the successor traps.
663 Distribution<StateID> successors;
664 quotient.generate_action_transitions(
665 quotient_state,
666 *quotient_action,
667 successors);
668
669 for (const StateID succ_id : successors.support()) {
670 if (visited.insert(succ_id).second) {
671 queue.push_back(succ_id);
672 }
673 }
674 } while (!queue.empty());
675
676 return policy;
677}
678
679template <typename State, typename Action, bool UseInterval>
680void TADepthFirstHeuristicSearch<State, Action, UseInterval>::print_statistics(
681 std::ostream& out) const
682{
683 return algorithm_.print_statistics(out);
684}
685
686template <typename State, typename Action, bool UseInterval>
687Interval TADepthFirstHeuristicSearch<State, Action, UseInterval>::lookup_bounds(
688 StateID state_id) const
689{
690 return algorithm_.lookup_bounds(state_id);
691}
692
693template <typename State, typename Action, bool UseInterval>
694bool TADepthFirstHeuristicSearch<State, Action, UseInterval>::was_visited(
695 StateID state_id) const
696{
697 return algorithm_.was_visited(state_id);
698}
699
700} // namespace probfd::algorithms::trap_aware_dfhs
Namespace dedicated to the depth-first heuristic search (DFHS) family with native trap handling suppo...
Definition trap_aware_dfhs.h:26
Interval as_interval(value_t lower_bound)
Returns the interval with the given lower bound and infinite upper bound.
double value_t
Typedef for the state value type.
Definition aliases.h:7
STL namespace.