AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
topological_value_iteration.h
1#ifndef PROBFD_ALGORITHMS_TOPOLOGICAL_VALUE_ITERATION_H
2#define PROBFD_ALGORITHMS_TOPOLOGICAL_VALUE_ITERATION_H
3
4#include "probfd/algorithms/types.h"
5
6#include "probfd/storage/per_state_storage.h"
7
8#include "probfd/distribution.h"
9#include "probfd/mdp_algorithm.h"
10
11#include <deque>
12#include <limits>
13#include <ostream>
14#include <vector>
15
16// Forward Declarations
// utils::CountdownTimer appears in this header only as a reference parameter
// (successor_loop, scc_found), so a forward declaration is sufficient here.
17namespace utils {
18class CountdownTimer;
19}
20
// policies::MapPolicy appears in this header only through the class-local
// alias `MapPolicy` and as a pointer parameter (`MapPolicy*`), so the
// two-parameter class template is merely declared, not defined, at this point.
21namespace probfd::policies {
22template <typename, typename>
23class MapPolicy;
24}
25
28
// Topological value iteration statistics.
//
// Plain aggregate of unsigned counters, each zero-initialised in-class.
// The places where the counters are incremented are not visible in this
// header, so the per-field notes below are assumptions to be confirmed
// against the implementation file.
32struct Statistics {
    // presumably: number of states expanded during exploration -- TODO confirm
33 unsigned long long expanded_states = 0;
    // presumably: number of terminal states encountered -- TODO confirm
34 unsigned long long terminal_states = 0;
    // presumably: number of goal states encountered -- TODO confirm
35 unsigned long long goal_states = 0;
    // presumably: number of strongly connected components found -- TODO confirm
36 unsigned long long sccs = 0;
    // presumably: number of SCCs consisting of a single state -- TODO confirm
37 unsigned long long singleton_sccs = 0;
    // presumably: number of Bellman backups (value updates) performed -- TODO confirm
38 unsigned long long bellman_backups = 0;
    // NOTE(review): what exactly is counted as "pruned" cannot be told from
    // this header -- confirm against the implementation.
39 unsigned long long pruned = 0;
40
    // Declared here, defined elsewhere. Takes the output stream by reference
    // and does not modify the statistics object (const member function).
41 void print(std::ostream& out) const;
42};
43
// Implements Topological Value Iteration (cited in the generated
// documentation under the key dai:etal:jair-11).
//
// The class derives from MDPAlgorithm<State, Action> and overrides
// compute_policy(), solve() and print_statistics(). It additionally offers a
// templated solve() overload that works on a caller-supplied value store.
//
// Template parameters:
//   State, Action - forwarded to the MDPAlgorithm base class.
//   UseInterval   - selects AlgorithmValueType through
//                   algorithms::AlgorithmValue<UseInterval>; according to the
//                   documentation excerpt for that alias this is
//                   std::conditional_t<UseInterval, Interval, value_t>.
//                   Defaults to false.
//
// NOTE(review): this listing skips some original source lines (182, 189 and
// 198 are absent), so a few declarations below are incomplete as shown.
// NOTE(review): std::optional, std::unique_ptr and uint8_t are used below,
// but <optional>, <memory> and <cstdint> are not among the standard headers
// included directly by this file; presumably they arrive transitively via the
// project headers -- confirm, or include them explicitly.
67template <typename State, typename Action, bool UseInterval = false>
68class TopologicalValueIteration : public MDPAlgorithm<State, Action> {
    // Names the base class through its injected-class-name, which avoids
    // repeating the template argument list.
69 using Base = typename TopologicalValueIteration::MDPAlgorithm;
70
    // Types re-exported from the base class for brevity.
71 using PolicyType = typename Base::PolicyType;
72 using MDPType = typename Base::MDPType;
73 using EvaluatorType = typename Base::EvaluatorType;
74
75 using MapPolicy = policies::MapPolicy<State, Action>;
76 using AlgorithmValueType = algorithms::AlgorithmValue<UseInterval>;
77
    // Per-state bookkeeping record; instances are kept in state_information_.
78 struct StateInfo {
79 // Status Flags
80 enum { NEW, CLOSED, ONSTACK };
81
    // presumably the index of the state's entry in stack_ -- TODO confirm
82 unsigned stack_id = 0;
    // Holds one of the status flags above; every state starts out as NEW.
83 uint8_t status = NEW;
84 };
85
    // Bookkeeping for a single Q-value (one action of one state).
86 struct QValueInfo {
87 // The action id this Q value belongs to.
88 Action action;
89
90 // Precomputed part of the Q-value.
91 // Sum of action cost plus those weighted successor values which
92 // have already converged due to topological ordering.
93 AlgorithmValueType conv_part;
94
95 // Pointers to successor values which have not yet converged,
96 // self-loops excluded.
97 std::vector<ItemProbabilityPair<AlgorithmValueType*>> nconv_successors;
98
99 QValueInfo(Action action, value_t action_cost);
100
    // NOTE(review): the meaning of the boolean result is not visible in this
    // header -- see the implementation file.
101 bool finalize_transition(value_t self_loop_prob);
102
    // Does not modify this record (const member function).
103 AlgorithmValueType compute_q_value() const;
104 };
105
    // Entry of stack_: value-update bookkeeping for one state.
106 struct StackInfo {
107 StateID state_id;
108
109 // Reference to the state value of the state.
    // (Stored as a raw, non-owning pointer.)
110 AlgorithmValueType* value;
111
    // NOTE(review): the next two comment lines disagree ("max" vs.
    // "Minimum"); presumably one of them is a leftover from a different
    // optimisation direction -- confirm which is intended.
112 // Precomputed part of the max of the value update.
113 // Minimum over all Q values of actions leaving the SCC.
114 AlgorithmValueType conv_part;
115
116 // Remaining Q values which have not yet converged.
117 std::vector<QValueInfo> nconv_qs;
118
119 // The optimal action
    // Empty until an action has been recorded.
120 std::optional<Action> best_action = std::nullopt;
121
122 // The optimal action among those leaving the SCC.
    // Empty until an action has been recorded.
123 std::optional<Action> best_converged = std::nullopt;
124
125 StackInfo(StateID state_id, AlgorithmValueType& value_ref);
126
    // NOTE(review): presumably reports whether the value changed -- the
    // meaning of the result is not visible here, confirm in the
    // implementation file.
127 bool update_value();
128 };
129
    // Entry of exploration_stack_: exploration state for one state.
    // NOTE(review): holds a reference to a StackInfo. If that reference
    // designates an element of stack_ (a std::vector), growth of the vector
    // would invalidate it -- verify how the implementation deals with this.
130 struct ExplorationInfo {
131 // Exploration State
132 std::vector<Action> aops; // Remaining unexpanded operators
133 Distribution<StateID> transition; // Currently expanded transition
134 Distribution<StateID>::const_iterator successor; // Current successor
135
    // Members of a struct are public by default, so this label does not
    // change the accessibility of anything in this struct.
136 public:
137 // Immutable info
138 StateID state_id; // State this information belongs to
139 StackInfo& stack_info;
140 unsigned stackidx; // Index on the stack of the associated state
141
    // NOTE(review): has no in-class initialiser; presumably set by the
    // constructor -- confirm.
142 unsigned lowlink;
143
    // Starts at zero (0_vt).
144 value_t self_loop_prob = 0_vt;
145
146 ExplorationInfo(
147 StateID state_id,
148 StackInfo& stack_info,
149 unsigned stackidx);
150
151 void update_lowlink(unsigned upd);
152
    // NOTE(review): the meaning of the boolean results of the four functions
    // below is not visible in this header -- see the implementation file.
153 bool next_transition(MDPType& mdp);
154 bool next_successor();
155
156 bool forward_non_loop_transition(MDPType& mdp, const State& state);
157 bool forward_non_loop_successor();
158
    // Returns a mutable reference to the action.
159 Action& get_current_action();
    // Returns the item-probability pair by value.
160 ItemProbabilityPair<StateID> get_current_successor();
161 };
162
163 using StackIterator = typename std::vector<StackInfo>::iterator;
164
165 // Algorithm parameters
    // Const member, so it cannot change after construction.
166 const bool expand_goals_;
167
168 // Algorithm state
169 storage::PerStateStorage<StateInfo> state_information_;
170 std::deque<ExplorationInfo> exploration_stack_;
171 std::vector<StackInfo> stack_;
172
173 Statistics statistics_;
174
175public:
    // explicit: prevents implicit conversion from bool to this class.
176 explicit TopologicalValueIteration(bool expand_goals);
177
    // Override of the base-class interface; returns an owning pointer to the
    // computed policy.
    // NOTE(review): source line 182 is absent from this listing, so the
    // parameter list below is incomplete as shown. The documentation excerpts
    // mention a progress-report registry type (progress_report.h), so
    // presumably a parameter of that type belongs between `state` and
    // `max_time` -- confirm against the real header.
178 std::unique_ptr<PolicyType> compute_policy(
179 MDPType& mdp,
180 EvaluatorType& heuristic,
181 param_type<State> state,
183 double max_time) override;
184
    // Override of the base-class interface; returns an Interval.
    // NOTE(review): source line 189 is absent from this listing; the same
    // caveat as for compute_policy() applies.
185 Interval solve(
186 MDPType& mdp,
187 EvaluatorType& heuristic,
188 param_type<State> state,
190 double max_time) override;
191
    // Prints algorithm statistics to the specified output stream.
192 void print_statistics(std::ostream& out) const override;
193
    // NOTE(review): source line 198, the declaration this attribute was
    // written for, is absent from this listing. The documentation excerpts
    // list `Statistics get_statistics() const` ("the algorithm statistics",
    // defined in topological_value_iteration_impl.h), so presumably that
    // declaration belongs directly after the attribute. As shown here, the
    // attribute would instead attach to the templated solve() below.
197 [[nodiscard]]
199
    // Overload of solve() that starts from a state id and operates on a
    // caller-supplied value store.
    //   max_time - defaults to positive infinity.
    //   policy   - optional (may be nullptr, the default).
    // NOTE(review): presumably the policy is filled in when non-null --
    // confirm in the implementation file.
207 template <typename ValueStore>
208 Interval solve(
209 MDPType& mdp,
210 EvaluatorType& heuristic,
211 StateID init_state_id,
212 ValueStore& value_store,
213 double max_time = std::numeric_limits<double>::infinity(),
214 MapPolicy* policy = nullptr);
215
216private:
    // NOTE(review): the doc comments of the private helpers below were
    // dropped from this listing (source-line gaps 217-221, 227-231, 237-245,
    // 253-255); their contracts must be taken from the real header or the
    // implementation file.
222 void push_state(
223 StateID state_id,
224 StateInfo& state_info,
225 AlgorithmValueType& state_value);
226
    // `auto&` in the parameter list makes this an abbreviated function
    // template (C++20), generic over the value store type.
232 bool initialize_state(
233 MDPType& mdp,
234 EvaluatorType& heuristic,
235 ExplorationInfo& exp_info,
236 auto& value_store);
237
246 template <typename ValueStore>
247 bool successor_loop(
248 MDPType& mdp,
249 ExplorationInfo& explore,
250 ValueStore& value_store,
251 utils::CountdownTimer& timer);
252
    // `auto scc` makes this an abbreviated function template (C++20); the
    // SCC argument is taken by value. `policy` is passed as a pointer.
256 void scc_found(auto scc, MapPolicy* policy, utils::CountdownTimer& timer);
257};
258
259} // namespace probfd::algorithms::topological_vi
260
261#define GUARD_INCLUDE_PROBFD_ALGORITHMS_TOPOLOGICAL_VALUE_ITERATION_H
262#include "probfd/algorithms/topological_value_iteration_impl.h"
263#undef GUARD_INCLUDE_PROBFD_ALGORITHMS_TOPOLOGICAL_VALUE_ITERATION_H
264
265#endif // PROBFD_ALGORITHMS_TOPOLOGICAL_VALUE_ITERATION_H
A convenience class that represents a finite probability distribution.
Definition task_state_space.h:27
An item-probability pair.
Definition distribution.h:20
Interface for MDP algorithm implementations.
Definition mdp_algorithm.h:29
A registry for print functions related to search progress.
Definition progress_report.h:33
Implements Topological Value Iteration (Dai et al., JAIR 2011; citation key dai:etal:jair-11).
Definition topological_value_iteration.h:68
Statistics get_statistics() const
Retrieve the algorithm statistics.
Definition topological_value_iteration_impl.h:249
void print_statistics(std::ostream &out) const override
Prints algorithm statistics to the specified output stream.
Definition topological_value_iteration_impl.h:241
Namespace dedicated to Topological Value Iteration (TVI).
Definition topological_value_iteration.h:27
std::conditional_t< UseInterval, Interval, value_t > AlgorithmValue
Convenience value type alias for algorithms selecting interval iteration behaviour based on a template parameter.
Definition types.h:14
double value_t
Typedef for the state value type.
Definition aliases.h:7
typename std::conditional_t< is_cheap_to_copy_v< T >, T, const T & > param_type
Alias template defining the best way to pass a parameter of a given type.
Definition type_traits.h:25
Represents a closed interval over the extended reals as a pair of lower and upper bound.
Definition interval.h:12
A StateID represents a state within a StateIDMap. Just like Fast Downward's StateID type,...
Definition types.h:22
Topological value iteration statistics.
Definition topological_value_iteration.h:32