AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
ao_star_impl.h
1#ifndef GUARD_INCLUDE_PROBFD_ALGORITHMS_AO_STAR_H
2#error "This file should only be included from ao_star.h"
3#endif
4
5#include "probfd/algorithms/utils.h"
6
7#include "probfd/algorithms/successor_sampler.h"
8
9#include "probfd/views/concat.h"
10
11#include "probfd/progress_report.h"
12
13#include "downward/utils/countdown_timer.h"
14
15#include <ranges>
16#include <type_traits>
17
19
20template <typename State, typename Action, bool UseInterval>
21AOStar<State, Action, UseInterval>::AOStar(
22 std::shared_ptr<PolicyPickerType> policy_chooser,
23 std::shared_ptr<SuccessorSamplerType> outcome_selection)
24 : Base(std::move(policy_chooser))
25 , outcome_selection_(std::move(outcome_selection))
26{
27}
28
29template <typename State, typename Action, bool UseInterval>
30Interval AOStar<State, Action, UseInterval>::do_solve(
31 MDPType& mdp,
32 EvaluatorType& heuristic,
33 param_type<State> initial_state,
34 ProgressReport& progress,
35 double max_time)
36{
37 using namespace std::views;
38
39 utils::CountdownTimer timer(max_time);
40
41 const StateID initstateid = mdp.get_state_id(initial_state);
42 auto& iinfo = this->state_infos_[initstateid];
43
44 progress.register_bound("v", [&iinfo]() {
45 return as_interval(iinfo.value);
46 });
47
48 progress.register_print([&](std::ostream& out) {
49 out << "i=" << this->statistics_.iterations;
50 });
51
52 iinfo.update_order = 0;
53
54 for (; !iinfo.is_solved(); progress.print()) {
55 StateID stateid = initstateid;
56
57 for (;;) {
58 timer.throw_if_expired();
59
60 auto& info = this->state_infos_[stateid];
61 assert(!info.is_solved());
62
63 const State state = mdp.get_state(stateid);
64
65 if (info.is_on_fringe()) {
66 ClearGuard _(transitions_);
67 this->expand_and_initialize(
68 mdp,
69 heuristic,
70 state,
71 info,
72 transitions_);
73
74 bool value_changed = this->update_value_check_solved(
75 mdp,
76 state,
77 transitions_,
78 info);
79
80 if (info.is_solved()) {
81 transitions_.clear();
82 this->backpropagate_tip_value(
83 mdp,
84 transitions_,
85 info,
86 timer);
87 break;
88 }
89
90 auto all_successors = transitions_ | transform([](auto& t) {
91 return t.successor_dist.support();
92 }) |
93 std::views::join;
94
95 unsigned min_succ_order = std::numeric_limits<unsigned>::max();
96
97 for (const StateID succ_id : all_successors) {
98 auto& succ_info = this->state_infos_[succ_id];
99
100 if (succ_info.is_marked() || succ_info.is_solved())
101 continue;
102
103 succ_info.mark();
104 succ_info.add_parent(stateid);
105 assert(
106 succ_info.update_order <
107 std::numeric_limits<unsigned>::max());
108 min_succ_order =
109 std::min(min_succ_order, succ_info.update_order);
110 }
111
112 assert(min_succ_order < std::numeric_limits<unsigned>::max());
113
114 for (const StateID succ_id : all_successors) {
115 this->state_infos_[succ_id].unmark();
116 }
117
118 this->backpropagate_update_order(
119 stateid,
120 info,
121 min_succ_order + 1,
122 timer);
123
124 if (value_changed) {
125 transitions_.clear();
126 this->backpropagate_tip_value(
127 mdp,
128 transitions_,
129 info,
130 timer);
131 break;
132 }
133 }
134
135 assert(
136 !info.is_on_fringe() && !info.is_goal_or_terminal() &&
137 !info.is_solved());
138
139 const auto action = info.get_policy();
140
141 assert(action.has_value());
142
143 ClearGuard guard(successor_dist_);
144
145 mdp.generate_action_transitions(state, *action, successor_dist_);
146
147 successor_dist_.remove_if_normalize([this](const auto& target) {
148 return this->state_infos_[target.item].is_solved();
149 });
150
151 stateid = outcome_selection_->sample(
152 stateid,
153 *action,
154 successor_dist_,
155 this->state_infos_);
156 }
157
158 ++this->statistics_.iterations;
159 }
160
161 return iinfo.get_bounds();
162}
163
164template <typename State, typename Action, bool UseInterval>
165bool AOStar<State, Action, UseInterval>::update_value_check_solved(
166 MDPType& mdp,
167 param_type<State> state,
168 std::vector<Transition<Action>> transitions,
169 StateInfo& info)
170{
171 const value_t termination_cost = mdp.get_termination_info(state).get_cost();
172
173 const auto value = this->compute_bellman_and_greedy(
174 mdp,
175 mdp.get_state_id(state),
176 transitions,
177 termination_cost,
178 qvalues_);
179
180 auto greedy_transition =
181 this->select_greedy_transition(mdp, info.get_policy(), transitions_);
182
183 bool value_changed = this->update_value(info, value);
184 this->update_policy(info, greedy_transition);
185
186 bool all_succs_solved = true;
187
188 if (greedy_transition) {
189 for (const auto succ_id : greedy_transition->successor_dist.support()) {
190 const auto& succ_info = this->state_infos_[succ_id];
191 all_succs_solved = all_succs_solved && succ_info.is_solved();
192 }
193 }
194
195 if (all_succs_solved) {
196 info.set_solved();
197 }
198
199 return value_changed;
200}
201
202} // namespace probfd::algorithms::ao_search::ao_star
Namespace dedicated to the AO* algorithm.
Definition ao_star.h:15
Interval as_interval(value_t lower_bound)
Returns the interval with the given lower bound and infinite upper bound.
double value_t
Typedef for the state value type.
Definition aliases.h:7
STL namespace.