1#ifndef GUARD_INCLUDE_PROBFD_ALGORITHMS_AO_STAR_H
2#error "This file should only be included from ao_star.h"
5#include "probfd/algorithms/utils.h"
7#include "probfd/algorithms/successor_sampler.h"
9#include "probfd/views/concat.h"
11#include "probfd/progress_report.h"
13#include "downward/utils/countdown_timer.h"
// Constructor of AOStar.
// Forwards policy_chooser to the Base sub-object and stores outcome_selection
// in the outcome_selection_ member. Both shared pointers are taken by value
// and moved into place.
// NOTE(review): the constructor body is not visible in this listing.
20template <
typename State,
typename Action,
bool UseInterval>
21AOStar<State, Action, UseInterval>::AOStar(
22 std::shared_ptr<PolicyPickerType> policy_chooser,
23 std::shared_ptr<SuccessorSamplerType> outcome_selection)
24 : Base(
std::move(policy_chooser))
25 , outcome_selection_(
std::move(outcome_selection))
// Main solving routine of AOStar.
// As far as this listing shows, it loops until the state info of the initial
// state reports is_solved() and then returns that info's bounds via
// get_bounds().
// NOTE(review): this listing is incomplete. Part of the parameter list, the
// argument lists of several calls and the closing braces are not visible, so
// the comments below describe only the statements that are shown.
29template <
typename State,
typename Action,
bool UseInterval>
30Interval AOStar<State, Action, UseInterval>::do_solve(
32 EvaluatorType& heuristic,
33 param_type<State> initial_state,
34 ProgressReport& progress,
37 using namespace std::views;
// Countdown timer built from max_time (presumably a parameter declared on a
// line that is not visible here); it is polled inside the loop below.
39 utils::CountdownTimer timer(max_time);
// Look up the id of the initial state and keep a reference to its state info;
// iinfo drives the loop condition and supplies the return value.
41 const StateID initstateid = mdp.get_state_id(initial_state);
42 auto& iinfo = this->state_infos_[initstateid];
// Register progress callbacks: a bound labelled v whose callback captures
// iinfo by reference, and a printer that writes the iteration counter.
44 progress.register_bound(
"v", [&iinfo]() {
48 progress.register_print([&](std::ostream& out) {
49 out <<
"i=" << this->statistics_.iterations;
// The initial state gets update order 0.
52 iinfo.update_order = 0;
// One pass per iteration; progress.print() is the increment expression, so it
// runs after each pass for as long as the initial state is unsolved.
54 for (; !iinfo.is_solved(); progress.print()) {
// Each pass starts again from the initial state.
55 StateID stateid = initstateid;
// Poll the timer on every step (by its name, this throws once the countdown
// has expired).
58 timer.throw_if_expired();
60 auto& info = this->state_infos_[stateid];
61 assert(!info.is_solved());
63 const State state = mdp.get_state(stateid);
// Fringe state: expand it and update its value.
// NOTE(review): ClearGuard presumably clears transitions_ when it goes out of
// scope - confirm against its definition.
65 if (info.is_on_fringe()) {
66 ClearGuard _(transitions_);
67 this->expand_and_initialize(
74 bool value_changed = this->update_value_check_solved(
// If the expansion already solved the state, its value is propagated by
// backpropagate_tip_value (arguments not visible here).
80 if (info.is_solved()) {
82 this->backpropagate_tip_value(
// View mapping every transition to the support of its successor distribution.
// NOTE(review): the loops below iterate StateID values, so a flattening step
// is presumably applied on a line that is not visible - confirm.
90 auto all_successors = transitions_ | transform([](
auto& t) {
91 return t.successor_dist.support();
// Sentinel: stays at the maximum unsigned value until a successor lowers it.
95 unsigned min_succ_order = std::numeric_limits<unsigned>::max();
// First pass over the successors: successors that are already marked or
// solved are tested for (the branch body is not visible); the current state is
// added as a parent and the smallest successor update order is tracked.
97 for (
const StateID succ_id : all_successors) {
98 auto& succ_info = this->state_infos_[succ_id];
100 if (succ_info.is_marked() || succ_info.is_solved())
104 succ_info.add_parent(stateid);
106 succ_info.update_order <
107 std::numeric_limits<unsigned>::max());
109 std::min(min_succ_order, succ_info.update_order);
// The sentinel must have been replaced by a real update order at this point.
112 assert(min_succ_order < std::numeric_limits<unsigned>::max());
// Second pass over the successors: unmark every one of them.
114 for (
const StateID succ_id : all_successors) {
115 this->state_infos_[succ_id].unmark();
118 this->backpropagate_update_order(
// transitions_ is cleared explicitly before the tip value is propagated.
125 transitions_.clear();
126 this->backpropagate_tip_value(
136 !info.is_on_fringe() && !info.is_goal_or_terminal() &&
// Follow the current policy action of this state; the assert requires that an
// action is set.
139 const auto action = info.get_policy();
141 assert(action.has_value());
// Generate the transition of that action into successor_dist_.
143 ClearGuard guard(successor_dist_);
145 mdp.generate_action_transitions(state, *action, successor_dist_);
// remove_if_normalize is given a predicate selecting targets whose state info
// is already solved.
147 successor_dist_.remove_if_normalize([
this](
const auto& target) {
148 return this->state_infos_[target.item].is_solved();
// The next state to visit is drawn through outcome_selection_.
151 stateid = outcome_selection_->sample(
// Count the finished iteration in the statistics.
158 ++this->statistics_.iterations;
// Result: the bounds stored in the state info of the initial state.
161 return iinfo.get_bounds();
// Recomputes the value and the greedy policy of a state and returns whether
// the stored value changed (the result of update_value).
// It also checks whether every successor in the support of the selected greedy
// transition is solved; what happens in that case is on lines not visible in
// this listing.
// NOTE(review): this listing is incomplete. Part of the parameter list, some
// call arguments and the closing braces are not visible.
// NOTE(review): the parameter transitions is taken by value, while
// select_greedy_transition below is passed the member transitions_ rather than
// the parameter. Confirm that this is intended and that the copy is needed.
164template <
typename State,
typename Action,
bool UseInterval>
165bool AOStar<State, Action, UseInterval>::update_value_check_solved(
167 param_type<State> state,
168 std::vector<Transition<Action>> transitions,
// Termination cost of the state, taken from the termination info of the MDP.
171 const value_t termination_cost = mdp.get_termination_info(state).get_cost();
// New value estimate from compute_bellman_and_greedy (remaining arguments are
// not visible here).
173 const auto value = this->compute_bellman_and_greedy(
175 mdp.get_state_id(state),
// Select the greedy transition, passing the policy currently stored in info.
180 auto greedy_transition =
181 this->select_greedy_transition(mdp, info.get_policy(), transitions_);
// Store the new value and policy; value_changed is what this function returns.
183 bool value_changed = this->update_value(info, value);
184 this->update_policy(info, greedy_transition);
// Starts out true and therefore stays true when no greedy transition exists.
186 bool all_succs_solved =
true;
// With a greedy transition, require every successor in its support to be
// solved.
188 if (greedy_transition) {
189 for (
const auto succ_id : greedy_transition->successor_dist.support()) {
190 const auto& succ_info = this->state_infos_[succ_id];
191 all_succs_solved = all_succs_solved && succ_info.is_solved();
// NOTE(review): the body of this branch is not visible in this listing.
195 if (all_succs_solved) {
199 return value_changed;
Namespace dedicated to the AO* algorithm.
Definition ao_star.h:15
Interval as_interval(value_t lower_bound)
Returns the interval with the given lower bound and infinite upper bound.
double value_t
Typedef for the state value type.
Definition aliases.h:7