AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
trap_aware_lrtdp.h
1#ifndef PROBFD_ALGORITHMS_TRAP_AWARE_LRTDP_H
2#define PROBFD_ALGORITHMS_TRAP_AWARE_LRTDP_H
3
4#include "probfd/algorithms/heuristic_search_base.h"
5#include "probfd/quotients/quotient_system.h"
6#include "probfd/storage/per_state_storage.h"
7
8#include "downward/utils/timer.h"
9
10// Forward Declarations
// Forward declaration only: this header refers to utils::CountdownTimer solely
// by reference (see the timer parameters of TALRTDPImpl's private member
// functions), so the full definition is not needed here.
namespace utils {
class CountdownTimer;
}
14
// Forward declaration of the successor sampler interface; it is used below as
// a single-parameter class template (SuccessorSampler<QAction>).
//
// NOTE(review): the declaration line itself (listing line 17) was missing from
// this listing, leaving a dangling `template <typename>`. The name is restored
// from its uses further down in this header — confirm against the original.
namespace probfd::algorithms {
template <typename>
class SuccessorSampler;
}
19
23
/// Selects the condition under which a trial is terminated.
///
/// A value of this type is accepted by the constructors of TALRTDPImpl and
/// TALRTDP (parameter `stop_consistent`).
///
/// NOTE(review): the precise meaning of each enumerator is decided by the
/// trial code in trap_aware_lrtdp_impl.h, which is not visible here.
/// NOTE(review): the opening of the enclosing namespace
/// (probfd::algorithms::trap_aware_lrtdp, listing line 22) is absent from this
/// listing although its closing brace is present — confirm against the
/// original header.
enum class TrialTerminationCondition {
    TERMINAL,
    CONSISTENT,
    INCONSISTENT,
    REVISITED,
};
30
31namespace internal {
32
33struct Statistics {
34 unsigned long long trials = 0;
35 unsigned long long trial_bellman_backups = 0;
36 unsigned long long check_and_solve_bellman_backups = 0;
37 unsigned long long traps = 0;
38 unsigned long long trial_length = 0;
39 utils::Timer trap_timer = utils::Timer(true);
40
41 void print(std::ostream& out) const;
42 void register_report(ProgressReport& report) const;
43};
44
45template <typename Action, bool UseInterval>
46struct PerStateInformation
47 : public heuristic_search::
48 PerStateBaseInformation<Action, true, UseInterval> {
49private:
50 using Base = heuristic_search::PerStateBaseInformation<Action, true, UseInterval>;
51
52public:
53 static constexpr uint8_t MARKED_TRIAL = 1 << Base::BITS;
54 static constexpr uint8_t SOLVED = 2 << Base::BITS;
55 static constexpr uint8_t BITS = Base::BITS + 2;
56 static constexpr uint8_t MASK = 3 << Base::BITS;
57
58 bool is_solved() const
59 {
60 return (this->info & MASK) == SOLVED || this->is_goal_or_terminal();
61 }
62
63 bool is_on_trial() const { return (this->info & MARKED_TRIAL); }
64
65 void set_solved() { this->info = (this->info & ~MASK) | SOLVED; }
66 void set_on_trial() { this->info = this->info | MARKED_TRIAL; }
67
68 void clear_trial_flag() { this->info = (this->info & ~MARKED_TRIAL); }
69};
70
71} // namespace internal
72
// Forward declaration: TALRTDPImpl below names TALRTDP in a friend
// declaration before TALRTDP itself is defined.
template <typename, typename, bool>
class TALRTDP;
75
76template <typename State, typename Action, bool UseInterval>
77class TALRTDPImpl
79 quotients::QuotientState<State, Action>,
80 quotients::QuotientAction<Action>,
81 internal::PerStateInformation<
82 quotients::QuotientAction<Action>,
83 UseInterval>> {
84 using Base = typename TALRTDPImpl::HeuristicSearchBase;
85
86 using AlgorithmValueType = Base::AlgorithmValueType;
87
88 using QuotientSystem = quotients::QuotientSystem<State, Action>;
89 using QState = quotients::QuotientState<State, Action>;
90 using QAction = quotients::QuotientAction<Action>;
91
92 using QEvaluator = typename Base::EvaluatorType;
93 using QuotientPolicyPicker = typename Base::PolicyPickerType;
94 using StateInfo = typename Base::StateInfo;
95
96 using QuotientSuccessorSampler = SuccessorSampler<QAction>;
97
98 template <typename, typename, bool>
99 friend class TALRTDP;
100
101 struct ExplorationInformation {
102 explicit ExplorationInformation(StateID state_id)
103 : state(state_id)
104 {
105 }
106
107 StateID state;
108 std::vector<StateID> successors;
109 bool is_root : 1 = true;
110 bool is_trap : 1 = true;
111 bool rv : 1 = true;
112
113 bool next_successor();
114 [[nodiscard]]
115 StateID get_successor() const;
116
117 void update(const ExplorationInformation& backtracked)
118 {
119 is_trap = is_trap && backtracked.is_trap;
120 rv = rv && backtracked.rv;
121 }
122
123 void update(const StateInfo& succ_info)
124 {
125 is_trap = false;
126 rv = rv && succ_info.is_solved();
127 }
128
129 void clear()
130 {
131 is_trap = true;
132 rv = true;
133 }
134 };
135
136 struct StackInfo {
137 StateID state_id;
138 std::vector<QAction> aops;
139
140 explicit StackInfo(StateID state_id)
141 : state_id(state_id)
142 {
143 }
144
145 template <size_t i>
146 friend auto& get(StackInfo& info)
147 {
148 if constexpr (i == 0) return info.state_id;
149 if constexpr (i == 1) return info.aops;
150 }
151
152 template <size_t i>
153 friend const auto& get(const StackInfo& info)
154 {
155 if constexpr (i == 0) return info.state_id;
156 if constexpr (i == 1) return info.aops;
157 }
158 };
159
160 static constexpr int STATE_UNSEEN = -1;
161 static constexpr int STATE_CLOSED = -2;
162
163 // Algorithm parameters
164 const TrialTerminationCondition stop_at_consistent_;
165 const bool reexpand_traps_;
166 const std::shared_ptr<QuotientSuccessorSampler> sample_;
167
168 // Algorithm state
169 std::deque<ExplorationInformation> queue_;
170 std::deque<StackInfo> stack_;
171 storage::StateHashMap<int> stack_index_;
172
173 std::deque<StateID> current_trial_;
174
175 internal::Statistics statistics_;
176
177 // Buffer
178 std::vector<Transition<QAction>> transitions_;
179 std::vector<AlgorithmValueType> qvalues_;
180
181public:
185 TALRTDPImpl(
186 std::shared_ptr<QuotientPolicyPicker> policy_chooser,
187 TrialTerminationCondition stop_consistent,
188 bool reexpand_traps,
189 std::shared_ptr<QuotientSuccessorSampler> succ_sampler);
190
191 Interval solve_quotient(
192 QuotientSystem& quotient,
193 QEvaluator& heuristic,
194 param_type<QState> state,
195 ProgressReport& progress,
196 double max_time);
197
198 void print_statistics(std::ostream& out) const;
199
200private:
201 bool trial(
202 QuotientSystem& quotient,
203 QEvaluator& heuristic,
204 StateID start_state,
205 utils::CountdownTimer& timer);
206
207 bool check_and_solve(
208 QuotientSystem& quotient,
209 QEvaluator& heuristic,
210 utils::CountdownTimer& timer);
211
212 bool push_successor(
213 QuotientSystem& quotient,
214 ExplorationInformation& einfo,
215 utils::CountdownTimer& timer);
216
217 void push(StateID state);
218
219 bool initialize(
220 QuotientSystem& quotient,
221 QEvaluator& heuristic,
222 StateID state,
223 StateInfo& state_info,
224 ExplorationInformation& e_info);
225};
226
227template <typename State, typename Action, bool UseInterval>
228class TALRTDP : public MDPAlgorithm<State, Action> {
229 using Base = typename TALRTDP::MDPAlgorithm;
230
231 using QuotientSystem = quotients::QuotientSystem<State, Action>;
232 using QState = quotients::QuotientState<State, Action>;
233 using QAction = quotients::QuotientAction<Action>;
234
235 using MDPType = typename Base::MDPType;
236 using EvaluatorType = typename Base::EvaluatorType;
237 using PolicyType = typename Base::PolicyType;
238
239 using QuotientPolicyPicker = PolicyPicker<QState, QAction>;
240 using QuotientSuccessorSampler = SuccessorSampler<QAction>;
241
242 TALRTDPImpl<State, Action, UseInterval> algorithm_;
243
244public:
248 TALRTDP(
249 std::shared_ptr<QuotientPolicyPicker> policy_chooser,
250 TrialTerminationCondition stop_consistent,
251 bool reexpand_traps,
252 std::shared_ptr<QuotientSuccessorSampler> succ_sampler);
253
254 Interval solve(
255 MDPType& mdp,
256 EvaluatorType& heuristic,
258 ProgressReport progress,
259 double max_time) final;
260
261 std::unique_ptr<PolicyType> compute_policy(
262 MDPType& mdp,
263 EvaluatorType& heuristic,
265 ProgressReport progress,
266 double max_time) final;
267
268 void print_statistics(std::ostream& out) const final;
269};
270
271} // namespace probfd::algorithms::trap_aware_lrtdp
272
273#define GUARD_INCLUDE_PROBFD_ALGORITHMS_TRAP_AWARE_LRTDP_H
274#include "probfd/algorithms/trap_aware_lrtdp_impl.h"
275#undef GUARD_INCLUDE_PROBFD_ALGORITHMS_TRAP_AWARE_LRTDP_H
276
277#endif // PROBFD_ALGORITHMS_TRAP_AWARE_LRTDP_H
Interface for MDP algorithm implementations.
Definition mdp_algorithm.h:29
A registry for print functions related to search progress.
Definition progress_report.h:33
A strategy interface used to break ties between multiple greedy actions for a state.
Definition policy_picker.h:57
An interface used to sample a state from a successor distribution.
Definition trap_aware_lrtdp.h:17
The common base class for MDP heuristic search algorithms.
Definition heuristic_search_base.h:113
Namespace dedicated to labelled real-time dynamic programming (LRTDP) with native trap handling support.
Definition trap_aware_lrtdp.h:22
This namespace contains implementations of SSP search algorithms.
Definition acyclic_value_iteration.h:22
typename std::conditional_t< is_cheap_to_copy_v< T >, T, const T & > param_type
Alias template defining the best way to pass a parameter of a given type.
Definition type_traits.h:25
Represents a closed interval over the extended reals as a pair of lower and upper bound.
Definition interval.h:12
A StateID represents a state within a StateIDMap. Just like Fast Downward's StateID type,...
Definition types.h:22