1#ifndef GUARD_INCLUDE_PROBFD_ALGORITHMS_TRAP_AWARE_DFHS_H
2#error "This file should only be included from trap_aware_dfhs.h"
5#include "probfd/algorithms/open_list.h"
7#include "probfd/policies/map_policy.h"
9#include "probfd/quotients/quotient_max_heuristic.h"
11#include "probfd/utils/guards.h"
13#include "downward/utils/countdown_timer.h"
// Writes the accumulated search statistics to `out`, one labelled value per
// line: iterations, traps, forward and backward Bellman backups, state
// re-expansions, and the accumulated trap-removal time.
// NOTE(review): this listing looks like text extracted from rendered docs --
// the numbers fused into the start of each line and the missing brace-only
// lines are extraction residue, not C++; restore from the real header.
// NOTE(review): std::endl flushes `out` after every line; '\n' would suffice
// unless a per-line flush is intended.
23inline void Statistics::print(std::ostream& out)
const
25 out <<
" Iterations: " << iterations << std::endl;
26 out <<
" Traps: " << traps << std::endl;
27 out <<
" Bellman backups (forward): " << fw_updates << std::endl;
28 out <<
" Bellman backups (backward): " << bw_updates << std::endl;
29 out <<
" State re-expansions: " << reexpansions << std::endl;
30 out <<
" Trap removal time: " << trap_timer << std::endl;
// Registers a printer callback with `report`. The visible part of the lambda
// prints the current iteration count and trap count
// ("iteration=..., traps=..."); the rest of its body is not shown here.
// NOTE(review): the lambda captures `this`, so this Statistics object must
// outlive every use of the callback by `report` -- confirm at call sites.
33inline void Statistics::register_report(ProgressReport& report)
const
35 report.register_print([
this](std::ostream& out) {
36 out <<
"iteration=" << iterations <<
", traps=" << traps;
// Advances to the next successor of the explored state. It drops the current
// (last) entry of `successors` and returns whether any successors remain.
// The member-function name line is missing from this listing; presumably
// this is next_successor(), which advance() calls -- TODO confirm.
template <
typename State,
typename Action,
bool UseInterval>
43bool TADFHSImpl<State, Action, UseInterval>::ExplorationInformation::
46 successors.pop_back();
47 return !successors.empty();
// Returns the successor currently being explored: the last element of
// `successors` (successors are consumed from the back via pop_back()).
// Assumed precondition: `successors` is non-empty. back() on an empty
// std::vector-like container is undefined behaviour.
template <
typename State,
typename Action,
bool UseInterval>
52TADFHSImpl<State, Action, UseInterval>::ExplorationInformation::get_successor()
55 return successors.back();
// Merges the flags of a backtracked child exploration (`other`) into this one.
// Each of value_converged, all_solved and is_trap stays true only if it is
// also true in `other`. This is a logical AND; no flag is set back to true.
template <
typename State,
typename Action,
bool UseInterval>
59void TADFHSImpl<State, Action, UseInterval>::ExplorationInformation::update(
60 const ExplorationInformation& other)
62 if (!other.value_converged) value_converged =
false;
63 if (!other.all_solved) all_solved =
false;
64 if (!other.is_trap) is_trap =
false;
// Resets the per-state exploration flags. The visible part sets
// value_converged back to true; any further resets are on lines not shown
// in this listing.
template <
typename State,
typename Action,
bool UseInterval>
68void TADFHSImpl<State, Action, UseInterval>::ExplorationInformation::clear()
70 value_converged =
true;
// Stores the algorithm configuration. The policy picker is forwarded to the
// base class, and every other flag is copied into the member of the same name.
// Some parameter lines (forward_updates, cutoff_tip, label_solved,
// reexpand_traps) are missing from this listing but are referenced by the
// initializer list below.
// NOTE(review): `policy_chooser` is taken by value and then copied into
// Base. `Base(std::move(policy_chooser))` would avoid an extra atomic
// refcount increment/decrement if Base takes the pointer by value.
template <
typename State,
typename Action,
bool UseInterval>
76TADFHSImpl<State, Action, UseInterval>::TADFHSImpl(
77 std::shared_ptr<QuotientPolicyPicker> policy_chooser,
79 BacktrackingUpdateType backtrack_update_type,
81 bool cutoff_inconsistent,
82 bool terminate_exploration_on_cutoff,
85 : Base(policy_chooser)
86 , forward_updates_(forward_updates)
87 , backtrack_update_type_(backtrack_update_type)
88 , cutoff_tip_(cutoff_tip)
89 , cutoff_inconsistent_(cutoff_inconsistent)
90 , terminate_exploration_on_cutoff_(terminate_exploration_on_cutoff)
91 , label_solved_(label_solved)
92 , reexpand_traps_(reexpand_traps)
// Runs trap-aware DFHS on `quotient` from `qstate` and returns the final
// bounds stored for that state. The steps are:
// 1. Start a CountdownTimer with `max_time` and initialize the initial state.
// 2. Register a "v" bound and the statistics printer with `progress`.
// 3. Run the value-iteration driver if label_solved_ is false, otherwise the
//    solved-labelling driver.
// The time limit is enforced through utils::CountdownTimer::throw_if_expired(),
// which the exploration and value-iteration helpers call; whatever it throws
// propagates out of this function.
template <
typename State,
typename Action,
bool UseInterval>
98Interval TADFHSImpl<State, Action, UseInterval>::solve_quotient(
99 QuotientSystem& quotient,
100 QEvaluator& heuristic,
101 param_type<QState> qstate,
102 ProgressReport& progress,
105 utils::CountdownTimer timer(max_time);
107 Base::initialize_initial_state(quotient, heuristic, qstate);
109 const StateID state_id = quotient.get_state_id(qstate);
110 const StateInfo& state_info = this->state_infos_[state_id];
// NOTE(review): the "v" callback captures a reference into state_infos_.
// This assumes state_infos_ never relocates entries while the search adds
// states, and that `progress` stops using the callback before the state
// infos are destroyed -- TODO confirm.
112 progress.register_bound(
"v", [&state_info]() {
116 statistics_.register_report(progress);
118 if (!label_solved_) {
119 dfhs_vi_driver(quotient, heuristic, state_id, progress, timer);
121 dfhs_label_driver(quotient, heuristic, state_id, progress, timer);
124 return state_info.get_bounds();
// Prints the base-class statistics to `out`, followed by this algorithm's own
// Statistics block.
template <
typename State,
typename Action,
bool UseInterval>
128void TADFHSImpl<State, Action, UseInterval>::print_statistics(
129 std::ostream& out)
const
131 Base::print_statistics(out);
132 statistics_.print(out);
// Value-iteration driver, used when solved-labelling is disabled. Each
// iteration does the following:
// 1. One policy exploration from `state`.
// 2. Value iteration over visited_states_, which the exploration fills.
// 3. Clears visited_states_ and increments the iteration counter.
// The do-while loop repeats until an iteration changes neither a value nor
// a policy.
template <
typename State,
typename Action,
bool UseInterval>
136void TADFHSImpl<State, Action, UseInterval>::dfhs_vi_driver(
137 QuotientSystem& quotient,
138 QEvaluator& heuristic,
140 ProgressReport& progress,
141 utils::CountdownTimer& timer)
143 UpdateResult vi_res{
true,
true};
146 policy_exploration(quotient, heuristic, state, timer);
148 vi_res = value_iteration(quotient, visited_states_, timer);
150 visited_states_.clear();
151 ++statistics_.iterations;
153 }
while (vi_res.value_changed || vi_res.policy_changed);
// Solved-labelling driver. It repeats policy explorations from `state`;
// `solved` becomes true once an exploration returns true AND `state` itself
// is labelled solved. The loop header is not visible here; presumably it
// runs until `solved` holds.
// `&&` short-circuits, so is_solved() is queried only when the exploration
// returned true.
template <
typename State,
typename Action,
bool UseInterval>
157void TADFHSImpl<State, Action, UseInterval>::dfhs_label_driver(
158 QuotientSystem& quotient,
159 QEvaluator& heuristic,
161 ProgressReport& progress,
162 utils::CountdownTimer& timer)
166 solved = policy_exploration(quotient, heuristic, state, timer) &&
167 this->state_infos_[state].is_solved();
168 visited_states_.clear();
169 ++statistics_.iterations;
// Records `action` as the action of the top stack entry. It then fills
// einfo.successors with the support of `successor_dist`, skipping self-loops
// back to `state`. The exploration is marked as a potential trap exactly
// when `action` has zero cost.
// The assert assumes the action reaches at least one state other than
// `state`; presumably self-loop-only actions are filtered out earlier --
// TODO confirm.
template <
typename State,
typename Action,
bool UseInterval>
175void TADFHSImpl<State, Action, UseInterval>::enqueue(
176 QuotientSystem& quotient,
177 ExplorationInformation& einfo,
180 const Distribution<StateID>& successor_dist)
182 stack_.back().action = action;
184 einfo.successors.reserve(successor_dist.size());
186 for (
const StateID item : successor_dist.support()) {
187 if (item == state)
continue;
188 einfo.successors.push_back(item);
191 assert(!einfo.successors.empty());
192 einfo.is_trap = quotient.get_action_cost(action) == 0;
// Moves the exploration of einfo.state on to its next successor. The return
// statements are not visible; presumably it returns true while successors
// remain and false once the state is finished. Callers loop on it in
// push_successor() and policy_exploration().
// When the state is finished, a backtracking Bellman backup is performed if
// either of these holds:
// - backtrack_update_type_ is SINGLE;
// - backtrack_update_type_ is ON_DEMAND and einfo has not value-converged.
// The backup result clears value_converged / all_solved as appropriate. It
// sets terminated_ when terminate_exploration_on_cutoff_ and
// cutoff_inconsistent_ are both set and the value changed.
template <
typename State,
typename Action,
bool UseInterval>
196bool TADFHSImpl<State, Action, UseInterval>::advance(
197 QuotientSystem& quotient,
198 ExplorationInformation& einfo)
200 using enum BacktrackingUpdateType;
203 einfo.value_converged =
false;
204 einfo.all_solved =
false;
205 }
else if (einfo.next_successor()) {
// Backtracking backup of einfo.state (counted as a backward update).
209 if (backtrack_update_type_ == SINGLE ||
210 (backtrack_update_type_ == ON_DEMAND && !einfo.value_converged)) {
211 ++statistics_.bw_updates;
213 const QState state = quotient.get_state(einfo.state);
214 const value_t termination_cost =
215 quotient.get_termination_info(state).get_cost();
217 ClearGuard _(transitions_, qvalues_);
218 this->generate_non_tip_transitions(quotient, state, transitions_);
220 StateInfo& state_info = this->state_infos_[einfo.state];
221 auto value = this->compute_bellman_and_greedy(
228 bool value_changed = this->update_value(state_info, value);
229 bool policy_changed = this->update_policy(
231 this->select_greedy_transition(
233 state_info.get_policy(),
// A value change breaks convergence; a value or policy change breaks
// "all solved".
235 einfo.value_converged = einfo.value_converged && !value_changed;
237 einfo.all_solved && !value_changed && !policy_changed;
238 terminated_ = terminated_ || (terminate_exploration_on_cutoff_ &&
239 cutoff_inconsistent_ && value_changed);
// Walks einfo's remaining successors in a Tarjan-style SCC step. Successors
// are handled by their stack_index_ status:
// - NEW: handled on lines not shown here (presumably pushed for exploration).
// - CLOSED: einfo can no longer be part of a trap, and all_solved
//   additionally requires that successor to be solved.
// - On the stack (status >= 0): einfo.lowlink is lowered to the successor's
//   stack index.
// The time limit appears to be checked once per successor via
// timer.throw_if_expired().
template <
typename State,
typename Action,
bool UseInterval>
246bool TADFHSImpl<State, Action, UseInterval>::push_successor(
247 QuotientSystem& quotient,
248 ExplorationInformation& einfo,
249 utils::CountdownTimer& timer)
252 timer.throw_if_expired();
254 const StateID succ = quotient.translate_state_id(einfo.get_successor());
256 const int succ_status = stack_index_[succ];
258 if (succ_status == NEW) {
261 }
else if (succ_status == CLOSED) {
262 einfo.is_trap =
false;
265 einfo.all_solved && this->state_infos_[succ].is_solved();
269 assert(succ_status >= 0);
270 einfo.lowlink = std::min(einfo.lowlink, succ_status);
272 }
while (advance(quotient, einfo));
// Pushes `state_id` onto both the exploration queue and the Tarjan stack.
// The stack size taken *before* the insertion is the position of the new
// StackInfo. It is used both as the state's stack_index_ entry and as the
// index passed to the new queue entry.
// NOTE(review): stack_.size() is presumably size_t, while stack indices are
// read back as int in push_successor() -- an implicit narrowing conversion.
template <
typename State,
typename Action,
bool UseInterval>
278void TADFHSImpl<State, Action, UseInterval>::push(StateID state_id)
280 queue_.emplace_back(state_id, stack_.size());
281 stack_index_[state_id] = stack_.size();
282 stack_.emplace_back(state_id);
// Prepares einfo.state for exploration. If the state is to be expanded, it
// fills einfo with the successors of the state's policy action.
// - Solved states are marked as non-traps (the rest of that branch is not
//   visible here).
// - Fringe (tip) states, and every state when forward_updates_ is set, get a
//   forward Bellman backup. A resulting cutoff sets terminated_ when
//   terminate_exploration_on_cutoff_ is set.
// - Returns false if the state has no policy action. Otherwise the policy
//   action's transitions are enqueued (the final return is not visible).
template <
typename State,
typename Action,
bool UseInterval>
286bool TADFHSImpl<State, Action, UseInterval>::initialize(
287 QuotientSystem& quotient,
288 QEvaluator& heuristic,
289 ExplorationInformation& einfo)
291 assert(!terminated_);
293 const StateID state_id = einfo.state;
295 StateInfo& state_info = this->state_infos_[state_id];
296 if (state_info.is_solved()) {
297 assert(label_solved_ || state_info.is_goal_or_terminal());
298 einfo.is_trap =
false;
302 const bool tip = state_info.is_on_fringe();
304 if (tip || forward_updates_) {
305 ClearGuard _(transitions_, qvalues_);
307 const QState state = quotient.get_state(einfo.state);
308 const value_t termination_cost =
309 quotient.get_termination_info(state).get_cost();
312 this->expand_and_initialize(
319 this->generate_non_tip_transitions(quotient, state, transitions_);
322 ++statistics_.fw_updates;
324 auto value = this->compute_bellman_and_greedy(
331 auto transition = this->select_greedy_transition(
333 state_info.get_policy(),
336 bool value_changed = this->update_value(state_info, value);
// NOTE(review): unlike advance(), the result of update_policy() is ignored
// here, so a policy change alone does not clear all_solved below -- confirm
// this asymmetry is intended.
337 this->update_policy(state_info, transition);
339 einfo.value_converged = einfo.value_converged && !value_changed;
340 einfo.all_solved = einfo.all_solved && !value_changed;
342 (cutoff_tip_ && tip) || (cutoff_inconsistent_ && value_changed);
343 terminated_ = terminate_exploration_on_cutoff_ && cutoff;
346 einfo.is_trap =
false;
351 einfo.is_trap =
false;
352 einfo.value_converged =
false;
353 einfo.all_solved =
false;
362 transition->successor_dist);
364 auto action = state_info.get_policy();
365 if (!action.has_value())
return false;
367 const QState state = quotient.get_state(state_id);
368 quotient.generate_action_transitions(state, *action, transition_);
369 enqueue(quotient, einfo, state_id, *action, transition_);
// Performs one depth-first exploration of the current greedy policy from
// `state`, detecting SCCs Tarjan-style on stack_ (lowlink / stack_index_).
// When backtracking reaches an SCC root (its lowlink equals its stack index):
// - If the SCC has more than one state and is still marked as a trap, it is
//   collapsed into one quotient state (timed by statistics_.trap_timer).
//   That state's policy is reset and it is put back on the fringe.
// - Otherwise, its members are marked CLOSED, possibly labelled solved, and
//   removed from the stack.
// Finished entries are merged into their parent via update().
// Returns the all_solved flag of the last finished entry once the queue is
// empty. The time limit is checked on each backtracking step.
template <
typename State,
typename Action,
bool UseInterval>
377bool TADFHSImpl<State, Action, UseInterval>::policy_exploration(
378 QuotientSystem& quotient,
379 QEvaluator& heuristic,
381 utils::CountdownTimer& timer)
383 assert(visited_states_.empty());
388 ExplorationInformation* einfo;
392 einfo = &queue_.back();
393 }
while (initialize(quotient, heuristic, *einfo) &&
394 push_successor(quotient, *einfo, timer));
397 const int last_lowlink = einfo->lowlink;
400 if (einfo->lowlink == stack_index_[einfo->state]) {
401 auto scc = stack_ | std::views::drop(last_lowlink);
403 if (scc.size() > 1 && einfo->is_trap) {
404 ++this->statistics_.traps;
406 const StateID state_id = einfo->state;
409 TimerScope scope(statistics_.trap_timer);
411 quotient.build_quotient(scc, *scc.begin());
412 StateInfo& state_info = this->state_infos_[state_id];
413 state_info.update_policy(std::nullopt);
414 state_info.set_on_fringe();
417 if (reexpand_traps_) {
418 stack_.erase(scc.begin(), scc.end());
424 stack_index_[state_id] = CLOSED;
426 einfo->value_converged =
false;
427 einfo->all_solved =
false;
429 for (
const auto state_id :
430 scc | std::views::transform(&StackInfo::state_id)) {
431 stack_index_[state_id] = CLOSED;
433 if (!einfo->all_solved)
continue;
435 StateInfo& mem_info = this->state_infos_[state_id];
436 if (mem_info.is_solved())
continue;
439 mem_info.set_solved();
441 visited_states_.push_back(state_id);
446 einfo->is_trap =
false;
447 stack_.erase(scc.begin(), scc.end());
450 ExplorationInformation bt_einfo = std::move(*einfo);
453 if (queue_.empty()) {
454 assert(stack_.empty());
455 stack_index_.clear();
// NOTE(review): `einfo` pointed at queue_.back(), whose contents were moved
// into bt_einfo above. If that entry was popped from queue_ on the lines not
// shown here (checking queue_.empty() suggests so), this reads through a
// dangling pointer; `bt_einfo.all_solved` is presumably what is meant.
456 return einfo->all_solved;
459 timer.throw_if_expired();
461 einfo = &queue_.back();
463 einfo->lowlink = std::min(last_lowlink, einfo->lowlink);
464 einfo->update(bt_einfo);
465 }
while (!advance(quotient, *einfo));
// Performs repeated Bellman sweeps over the states in `range`, checking the
// time limit before each state. Each sweep recomputes every state's value and
// greedy policy. Sweeping continues while some value changed and no policy
// changed. updated_all accumulates whether any value / any policy changed over
// all sweeps; presumably it is the returned result (the return is not visible).
// NOTE(review): `range` is traversed once per sweep, so it must be multi-pass.
// std::ranges::forward_range would express that requirement more precisely
// than input_range (the caller in this file passes visited_states_).
template <
typename State,
typename Action,
bool UseInterval>
470auto TADFHSImpl<State, Action, UseInterval>::value_iteration(
471 QuotientSystem& quotient,
472 const std::ranges::input_range
auto& range,
473 utils::CountdownTimer& timer) -> UpdateResult
475 UpdateResult updated_all(
false,
false);
476 bool value_changed_for_any;
477 bool policy_changed_for_any;
480 value_changed_for_any =
false;
481 policy_changed_for_any =
false;
483 for (
const StateID
id : range) {
484 timer.throw_if_expired();
486 const QState state = quotient.get_state(
id);
487 const value_t termination_cost =
488 quotient.get_termination_info(state).get_cost();
490 ClearGuard _(transitions_, qvalues_);
491 this->generate_non_tip_transitions(quotient, state, transitions_);
493 StateInfo& state_info = this->state_infos_[id];
494 const auto value = this->compute_bellman_and_greedy(
501 bool value_changed = this->update_value(state_info, value);
502 bool policy_changed = this->update_policy(
504 this->select_greedy_transition(
506 state_info.get_policy(),
508 value_changed_for_any = value_changed_for_any || value_changed;
509 policy_changed_for_any = policy_changed_for_any || policy_changed;
512 updated_all.value_changed =
513 updated_all.value_changed || value_changed_for_any;
514 updated_all.policy_changed =
515 updated_all.policy_changed || policy_changed_for_any;
516 }
while (value_changed_for_any && !policy_changed_for_any);
// Constructs the wrapped TADFHSImpl, forwarding all configuration parameters
// (the policy picker is moved). Several parameter and argument lines are not
// visible in this listing.
// NOTE(review): judging by argument positions, stop_exploration_inconsistent
// is forwarded as TADFHSImpl's terminate_exploration_on_cutoff. That flag also
// stops exploration on tip cutoffs (cutoff_tip), not only on inconsistent
// ones -- confirm the parameter name matches that behaviour.
template <
typename State,
typename Action,
bool UseInterval>
522TADepthFirstHeuristicSearch<State, Action, UseInterval>::
523 TADepthFirstHeuristicSearch(
524 std::shared_ptr<QuotientPolicyPicker> policy_chooser,
525 bool forward_updates,
526 BacktrackingUpdateType backtrack_update_type,
528 bool cutoff_inconsistent,
529 bool stop_exploration_inconsistent,
531 bool reexpand_removed_traps)
533 std::move(policy_chooser),
535 backtrack_update_type,
538 stop_exploration_inconsistent,
540 reexpand_removed_traps)
// Solves `mdp` from `state`. It wraps the MDP in a local QuotientSystem and
// the heuristic in a QuotientMaxHeuristic, then delegates to
// TADFHSImpl::solve_quotient for the translated initial state and returns the
// resulting bounds.
template <
typename State,
typename Action,
bool UseInterval>
545Interval TADepthFirstHeuristicSearch<State, Action, UseInterval>::solve(
547 EvaluatorType& heuristic,
548 param_type<State> state,
549 ProgressReport progress,
552 QuotientSystem quotient(mdp);
553 quotients::QuotientMaxHeuristic<State, Action> qheuristic(heuristic);
554 return algorithm_.solve_quotient(
557 quotient.translate_state(state),
// Solves `mdp` from `state` like solve(), then builds a MapPolicy for the
// original MDP. It traverses, queue-based, the quotient states reachable
// under the computed quotient policy:
// - Each reached quotient state's exiting action (quotient_action) is
//   recorded as a decision, with the quotient value as a lower bound
//   (as_interval: infinite upper bound).
// - For collapsed quotient states (more than one member), a backward BFS
//   from the exiting member over the collapsed inner actions gives each
//   remaining member an inner action that leads towards the exit.
// - Successors of the exiting action are enqueued if not yet visited.
template <
typename State,
typename Action,
bool UseInterval>
563auto TADepthFirstHeuristicSearch<State, Action, UseInterval>::compute_policy(
565 EvaluatorType& heuristic,
566 param_type<State> state,
567 ProgressReport progress,
568 double max_time) -> std::unique_ptr<PolicyType>
570 QuotientSystem quotient(mdp);
571 quotients::QuotientMaxHeuristic<State, Action> qheuristic(heuristic);
573 QState qinit = quotient.translate_state(state);
574 algorithm_.solve_quotient(quotient, qheuristic, qinit, progress, max_time);
591 using MapPolicy = policies::MapPolicy<State, Action>;
// NOTE(review): std::make_unique<MapPolicy>(&mdp) would replace the naked
// `new` below.
592 std::unique_ptr<MapPolicy> policy(
new MapPolicy(&mdp));
594 const StateID initial_state_id = quotient.get_state_id(qinit);
596 std::deque<StateID> queue({initial_state_id});
597 std::set<StateID> visited({initial_state_id});
600 const StateID quotient_id = queue.front();
601 const QState quotient_state = quotient.get_state(quotient_id);
604 const auto& state_info = algorithm_.state_infos_[quotient_id];
606 std::optional quotient_action = state_info.get_policy();
609 if (!quotient_action) {
613 const Interval quotient_bound =
as_interval(state_info.value);
615 const StateID exiting_id = quotient_action->state_id;
617 policy->emplace_decision(
619 quotient_action->action,
623 if (quotient_state.num_members() != 1) {
624 std::unordered_map<StateID, std::set<QAction>> parents;
627 std::vector<QAction> inner_actions;
628 quotient_state.get_collapsed_actions(inner_actions);
630 for (
const QAction& qaction : inner_actions) {
631 StateID source_id = qaction.state_id;
632 Action action = qaction.action;
634 const State source = mdp.get_state(source_id);
636 Distribution<StateID> successors;
637 mdp.generate_action_transitions(source, action, successors);
639 for (
const StateID succ_id : successors.support()) {
640 parents[succ_id].insert(qaction);
646 std::deque<StateID> inverse_queue({exiting_id});
647 std::set<StateID> inverse_visited({exiting_id});
650 const StateID next_id = inverse_queue.front();
651 inverse_queue.pop_front();
// NOTE(review): parents[next_id] default-inserts an empty set for states
// with no recorded parents; a find() lookup would avoid the insertion.
653 for (
const auto& [pred_id, act] : parents[next_id]) {
654 if (inverse_visited.insert(pred_id).second) {
655 policy->emplace_decision(pred_id, act, quotient_bound);
656 inverse_queue.push_back(pred_id);
659 }
while (!inverse_queue.empty());
663 Distribution<StateID> successors;
664 quotient.generate_action_transitions(
669 for (
const StateID succ_id : successors.support()) {
670 if (visited.insert(succ_id).second) {
671 queue.push_back(succ_id);
674 }
while (!queue.empty());
// Prints the wrapped algorithm's statistics to `out`. Returning the void call
// expression from a void function is legal C++.
template <
typename State,
typename Action,
bool UseInterval>
680void TADepthFirstHeuristicSearch<State, Action, UseInterval>::print_statistics(
681 std::ostream& out)
const
683 return algorithm_.print_statistics(out);
// Returns the bounds the wrapped algorithm currently stores for `state_id`.
template <
typename State,
typename Action,
bool UseInterval>
687Interval TADepthFirstHeuristicSearch<State, Action, UseInterval>::lookup_bounds(
688 StateID state_id)
const
690 return algorithm_.lookup_bounds(state_id);
// Returns whether the wrapped algorithm reports `state_id` as visited.
template <
typename State,
typename Action,
bool UseInterval>
694bool TADepthFirstHeuristicSearch<State, Action, UseInterval>::was_visited(
695 StateID state_id)
const
697 return algorithm_.was_visited(state_id);
Namespace dedicated to the depth-first heuristic search (DFHS) family with native trap handling suppo...
Definition trap_aware_dfhs.h:26
Interval as_interval(value_t lower_bound)
Returns the interval with the given lower bound and infinite upper bound.
double value_t
Typedef for the state value type.
Definition aliases.h:7