1#ifndef GUARD_INCLUDE_PROBFD_ALGORITHMS_HEURISTIC_DEPTH_FIRST_SEARCH_H
2#error "This file should only be included from heuristic_depth_first_search.h"
5#include "downward/utils/countdown_timer.h"
13inline void Statistics::print(std::ostream& out)
const
15 out <<
" Iterations: " << iterations << std::endl;
16 out <<
" Value iterations: " << convergence_value_iterations << std::endl;
17 out <<
" Bellman backups (forward): " << forward_updates << std::endl;
18 out <<
" Bellman backups (backtracking): " << backtracking_updates
20 out <<
" Bellman backups (convergence): " << convergence_updates
// Advances the successor cursor: discards the successor at the back of
// `successors` (the one get_current_successor() returns) and reports whether
// any successors are left to explore.
// NOTE(review): precondition — `successors` must be non-empty when called;
// presumably a std::vector, where pop_back() on an empty container is UB.
24bool ExpansionInfo::next_successor()
26 successors.pop_back();
27 return !successors.empty();
// Returns the successor currently being explored, i.e. the last element of
// `successors` (successors are consumed back-to-front by next_successor()).
// NOTE(review): assumes `successors` is non-empty — back() on an empty
// container is UB; confirm all callers guarantee this.
30StateID ExpansionInfo::get_current_successor()
const
32 return successors.back();
37template <
typename State,
typename Action,
bool UseInterval>
38HeuristicDepthFirstSearch<State, Action, UseInterval>::
39 HeuristicDepthFirstSearch(
40 std::shared_ptr<PolicyPicker> policy_chooser,
42 BacktrackingUpdateType backtrack_update_type,
44 bool cutoff_inconsistent,
45 bool terminate_exploration_on_cutoff,
47 : Base(
std::move(policy_chooser))
48 , forward_updates_(forward_updates)
49 , backtrack_update_type_(backtrack_update_type)
50 , cutoff_tip_(cutoff_tip)
51 , cutoff_inconsistent_(cutoff_inconsistent)
52 , terminate_exploration_on_cutoff_(terminate_exploration_on_cutoff)
53 , label_solved_(label_solved)
57template <
typename State,
typename Action,
bool UseInterval>
60 this->state_infos_.reset();
63template <
typename State,
typename Action,
bool UseInterval>
71 utils::CountdownTimer timer(max_time);
74 const StateInfo& state_info = this->state_infos_[stateid];
81 solve_with_vi_termination(mdp, heuristic, stateid, progress, timer);
83 solve_without_vi_termination(mdp, heuristic, stateid, progress, timer);
86 return state_info.get_bounds();
89template <
typename State,
typename Action,
bool UseInterval>
93 statistics_.print(out);
96template <
typename State,
typename Action,
bool UseInterval>
103 utils::CountdownTimer& timer)
107 terminate = policy_exploration<true>(mdp, heuristic, stateid, timer) &&
108 value_iteration(mdp, visited_, timer);
111 ++statistics_.iterations;
113 }
while (!terminate);
116template <
typename State,
typename Action,
bool UseInterval>
117void HeuristicDepthFirstSearch<State, Action, UseInterval>::
118 solve_without_vi_termination(
123 utils::CountdownTimer& timer)
127 terminate = policy_exploration<false>(mdp, heuristic, stateid, timer);
128 ++statistics_.iterations;
130 assert(visited_.empty());
131 }
while (!terminate);
134template <
typename State,
typename Action,
bool UseInterval>
135template <
bool GetVisited>
136bool HeuristicDepthFirstSearch<State, Action, UseInterval>::policy_exploration(
138 Evaluator& heuristic,
140 utils::CountdownTimer& timer)
142 using namespace internal;
144 ClearGuard _(local_state_infos_);
146 bool keep_expanding =
true;
148 ExpansionInfo* einfo;
150 LocalStateInfo* lsinfo;
157 einfo = &expansion_queue_.back();
158 sinfo = &this->state_infos_[einfo->stateid];
159 lsinfo = &local_state_infos_[einfo->stateid];
160 }
while (initialize(mdp, heuristic, *einfo, *sinfo) &&
161 push_successor(mdp, *einfo, *sinfo, *lsinfo, timer));
165 unsigned last_lowlink = lsinfo->lowlink;
166 bool last_solved = einfo->solved;
167 bool last_value_converged = einfo->value_converged;
169 if (lsinfo->index == lsinfo->lowlink) {
170 auto scc = stack_ | std::views::drop(lsinfo->index);
172 for (
const StateID state_id : scc) {
173 local_state_infos_[state_id].status =
174 LocalStateInfo::CLOSED;
176 if (!einfo->solved)
continue;
178 StateInfo& mem_info = this->state_infos_[state_id];
179 if (mem_info.is_solved())
continue;
182 mem_info.set_solved();
185 if constexpr (GetVisited) {
186 visited_.push_back(state_id);
190 stack_.erase(scc.begin(), scc.end());
193 expansion_queue_.pop_back();
195 if (expansion_queue_.empty())
return last_solved;
197 einfo = &expansion_queue_.back();
198 sinfo = &this->state_infos_[einfo->stateid];
199 lsinfo = &local_state_infos_[einfo->stateid];
201 lsinfo->lowlink = std::min(lsinfo->lowlink, last_lowlink);
203 einfo->solved && last_solved && last_value_converged;
204 einfo->value_converged =
205 einfo->value_converged && last_value_converged;
207 if (terminate_exploration_on_cutoff_ && !einfo->solved) {
208 keep_expanding =
false;
210 }
while (!keep_expanding || !advance(mdp, *einfo, *sinfo));
214template <
typename State,
typename Action,
bool UseInterval>
215bool HeuristicDepthFirstSearch<State, Action, UseInterval>::advance(
217 ExpansionInfo& einfo,
218 StateInfo& state_info)
220 using enum BacktrackingUpdateType;
222 if (einfo.next_successor()) {
226 if (backtrack_update_type_ == SINGLE ||
227 (backtrack_update_type_ == ON_DEMAND && !einfo.value_converged)) {
228 assert(!state_info.is_on_fringe());
230 const State state = mdp.get_state(einfo.stateid);
231 const value_t termination_cost =
232 mdp.get_termination_info(state).get_cost();
234 ClearGuard _(transitions_, qvalues_);
235 this->generate_non_tip_transitions(mdp, state, transitions_);
237 statistics_.backtracking_updates++;
239 auto value = this->compute_bellman_and_greedy(
246 auto transition = this->select_greedy_transition(
248 state_info.get_policy(),
251 bool value_changed = this->update_value(state_info, value);
252 bool policy_changed = this->update_policy(state_info, transition);
257 einfo.value_converged = !value_changed;
258 einfo.solved = einfo.solved && !value_changed && !policy_changed;
264template <
typename State,
typename Action,
bool UseInterval>
265bool HeuristicDepthFirstSearch<State, Action, UseInterval>::push_successor(
267 ExpansionInfo& einfo,
269 LocalStateInfo& lsinfo,
270 utils::CountdownTimer& timer)
272 using namespace internal;
275 timer.throw_if_expired();
277 const StateID succid = einfo.get_current_successor();
278 const LocalStateInfo& succ_info = local_state_infos_[succid];
280 const int succ_status = succ_info.status;
282 if (succ_status == LocalStateInfo::NEW) {
285 }
else if (succ_status == LocalStateInfo::CLOSED) {
288 einfo.solved && this->state_infos_[succid].is_solved();
291 assert(succ_status == LocalStateInfo::ONSTACK);
292 lsinfo.lowlink = std::min(lsinfo.lowlink, succ_info.index);
294 }
while (advance(mdp, einfo, sinfo));
299template <
typename State,
typename Action,
bool UseInterval>
300void HeuristicDepthFirstSearch<State, Action, UseInterval>::push(
303 LocalStateInfo& info = local_state_infos_[stateid];
304 info.status = LocalStateInfo::ONSTACK;
305 info.open(stack_.size());
306 stack_.push_back(stateid);
307 expansion_queue_.emplace_back(stateid);
310template <
typename State,
typename Action,
bool UseInterval>
311bool HeuristicDepthFirstSearch<State, Action, UseInterval>::initialize(
313 Evaluator& heuristic,
314 ExpansionInfo& einfo,
317 using namespace internal;
320 if (sinfo.is_solved()) {
321 assert(label_solved_ || sinfo.is_goal_or_terminal());
325 const StateID stateid = einfo.stateid;
327 const bool is_tip_state = sinfo.is_on_fringe();
329 if (forward_updates_ || is_tip_state) {
330 const State state = mdp.get_state(einfo.stateid);
331 const value_t termination_cost =
332 mdp.get_termination_info(state).get_cost();
334 ClearGuard _(transitions_, qvalues_);
337 this->expand_and_initialize(
344 this->generate_non_tip_transitions(mdp, state, transitions_);
347 statistics_.forward_updates++;
349 auto value = this->compute_bellman_and_greedy(
356 auto transition = this->select_greedy_transition(
361 einfo.value_converged = !this->update_value(sinfo, value);
362 this->update_policy(sinfo, transition);
364 if constexpr (UseInterval) {
365 einfo.value_converged = einfo.value_converged &&
366 sinfo.value.bounds_approximately_equal();
373 const bool cutoff = (cutoff_tip_ && is_tip_state) ||
374 (cutoff_inconsistent_ && !einfo.value_converged);
377 einfo.solved =
false;
382 std::ranges::to<std::vector>(transition->successor_dist.support());
384 const auto action = sinfo.get_policy();
385 if (!action.has_value())
return false;
387 const State state = mdp.get_state(stateid);
388 Distribution<StateID> successor_dist;
389 mdp.generate_action_transitions(state, *action, successor_dist);
391 std::ranges::to<std::vector>(successor_dist.support());
397template <
typename State,
typename Action,
bool UseInterval>
398bool HeuristicDepthFirstSearch<State, Action, UseInterval>::value_iteration(
400 const std::ranges::input_range
auto& range,
401 utils::CountdownTimer& timer)
403 ++statistics_.convergence_value_iterations;
406 auto [value_changed, policy_changed] =
407 vi_step(mdp, range, timer, statistics_.convergence_updates);
409 if (policy_changed)
return false;
410 if (!value_changed)
break;
416template <
typename State,
typename Action,
bool UseInterval>
418HeuristicDepthFirstSearch<State, Action, UseInterval>::vi_step(
420 const std::ranges::input_range
auto& range,
421 utils::CountdownTimer& timer,
422 unsigned long long& stat_counter)
424 bool values_not_conv =
false;
425 bool policy_not_conv =
false;
427 for (
const StateID
id : range) {
428 timer.throw_if_expired();
430 StateInfo& state_info = this->state_infos_[id];
432 const State state = mdp.get_state(
id);
433 const value_t termination_cost =
434 mdp.get_termination_info(state).get_cost();
436 ClearGuard _(transitions_, qvalues_);
438 this->generate_non_tip_transitions(mdp, state, transitions_);
440 const auto value = this->compute_bellman_and_greedy(
449 auto transition = this->select_greedy_transition(
451 state_info.get_policy(),
454 bool value_changed = this->update_value(state_info, value);
455 bool policy_changed = this->update_policy(state_info, transition);
456 values_not_conv = values_not_conv || value_changed;
457 policy_not_conv = policy_not_conv || policy_changed;
459 if constexpr (UseInterval) {
460 values_not_conv = values_not_conv ||
461 !state_info.value.bounds_approximately_equal();
465 return std::make_pair(values_not_conv, policy_not_conv);
The interface representing heuristic functions.
Definition mdp_algorithm.h:16
Basic interface for MDPs.
Definition mdp_algorithm.h:14
A registry for print functions related to search progress.
Definition progress_report.h:33
void print()
Prints the output to the internal output stream, if enabled.
void register_bound(const std::string &property_name, BoundProperty property)
Appends a new bound property with a given name to the list of bound properties to be printed when the...
virtual StateID get_state_id(param_type< State > state)=0
Get the state ID for a given state.
Implementation of the depth-first heuristic search algorithm family described by Steinmetz et al. (ICAPS 2016).
Definition depth_first_heuristic_search.h:104
Namespace dedicated to Depth-First Heuristic Search.
Definition depth_first_heuristic_search.h:17
Interval as_interval(value_t lower_bound)
Returns the interval with the given lower bound and an infinite upper bound.
double value_t
Typedef for the state value type.
Definition aliases.h:7
typename std::conditional_t< is_cheap_to_copy_v< T >, T, const T & > param_type
Alias template defining the best way to pass a parameter of a given type.
Definition type_traits.h:25
Represents a closed interval over the extended reals as a pair of lower and upper bound.
Definition interval.h:12
A StateID represents a state within a StateIDMap. Just like Fast Downward's StateID type,...
Definition types.h:22