AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
depth_first_heuristic_search_impl.h
1#ifndef GUARD_INCLUDE_PROBFD_ALGORITHMS_HEURISTIC_DEPTH_FIRST_SEARCH_H
2#error "This file should only be included from heuristic_depth_first_search.h"
3#endif
4
5#include "downward/utils/countdown_timer.h"
6
7#include <cassert>
8
10
11namespace internal {
12
13inline void Statistics::print(std::ostream& out) const
14{
15 out << " Iterations: " << iterations << std::endl;
16 out << " Value iterations: " << convergence_value_iterations << std::endl;
17 out << " Bellman backups (forward): " << forward_updates << std::endl;
18 out << " Bellman backups (backtracking): " << backtracking_updates
19 << std::endl;
20 out << " Bellman backups (convergence): " << convergence_updates
21 << std::endl;
22}
23
24bool ExpansionInfo::next_successor()
25{
26 successors.pop_back();
27 return !successors.empty();
28}
29
30StateID ExpansionInfo::get_current_successor() const
31{
32 return successors.back();
33}
34
35} // namespace internal
36
37template <typename State, typename Action, bool UseInterval>
38HeuristicDepthFirstSearch<State, Action, UseInterval>::
39 HeuristicDepthFirstSearch(
40 std::shared_ptr<PolicyPicker> policy_chooser,
41 bool forward_updates,
42 BacktrackingUpdateType backtrack_update_type,
43 bool cutoff_tip,
44 bool cutoff_inconsistent,
45 bool terminate_exploration_on_cutoff,
46 bool label_solved)
47 : Base(std::move(policy_chooser))
48 , forward_updates_(forward_updates)
49 , backtrack_update_type_(backtrack_update_type)
50 , cutoff_tip_(cutoff_tip)
51 , cutoff_inconsistent_(cutoff_inconsistent)
52 , terminate_exploration_on_cutoff_(terminate_exploration_on_cutoff)
53 , label_solved_(label_solved)
54{
55}
56
57template <typename State, typename Action, bool UseInterval>
62
// Solves the MDP for the start state `state` within `max_time` (the
// signature line is not visible in this rendering).
// A countdown timer limited to `max_time` is created and the start state's
// value is registered as the "v" bound of the progress report. The search
// then dispatches on label_solved_ and the start state's bounds are returned.
63template <typename State, typename Action, bool UseInterval>
65 MDP& mdp,
66 Evaluator& heuristic,
68 ProgressReport& progress,
69 double max_time)
70{
71 utils::CountdownTimer timer(max_time);
72
73 const StateID stateid = mdp.get_state_id(state);
74 const StateInfo& state_info = this->state_infos_[stateid];
75
// NOTE(review): the lambda captures state_info by reference. That is only
// safe while the state_infos_ entry keeps a stable address and `progress`
// does not call the lambda after this object is gone. Confirm the storage
// guarantees this.
76 progress.register_bound("v", [&state_info]() {
77 return as_interval(state_info.value);
78 });
79
// Without labelling, termination is detected by a value-iteration phase
// over the visited states. With labelling, exploration is repeated until it
// reports the start state as solved.
80 if (!label_solved_) {
81 solve_with_vi_termination(mdp, heuristic, stateid, progress, timer);
82 } else {
83 solve_without_vi_termination(mdp, heuristic, stateid, progress, timer);
84 }
85
86 return state_info.get_bounds();
87}
88
// Prints this algorithm's statistics to `out`: iteration count, value
// iterations and Bellman backup counters (delegates to Statistics::print).
89template <typename State, typename Action, bool UseInterval>
91 print_additional_statistics(std::ostream& out) const
92{
93 statistics_.print(out);
94}
95
// Main loop used when solved-labelling is disabled.
// Each iteration explores the greedy policy graph from `stateid` while
// recording visited states in visited_. If exploration reports the start
// state as solved, value iteration is then run over exactly those states.
// The loop stops once exploration succeeds and value iteration converges
// without a policy change. Progress is printed after every iteration.
96template <typename State, typename Action, bool UseInterval>
99 MDP& mdp,
100 Evaluator& heuristic,
101 StateID stateid,
102 ProgressReport& progress,
103 utils::CountdownTimer& timer)
104{
105 bool terminate;
106 do {
// `&&` short-circuits: value iteration only runs when exploration
// reported the start state as solved.
107 terminate = policy_exploration<true>(mdp, heuristic, stateid, timer) &&
108 value_iteration(mdp, visited_, timer);
109
// The visited set is per-iteration; start the next exploration afresh.
110 visited_.clear();
111 ++statistics_.iterations;
112 progress.print();
113 } while (!terminate);
114}
115
116template <typename State, typename Action, bool UseInterval>
117void HeuristicDepthFirstSearch<State, Action, UseInterval>::
118 solve_without_vi_termination(
119 MDP& mdp,
120 Evaluator& heuristic,
121 StateID stateid,
122 ProgressReport& progress,
123 utils::CountdownTimer& timer)
124{
125 bool terminate;
126 do {
127 terminate = policy_exploration<false>(mdp, heuristic, stateid, timer);
128 ++statistics_.iterations;
129 progress.print();
130 assert(visited_.empty());
131 } while (!terminate);
132}
133
// Performs one depth-first exploration of the greedy policy graph starting
// at `state`.
// The recursion is simulated with the explicit expansion_queue_. Strongly
// connected components are tracked Tarjan-style: each state gets an index
// and lowlink in local_state_infos_, and states are kept on stack_.
//
// When an SCC root is backtracked and its solved flag is still set, every
// state of that SCC that is not solved yet is labelled solved (only if
// label_solved_). When GetVisited is true, such states are also appended
// to visited_.
//
// Returns the solved flag of the start state once the whole expansion queue
// has been unwound. The time limit is checked in push_successor().
134template <typename State, typename Action, bool UseInterval>
135template <bool GetVisited>
136bool HeuristicDepthFirstSearch<State, Action, UseInterval>::policy_exploration(
137 MDP& mdp,
138 Evaluator& heuristic,
139 StateID state,
140 utils::CountdownTimer& timer)
141{
142 using namespace internal;
143
// NOTE(review): ClearGuard presumably clears local_state_infos_ when this
// scope is left (normally or via exception), so per-exploration data never
// leaks into the next call. Confirm.
144 ClearGuard _(local_state_infos_);
145
// Set to false once terminate_exploration_on_cutoff_ is on and an unsolved
// child is backtracked. After that, no further successors are advanced and
// the remaining queue is only unwound.
146 bool keep_expanding = true;
147
148 ExpansionInfo* einfo;
149 StateInfo* sinfo;
150 LocalStateInfo* lsinfo;
151
152 push(state);
153
154 for (;;) {
155 // DFS recursion
// Pointers are re-read on every iteration: push_successor() may have
// pushed onto expansion_queue_. Depending on the container type, that can
// invalidate previously taken pointers.
156 do {
157 einfo = &expansion_queue_.back();
158 sinfo = &this->state_infos_[einfo->stateid];
159 lsinfo = &local_state_infos_[einfo->stateid];
160 } while (initialize(mdp, heuristic, *einfo, *sinfo) &&
161 push_successor(mdp, *einfo, *sinfo, *lsinfo, timer));
162
163 // Iterative backtracking
164 do {
// Snapshot the finished state's data; its queue entry is popped below.
165 unsigned last_lowlink = lsinfo->lowlink;
166 bool last_solved = einfo->solved;
167 bool last_value_converged = einfo->value_converged;
168
// index == lowlink: this state is the root of an SCC made up of all
// states from its index to the top of stack_.
169 if (lsinfo->index == lsinfo->lowlink) {
170 auto scc = stack_ | std::views::drop(lsinfo->index);
171
172 for (const StateID state_id : scc) {
173 local_state_infos_[state_id].status =
174 LocalStateInfo::CLOSED;
175
// The SCC root's solved flag decides for the whole component.
176 if (!einfo->solved) continue;
177
178 StateInfo& mem_info = this->state_infos_[state_id];
179 if (mem_info.is_solved()) continue;
180
181 if (label_solved_) {
182 mem_info.set_solved();
183 }
184
185 if constexpr (GetVisited) {
186 visited_.push_back(state_id);
187 }
188 }
189
// Remove the component's states from the SCC stack.
190 stack_.erase(scc.begin(), scc.end());
191 }
192
193 expansion_queue_.pop_back();
194
195 if (expansion_queue_.empty()) return last_solved;
196
197 einfo = &expansion_queue_.back();
198 sinfo = &this->state_infos_[einfo->stateid];
199 lsinfo = &local_state_infos_[einfo->stateid];
200
// Propagate lowlink, solved status and value convergence from the finished
// child to its parent.
201 lsinfo->lowlink = std::min(lsinfo->lowlink, last_lowlink);
202 einfo->solved =
203 einfo->solved && last_solved && last_value_converged;
204 einfo->value_converged =
205 einfo->value_converged && last_value_converged;
206
207 if (terminate_exploration_on_cutoff_ && !einfo->solved) {
208 keep_expanding = false;
209 }
// Keep backtracking while expansion is stopped or the parent has no
// successor left. Otherwise resume the DFS at the parent's next successor.
210 } while (!keep_expanding || !advance(mdp, *einfo, *sinfo));
211 }
212}
212}
213
214template <typename State, typename Action, bool UseInterval>
215bool HeuristicDepthFirstSearch<State, Action, UseInterval>::advance(
216 MDP& mdp,
217 ExpansionInfo& einfo,
218 StateInfo& state_info)
219{
220 using enum BacktrackingUpdateType;
221
222 if (einfo.next_successor()) {
223 return true;
224 }
225
226 if (backtrack_update_type_ == SINGLE ||
227 (backtrack_update_type_ == ON_DEMAND && !einfo.value_converged)) {
228 assert(!state_info.is_on_fringe());
229
230 const State state = mdp.get_state(einfo.stateid);
231 const value_t termination_cost =
232 mdp.get_termination_info(state).get_cost();
233
234 ClearGuard _(transitions_, qvalues_);
235 this->generate_non_tip_transitions(mdp, state, transitions_);
236
237 statistics_.backtracking_updates++;
238
239 auto value = this->compute_bellman_and_greedy(
240 mdp,
241 einfo.stateid,
242 transitions_,
243 termination_cost,
244 qvalues_);
245
246 auto transition = this->select_greedy_transition(
247 mdp,
248 state_info.get_policy(),
249 transitions_);
250
251 bool value_changed = this->update_value(state_info, value);
252 bool policy_changed = this->update_policy(state_info, transition);
253
254 // Note: it is only necessary to check whether eps-consistency
255 // was reached on backward update when both directions are
256 // enabled
257 einfo.value_converged = !value_changed;
258 einfo.solved = einfo.solved && !value_changed && !policy_changed;
259 }
260
261 return false;
262}
263
264template <typename State, typename Action, bool UseInterval>
265bool HeuristicDepthFirstSearch<State, Action, UseInterval>::push_successor(
266 MDP& mdp,
267 ExpansionInfo& einfo,
268 StateInfo& sinfo,
269 LocalStateInfo& lsinfo,
270 utils::CountdownTimer& timer)
271{
272 using namespace internal;
273
274 do {
275 timer.throw_if_expired();
276
277 const StateID succid = einfo.get_current_successor();
278 const LocalStateInfo& succ_info = local_state_infos_[succid];
279
280 const int succ_status = succ_info.status;
281
282 if (succ_status == LocalStateInfo::NEW) {
283 push(succid);
284 return true;
285 } else if (succ_status == LocalStateInfo::CLOSED) {
286 if (label_solved_) {
287 einfo.solved =
288 einfo.solved && this->state_infos_[succid].is_solved();
289 }
290 } else {
291 assert(succ_status == LocalStateInfo::ONSTACK);
292 lsinfo.lowlink = std::min(lsinfo.lowlink, succ_info.index);
293 }
294 } while (advance(mdp, einfo, sinfo));
295
296 return false;
297}
298
299template <typename State, typename Action, bool UseInterval>
300void HeuristicDepthFirstSearch<State, Action, UseInterval>::push(
301 StateID stateid)
302{
303 LocalStateInfo& info = local_state_infos_[stateid];
304 info.status = LocalStateInfo::ONSTACK;
305 info.open(stack_.size());
306 stack_.push_back(stateid);
307 expansion_queue_.emplace_back(stateid);
308}
309
// Prepares the expansion of the state described by `einfo` / `sinfo` and
// fills einfo.successors with the successor states to explore.
//
// If forward_updates_ is set or the state is a tip (fringe) state, a Bellman
// update is performed first; tip states are expanded via
// expand_and_initialize(). Otherwise no update is made, and the successors
// of the currently stored policy action are used.
//
// Returns false (nothing to explore) in any of these cases:
// - the state is already solved;
// - it has no greedy transition or no stored policy action;
// - it is cut off.
// Returns true otherwise.
310template <typename State, typename Action, bool UseInterval>
311bool HeuristicDepthFirstSearch<State, Action, UseInterval>::initialize(
312 MDP& mdp,
313 Evaluator& heuristic,
314 ExpansionInfo& einfo,
315 StateInfo& sinfo)
316{
317 using namespace internal;
318
319 // Already solved: nothing to expand. With labelling disabled, only goal/terminal states may be solved here (see assert).
320 if (sinfo.is_solved()) {
321 assert(label_solved_ || sinfo.is_goal_or_terminal());
322 return false;
323 }
324
// NOTE(review): only used in the else branch below; the update branch
// reads einfo.stateid directly.
325 const StateID stateid = einfo.stateid;
326
327 const bool is_tip_state = sinfo.is_on_fringe();
328
329 if (forward_updates_ || is_tip_state) {
330 const State state = mdp.get_state(einfo.stateid);
331 const value_t termination_cost =
332 mdp.get_termination_info(state).get_cost();
333
334 ClearGuard _(transitions_, qvalues_);
335
336 if (is_tip_state) {
337 this->expand_and_initialize(
338 mdp,
339 heuristic,
340 state,
341 sinfo,
342 transitions_);
343 } else {
344 this->generate_non_tip_transitions(mdp, state, transitions_);
345 }
346
347 statistics_.forward_updates++;
348
349 auto value = this->compute_bellman_and_greedy(
350 mdp,
351 einfo.stateid,
352 transitions_,
353 termination_cost,
354 qvalues_);
355
356 auto transition = this->select_greedy_transition(
357 mdp,
358 sinfo.get_policy(),
359 transitions_);
360
// NOTE(review): the result of update_policy() is ignored here. advance()
// and vi_step() treat a policy change as non-convergence. Confirm this
// asymmetry is intended.
361 einfo.value_converged = !this->update_value(sinfo, value);
362 this->update_policy(sinfo, transition);
363
// With interval bounds, convergence also requires the bounds to
// (approximately) coincide.
364 if constexpr (UseInterval) {
365 einfo.value_converged = einfo.value_converged &&
366 sinfo.value.bounds_approximately_equal();
367 }
368
// No greedy transition: there is nothing to explore from this state.
369 if (!transition) {
370 return false;
371 }
372
// Cut off tip states (cutoff_tip_) and states whose value is still
// changing (cutoff_inconsistent_). A cut-off state is never solved.
373 const bool cutoff = (cutoff_tip_ && is_tip_state) ||
374 (cutoff_inconsistent_ && !einfo.value_converged);
375
376 if (cutoff) {
377 einfo.solved = false;
378 return false;
379 }
380
381 einfo.successors =
382 std::ranges::to<std::vector>(transition->successor_dist.support());
383 } else {
// No update: follow the currently stored policy action, if any.
384 const auto action = sinfo.get_policy();
385 if (!action.has_value()) return false;
386
387 const State state = mdp.get_state(stateid);
388 Distribution<StateID> successor_dist;
389 mdp.generate_action_transitions(state, *action, successor_dist);
390 einfo.successors =
391 std::ranges::to<std::vector>(successor_dist.support());
392 }
393
394 return true;
395}
396
397template <typename State, typename Action, bool UseInterval>
398bool HeuristicDepthFirstSearch<State, Action, UseInterval>::value_iteration(
399 MDP& mdp,
400 const std::ranges::input_range auto& range,
401 utils::CountdownTimer& timer)
402{
403 ++statistics_.convergence_value_iterations;
404
405 for (;;) {
406 auto [value_changed, policy_changed] =
407 vi_step(mdp, range, timer, statistics_.convergence_updates);
408
409 if (policy_changed) return false;
410 if (!value_changed) break;
411 }
412
413 return true;
414}
415
416template <typename State, typename Action, bool UseInterval>
417std::pair<bool, bool>
418HeuristicDepthFirstSearch<State, Action, UseInterval>::vi_step(
419 MDP& mdp,
420 const std::ranges::input_range auto& range,
421 utils::CountdownTimer& timer,
422 unsigned long long& stat_counter)
423{
424 bool values_not_conv = false;
425 bool policy_not_conv = false;
426
427 for (const StateID id : range) {
428 timer.throw_if_expired();
429
430 StateInfo& state_info = this->state_infos_[id];
431
432 const State state = mdp.get_state(id);
433 const value_t termination_cost =
434 mdp.get_termination_info(state).get_cost();
435
436 ClearGuard _(transitions_, qvalues_);
437
438 this->generate_non_tip_transitions(mdp, state, transitions_);
439
440 const auto value = this->compute_bellman_and_greedy(
441 mdp,
442 id,
443 transitions_,
444 termination_cost,
445 qvalues_);
446
447 ++stat_counter;
448
449 auto transition = this->select_greedy_transition(
450 mdp,
451 state_info.get_policy(),
452 transitions_);
453
454 bool value_changed = this->update_value(state_info, value);
455 bool policy_changed = this->update_policy(state_info, transition);
456 values_not_conv = values_not_conv || value_changed;
457 policy_not_conv = policy_not_conv || policy_changed;
458
459 if constexpr (UseInterval) {
460 values_not_conv = values_not_conv ||
461 !state_info.value.bounds_approximately_equal();
462 }
463 }
464
465 return std::make_pair(values_not_conv, policy_not_conv);
466}
467
468} // namespace probfd::algorithms::heuristic_depth_first_search
The interface representing heuristic functions.
Definition mdp_algorithm.h:16
Basic interface for MDPs.
Definition mdp_algorithm.h:14
A registry for print functions related to search progress.
Definition progress_report.h:33
void print()
Prints the output to the internal output stream, if enabled.
void register_bound(const std::string &property_name, BoundProperty property)
Appends a new bound property with a given name to the list of bound properties to be printed when the...
virtual StateID get_state_id(param_type< State > state)=0
Get the state ID for a given state.
Implementation of the depth-first heuristic search algorithm family of Steinmetz et al. (ICAPS 2016).
Definition depth_first_heuristic_search.h:104
Namespace dedicated to Depth-First Heuristic Search.
Definition depth_first_heuristic_search.h:17
Interval as_interval(value_t lower_bound)
Returns the interval with the given lower bound and infinite upper bound.
double value_t
Typedef for the state value type.
Definition aliases.h:7
typename std::conditional_t< is_cheap_to_copy_v< T >, T, const T & > param_type
Alias template defining the best way to pass a parameter of a given type.
Definition type_traits.h:25
STL namespace.
Represents a closed interval over the extended reals as a pair of lower and upper bound.
Definition interval.h:12
A StateID represents a state within a StateIDMap. Just like Fast Downward's StateID type,...
Definition types.h:22