1#ifndef GUARD_INCLUDE_PROBFD_ALGORITHMS_TRAP_AWARE_LRTDP_H
2#error "This file should only be included from trap_aware_lrtdp.h"
5#include "probfd/algorithms/successor_sampler.h"
7#include "probfd/quotients/quotient_max_heuristic.h"
9#include "probfd/utils/guards.h"
11#include "downward/utils/countdown_timer.h"
17inline void Statistics::print(std::ostream& out)
const
19 out <<
" Trials: " << trials << std::endl;
20 out <<
" Average trial length: "
21 << (
static_cast<double>(trial_length) /
static_cast<double>(trials))
23 out <<
" Bellman backups (trials): " << trial_bellman_backups << std::endl;
24 out <<
" Bellman backups (check&solved): "
25 << check_and_solve_bellman_backups << std::endl;
26 out <<
" Trap removals: " << traps << std::endl;
27 out <<
" Trap removal time: " << trap_timer << std::endl;
30inline void Statistics::register_report(ProgressReport& report)
const
32 report.register_print([
this](std::ostream& out) {
33 out <<
"traps=" << traps <<
", trials=" << trials;
39template <
typename State,
typename Action,
bool UseInterval>
40bool TALRTDPImpl<State, Action, UseInterval>::ExplorationInformation::
43 successors.pop_back();
44 return !successors.empty();
47template <
typename State,
typename Action,
bool UseInterval>
49TALRTDPImpl<State, Action, UseInterval>::ExplorationInformation::get_successor()
52 return successors.back();
55template <
typename State,
typename Action,
bool UseInterval>
56TALRTDPImpl<State, Action, UseInterval>::TALRTDPImpl(
57 std::shared_ptr<QuotientPolicyPicker> policy_chooser,
58 TrialTerminationCondition stop_consistent,
60 std::shared_ptr<QuotientSuccessorSampler> succ_sampler)
61 : Base(policy_chooser)
62 , stop_at_consistent_(stop_consistent)
63 , reexpand_traps_(reexpand_traps)
64 , sample_(succ_sampler)
65 , stack_index_(STATE_UNSEEN)
69template <
typename State,
typename Action,
bool UseInterval>
70Interval TALRTDPImpl<State, Action, UseInterval>::solve_quotient(
71 QuotientSystem& quotient,
72 QEvaluator& heuristic,
73 param_type<QState> state,
74 ProgressReport& progress,
77 utils::CountdownTimer timer(max_time);
79 Base::initialize_initial_state(quotient, heuristic, state);
81 const StateID state_id = quotient.get_state_id(state);
82 const StateInfo& state_info = this->state_infos_[state_id];
84 progress.register_bound(
"v", [&state_info]() {
88 this->statistics_.register_report(progress);
92 terminate = trial(quotient, heuristic, state_id, timer);
93 assert(state_id == quotient.translate_state_id(state_id));
98 return state_info.get_bounds();
101template <
typename State,
typename Action,
bool UseInterval>
102void TALRTDPImpl<State, Action, UseInterval>::print_statistics(
103 std::ostream& out)
const
105 this->statistics_.print(out);
108template <
typename State,
typename Action,
bool UseInterval>
109bool TALRTDPImpl<State, Action, UseInterval>::trial(
110 QuotientSystem& quotient,
111 QEvaluator& heuristic,
113 utils::CountdownTimer& timer)
117 assert(current_trial_.empty());
119 ClearGuard guard(current_trial_);
120 current_trial_.push_back(start_state);
122 timer.throw_if_expired();
124 StateID stateid = current_trial_.back();
125 auto& info = this->state_infos_[stateid];
127 if (info.is_solved()) {
128 current_trial_.pop_back();
132 const QState state = quotient.get_state(stateid);
133 const value_t termination_cost =
134 quotient.get_termination_info(state).get_cost();
136 ClearGuard _(transitions_, qvalues_);
137 if (info.is_on_fringe()) {
138 this->expand_and_initialize(
145 this->generate_non_tip_transitions(quotient, state, transitions_);
148 const auto value = this->compute_bellman_and_greedy(
155 statistics_.trial_bellman_backups++;
157 auto transition = this->select_greedy_transition(
162 bool value_changed = this->update_value(info, value);
163 this->update_policy(info, transition);
165 if (!transition.has_value()) {
167 current_trial_.pop_back();
171 if ((stop_at_consistent_ == CONSISTENT && !value_changed) ||
172 (stop_at_consistent_ == INCONSISTENT && value_changed) ||
173 (stop_at_consistent_ == REVISITED && info.is_on_trial())) {
177 if (stop_at_consistent_ == REVISITED) {
181 auto next = sample_->sample(
184 transition->successor_dist,
187 current_trial_.push_back(next);
190 statistics_.trial_length += current_trial_.size();
191 if (stop_at_consistent_ == REVISITED) {
192 for (
const StateID state : current_trial_) {
193 this->state_infos_[state].clear_trial_flag();
198 timer.throw_if_expired();
200 if (!check_and_solve(quotient, heuristic, timer)) {
204 current_trial_.pop_back();
205 }
while (!current_trial_.empty());
210template <
typename State,
typename Action,
bool UseInterval>
211bool TALRTDPImpl<State, Action, UseInterval>::check_and_solve(
212 QuotientSystem& quotient,
213 QEvaluator& heuristic,
214 utils::CountdownTimer& timer)
216 assert(!this->current_trial_.empty());
218 push(quotient.translate_state_id(this->current_trial_.back()));
220 ExplorationInformation* einfo;
225 einfo = &queue_.back();
226 sinfo = &this->state_infos_[einfo->state];
227 }
while (this->initialize(
233 this->push_successor(quotient, *einfo, timer));
236 if (einfo->is_root) {
237 const StateID state_id = einfo->state;
238 const unsigned stack_index = stack_index_[state_id];
239 auto scc = stack_ | std::views::drop(stack_index);
241 if (einfo->is_trap && scc.size() > 1) {
243 for (
const auto& entry : scc) {
244 stack_index_[entry.state_id] = STATE_CLOSED;
247 TimerScope scope(statistics_.trap_timer);
248 quotient.build_quotient(scc, *scc.begin());
249 sinfo->update_policy(std::nullopt);
251 stack_.erase(scc.begin(), scc.end());
253 if (reexpand_traps_) {
259 ++statistics_.check_and_solve_bellman_backups;
261 const QState state = quotient.get_state(state_id);
262 const value_t termination_cost =
263 quotient.get_termination_info(state).get_cost();
266 ClearGuard _(transitions_, qvalues_);
267 this->generate_non_tip_transitions(
272 auto value = this->compute_bellman_and_greedy(
279 auto transition = this->select_greedy_transition(
284 this->update_value(*sinfo, value);
285 this->update_policy(*sinfo, transition);
290 for (
const auto& entry : scc) {
291 const StateID
id = entry.state_id;
292 StateInfo& info = this->state_infos_[id];
293 stack_index_[id] = STATE_CLOSED;
294 if (info.is_solved())
continue;
298 const QState state = quotient.get_state(
id);
299 const value_t termination_cost =
300 quotient.get_termination_info(state).get_cost();
302 ClearGuard _(transitions_, qvalues_);
303 this->generate_non_tip_transitions(
308 ++this->statistics_.check_and_solve_bellman_backups;
310 auto value = this->compute_bellman_and_greedy(
317 auto transition = this->select_greedy_transition(
322 this->update_value(info, value);
323 this->update_policy(info, transition);
326 stack_.erase(scc.begin(), scc.end());
329 einfo->is_trap =
false;
332 ExplorationInformation bt_einfo = std::move(*einfo);
336 if (queue_.empty()) {
337 assert(stack_.empty());
338 stack_index_.clear();
339 return sinfo->is_solved();
342 timer.throw_if_expired();
344 einfo = &queue_.back();
345 sinfo = &this->state_infos_[einfo->state];
347 einfo->update(bt_einfo);
348 }
while (!einfo->next_successor() ||
349 !this->push_successor(quotient, *einfo, timer));
353template <
typename State,
typename Action,
bool UseInterval>
354bool TALRTDPImpl<State, Action, UseInterval>::push_successor(
355 QuotientSystem& quotient,
356 ExplorationInformation& einfo,
357 utils::CountdownTimer& timer)
360 timer.throw_if_expired();
362 const StateID succ = quotient.translate_state_id(einfo.get_successor());
363 StateInfo& succ_info = this->state_infos_[succ];
364 int& sidx = stack_index_[succ];
365 if (sidx == STATE_UNSEEN) {
368 }
else if (sidx >= 0) {
369 int& sidx2 = stack_index_[einfo.state];
372 einfo.is_root =
false;
375 einfo.update(succ_info);
377 }
while (einfo.next_successor());
382template <
typename State,
typename Action,
bool UseInterval>
383void TALRTDPImpl<State, Action, UseInterval>::push(StateID state)
385 queue_.emplace_back(state);
386 stack_index_[state] = stack_.size();
387 stack_.emplace_back(state);
390template <
typename State,
typename Action,
bool UseInterval>
391bool TALRTDPImpl<State, Action, UseInterval>::initialize(
392 QuotientSystem& quotient,
393 QEvaluator& heuristic,
395 StateInfo& state_info,
396 ExplorationInformation& e_info)
398 assert(quotient.translate_state_id(state_id) == state_id);
400 if (state_info.is_solved()) {
401 e_info.is_trap =
false;
405 const QState state = quotient.get_state(state_id);
406 const value_t termination_cost =
407 quotient.get_termination_info(state).get_cost();
409 ClearGuard _(transitions_, qvalues_);
411 if (state_info.is_on_fringe()) {
412 this->expand_and_initialize(
419 this->generate_non_tip_transitions(quotient, state, transitions_);
422 ++this->statistics_.check_and_solve_bellman_backups;
424 const auto value = this->compute_bellman_and_greedy(
431 auto transition = this->select_greedy_transition(
433 state_info.get_policy(),
436 bool value_changed = this->update_value(state_info, value);
437 this->update_policy(state_info, transition);
440 e_info.rv = e_info.rv && !value_changed;
441 e_info.is_trap =
false;
447 e_info.is_trap =
false;
451 for (
const StateID sel : transition->successor_dist.support()) {
452 if (sel != state_id) {
453 e_info.successors.push_back(sel);
457 assert(!e_info.successors.empty());
458 e_info.is_trap = quotient.get_action_cost(transition->action) == 0;
459 stack_.back().aops.emplace_back(transition->action);
463template <
typename State,
typename Action,
bool UseInterval>
464TALRTDP<State, Action, UseInterval>::TALRTDP(
465 std::shared_ptr<QuotientPolicyPicker> policy_chooser,
466 TrialTerminationCondition stop_consistent,
468 std::shared_ptr<QuotientSuccessorSampler> succ_sampler)
469 : algorithm_(policy_chooser, stop_consistent, reexpand_traps, succ_sampler)
473template <
typename State,
typename Action,
bool UseInterval>
474Interval TALRTDP<State, Action, UseInterval>::solve(
476 EvaluatorType& heuristic,
478 ProgressReport progress,
481 QuotientSystem quotient(mdp);
482 quotients::QuotientMaxHeuristic<State, Action> qheuristic(heuristic);
483 return algorithm_.solve_quotient(
486 quotient.translate_state(s),
491template <
typename State,
typename Action,
bool UseInterval>
492auto TALRTDP<State, Action, UseInterval>::compute_policy(
494 EvaluatorType& heuristic,
495 param_type<State> state,
496 ProgressReport progress,
497 double max_time) -> std::unique_ptr<PolicyType>
499 QuotientSystem quotient(mdp);
500 quotients::QuotientMaxHeuristic<State, Action> qheuristic(heuristic);
502 QState qinit = quotient.translate_state(state);
503 algorithm_.solve_quotient(quotient, qheuristic, qinit, progress, max_time);
520 using MapPolicy = policies::MapPolicy<State, Action>;
521 std::unique_ptr<MapPolicy> policy(
new MapPolicy(&mdp));
523 const StateID initial_state_id = quotient.get_state_id(qinit);
525 std::deque<StateID> queue({initial_state_id});
526 std::set<StateID> visited({initial_state_id});
529 const StateID quotient_id = queue.front();
530 const QState quotient_state = quotient.get_state(quotient_id);
533 const auto& state_info = algorithm_.state_infos_[quotient_id];
535 std::optional quotient_action = state_info.get_policy();
538 if (!quotient_action) {
542 const Interval quotient_bound =
as_interval(state_info.value);
544 const StateID exiting_id = quotient_action->state_id;
546 policy->emplace_decision(
548 quotient_action->action,
552 if (quotient_state.num_members() != 1) {
553 std::unordered_map<StateID, std::set<QAction>> parents;
556 std::vector<QAction> inner_actions;
557 quotient_state.get_collapsed_actions(inner_actions);
559 for (
const QAction& qaction : inner_actions) {
560 StateID source_id = qaction.state_id;
561 Action action = qaction.action;
563 const State source = mdp.get_state(source_id);
565 Distribution<StateID> successors;
566 mdp.generate_action_transitions(source, action, successors);
568 for (
const StateID succ_id : successors.support()) {
569 parents[succ_id].insert(qaction);
575 std::deque<StateID> inverse_queue({exiting_id});
576 std::set<StateID> inverse_visited({exiting_id});
579 const StateID next_id = inverse_queue.front();
580 inverse_queue.pop_front();
582 for (
const auto& [pred_id, act] : parents[next_id]) {
583 if (inverse_visited.insert(pred_id).second) {
584 policy->emplace_decision(pred_id, act, quotient_bound);
585 inverse_queue.push_back(pred_id);
588 }
while (!inverse_queue.empty());
592 Distribution<StateID> successors;
593 quotient.generate_action_transitions(
598 for (
const StateID succ_id : successors.support()) {
599 if (visited.insert(succ_id).second) {
600 queue.push_back(succ_id);
603 }
while (!queue.empty());
608template <
typename State,
typename Action,
bool UseInterval>
609void TALRTDP<State, Action, UseInterval>::print_statistics(
610 std::ostream& out)
const
612 return algorithm_.print_statistics(out);
TrialTerminationCondition
Enumeration type specifying the termination condition for trials sampled during LRTDP.
Definition lrtdp.h:25
Namespace dedicated to labelled real-time dynamic programming (LRTDP) with native trap handling suppo...
Definition trap_aware_lrtdp.h:22
Interval as_interval(value_t lower_bound)
Returns the interval with the given lower bound and infinite upper bound.
double value_t
Typedef for the state value type.
Definition aliases.h:7