AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
ao_search.h
1#ifndef PROBFD_ALGORITHMS_AO_SEARCH_H
2#define PROBFD_ALGORITHMS_AO_SEARCH_H
3
4#include "probfd/algorithms/heuristic_search_base.h"
5
6#include "probfd/storage/per_state_storage.h"
7
8#include <iosfwd>
9#include <queue>
10#include <type_traits>
11#include <vector>
12
15
/// Counters collected while an AO* family algorithm runs.
struct Statistics {
    /// Number of iterations performed so far; value-initialized to zero.
    unsigned long long iterations{0};

    /// Writes the collected counters to @p out (definition not shown here).
    void print(std::ostream& out) const;
};
20
21template <typename Action, bool Interval, bool StorePolicy>
22struct PerStateInformation
23 : public heuristic_search::
24 PerStateBaseInformation<Action, StorePolicy, Interval> {
25private:
26 using Base = heuristic_search::PerStateBaseInformation<Action, StorePolicy, Interval>;
27
28public:
29 static constexpr uint8_t MARK = 1 << Base::BITS;
30 static constexpr uint8_t SOLVED = 2 << Base::BITS;
31 static constexpr uint8_t MASK = 3 << Base::BITS;
32 static constexpr uint8_t BITS = Base::BITS + 2;
33
34 unsigned update_order = 0;
35 std::vector<StateID> parents;
36
37 [[nodiscard]]
38 bool is_marked() const
39 {
40 return this->info & MARK;
41 }
42
43 [[nodiscard]]
44 bool is_solved() const
45 {
46 return this->info & SOLVED || this->is_goal_or_terminal();
47 }
48
49 [[nodiscard]]
50 const std::vector<StateID>& get_parents() const
51 {
52 return parents;
53 }
54
55 [[nodiscard]]
56 std::vector<StateID>& get_parents()
57 {
58 return parents;
59 }
60
61 void mark()
62 {
63 assert(!is_solved());
64 this->info = (this->info & ~MASK) | MARK;
65 }
66
67 void unmark() { this->info = (this->info & ~MARK); }
68
69 void set_solved() { this->info = (this->info & ~MASK) | SOLVED; }
70
71 void add_parent(StateID s) { parents.push_back(s); }
72};
73
82template <typename State, typename Action, typename StateInfo>
83class AOBase
84 : public heuristic_search::
85 HeuristicSearchAlgorithm<State, Action, StateInfo> {
86 using Base = typename AOBase::HeuristicSearchAlgorithm;
87
88protected:
89 using MDPType = typename Base::MDPType;
90 using EvaluatorType = typename Base::EvaluatorType;
91 using PolicyPickerType = typename Base::PolicyPicker;
92
93private:
94 struct PrioritizedStateID {
95 unsigned update_order;
96 StateID state_id;
97
98 friend bool operator<(
99 const PrioritizedStateID& left,
100 const PrioritizedStateID& right)
101 {
102 return left.update_order > right.update_order;
103 }
104 };
105
106 std::priority_queue<PrioritizedStateID> queue_;
107
108protected:
109 Statistics statistics_;
110
111public:
112 // Inherit constructor
113 using Base::Base;
114
115protected:
116 void print_additional_statistics(std::ostream& out) const override;
117
118 void backpropagate_tip_value(
119 this auto& self,
120 MDPType& mdp,
121 std::vector<Transition<Action>>& transitions,
122 StateInfo& state_info,
123 utils::CountdownTimer& timer);
124
125 void backpropagate_update_order(
126 StateID tip,
127 StateInfo& info,
128 unsigned update_order,
129 utils::CountdownTimer& timer);
130
131private:
132 void push_parents_to_queue(StateInfo& info);
133};
134
135} // namespace probfd::algorithms::ao_search
136
137#define GUARD_INCLUDE_PROBFD_ALGORITHMS_AO_SEARCH_H
138#include "probfd/algorithms/ao_search_impl.h"
139#undef GUARD_INCLUDE_PROBFD_ALGORITHMS_AO_SEARCH_H
140
141#endif // PROBFD_ALGORITHMS_AO_SEARCH_H
Base class for the AO* algorithm family.
Definition ao_search.h:85
Namespace dedicated to the AO* family of MDP algorithms.
Definition ao_search.h:14
A StateID represents a state within a StateIDMap. Just like Fast Downward's StateID type,...
Definition types.h:22