1#ifndef GUARD_INCLUDE_PROBFD_ALGORITHMS_EXHAUSTIVE_DFS_H
2#error "This file should only be included from exhaustive_dfs.h"
5#include "probfd/algorithms/utils.h"
7#include "probfd/algorithms/transition_sorter.h"
9#include "probfd/utils/not_implemented.h"
11#include "probfd/evaluator.h"
// Folds the candidate value v into the bound(s) held in x.
// Only the signature is visible in this excerpt. Callers in this file branch
// on the returned bool (e.g. propagate_value_along_trace continues only while
// it returns true) — presumably "x was changed"; confirm against the full body.
28inline bool update_lower_bound(Interval& x,
value_t v)
// Writes the search counters collected during exhaustive DFS to `out`,
// one labelled line per counter.
38void Statistics::print(std::ostream& out)
const
40 out <<
" Expanded " << expanded <<
" state(s)." << std::endl;
41 out <<
" Evaluated " << evaluated <<
" state(s)." << std::endl;
42 out <<
" Evaluations: " << evaluations << std::endl;
43 out <<
" Terminal states: " << terminal << std::endl;
44 out <<
" Pure self-loop states: " << self_loop << std::endl;
45 out <<
" Goal states: " << goal_states <<
" state(s)." << std::endl;
46 out <<
" Dead ends: " << dead_ends <<
" state(s)." << std::endl;
47 out <<
" State value updates: " << value_updates << std::endl;
48 out <<
" Backtracked from " << backtracks <<
" state(s)." << std::endl;
49 out <<
" Found " << sccs <<
" SCC(s)." << std::endl;
50 out <<
" Found " << dead_end_sccs <<
" dead-end SCC(s)." << std::endl;
51 out <<
" Partially pruned " << pruned_dead_end_sccs <<
" dead-end SCC(s)."
// NOTE(review): the average below divides by dead_end_sccs. If that counter
// is 0, this prints NaN (0.0 / 0) unless a guard exists on a line not shown
// in this excerpt. The static_cast<int> may also narrow the counter if it is
// declared as a wider/unsigned type — confirm its declaration.
53 out <<
" Average dead-end SCC size: "
54 << (
static_cast<double>(summed_dead_end_scc_sizes) /
55 static_cast<int>(dead_end_sccs))
// Constructor: stores the optional transition sorter, the cost bound and the
// value-propagation switches used later by run_exploration().
// Some parameters (cost_bound, path_updates) are declared on lines not shown
// in this excerpt.
59template <
typename State,
typename Action,
bool UseInterval>
60ExhaustiveDepthFirstSearch<State, Action, UseInterval>::
61 ExhaustiveDepthFirstSearch(
62 std::shared_ptr<TransitionSorterType> transition_sorting,
65 bool only_propagate_when_changed)
66 : transition_sort_(transition_sorting)
67 , cost_bound_(cost_bound)
// trivial_bound_ is derived from cost_bound via a lambda that branches on
// UseInterval. It is the initial value every new search node receives in
// initialize_search_node(). The lambda's closing/invocation lines are not
// visible here (presumably invoked immediately).
68 , trivial_bound_([=] {
69 if constexpr (UseInterval) {
72 return cost_bound.upper;
// value_propagation_ and only_propagate_when_changed_ jointly decide
// whether run_exploration() calls propagate_value_along_trace().
75 , value_propagation_(path_updates)
76 , only_propagate_when_changed_(only_propagate_when_changed)
// Runs the exhaustive DFS from `state` and returns the bounds stored for it
// in search_space_.
// Steps: initialise the root node; if that reports false, return its bounds
// immediately. Otherwise try to push it on the DFS stack; if that fails,
// return its bounds. Otherwise register progress reporting and run the main
// exploration loop.
80template <
typename State,
typename Action,
bool UseInterval>
81Interval ExhaustiveDepthFirstSearch<State, Action, UseInterval>::solve(
83 EvaluatorType& heuristic,
84 param_type<State> state,
85 ProgressReport progress,
88 StateID stateid = mdp.get_state_id(state);
89 SearchNodeInfo& info = search_space_[stateid];
90 if (!initialize_search_node(mdp, heuristic, state, info)) {
91 return search_space_.lookup_bounds(stateid);
// NOTE(review): push_state also takes its non-expanding paths for terminal
// states and pure self-loop states, not only dead ends. The message below
// may therefore be imprecise — confirm against push_state's return statements.
94 if (!push_state(mdp, heuristic, stateid, info)) {
95 std::cout <<
"initial state is dead end!" << std::endl;
96 return search_space_.lookup_bounds(stateid);
99 register_value_reports(info, progress);
100 run_exploration(mdp, heuristic, progress);
102 return search_space_.lookup_bounds(stateid);
// Policy extraction entry point; returns a std::unique_ptr<PolicyType>.
// The body is not visible in this excerpt. The file includes
// "probfd/utils/not_implemented.h", so this is presumably unsupported for
// this algorithm — TODO confirm.
105template <
typename State,
typename Action,
bool UseInterval>
106auto ExhaustiveDepthFirstSearch<State, Action, UseInterval>::compute_policy(
111 double) -> std::unique_ptr<PolicyType>
// Const member that forwards to Statistics::print, writing the collected
// search statistics to `out`. (The function name line is not visible in this
// excerpt.)
116template <
typename State,
typename Action,
bool UseInterval>
118 std::ostream& out)
const
120 statistics_.print(out);
// Converts a node's stored value into an Interval. The visible return
// (presumably the non-UseInterval branch) pairs the scalar value with
// INFINITE_VALUE as the upper end. The signature and the UseInterval branch
// are on lines not shown in this excerpt.
123template <
typename State,
typename Action,
bool UseInterval>
128 if constexpr (UseInterval) {
131 return Interval(info.value, INFINITE_VALUE);
// StateID overload: looks up the state with mdp.get_state(state_id) and
// forwards to the State-based initialize_search_node overload.
136template <
typename State,
typename Action,
bool UseInterval>
137bool ExhaustiveDepthFirstSearch<State, Action, UseInterval>::
138 initialize_search_node(
140 EvaluatorType& heuristic,
142 SearchNodeInfo& info)
144 return initialize_search_node(
147 mdp.get_state(state_id),
// Initialises a fresh search node (asserted to be new) for `state`:
//  - the value starts at trivial_bound_ and the termination cost is recorded;
//  - goal states get value = termination cost and are counted;
//  - otherwise the heuristic is evaluated. An estimate equal to the
//    termination cost marks the node as a dead end with value = term_cost.
//  - In the UseInterval case, the interval's lower end is set to the estimate.
// The return statements are not visible here. solve() uses a false result to
// skip exploration entirely.
151template <
typename State,
typename Action,
bool UseInterval>
152bool ExhaustiveDepthFirstSearch<State, Action, UseInterval>::
153 initialize_search_node(
155 EvaluatorType& heuristic,
156 param_type<State> state,
157 SearchNodeInfo& info)
159 assert(info.is_new());
160 info.value = trivial_bound_;
162 TerminationInfo term_info = mdp.get_termination_info(state);
163 const value_t term_cost = term_info.get_cost();
164 info.term_cost = term_cost;
166 if (term_info.is_goal_state()) {
168 info.value = AlgorithmValueType(term_cost);
169 ++statistics_.goal_states;
// Exact floating-point comparison: the dead-end test relies on the
// heuristic returning exactly the termination cost.
173 const value_t estimate = heuristic.evaluate(state);
174 if (estimate == term_cost) {
175 info.value = AlgorithmValueType(term_cost);
176 info.mark_dead_end();
177 ++statistics_.dead_ends;
181 if constexpr (UseInterval) {
182 info.value.lower = estimate;
// Expands state_id and, if it has non-self-loop transitions, pushes it onto
// the DFS stacks (expansion_infos_ / stack_infos_).
// - No successors: value = termination cost, counted as terminal.
// - Otherwise, for each action:
//   - self-loop successors are filtered out of the distribution;
//   - new successors are initialised;
//   - closed successors are folded into t.base (prob * value) and into the
//     dead-end bookkeeping;
//   - t.self_loop is turned into the multiplier 1 / (1 - p_self).
// - Surviving transitions are compacted to the front (index j) and the rest
//   are erased.
// - If none remain, the stack entries are popped again. A state with only
//   self-loops gets value = termination cost and is counted as a self-loop
//   state.
// Several lines (returns, the remove_if predicate's tail, the loop/compaction
// conditions) are not visible in this excerpt.
190template <
typename State,
typename Action,
bool UseInterval>
191bool ExhaustiveDepthFirstSearch<State, Action, UseInterval>::push_state(
193 EvaluatorType& heuristic,
195 SearchNodeInfo& info)
197 std::vector<Action> aops;
198 std::vector<Distribution<StateID>> successors;
199 const State state = mdp.get_state(state_id);
200 mdp.generate_all_transitions(state, aops, successors);
201 if (successors.empty()) {
202 info.value = AlgorithmValueType(info.term_cost);
204 statistics_.terminal++;
208 statistics_.expanded++;
// Optional heuristic reordering of the transitions before exploration.
210 if (transition_sort_ !=
nullptr) {
211 transition_sort_->sort(state, aops, successors, search_space_);
214 expansion_infos_.emplace_back(stack_infos_.size());
215 stack_infos_.emplace_back(state_id);
217 ExpansionInformation& exp = expansion_infos_.back();
218 StackInformation& si = stack_infos_.back();
220 si.successors.resize(aops.size());
// NOTE(review): `cost` is a snapshot of the node's current value, taken
// before the loop, and is added to every transition's base below — confirm
// this is the intended term.
222 const auto cost = info.get_value();
224 bool pure_self_loop =
true;
227 for (
unsigned i = 0; i < aops.size(); ++i) {
228 auto& succs = successors[i];
229 auto& t = si.successors[i];
230 bool all_self_loops =
true;
// Filters the distribution in place. Self-loop entries do not survive
// (run_exploration later asserts succ_id != stateid). Closed successors
// contribute directly to t.base.
232 succs.remove_if([&,
this, state_id](
auto& elem) {
233 const auto [succ_id, prob] = elem;
236 if (succ_id == state_id) {
241 SearchNodeInfo& succ_info = search_space_[succ_id];
242 if (succ_info.is_new()) {
243 initialize_search_node(mdp, heuristic, succ_id, succ_info);
246 if (succ_info.is_closed()) {
247 t.base += prob * succ_info.get_value();
248 exp.update_successors_dead(succ_info.is_dead_end());
249 exp.all_successors_marked_dead =
250 exp.all_successors_marked_dead &&
251 succ_info.is_marked_dead_end();
253 all_self_loops =
false;
261 const auto& a = aops[i];
// NOTE(review): `cost + action cost` is added to t.base here and again at
// line 270. Whether these are mutually exclusive branches depends on lines
// 268-269, which are not visible — confirm t.base is not double-counted.
263 if (!all_self_loops) {
264 pure_self_loop =
false;
265 t.base += cost + mdp.get_action_cost(a);
266 auto non_loop = 1_vt - t.self_loop;
267 update_lower_bound(info.value, t.base / non_loop);
270 t.base += cost + mdp.get_action_cost(a);
272 if (t.self_loop == 0_vt) {
275 assert(t.self_loop < 1_vt);
// From here on, self_loop holds the multiplier 1/(1-p) rather than the
// probability p. Consumers compute base * self_loop.
276 t.self_loop = 1_vt / (1_vt - t.self_loop);
280 si.successors[j] = std::move(si.successors[i]);
281 successors[j] = std::move(successors[i]);
288 expansion_infos_.pop_back();
289 stack_infos_.pop_back();
291 if (pure_self_loop) {
292 info.value = AlgorithmValueType(info.term_cost);
294 ++statistics_.self_loop;
296 info.value = AlgorithmValueType(info.get_value());
// Drop the transitions beyond the compacted prefix [0, j).
303 successors.erase(successors.begin() + j, successors.end());
304 si.successors.erase(si.successors.begin() + j, si.successors.end());
307 info.set_onstack(stack_infos_.size() - 1);
308 exp.successors = std::move(successors);
// Exploration consumes distributions from the back of exp.successors.
309 exp.succ = exp.successors.back().begin();
// Main iterative DFS loop (Tarjan-style lowlink bookkeeping over stack_infos_).
// Each iteration resumes the top expansion frame:
//  - open successors are pushed (descend);
//  - on-stack successors lower the lowlink and are recorded in
//    inc->successors for later SCC value iteration;
//  - closed successors are folded into inc->base.
// After each distribution, the node's bound is updated with
// base * self_loop, and exploration of the node stops early once
// check_early_convergence() holds.
// When a frame is exhausted and is an SCC root (stack_index == lowlink):
//  - a dead-end SCC has every member set to its termination cost;
//  - otherwise the SCC's values are iterated to a fixpoint (loop header not
//    visible).
// The SCC is then removed from the stack. Optionally, the new value is
// propagated up the trace.
314template <
typename State,
typename Action,
bool UseInterval>
315void ExhaustiveDepthFirstSearch<State, Action, UseInterval>::run_exploration(
317 EvaluatorType& heuristic,
318 ProgressReport& progress)
322 while (!expansion_infos_.empty()) {
323 ExpansionInformation& expanding = expansion_infos_.back();
324 assert(expanding.stack_index < stack_infos_.size());
325 assert(!expanding.successors.empty());
326 assert(expanding.succ != expanding.successors.back().end());
328 StackInformation& stack_info = stack_infos_[expanding.stack_index];
329 assert(!stack_info.successors.empty());
331 const StateID stateid = stack_info.state_ref;
332 SearchNodeInfo& node_info = search_space_[stateid];
// Merge the dead-end flags reported by the child frame we just returned from.
334 expanding.update_successors_dead(last_all_dead_);
335 expanding.all_successors_marked_dead =
336 expanding.all_successors_marked_dead && last_all_marked_dead_;
// `inc` is the SCCTransition currently being filled. It is indexed from
// the back, offset by stack_info.i.
338 int idx = stack_info.successors.size() - stack_info.i - 1;
339 SCCTransition* inc = &stack_info.successors[idx];
340 bool val_changed =
false;
341 bool completely_explored =
false;
344 for (; expanding.succ != expanding.successors.back().end();
346 const auto [succ_id, prob] = *expanding.succ;
348 assert(succ_id != stateid);
349 SearchNodeInfo& succ_info = search_space_[succ_id];
350 assert(!succ_info.is_new());
352 if (succ_info.is_open()) {
353 if (push_state(mdp, heuristic, succ_id, succ_info)) {
357 expanding.update_successors_dead(succ_info.is_dead_end());
// NOTE(review): likely bug. This conjoins all_successors_are_dead, not
// all_successors_marked_dead, so an earlier "not marked" successor can be
// forgotten. Lines 249-251 (push_state) and 335-336 accumulate with
// all_successors_marked_dead && ... — probably intended here too.
358 expanding.all_successors_marked_dead =
359 expanding.all_successors_are_dead &&
360 succ_info.is_marked_dead_end();
361 inc->base += prob * succ_info.get_value();
362 }
else if (succ_info.is_onstack()) {
364 std::min(node_info.lowlink, succ_info.lowlink);
365 inc->successors.add_probability(succ_id, prob);
367 assert(succ_info.is_closed());
368 expanding.update_successors_dead(succ_info.is_dead_end());
// NOTE(review): same likely bug as lines 358-360 — uses
// all_successors_are_dead instead of all_successors_marked_dead.
369 expanding.all_successors_marked_dead =
370 expanding.all_successors_are_dead &&
371 succ_info.is_marked_dead_end();
372 inc->base += prob * succ_info.get_value();
// Distribution finished: bound update uses base scaled by the self-loop
// multiplier 1/(1-p) computed in push_state.
376 expanding.successors.pop_back();
377 if (update_lower_bound(
379 inc->base * inc->self_loop)) {
381 if (check_early_convergence(node_info)) {
382 expanding.successors.clear();
386 if (expanding.successors.empty()) {
387 if (inc->successors.empty()) {
388 if (stack_info.i > 0)
389 std::swap(stack_info.successors.back(), *inc);
390 stack_info.successors.pop_back();
396 if (inc->successors.empty()) {
397 if (stack_info.i > 0) {
398 std::swap(stack_info.successors.back(), *inc);
401 stack_info.successors.pop_back();
402 int t = stack_info.successors.size() - stack_info.i - 1;
403 inc = &stack_info.successors[t];
409 expanding.succ = expanding.successors.back().begin();
// Backtracking: publish this frame's dead-end flags to the parent frame.
412 last_all_dead_ = expanding.all_successors_are_dead;
413 last_all_marked_dead_ = expanding.all_successors_marked_dead;
414 statistics_.backtracks++;
// SCC root: members are the stack entries from the top down to stateid.
416 if (expanding.stack_index == node_info.lowlink) {
419 auto rend = stack_infos_.rbegin();
420 if (expanding.all_successors_are_dead) {
421 unsigned scc_size = 0;
424 auto& info = search_space_[rend->state_ref];
425 info.value = AlgorithmValueType(info.term_cost);
427 }
while ((rend++)->state_ref != stateid);
429 statistics_.dead_end_sccs++;
430 statistics_.summed_dead_end_scc_sizes += scc_size;
432 unsigned scc_size = 0;
434 auto& info = search_space_[rend->state_ref];
437 if constexpr (UseInterval) {
441 AlgorithmValueType(info.value.lower)) ||
446 }
while ((rend++)->state_ref != stateid);
// Value iteration over the SCC. Each member's value becomes the max over its
// transitions of (base + sum prob * succ value) * self_loop.
449 unsigned iterations = 0;
453 for (
auto it = stack_infos_.rbegin(); it != rend;
455 StackInformation& s = *it;
456 assert(!s.successors.empty());
457 value_t best = s.successors.back().base;
459 std::views::reverse(s.successors)) {
461 for (
auto [succ_id, prob] : t.successors) {
464 search_space_[succ_id].get_value();
466 t_first = t_first * t.self_loop;
467 best = best > t_first ? best : t_first;
470 SearchNodeInfo& snode_info =
471 search_space_[s.state_ref];
472 if (best > snode_info.get_value()) {
474 snode_info.get_value(),
476 snode_info.value = AlgorithmValueType(best);
482 val_changed = val_changed || iterations > 1;
// Remove the whole SCC from the stack. After this, stack_info (a
// reference into stack_infos_) must not be used.
486 stack_infos_.erase(rend.base(), stack_infos_.end());
489 expansion_infos_.pop_back();
491 completely_explored =
true;
495 if ((val_changed || !only_propagate_when_changed_) &&
496 value_propagation_) {
497 propagate_value_along_trace(
499 node_info.get_value(),
// Pushes a freshly computed value `val` up the current DFS trace. For each
// expansion frame, from the top of the stack downward, it recomputes the
// frame's in-progress transition value as base + p(current successor) * val
// and tightens the node's bound. Propagation presumably stops at the first
// frame whose bound does not change (loop tail not visible). The check
// against rend() detects that the whole trace, including the root, was
// updated.
505template <
typename State,
typename Action,
bool UseInterval>
506void ExhaustiveDepthFirstSearch<State, Action, UseInterval>::
507 propagate_value_along_trace(
510 ProgressReport& progress)
512 auto it = expansion_infos_.rbegin();
517 for (; it != expansion_infos_.rend(); ++it) {
518 StackInformation& st = stack_infos_[it->stack_index];
519 SearchNodeInfo& sn = search_space_[st.state_ref];
520 const auto& t = st.successors[st.successors.size() - st.i - 1];
// NOTE(review): unlike run_exploration (inc->base * inc->self_loop) and the
// SCC iteration (t_first * t.self_loop), v here is not scaled by
// t.self_loop — confirm whether that is intended.
521 const value_t v = t.base + it->succ->probability * val;
522 if (!update_lower_bound(sn.value, v)) {
529 if (it == expansion_infos_.rend()) {
// Returns whether exploration of `node` may stop early.
//  - UseInterval: the interval has collapsed (upper <= lower).
//  - Scalar: the value has reached cost_bound_.lower (value <= lower bound).
// run_exploration() clears the remaining successors when this returns true.
534template <
typename State,
typename Action,
bool UseInterval>
535bool ExhaustiveDepthFirstSearch<State, Action, UseInterval>::
536 check_early_convergence(
const SearchNodeInfo& node)
const
538 if constexpr (UseInterval) {
539 return node.value.upper <= node.value.lower;
541 return node.value <= cost_bound_.lower;
A registry for print functions related to search progress.
Definition progress_report.h:33
void register_bound(const std::string &property_name, BoundProperty property)
Appends a new bound property with a given name to the list of bound properties to be printed when the...
Implementation of an anytime topological value iteration variant.
Definition exhaustive_dfs.h:210
namespace for anytime TVI
Definition exhaustive_dfs.h:24
bool update(Interval &lhs, Interval rhs, value_t epsilon=g_epsilon)
Intersects two intervals and assigns the result to the left operand.
double value_t
Typedef for the state value type.
Definition aliases.h:7
bool is_approx_equal(value_t v1, value_t v2, value_t epsilon=g_epsilon)
Equivalent to $|v_1 - v_2| \leq \epsilon$.
Represents a closed interval over the extended reals as a pair of lower and upper bound.
Definition interval.h:12