AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
lrtdp_impl.h
1#ifndef GUARD_INCLUDE_PROBFD_ALGORITHMS_LRTDP_H
2#error "This file should only be included from lrtdp.h"
3#endif
4
5#include "probfd/algorithms/successor_sampler.h"
6
7#include "downward/utils/countdown_timer.h"
8
9#include <cassert>
10#include <ranges>
11
13
14namespace internal {
15
16inline void Statistics::print(std::ostream& out) const
17{
18 out << " Trials: " << trials << std::endl;
19 out << " Bellman backups (trials): " << trial_bellman_backups << std::endl;
20 out << " Bellman backups (check&solved): "
21 << check_and_solve_bellman_backups << std::endl;
22}
23
24} // namespace internal
25
26template <typename State, typename Action, bool UseInterval>
28 std::shared_ptr<PolicyPickerType> policy_chooser,
29 TrialTerminationCondition stop_consistent,
30 std::shared_ptr<SuccessorSamplerType> succ_sampler)
31 : Base(policy_chooser)
32 , stop_consistent_(stop_consistent)
33 , sample_(succ_sampler)
34{
35}
36
37template <typename State, typename Action, bool UseInterval>
39{
40 this->state_infos_.reset();
41}
42
// Runs LRTDP trials from a given state until that state is labelled solved,
// then returns the bounds stored in its state info.
//
// NOTE(review): the embedded listing numbers jump 43 -> 45 and 46 -> 48, so
// the declarator line (return type and function name) and one parameter line
// are missing from this listing. The body reads a variable `state` that none
// of the visible parameters declares - presumably the missing parameter;
// confirm against the class declaration in lrtdp.h.
//
// mdp       - queried for the state's ID here and passed on to trial().
// heuristic - passed on to trial().
// progress  - receives a bound property "v" and a printer for the trial
//             count, and is printed after every trial.
// max_time  - used to construct the countdown timer handed to trial().
43template <typename State, typename Action, bool UseInterval>
45 MDPType& mdp,
46 EvaluatorType& heuristic,
48 ProgressReport& progress,
49 double max_time)
50{
51 utils::CountdownTimer timer(max_time);
52
53 const StateID state_id = mdp.get_state_id(state);
// NOTE(review): this reference is held across all trials and captured by
// reference in the lambda below; assumes state_infos_ keeps references stable
// while further state infos are created - confirm.
54 const StateInfo& state_info = this->state_infos_[state_id];
55
// Report the current value of the state as the bound named "v".
56 progress.register_bound("v", [&state_info]() {
57 return as_interval(state_info.value);
58 });
59
// Additionally print the number of trials performed so far.
60 progress.register_print(
61 [&](std::ostream& out) { out << "trials=" << statistics_.trials; });
62
// One trial per iteration; the loop ends once the state is labelled solved.
// The timer is passed to trial(), which calls timer.throw_if_expired().
63 while (!state_info.is_solved()) {
64 trial(mdp, heuristic, state_id, timer);
65 this->statistics_.trials++;
66 progress.print();
67 }
68
69 return state_info.get_bounds();
70}
71
72template <typename State, typename Action, bool UseInterval>
74 std::ostream& out) const
75{
76 statistics_.print(out);
77}
78
// Performs a single trial: starting from initial_state, each visited state is
// backed up, a greedy transition is selected and a successor is sampled, with
// the visited state IDs recorded in current_trial_. Afterwards
// check_and_solve() is run on the recorded states from the last one towards
// the first, stopping at the first state for which it returns false.
//
// NOTE(review): the embedded listing numbers jump 79 -> 81 and 87 -> 89, so
// the declarator line of this definition and one statement line near the top
// of the body are missing from this listing. The call
// `trial(mdp, heuristic, state_id, timer)` elsewhere in this file matches the
// visible parameter list - presumably this is that function; confirm against
// lrtdp.h.
//
// mdp           - source of states, termination costs and transitions.
// heuristic     - passed to expand_and_initialize() for states on the fringe.
// initial_state - first state of the trial; asserted not to be solved.
// timer         - checked via throw_if_expired() in every loop iteration.
79template <typename State, typename Action, bool UseInterval>
81 MDPType& mdp,
82 EvaluatorType& heuristic,
83 StateID initial_state,
84 utils::CountdownTimer& timer)
85{
86 assert(!this->state_infos_[initial_state].is_solved());
87
89
// ClearGuard is documented in this listing as an RAII helper that clears the
// given containers when it goes out of scope.
90 ClearGuard guard(current_trial_);
91
92 current_trial_.push_back(initial_state);
93
// Sampling phase: extend the trial one state per iteration.
94 for (;;) {
95 timer.throw_if_expired();
96
97 const StateID state_id = current_trial_.back();
98
99 auto& state_info = this->state_infos_[state_id];
100
// A state already labelled solved ends the trial and is not kept in it.
101 if (state_info.is_solved()) {
102 current_trial_.pop_back();
103 break;
104 }
105
106 const State state = mdp.get_state(state_id);
107 const value_t termination_cost =
108 mdp.get_termination_info(state).get_cost();
109
// Scratch buffers for this iteration, guarded by a ClearGuard.
110 ClearGuard _(transitions_, qvalues_);
111
// States on the fringe are expanded and initialized with the heuristic; for
// all other states only the transitions are generated.
112 if (state_info.is_on_fringe()) {
113 this->expand_and_initialize(
114 mdp,
115 heuristic,
116 state,
117 state_info,
118 transitions_);
119 } else {
120 this->generate_non_tip_transitions(mdp, state, transitions_);
121 }
122
123 this->statistics_.trial_bellman_backups++;
124
125 auto value = this->compute_bellman_and_greedy(
126 mdp,
127 state_id,
128 transitions_,
129 termination_cost,
130 qvalues_);
131
132 auto transition = this->select_greedy_transition(
133 mdp,
134 state_info.get_policy(),
135 transitions_);
136
137 bool value_changed = this->update_value(state_info, value);
138 this->update_policy(state_info, transition);
139
// No greedy transition was selected: label the state solved, drop it from the
// trial and stop sampling.
140 if (!transition) {
141 state_info.mark_solved();
142 current_trial_.pop_back();
143 break;
144 }
145
146 assert(!state_info.is_goal_or_terminal());
147
// Configured trial termination: stop on an unchanged value (CONSISTENT), on a
// changed value (INCONSISTENT), or on reaching a state already marked closed
// (REVISITED). In these cases the current state stays in current_trial_.
148 if ((stop_consistent_ == CONSISTENT && !value_changed) ||
149 (stop_consistent_ == INCONSISTENT && value_changed) ||
150 (stop_consistent_ == REVISITED && state_info.is_closed())) {
151 break;
152 }
153
// Under REVISITED the closed flag marks states already visited in this trial.
154 if (stop_consistent_ == REVISITED) {
155 state_info.mark_closed();
156 }
157
// Sample the next state from the greedy transition's successor distribution.
158 auto next = sample_->sample(
159 state_id,
160 transition->action,
161 transition->successor_dist,
162 this->state_infos_);
163
164 current_trial_.push_back(next);
165 }
166
// Under REVISITED, clear the closed flag on every recorded state except the
// last entry: drop(1) on the reversed view skips the final element.
167 if (stop_consistent_ == REVISITED) {
168 for (const StateID state :
169 current_trial_ | std::views::reverse | std::views::drop(1)) {
170 auto& info = this->state_infos_[state];
171 assert(info.is_closed());
172 info.unmark_closed();
173 }
174 }
175
// Labelling phase: run check_and_solve() on the recorded states from the back
// of the trial to the front, stopping at the first state for which it returns
// false.
// NOTE(review): if the very first state left the loop above through the
// "no greedy transition" path, current_trial_ is empty at this point and
// back() is called before the emptiness test of this do-while - confirm
// whether a `while (!current_trial_.empty())` loop is intended.
176 do {
177 timer.throw_if_expired();
178
179 if (!check_and_solve(mdp, heuristic, current_trial_.back(), timer)) {
180 break;
181 }
182
183 current_trial_.pop_back();
184 } while (!current_trial_.empty());
185}
186
187template <typename State, typename Action, bool UseInterval>
188bool LRTDP<State, Action, UseInterval>::check_and_solve(
189 MDPType& mdp,
190 EvaluatorType& heuristic,
191 StateID init_state_id,
192 utils::CountdownTimer& timer)
193{
194 assert(!current_trial_.empty() && policy_queue_.empty());
195
196 ClearGuard guard(visited_);
197
198 {
199 StateInfo& state_info = this->state_infos_[init_state_id];
200 if (state_info.is_solved()) return true;
201 policy_queue_.emplace_back(init_state_id);
202 state_info.mark_closed();
203 }
204
205 bool rv = true;
206
207 do {
208 timer.throw_if_expired();
209
210 const auto state_id = policy_queue_.back();
211 policy_queue_.pop_back();
212
213 auto& info = this->state_infos_[state_id];
214 assert(!info.is_solved());
215 assert(info.is_closed());
216
217 visited_.push_front(state_id);
218
219 const State state = mdp.get_state(state_id);
220 const value_t termination_cost =
221 mdp.get_termination_info(state).get_cost();
222
223 ClearGuard _(transitions_, qvalues_);
224
225 if (info.is_on_fringe()) {
226 this->expand_and_initialize(
227 mdp,
228 heuristic,
229 state,
230 info,
231 transitions_);
232 } else {
233 this->generate_non_tip_transitions(mdp, state, transitions_);
234 }
235
236 this->statistics_.check_and_solve_bellman_backups++;
237
238 auto value = this->compute_bellman_and_greedy(
239 mdp,
240 state_id,
241 transitions_,
242 termination_cost,
243 qvalues_);
244
245 auto transition = this->select_greedy_transition(
246 mdp,
247 info.get_policy(),
248 transitions_);
249
250 bool value_changed = this->update_value(info, value);
251 this->update_policy(info, transition);
252
253 if constexpr (UseInterval) {
254 if (!info.bounds_agree()) {
255 rv = false;
256 continue;
257 }
258 } else {
259 if (value_changed) {
260 rv = false;
261 continue;
262 }
263 }
264
265 if (!transition) {
266 info.mark_solved();
267 continue;
268 }
269
270 for (StateID succ_id : transition->successor_dist.support()) {
271 StateInfo& succ_info = this->state_infos_[succ_id];
272 if (!succ_info.is_closed() && !succ_info.is_solved()) {
273 succ_info.mark_closed();
274 policy_queue_.emplace_back(succ_id);
275 }
276 }
277 } while (!policy_queue_.empty());
278
279 for (StateID sid : visited_) {
280 StateInfo& info = this->state_infos_[sid];
281
282 if (info.is_solved()) continue;
283
284 assert(info.is_closed());
285 info.unmark_closed();
286
287 if (rv) {
288 info.mark_solved();
289 } else {
290 assert(!info.is_on_fringe());
291
292 const State state = mdp.get_state(sid);
293 const value_t termination_cost =
294 mdp.get_termination_info(state).get_cost();
295
296 ClearGuard _(transitions_, qvalues_);
297 this->generate_non_tip_transitions(mdp, state, transitions_);
298
299 statistics_.check_and_solve_bellman_backups++;
300
301 auto value = this->compute_bellman_and_greedy(
302 mdp,
303 sid,
304 transitions_,
305 termination_cost,
306 qvalues_);
307
308 auto transition = this->select_greedy_transition(
309 mdp,
310 info.get_policy(),
311 transitions_);
312
313 this->update_value(info, value);
314 this->update_policy(info, transition);
315 }
316 }
317
318 return rv;
319}
320
321} // namespace probfd::algorithms::lrtdp
A registry for print functions related to search progress.
Definition progress_report.h:33
void print()
Prints the output to the internal output stream, if enabled.
void register_bound(const std::string &property_name, BoundProperty property)
Appends a new bound property with a given name to the list of bound properties to be printed when the...
void register_print(Printer f)
Appends a new printer to the list of printers.
Helper RAII class that ensures that containers are cleared when going out of scope.
Definition utils.h:23
Implements the labelled real-time dynamic programming (LRTDP) algorithm of Bonet and Geffner (ICAPS 2003).
Definition lrtdp.h:129
LRTDP(std::shared_ptr< PolicyPickerType > policy_chooser, TrialTerminationCondition stop_consistent, std::shared_ptr< SuccessorSamplerType > succ_sampler)
Constructs an LRTDP solver object.
Definition lrtdp_impl.h:27
void print_additional_statistics(std::ostream &out) const override
Prints additional statistics to the output stream.
Definition lrtdp_impl.h:73
void reset_search_state() override
Resets the heuristic search algorithm object to a clean state.
Definition lrtdp_impl.h:38
Namespace dedicated to labelled real-time dynamic programming (LRTDP).
Definition lrtdp.h:19
TrialTerminationCondition
Enumeration type specifying the termination condition for trials sampled during LRTDP.
Definition lrtdp.h:25
Interval as_interval(value_t lower_bound)
Returns the interval with the given lower bound and infinite upper bound.
double value_t
Typedef for the state value type.
Definition aliases.h:7
typename std::conditional_t< is_cheap_to_copy_v< T >, T, const T & > param_type
Alias template defining the best way to pass a parameter of a given type.
Definition type_traits.h:25
Represents a closed interval over the extended reals as a pair of lower and upper bound.
Definition interval.h:12
A StateID represents a state within a StateIDMap. Just like Fast Downward's StateID type,...
Definition types.h:22