AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
ta_topological_value_iteration.h
1#ifndef PROBFD_ALGORITHMS_TA_TOPOLOGICAL_VALUE_ITERATION_H
2#define PROBFD_ALGORITHMS_TA_TOPOLOGICAL_VALUE_ITERATION_H
3
4#include "probfd/algorithms/types.h"
5
6#include "probfd/storage/per_state_storage.h"
7
8#include "probfd/distribution.h"
9#include "probfd/mdp_algorithm.h"
10
11#include "downward/utils/timer.h"
12
13#include <deque>
14#include <limits>
15#include <ostream>
16#include <set>
17#include <vector>
18
19namespace utils {
20class CountdownTimer;
21}
22
25
29struct Statistics {
30 unsigned long long expanded_states = 0;
31 unsigned long long terminal_states = 0;
32 unsigned long long goal_states = 0;
33 unsigned long long sccs = 0;
34 unsigned long long singleton_sccs = 0;
35 unsigned long long bellman_backups = 0;
36 unsigned long long pruned = 0;
37
38 utils::Timer initialize_state_timer = utils::Timer(false);
39 utils::Timer successor_handling_timer = utils::Timer(false);
40 utils::Timer scc_handling_timer = utils::Timer(false);
41 utils::Timer backtracking_timer = utils::Timer(false);
42 utils::Timer vi_timer = utils::Timer(false);
43 utils::Timer decomposition_timer = utils::Timer(false);
44 utils::Timer solvability_timer = utils::Timer(false);
45
46 void print(std::ostream& out) const;
47};
48
// Trap-aware variant of Topological Value Iteration (TATVI).
//
// Template parameters:
//   State, Action -- forwarded to the MDPAlgorithm<State, Action> base.
//   UseInterval   -- selects AlgorithmValueType = AlgorithmValue<UseInterval>,
//                    which is Interval when true and value_t otherwise.
//
// Structure visible in this header:
//   - a Tarjan-style SCC search over the reachable MDP (ExplorationInfo,
//     StackInfo, exploration_stack_, stack_);
//   - a second Tarjan-style search used for end-component decomposition
//     (ECDExplorationInfo, exploration_stack_ecd_, stack_ecd_);
//   - a queue of SCCs awaiting decomposition (DecompositionQueue).
// Member definitions are presumably in ta_topological_value_iteration_impl.h,
// which is included at the end of this file.
//
// NOTE(review): this header uses assert, std::views, uint32_t,
// std::unique_ptr and std::ranges::lexicographical_compare but does not
// directly include <cassert>, <ranges>, <cstdint>, <memory> or <algorithm>;
// it relies on transitive includes.
62template <typename State, typename Action, bool UseInterval = false>
63class TATopologicalValueIteration : public MDPAlgorithm<State, Action> {
    // Base names the MDPAlgorithm<State, Action> base through its injected
    // class name.
64 using Base = typename TATopologicalValueIteration::MDPAlgorithm;
65
66 using MDPType = typename Base::MDPType;
67 using EvaluatorType = typename Base::EvaluatorType;
68 using PolicyType = typename Base::PolicyType;
69
    // Scalar (value_t) or Interval state value, depending on UseInterval.
70 using AlgorithmValueType = algorithms::AlgorithmValue<UseInterval>;
71
    // Per-state bookkeeping, stored in state_information_.
72 struct StateInfo {
73 // Status Flags
74 enum { NEW, CLOSED, ONSTACK };
75
        // UNDEF (2^31 - 1) is the largest value that fits into the 31-bit
        // stack_id field and is that field's default value.
        // UNDEF_ECD (2^32 - 1) is the default value of ecd_stack_id.
76 static constexpr uint32_t UNDEF =
77 std::numeric_limits<uint32_t>::max() >> 1;
78 static constexpr uint32_t UNDEF_ECD =
79 std::numeric_limits<uint32_t>::max();
80
        // explored (1 bit) and stack_id (31 bits) are bit-fields sized to
        // pack into a single 32-bit unsigned on typical ABIs.
81 unsigned explored : 1 = 0;
82 unsigned stack_id : 31 = UNDEF;
83 unsigned ecd_stack_id = UNDEF_ECD;
84
        // Status queries, defined out of line. Presumably they map the
        // fields above to NEW / CLOSED / ONSTACK -- confirm in the impl.
85 [[nodiscard]]
86 auto get_status() const;
87
88 [[nodiscard]]
89 auto get_ecd_status() const;
90 };
91
92 struct StackInfo;
93
    // Q-value of a single transition. It is split into a part that is
    // already fixed and the successors whose values may still change.
94 struct QValueInfo {
95 // Precomputed part of the Q-value.
96 // Sum of action cost plus those weighted successor values which
97 // have already converged due to topological ordering.
98 mutable AlgorithmValueType conv_part;
99
100 // Successor states (with their probabilities) whose values have not
101 // yet converged, self-loops excluded.
102 std::vector<ItemProbabilityPair<StateID>> scc_successors;
103
        // Defined out of line. Presumably returns conv_part plus the
        // probability-weighted values of scc_successors read from
        // value_store -- confirm in the impl.
104 template <typename ValueStore>
105 AlgorithmValueType compute_q_value(ValueStore& value_store) const;
106 };
107
    // One frame of the explicit DFS stack (exploration_stack_) of the
    // Tarjan-style SCC search over the reachable MDP.
108 struct ExplorationInfo {
109 // Immutable state
110 StateID state_id;
        // Refers to an element of stack_. stack_ is a std::deque, and
        // push_back on a deque keeps references to existing elements valid.
111 StackInfo& stack_info;
112 unsigned stackidx;
113
114 // Tarjan's algorithm state
115 unsigned lowlink;
116
117 // Exploration State -- Remaining operators
118 std::vector<Action> aops;
119
120 // Exploration State -- Currently expanded transition and successor
121 Distribution<StateID> transition;
122 typename Distribution<StateID>::const_iterator successor;
123
124 // Exploration state -- Current Q value info
125 QValueInfo q_value;
126
        // Always an Interval, regardless of UseInterval. Its role is not
        // visible in this header.
127 Interval exit_interval;
128
129 // End component decomposition state
130
131 // recursive decomposition flag
132 // Recursively decompose the SCC if there is a zero-cost transition
133 // in it that can leave and remain in the scc, or a non-zero-cost
134 // transition that can remain in the MDP. Both cannot be part of an
135 // end component and removing them affects connectivity of the SCCs,
136 // so recursion is necessary after removal.
137 bool has_all_zero : 1 = true;
138
139 ExplorationInfo(
140 StateID state_id,
141 StackInfo& stack_info,
142 unsigned int stackidx);
143
        // Iteration helpers, defined out of line. Judging by their names
        // and bool results, they presumably advance to the next transition
        // or successor and return false once exhausted -- confirm.
144 bool next_transition(MDPType& mdp);
145 bool forward_non_loop_transition(MDPType& mdp, const State& state);
146 bool next_successor();
147
148 ItemProbabilityPair<StateID> get_current_successor();
149 };
150
    // Ordering used by StackInfo::non_ec_transitions.
    // It compares scc_successors lexicographically: entries are ordered by
    // state id first and, for equal ids, by probability via is_approx_less.
    // conv_part is not compared. Two QValueInfos whose scc_successors
    // compare equal are therefore equivalent keys in the std::set.
    // NOTE(review): an approximate "less" need not be a strict weak
    // ordering, which std::set requires. Also, the lambda parameters
    // shadow the outer left/right.
151 struct cmp_qval_info {
152 bool operator()(const QValueInfo& left, const QValueInfo& right) const
153 {
154 return std::ranges::lexicographical_compare(
155 left.scc_successors,
156 right.scc_successors,
157 [](const auto& left, const auto& right) {
158 return left.item < right.item ||
159 (left.item == right.item && is_approx_less(
160 left.probability,
161 right.probability));
162 });
163 }
164 };
165
    // Entry of the Tarjan stack stack_ for one state.
166 struct StackInfo {
167 StateID state_id;
168
169 // Reference to the state value of the state.
170 AlgorithmValueType* value;
171
172 // Precomputed portion of the Bellman update.
173 // Maximum over all Q values for actions which always
174 // leave the current scc.
175 AlgorithmValueType conv_part;
176
177 // Q value structs for transitions belonging to the scc,
178 // but not to an end component.
179 std::set<QValueInfo, cmp_qval_info> non_ec_transitions;
180
181 // Q value structs for transitions currently assumed to belong
182 // to an end component within the current scc.
183 // Iteratively refined during end component decomposition.
184 std::vector<QValueInfo> ec_transitions;
185
        // A (parent entry index, transition index) pair. What the indices
        // refer to is defined by the impl; presumably stack positions and
        // positions in ec_transitions -- confirm.
186 struct ParentTransition {
187 unsigned parent_idx;
188 unsigned parent_transition_idx;
189 };
190
191 struct TransitionFlags {
192 bool is_active_exiting : 1; // Is the transition active and an SCC
193 // exit?
194 bool is_active : 1; // Is the transition active?
195 };
196
197 unsigned active_exit_transitions =
198 0; // Number of active exit transitions.
199 unsigned active_transitions = 0; // Number of active transitions.
200
201 std::vector<TransitionFlags> transition_flags;
202 std::vector<ParentTransition> parents;
203
204 StackInfo(StateID state_id, AlgorithmValueType& value_ref);
205
        // Takes info by rvalue reference; defined out of line.
206 void add_non_ec_transition(QValueInfo&& info);
207 };
208
    // One frame of the DFS stack (exploration_stack_ecd_) of the
    // Tarjan-style search run during end-component decomposition.
209 struct ECDExplorationInfo {
210 // Immutable info
211 StackInfo& stack_info;
212 unsigned stackidx;
213
        // Current and end iterators over a std::vector<QValueInfo>.
        // Presumably this is stack_info.ec_transitions -- confirm in the impl.
214 // Exploration state - Action
215 typename std::vector<QValueInfo>::iterator action;
216 typename std::vector<QValueInfo>::iterator end;
217
218 // Exploration state - Transition successor
219 typename std::vector<ItemProbabilityPair<StateID>>::iterator successor;
220
221 // Tarjan's algorithm state
222 unsigned lowlink;
223
224 // End component decomposition state
225
226 // ECD recursion flag. Recurse if there is a transition that can leave
227 // and remain in the current scc.
228 bool recurse : 1 = false;
229
230 // Whether the current transition remains in or leaves the current scc.
231 bool leaves_scc : 1 = false;
232 bool remains_scc : 1 = false;
233
234 ECDExplorationInfo(StackInfo& stack_info, unsigned stackidx);
235
236 bool next_transition();
237 bool next_successor();
238
239 ItemProbabilityPair<StateID> get_current_successor();
240 };
241
    // LIFO queue of SCCs awaiting (further) decomposition, stored flat.
    // state_ids holds the states of all queued SCCs back to back.
    // Each entry of scc_spans is the offset in state_ids where an SCC
    // begins; register_new_scc records the current size of state_ids.
242 struct DecompositionQueue {
243 std::vector<StateID> state_ids;
244 std::vector<std::size_t> scc_spans;
245
246 void reserve(std::size_t num_states)
247 {
248 state_ids.reserve(num_states);
249 scc_spans.reserve(num_states);
250 }
251
252 void register_new_scc() { scc_spans.push_back(state_ids.size()); }
253
254 void add_scc_state(StateID state_id) { state_ids.push_back(state_id); }
255
        // Moves the states of the most recently registered SCC into r, which
        // must be empty, and removes that SCC from the queue.
        // Returns false, leaving r empty, when state_ids is empty.
        // NOTE(review): emptiness is tested on state_ids rather than on
        // scc_spans. As a result:
        //   - if every remaining registered SCC is empty, their entries stay
        //     in scc_spans;
        //   - scc_spans.back() is undefined behaviour if state_ids is
        //     non-empty while scc_spans is empty.
256 bool pop_scc(std::vector<StateID>& r)
257 {
258 using namespace std::views;
259
260 assert(r.empty());
261
262 if (state_ids.empty()) return false;
263
            // The tail of state_ids starting at the last recorded offset is
            // the last registered SCC.
264 auto scc_view = state_ids | drop(scc_spans.back());
265
266 for (const auto state_id : scc_view) {
267 r.push_back(state_id);
268 }
269
270 state_ids.erase(scc_view.begin(), scc_view.end());
271
272 scc_spans.pop_back();
273
274 return true;
275 }
276 };
277
    // Per-state StateInfo records.
278 storage::PerStateStorage<StateInfo> state_information_;
    // DFS frames of the SCC search.
279 std::vector<ExplorationInfo> exploration_stack_;
    // Tarjan stack. It is a std::deque, presumably so that the StackInfo&
    // references held by ExplorationInfo and ECDExplorationInfo stay valid
    // as it grows.
280 std::deque<StackInfo> stack_;
281
    // DFS frames and Tarjan stack of the end-component decomposition search.
282 std::vector<ECDExplorationInfo> exploration_stack_ecd_;
283 std::vector<StateID> stack_ecd_;
284
    // SCCs awaiting decomposition, plus a buffer for the SCC currently being
    // handled (presumably filled via DecompositionQueue::pop_scc).
285 DecompositionQueue decomposition_queue_;
286 std::vector<StateID> scc_;
287
288 Statistics statistics_;
289
290public:
291 TATopologicalValueIteration() = default;
292
    // Pre-reserves the vector-backed containers (including both vectors of
    // decomposition_queue_) for num_states_hint entries.
    // stack_, a std::deque, is not reserved.
293 explicit TATopologicalValueIteration(std::size_t num_states_hint)
294 {
295 exploration_stack_.reserve(num_states_hint);
296 exploration_stack_ecd_.reserve(num_states_hint);
297 stack_ecd_.reserve(num_states_hint);
298 decomposition_queue_.reserve(num_states_hint);
299 scc_.reserve(num_states_hint);
300 }
301
    // MDPAlgorithm interface overrides.
302 Interval solve(
303 MDPType& mdp,
304 EvaluatorType& heuristic,
305 param_type<State> state,
    // NOTE(review): source line 306 is missing from this listing.
    // compute_policy below has a `ProgressReport progress` parameter at the
    // same position, which suggests the same here -- confirm against the
    // real header.
307 double max_time) override;
308
309 std::unique_ptr<PolicyType> compute_policy(
310 MDPType& mdp,
311 EvaluatorType& heuristic,
312 param_type<State> state,
313 ProgressReport progress,
314 double max_time) override;
315
316 void print_statistics(std::ostream& out) const override;
317
    // NOTE(review): source line 322 is missing from this listing.
    // The Doxygen index lists `Statistics get_statistics() const`
    // ("Retrieve the algorithm statistics"). That is presumably the
    // declaration this [[nodiscard]] applies to.
321 [[nodiscard]]
323
    // Overload starting from init_state_id with a caller-provided
    // value_store (an abbreviated-template parameter).
    // max_time defaults to +infinity, presumably meaning no time limit.
331 Interval solve(
332 MDPType& mdp,
333 const EvaluatorType& heuristic,
334 StateID init_state_id,
335 auto& value_store,
336 double max_time = std::numeric_limits<double>::infinity());
337
    // Internal steps of the search, all defined out of line. Their precise
    // contracts are in the impl header; names and signatures suggest:
    //   - push_state / initialize_state / successor_loop / scc_found drive
    //     the Tarjan-style SCC search;
    //   - find_and_decompose_sccs / initialize_ecd / push_successor_ecd /
    //     scc_found_ecd drive the end-component decomposition.
338private:
342 void push_state(
343 StateID state_id,
344 StateInfo& state_info,
345 AlgorithmValueType& value);
346
352 bool initialize_state(
353 MDPType& mdp,
354 const EvaluatorType& heuristic,
355 ExplorationInfo& exp_info,
356 auto& value_store);
357
366 bool successor_loop(
367 MDPType& mdp,
368 ExplorationInfo& explore,
369 auto& value_store,
370 utils::CountdownTimer& timer);
371
375 void scc_found(
376 auto& value_store,
377 ExplorationInfo& exp_info,
378 unsigned int stack_idx,
379 utils::CountdownTimer& timer);
380
381 void find_and_decompose_sccs(utils::CountdownTimer& timer);
382
383 bool initialize_ecd(ECDExplorationInfo& exp_info);
384
385 bool
386 push_successor_ecd(ECDExplorationInfo& e, utils::CountdownTimer& timer);
387
388 void scc_found_ecd(ECDExplorationInfo& e);
389};
390
391} // namespace probfd::algorithms::ta_topological_vi
392
393#define GUARD_INCLUDE_PROBFD_ALGORITHMS_TA_TOPOLOGICAL_VALUE_ITERATION_H
394#include "probfd/algorithms/ta_topological_value_iteration_impl.h"
395#undef GUARD_INCLUDE_PROBFD_ALGORITHMS_TA_TOPOLOGICAL_VALUE_ITERATION_H
396
397#endif // PROBFD_ALGORITHMS_TA_TOPOLOGICAL_VALUE_ITERATION_H
A convenience class that represents a finite probability distribution.
Definition task_state_space.h:27
An item-probability pair.
Definition distribution.h:20
Interface for MDP algorithm implementations.
Definition mdp_algorithm.h:29
A registry for print functions related to search progress.
Definition progress_report.h:33
Implements a trap-aware variant of Topological Value Iteration.
Definition ta_topological_value_iteration.h:63
void print_statistics(std::ostream &out) const override
Prints algorithm statistics to the specified output stream.
Definition ta_topological_value_iteration_impl.h:283
Statistics get_statistics() const
Retrieve the algorithm statistics.
Definition ta_topological_value_iteration_impl.h:291
Namespace dedicated to trap-aware Topological Value Iteration (TATVI).
Definition ta_topological_value_iteration.h:24
std::conditional_t< UseInterval, Interval, value_t > AlgorithmValue
Convenience value type alias for algorithms selecting interval iteration behaviour based on a templat...
Definition types.h:14
typename std::conditional_t< is_cheap_to_copy_v< T >, T, const T & > param_type
Alias template defining the best way to pass a parameter of a given type.
Definition type_traits.h:25
Represents a closed interval over the extended reals as a pair of lower and upper bound.
Definition interval.h:12
A StateID represents a state within a StateIDMap. Just like Fast Downward's StateID type,...
Definition types.h:22
Topological value iteration statistics.
Definition ta_topological_value_iteration.h:29