AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
exhaustive_dfs_impl.h
1#ifndef GUARD_INCLUDE_PROBFD_ALGORITHMS_EXHAUSTIVE_DFS_H
2#error "This file should only be included from exhaustive_dfs.h"
3#endif
4
#include "probfd/algorithms/utils.h"

#include "probfd/algorithms/transition_sorter.h"

#include "probfd/utils/not_implemented.h"

#include "probfd/evaluator.h"

#include <cassert>
#include <ranges>
#include <utility>
15
17
18inline bool update_lower_bound(value_t& x, value_t v)
19{
20 if (v > x) {
21 x = v;
22 return true;
23 }
24
25 return false;
26}
27
28inline bool update_lower_bound(Interval& x, value_t v)
29{
30 if (v > x.lower) {
31 x.lower = v;
32 return true;
33 }
34
35 return false;
36}
37
38void Statistics::print(std::ostream& out) const
39{
40 out << " Expanded " << expanded << " state(s)." << std::endl;
41 out << " Evaluated " << evaluated << " state(s)." << std::endl;
42 out << " Evaluations: " << evaluations << std::endl;
43 out << " Terminal states: " << terminal << std::endl;
44 out << " Pure self-loop states: " << self_loop << std::endl;
45 out << " Goal states: " << goal_states << " state(s)." << std::endl;
46 out << " Dead ends: " << dead_ends << " state(s)." << std::endl;
47 out << " State value updates: " << value_updates << std::endl;
48 out << " Backtracked from " << backtracks << " state(s)." << std::endl;
49 out << " Found " << sccs << " SCC(s)." << std::endl;
50 out << " Found " << dead_end_sccs << " dead-end SCC(s)." << std::endl;
51 out << " Partially pruned " << pruned_dead_end_sccs << " dead-end SCC(s)."
52 << std::endl;
53 out << " Average dead-end SCC size: "
54 << (static_cast<double>(summed_dead_end_scc_sizes) /
55 static_cast<int>(dead_end_sccs))
56 << std::endl;
57}
58
59template <typename State, typename Action, bool UseInterval>
60ExhaustiveDepthFirstSearch<State, Action, UseInterval>::
61 ExhaustiveDepthFirstSearch(
62 std::shared_ptr<TransitionSorterType> transition_sorting,
63 Interval cost_bound,
64 bool path_updates,
65 bool only_propagate_when_changed)
66 : transition_sort_(transition_sorting)
67 , cost_bound_(cost_bound)
68 , trivial_bound_([=] {
69 if constexpr (UseInterval) {
70 return cost_bound;
71 } else {
72 return cost_bound.upper;
73 }
74 }())
75 , value_propagation_(path_updates)
76 , only_propagate_when_changed_(only_propagate_when_changed)
77{
78}
79
80template <typename State, typename Action, bool UseInterval>
81Interval ExhaustiveDepthFirstSearch<State, Action, UseInterval>::solve(
82 MDPType& mdp,
83 EvaluatorType& heuristic,
84 param_type<State> state,
85 ProgressReport progress,
86 double)
87{
88 StateID stateid = mdp.get_state_id(state);
89 SearchNodeInfo& info = search_space_[stateid];
90 if (!initialize_search_node(mdp, heuristic, state, info)) {
91 return search_space_.lookup_bounds(stateid);
92 }
93
94 if (!push_state(mdp, heuristic, stateid, info)) {
95 std::cout << "initial state is dead end!" << std::endl;
96 return search_space_.lookup_bounds(stateid);
97 }
98
99 register_value_reports(info, progress);
100 run_exploration(mdp, heuristic, progress);
101
102 return search_space_.lookup_bounds(stateid);
103}
104
105template <typename State, typename Action, bool UseInterval>
106auto ExhaustiveDepthFirstSearch<State, Action, UseInterval>::compute_policy(
107 MDPType&,
108 EvaluatorType&,
109 param_type<State>,
110 ProgressReport,
111 double) -> std::unique_ptr<PolicyType>
112{
113 not_implemented();
114}
115
// Writes the statistics collected by this search to `out` by delegating to
// Statistics::print.
// NOTE(review): this listing lacks the line carrying the return type, class
// qualifier and method name (original line 117); the name is not visible
// here.
116template <typename State, typename Action, bool UseInterval>
118 std::ostream& out) const
119{
120 statistics_.print(out);
121}
122
123template <typename State, typename Action, bool UseInterval>
125 register_value_reports(const SearchNodeInfo& info, ProgressReport& progress)
126{
127 progress.register_bound("v", [info]() {
128 if constexpr (UseInterval) {
129 return info.value;
130 } else {
131 return Interval(info.value, INFINITE_VALUE);
132 }
133 });
134}
135
136template <typename State, typename Action, bool UseInterval>
137bool ExhaustiveDepthFirstSearch<State, Action, UseInterval>::
138 initialize_search_node(
139 MDPType& mdp,
140 EvaluatorType& heuristic,
141 StateID state_id,
142 SearchNodeInfo& info)
143{
144 return initialize_search_node(
145 mdp,
146 heuristic,
147 mdp.get_state(state_id),
148 info);
149}
150
151template <typename State, typename Action, bool UseInterval>
152bool ExhaustiveDepthFirstSearch<State, Action, UseInterval>::
153 initialize_search_node(
154 MDPType& mdp,
155 EvaluatorType& heuristic,
156 param_type<State> state,
157 SearchNodeInfo& info)
158{
159 assert(info.is_new());
160 info.value = trivial_bound_;
161
162 TerminationInfo term_info = mdp.get_termination_info(state);
163 const value_t term_cost = term_info.get_cost();
164 info.term_cost = term_cost;
165
166 if (term_info.is_goal_state()) {
167 info.close();
168 info.value = AlgorithmValueType(term_cost);
169 ++statistics_.goal_states;
170 return false;
171 }
172
173 const value_t estimate = heuristic.evaluate(state);
174 if (estimate == term_cost) {
175 info.value = AlgorithmValueType(term_cost);
176 info.mark_dead_end();
177 ++statistics_.dead_ends;
178 return false;
179 }
180
181 if constexpr (UseInterval) {
182 info.value.lower = estimate;
183 }
184
185 info.open();
186
187 return true;
188}
189
// Expands the state with id `state_id` and, if anything remains to explore,
// pushes it onto the DFS stack.
//
// All transitions are generated and optionally reordered by the transition
// sorter. For each transition:
// - self-loop probability mass is accumulated separately;
// - successors that are already closed (after initialising new ones) are
//   folded into the transition's base value;
// - only successors that still need exploration are kept.
// Transitions without remaining successors only contribute a lower bound
// for the node. The others are compacted to the front and handed to a new
// expansion record.
//
// Returns true iff the state was pushed (set on-stack). Returns false if it
// was resolved immediately:
// - no transitions: dead end valued at its termination cost;
// - only pure self-loop transitions: dead end valued at its termination cost;
// - no transition with unexplored successors: closed with its current value.
//
// NOTE(review): `info` is a reference into search_space_ that is held while
// search_space_ is indexed with (possibly new) successor ids; this assumes
// search_space_ entries have stable addresses.
190template <typename State, typename Action, bool UseInterval>
191bool ExhaustiveDepthFirstSearch<State, Action, UseInterval>::push_state(
192 MDPType& mdp,
193 EvaluatorType& heuristic,
194 StateID state_id,
195 SearchNodeInfo& info)
196{
197 std::vector<Action> aops;
198 std::vector<Distribution<StateID>> successors;
199 const State state = mdp.get_state(state_id);
200 mdp.generate_all_transitions(state, aops, successors);
// No applicable transitions: the state is terminal and becomes a dead end
// valued at its termination cost.
201 if (successors.empty()) {
202 info.value = AlgorithmValueType(info.term_cost);
203 info.set_dead_end();
204 statistics_.terminal++;
205 return false;
206 }
207
208 statistics_.expanded++;
209
// Optionally reorder the transitions (transition_sort_ may be null).
210 if (transition_sort_ != nullptr) {
211 transition_sort_->sort(state, aops, successors, search_space_);
212 }
213
// Tentatively create the expansion record and the stack entry; both are
// popped again below if nothing is left to explore.
214 expansion_infos_.emplace_back(stack_infos_.size());
215 stack_infos_.emplace_back(state_id);
216
217 ExpansionInformation& exp = expansion_infos_.back();
218 StackInformation& si = stack_infos_.back();
219
220 si.successors.resize(aops.size());
221
// NOTE(review): `cost` is the node's own current value, and it is added to
// every transition's base below together with the action cost; confirm
// this is intended.
222 const auto cost = info.get_value();
223
// Stays true only while every transition consists solely of self-loops.
224 bool pure_self_loop = true;
225
// Write index used to compact transitions that still have unexplored
// successors to the front of `successors` / `si.successors`.
226 unsigned j = 0;
227 for (unsigned i = 0; i < aops.size(); ++i) {
228 auto& succs = successors[i];
229 auto& t = si.successors[i];
230 bool all_self_loops = true;
231
// Drop self-loops (accumulating their probability in t.self_loop) and
// already-closed successors (folding prob * value into t.base and their
// dead-end flags into exp); keep all other successors.
232 succs.remove_if([&, this, state_id](auto& elem) {
233 const auto [succ_id, prob] = elem;
234
235 // Remove self loops
236 if (succ_id == state_id) {
237 t.self_loop += prob;
238 return true;
239 }
240
241 SearchNodeInfo& succ_info = search_space_[succ_id];
242 if (succ_info.is_new()) {
243 initialize_search_node(mdp, heuristic, succ_id, succ_info);
244 }
245
246 if (succ_info.is_closed()) {
247 t.base += prob * succ_info.get_value();
248 exp.update_successors_dead(succ_info.is_dead_end());
249 exp.all_successors_marked_dead =
250 exp.all_successors_marked_dead &&
251 succ_info.is_marked_dead_end();
252
253 all_self_loops = false;
254
255 return true;
256 }
257
258 return false;
259 });
260
261 const auto& a = aops[i];
// Nothing left to explore for this transition. If it had at least one
// closed non-self-loop successor, use it to raise the node's lower bound,
// normalising by the non-self-loop probability.
262 if (succs.empty()) {
263 if (!all_self_loops) {
264 pure_self_loop = false;
265 t.base += cost + mdp.get_action_cost(a);
266 auto non_loop = 1_vt - t.self_loop;
267 update_lower_bound(info.value, t.base / non_loop);
268 }
269 } else {
270 t.base += cost + mdp.get_action_cost(a);
271
// Store the self-loop correction as the multiplier 1 / (1 - p_loop), or 1
// when there is no self-loop.
272 if (t.self_loop == 0_vt) {
273 t.self_loop = 1_vt;
274 } else {
275 assert(t.self_loop < 1_vt);
276 t.self_loop = 1_vt / (1_vt - t.self_loop);
277 }
278
279 if (i != j) {
280 si.successors[j] = std::move(si.successors[i]);
281 successors[j] = std::move(successors[i]);
282 }
283 ++j;
284 }
285 }
286
// No transition has unexplored successors: undo the tentative push and
// resolve the state right away.
287 if (j == 0) {
288 expansion_infos_.pop_back();
289 stack_infos_.pop_back();
290
291 if (pure_self_loop) {
292 info.value = AlgorithmValueType(info.term_cost);
293 info.set_dead_end();
294 ++statistics_.self_loop;
295 } else {
296 info.value = AlgorithmValueType(info.get_value());
297 info.close();
298 }
299
300 return false;
301 }
302
// Keep only the compacted transitions; exploration starts with the last one.
303 successors.erase(successors.begin() + j, successors.end());
304 si.successors.erase(si.successors.begin() + j, si.successors.end());
305 si.i = 0;
306
307 info.set_onstack(stack_infos_.size() - 1);
308 exp.successors = std::move(successors);
309 exp.succ = exp.successors.back().begin();
310
311 return true;
312}
313
314template <typename State, typename Action, bool UseInterval>
315void ExhaustiveDepthFirstSearch<State, Action, UseInterval>::run_exploration(
316 MDPType& mdp,
317 EvaluatorType& heuristic,
318 ProgressReport& progress)
319{
320 using namespace std;
321
322 while (!expansion_infos_.empty()) {
323 ExpansionInformation& expanding = expansion_infos_.back();
324 assert(expanding.stack_index < stack_infos_.size());
325 assert(!expanding.successors.empty());
326 assert(expanding.succ != expanding.successors.back().end());
327
328 StackInformation& stack_info = stack_infos_[expanding.stack_index];
329 assert(!stack_info.successors.empty());
330
331 const StateID stateid = stack_info.state_ref;
332 SearchNodeInfo& node_info = search_space_[stateid];
333
334 expanding.update_successors_dead(last_all_dead_);
335 expanding.all_successors_marked_dead =
336 expanding.all_successors_marked_dead && last_all_marked_dead_;
337
338 int idx = stack_info.successors.size() - stack_info.i - 1;
339 SCCTransition* inc = &stack_info.successors[idx];
340 bool val_changed = false;
341 bool completely_explored = false;
342
343 for (;;) {
344 for (; expanding.succ != expanding.successors.back().end();
345 ++expanding.succ) {
346 const auto [succ_id, prob] = *expanding.succ;
347
348 assert(succ_id != stateid);
349 SearchNodeInfo& succ_info = search_space_[succ_id];
350 assert(!succ_info.is_new());
351
352 if (succ_info.is_open()) {
353 if (push_state(mdp, heuristic, succ_id, succ_info)) {
354 goto skip;
355 }
356
357 expanding.update_successors_dead(succ_info.is_dead_end());
358 expanding.all_successors_marked_dead =
359 expanding.all_successors_are_dead &&
360 succ_info.is_marked_dead_end();
361 inc->base += prob * succ_info.get_value();
362 } else if (succ_info.is_onstack()) {
363 node_info.lowlink =
364 std::min(node_info.lowlink, succ_info.lowlink);
365 inc->successors.add_probability(succ_id, prob);
366 } else {
367 assert(succ_info.is_closed());
368 expanding.update_successors_dead(succ_info.is_dead_end());
369 expanding.all_successors_marked_dead =
370 expanding.all_successors_are_dead &&
371 succ_info.is_marked_dead_end();
372 inc->base += prob * succ_info.get_value();
373 }
374 }
375
376 expanding.successors.pop_back();
377 if (update_lower_bound(
378 node_info.value,
379 inc->base * inc->self_loop)) {
380 val_changed = true;
381 if (check_early_convergence(node_info)) {
382 expanding.successors.clear();
383 }
384 }
385
386 if (expanding.successors.empty()) {
387 if (inc->successors.empty()) {
388 if (stack_info.i > 0)
389 std::swap(stack_info.successors.back(), *inc);
390 stack_info.successors.pop_back();
391 }
392
393 break;
394 }
395
396 if (inc->successors.empty()) {
397 if (stack_info.i > 0) {
398 std::swap(stack_info.successors.back(), *inc);
399 }
400
401 stack_info.successors.pop_back();
402 int t = stack_info.successors.size() - stack_info.i - 1;
403 inc = &stack_info.successors[t];
404 } else {
405 --inc;
406 ++stack_info.i;
407 }
408
409 expanding.succ = expanding.successors.back().begin();
410 }
411
412 last_all_dead_ = expanding.all_successors_are_dead;
413 last_all_marked_dead_ = expanding.all_successors_marked_dead;
414 statistics_.backtracks++;
415
416 if (expanding.stack_index == node_info.lowlink) {
417 ++statistics_.sccs;
418
419 auto rend = stack_infos_.rbegin();
420 if (expanding.all_successors_are_dead) {
421 unsigned scc_size = 0;
422 do {
423 ++scc_size;
424 auto& info = search_space_[rend->state_ref];
425 info.value = AlgorithmValueType(info.term_cost);
426 info.set_dead_end();
427 } while ((rend++)->state_ref != stateid);
428
429 statistics_.dead_end_sccs++;
430 statistics_.summed_dead_end_scc_sizes += scc_size;
431 } else {
432 unsigned scc_size = 0;
433 do {
434 auto& info = search_space_[rend->state_ref];
435 info.close();
436
437 if constexpr (UseInterval) {
438 val_changed =
439 update(
440 info.value,
441 AlgorithmValueType(info.value.lower)) ||
442 val_changed;
443 }
444
445 ++scc_size;
446 } while ((rend++)->state_ref != stateid);
447
448 if (scc_size > 1) {
449 unsigned iterations = 0;
450 bool changed;
451 do {
452 changed = false;
453 for (auto it = stack_infos_.rbegin(); it != rend;
454 ++it) {
455 StackInformation& s = *it;
456 assert(!s.successors.empty());
457 value_t best = s.successors.back().base;
458 for (const auto& t :
459 std::views::reverse(s.successors)) {
460 value_t t_first = t.base;
461 for (auto [succ_id, prob] : t.successors) {
462 t_first +=
463 prob *
464 search_space_[succ_id].get_value();
465 }
466 t_first = t_first * t.self_loop;
467 best = best > t_first ? best : t_first;
468 }
469
470 SearchNodeInfo& snode_info =
471 search_space_[s.state_ref];
472 if (best > snode_info.get_value()) {
473 changed = changed || !is_approx_equal(
474 snode_info.get_value(),
475 best);
476 snode_info.value = AlgorithmValueType(best);
477 }
478 }
479 ++iterations;
480 } while (changed);
481
482 val_changed = val_changed || iterations > 1;
483 }
484 }
485
486 stack_infos_.erase(rend.base(), stack_infos_.end());
487 }
488
489 expansion_infos_.pop_back();
490
491 completely_explored = true;
492
493 skip:
494
495 if ((val_changed || !only_propagate_when_changed_) &&
496 value_propagation_) {
497 propagate_value_along_trace(
498 completely_explored,
499 node_info.get_value(),
500 progress);
501 }
502 }
503}
504
505template <typename State, typename Action, bool UseInterval>
506void ExhaustiveDepthFirstSearch<State, Action, UseInterval>::
507 propagate_value_along_trace(
508 bool was_poped,
509 value_t val,
510 ProgressReport& progress)
511{
512 auto it = expansion_infos_.rbegin();
513 if (!was_poped) {
514 it += 2;
515 }
516
517 for (; it != expansion_infos_.rend(); ++it) {
518 StackInformation& st = stack_infos_[it->stack_index];
519 SearchNodeInfo& sn = search_space_[st.state_ref];
520 const auto& t = st.successors[st.successors.size() - st.i - 1];
521 const value_t v = t.base + it->succ->probability * val;
522 if (!update_lower_bound(sn.value, v)) {
523 break;
524 }
525
526 val = v;
527 }
528
529 if (it == expansion_infos_.rend()) {
530 progress.print();
531 }
532}
533
534template <typename State, typename Action, bool UseInterval>
535bool ExhaustiveDepthFirstSearch<State, Action, UseInterval>::
536 check_early_convergence(const SearchNodeInfo& node) const
537{
538 if constexpr (UseInterval) {
539 return node.value.upper <= node.value.lower;
540 } else {
541 return node.value <= cost_bound_.lower;
542 }
543}
544
545} // namespace probfd::algorithms::exhaustive_dfs
A registry for print functions related to search progress.
Definition progress_report.h:33
void register_bound(const std::string &property_name, BoundProperty property)
Appends a new bound property with a given name to the list of bound properties to be printed when the...
Implementation of an anytime topological value iteration variant.
Definition exhaustive_dfs.h:210
namespace for anytime TVI
Definition exhaustive_dfs.h:24
bool update(Interval &lhs, Interval rhs, value_t epsilon=g_epsilon)
Intersects two intervals and assigns the result to the left operand.
double value_t
Typedef for the state value type.
Definition aliases.h:7
bool is_approx_equal(value_t v1, value_t v2, value_t epsilon=g_epsilon)
Equivalent to $|v_1 - v_2| \leq \epsilon$.
STL namespace.
Represents a closed interval over the extended reals as a pair of lower and upper bound.
Definition interval.h:12