AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
trap_aware_dfhs.h
1#ifndef PROBFD_ALGORITHMS_TRAP_AWARE_DFHS_H
2#define PROBFD_ALGORITHMS_TRAP_AWARE_DFHS_H
3
4#include "probfd/algorithms/heuristic_search_base.h"
5#include "probfd/quotients/quotient_system.h"
6
7#include "probfd/distribution.h"
8
9#include "downward/utils/timer.h"
10
11#include <type_traits>
12#include <vector>
13
// Forward declarations.
// Only references to utils::CountdownTimer are needed in this header (see the
// `utils::CountdownTimer& timer` parameters below), so a declaration suffices.
namespace utils {
class CountdownTimer;
} // namespace utils
18
namespace probfd::algorithms {
/// Forward declaration of the open-list interface template.
/// Only a declaration is needed here: this header merely names
/// `OpenList<QAction>` in a type alias.
template <typename>
class OpenList;
} // namespace probfd::algorithms
23
27
/// Selects how (if at all) value updates are performed while backtracking.
/// NOTE(review): the exact semantics of each enumerator live in the
/// implementation; the names suggest "never", "only when needed" and
/// "a single update" respectively — confirm there.
enum class BacktrackingUpdateType {
    DISABLED,
    ON_DEMAND,
    SINGLE,
};
29
30namespace internal {
31
32struct Statistics {
33 utils::Timer trap_timer = utils::Timer(true);
34 unsigned long long iterations = 0;
35 unsigned long long traps = 0;
36 unsigned long long reexpansions = 0;
37 unsigned long long fw_updates = 0;
38 unsigned long long bw_updates = 0;
39
40 void print(std::ostream& out) const;
41 void register_report(ProgressReport& report) const;
42};
43
44template <typename Action, bool UseInterval>
45struct PerStateInformation
46 : public heuristic_search::
47 PerStateBaseInformation<Action, true, UseInterval> {
48private:
49 using Base = heuristic_search::PerStateBaseInformation<Action, true, UseInterval>;
50
51public:
52 static constexpr uint8_t SOLVED = 1 << Base::BITS;
53 static constexpr uint8_t BITS = Base::BITS + 1;
54 static constexpr uint8_t MASK = 1 << Base::BITS;
55
56 [[nodiscard]]
57 bool is_solved() const
58 {
59 return this->info & SOLVED || this->is_goal_or_terminal();
60 }
61
62 void set_solved() { this->info = (this->info & ~MASK) | SOLVED; }
63};
64
65} // namespace internal
66
// Forward declaration of the public wrapper so that TADFHSImpl can declare it
// as a friend before its definition appears further below.
template <typename State, typename Action, bool UseInterval>
class TADepthFirstHeuristicSearch;
69
70template <typename State, typename Action, bool UseInterval>
71class TADFHSImpl
73 quotients::QuotientState<State, Action>,
74 quotients::QuotientAction<Action>,
75 internal::PerStateInformation<
76 quotients::QuotientAction<Action>,
77 UseInterval>> {
78 using Base = typename TADFHSImpl::HeuristicSearchBase;
79
80 using AlgorithmValueType = Base::AlgorithmValueType;
81
82 using QuotientSystem = quotients::QuotientSystem<State, Action>;
83 using QState = quotients::QuotientState<State, Action>;
84 using QAction = quotients::QuotientAction<Action>;
85
86 using QEvaluator = typename Base::EvaluatorType;
87 using QuotientPolicyPicker = typename Base::PolicyPickerType;
88 using StateInfo = typename Base::StateInfo;
89
90 using QuotientOpenList = OpenList<QAction>;
91
92 template <typename, typename, bool>
93 friend class TADepthFirstHeuristicSearch;
94
95 struct UpdateResult {
96 bool value_changed;
97 bool policy_changed;
98 };
99
100 struct ExplorationInformation {
101 StateID state;
102 int lowlink;
103 std::vector<StateID> successors;
104
106 bool value_converged : 1 = true;
108 bool all_solved : 1 = true;
110 bool is_trap : 1 = true;
111
112 explicit ExplorationInformation(StateID state, int stack_index)
113 : state(state)
114 , lowlink(stack_index)
115 {
116 }
117
118 bool next_successor();
119 StateID get_successor() const;
120
121 void update(const ExplorationInformation& other);
122
123 void clear();
124 };
125
126 struct StackInfo {
127 StateID state_id;
128 std::optional<QAction> action;
129
130 explicit StackInfo(StateID state_id)
131 : state_id(state_id)
132 {
133 }
134
135 template <size_t i>
136 friend auto get(StackInfo& info)
137 {
138 if constexpr (i == 0) return info.state_id;
139 if constexpr (i == 1) return std::views::single(*info.action);
140 }
141
142 template <size_t i>
143 friend auto get(const StackInfo& info)
144 {
145 if constexpr (i == 0) return info.state_id;
146 if constexpr (i == 1) return std::views::single(*info.action);
147 }
148 };
149
150 static constexpr int NEW = -1;
151 static constexpr int CLOSED = -2;
152
153 // Algorithm parameters
154 const bool forward_updates_;
155 const BacktrackingUpdateType backtrack_update_type_;
156 const bool cutoff_tip_;
157 const bool cutoff_inconsistent_;
158 const bool terminate_exploration_on_cutoff_;
159 const bool label_solved_;
160 const bool reexpand_traps_;
161
162 // Algorithm state
163 std::deque<ExplorationInformation> queue_;
164 std::vector<StackInfo> stack_;
165 std::vector<StateID> visited_states_;
166 storage::StateHashMap<int> stack_index_;
167
168 bool terminated_ = false;
169
170 // Re-used buffer
171 std::vector<Transition<QAction>> transitions_;
172 std::vector<AlgorithmValueType> qvalues_;
173 Distribution<StateID> transition_;
174
175 internal::Statistics statistics_;
176
177public:
181 TADFHSImpl(
182 std::shared_ptr<QuotientPolicyPicker> policy_chooser,
183 bool forward_updates,
184 BacktrackingUpdateType backtrack_update_type,
185 bool cutoff_tip,
186 bool cutoff_inconsistent,
187 bool terminate_exploration_on_cutoff,
188 bool label_solved,
189 bool reexpand_traps);
190
191 Interval solve_quotient(
192 QuotientSystem& quotient,
193 QEvaluator& heuristic,
194 param_type<QState> qstate,
195 ProgressReport& progress,
196 double max_time);
197
198 void print_statistics(std::ostream& out) const;
199
200private:
201 void dfhs_vi_driver(
202 QuotientSystem& quotient,
203 QEvaluator& heuristic,
204 StateID state,
205 ProgressReport& progress,
206 utils::CountdownTimer& timer);
207
208 void dfhs_label_driver(
209 QuotientSystem& quotient,
210 QEvaluator& heuristic,
211 StateID state,
212 ProgressReport& progress,
213 utils::CountdownTimer& timer);
214
215 void enqueue(
216 QuotientSystem& quotient,
217 ExplorationInformation& einfo,
218 StateID state,
219 QAction action,
220 const Distribution<StateID>& successor_dist);
221
222 bool advance(QuotientSystem& quotient, ExplorationInformation& einfo);
223
224 bool push_successor(
225 QuotientSystem& quotient,
226 ExplorationInformation& einfo,
227 utils::CountdownTimer& timer);
228
229 bool initialize(
230 QuotientSystem& quotient,
231 QEvaluator& heuristic,
232 ExplorationInformation& einfo);
233
234 void push(StateID state_id);
235
236 bool policy_exploration(
237 QuotientSystem& quotient,
238 QEvaluator& heuristic,
239 StateID start_state,
240 utils::CountdownTimer& timer);
241
242 UpdateResult value_iteration(
243 QuotientSystem& quotient,
244 const std::ranges::input_range auto& range,
245 utils::CountdownTimer& timer);
246};
247
248template <typename State, typename Action, bool UseInterval>
249class TADepthFirstHeuristicSearch : public MDPAlgorithm<State, Action> {
250 using Base = typename TADepthFirstHeuristicSearch::MDPAlgorithm;
251
252 using PolicyType = typename Base::PolicyType;
253 using MDPType = typename Base::MDPType;
254 using EvaluatorType = typename Base::EvaluatorType;
255
256 using QuotientSystem = quotients::QuotientSystem<State, Action>;
257 using QState = quotients::QuotientState<State, Action>;
258 using QAction = quotients::QuotientAction<Action>;
259
260 using QuotientPolicyPicker = PolicyPicker<QState, QAction>;
261
262 TADFHSImpl<State, Action, UseInterval> algorithm_;
263
264public:
268 TADepthFirstHeuristicSearch(
269 std::shared_ptr<QuotientPolicyPicker> policy_chooser,
270 bool forward_updates,
271 BacktrackingUpdateType backtrack_update_type,
272 bool cutoff_tip,
273 bool cutoff_inconsistent,
274 bool stop_exploration_inconsistent,
275 bool label_solved,
276 bool reexpand_removed_traps);
277
278 Interval solve(
279 MDPType& mdp,
280 EvaluatorType& heuristic,
281 param_type<State> state,
282 ProgressReport progress,
283 double max_time) override;
284
285 std::unique_ptr<PolicyType> compute_policy(
286 MDPType& mdp,
287 EvaluatorType& heuristic,
288 param_type<State> state,
289 ProgressReport progress,
290 double max_time) override;
291
292 void print_statistics(std::ostream& out) const override;
293
294 [[nodiscard]]
295 Interval lookup_bounds(StateID state_id) const;
296
297 [[nodiscard]]
298 bool was_visited(StateID state_id) const;
299};
300
301} // namespace probfd::algorithms::trap_aware_dfhs
302
303#define GUARD_INCLUDE_PROBFD_ALGORITHMS_TRAP_AWARE_DFHS_H
304#include "probfd/algorithms/trap_aware_dfhs_impl.h"
305#undef GUARD_INCLUDE_PROBFD_ALGORITHMS_TRAP_AWARE_DFHS_H
306
307#endif
A convenience class that represents a finite probability distribution.
Definition task_state_space.h:27
Interface for MDP algorithm implementations.
Definition mdp_algorithm.h:29
A registry for print functions related to search progress.
Definition progress_report.h:33
An interface for open lists used during search algorithms.
Definition trap_aware_dfhs.h:21
A strategy interface used to break ties between multiple greedy actions for a state.
Definition policy_picker.h:57
The common base class for MDP heuristic search algorithms.
Definition heuristic_search_base.h:113
Namespace dedicated to the depth-first heuristic search (DFHS) family with native trap handling support.
Definition trap_aware_dfhs.h:26
This namespace contains implementations of SSP search algorithms.
Definition acyclic_value_iteration.h:22
template <typename T> using param_type = std::conditional_t<is_cheap_to_copy_v<T>, T, const T&>
Alias template defining the best way to pass a parameter of a given type.
Definition type_traits.h:25
Represents a closed interval over the extended reals as a pair of lower and upper bound.
Definition interval.h:12
A StateID represents a state within a StateIDMap. Just like Fast Downward's StateID type,...
Definition types.h:22