AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
heuristic_search_base.h
1#ifndef PROBFD_ALGORITHMS_HEURISTIC_SEARCH_BASE_H
2#define PROBFD_ALGORITHMS_HEURISTIC_SEARCH_BASE_H
3
4#include "probfd/algorithms/heuristic_search_state_information.h"
5#include "probfd/algorithms/types.h"
6
7#include "probfd/mdp_algorithm.h"
8#include "probfd/progress_report.h"
9
10#if defined(EXPENSIVE_STATISTICS)
11#include "downward/utils/timer.h"
12#endif
13
14#include <algorithm>
15#include <iostream>
16#include <limits>
17#include <type_traits>
18#include <vector>
19
// Forward declarations of templates that this header only refers to by name.

namespace probfd {

template <typename T>
class Distribution;

template <typename T>
struct Transition;

template <typename T1, typename T2>
class CostFunction;

} // namespace probfd

namespace probfd::algorithms {

template <typename T1, typename T2>
class PolicyPicker;

template <typename T>
class SuccessorSampler;

} // namespace probfd::algorithms
// Forward declarations of the FRET class templates that HeuristicSearchBase
// declares as friends. The namespace opener restored here matches the
// closing `} // namespace probfd::algorithms::fret` below.
namespace probfd::algorithms::fret {
template <typename, typename, typename, typename>
class FRET;
template <typename, typename, typename>
class PolicyGraph;
template <typename, typename, typename>
class ValueGraph;
} // namespace probfd::algorithms::fret
45
48
49namespace internal {
50
54struct Statistics {
55 unsigned long long evaluated_states = 0;
56 unsigned long long pruned_states = 0;
57 unsigned long long goal_states = 0;
58
59 unsigned long long expanded_states = 0;
60 unsigned long long terminal_states = 0;
61 unsigned long long self_loop_states = 0;
62
63 unsigned long long value_changes = 0;
64 unsigned long long policy_changes = 0;
65 unsigned long long value_updates = 0;
66 unsigned long long policy_updates = 0;
67
68 value_t initial_state_estimate = 0;
69 bool initial_state_found_terminal = false;
70
71#if defined(EXPENSIVE_STATISTICS)
72 utils::Timer update_time = utils::Timer(true);
73 utils::Timer policy_selection_time = utils::Timer(true);
74#endif
75
79 void print(std::ostream& out) const;
80};
81
82template <typename StateInfo>
83class StateInfos : public StateProperties {
84 storage::PerStateStorage<StateInfo> state_infos_;
85
86public:
87 StateInfo& operator[](StateID sid) { return state_infos_[sid]; }
88 const StateInfo& operator[](StateID sid) const { return state_infos_[sid]; }
89
90 value_t lookup_value(StateID state_id) override
91 {
92 return state_infos_[state_id].get_value();
93 }
94
95 Interval lookup_bounds(StateID state_id) override
96 {
97 return state_infos_[state_id].get_bounds();
98 }
99
100 void reset() { std::ranges::for_each(state_infos_, &StateInfo::clear); }
101};
102
103} // namespace internal
104
/// @brief The common base class for MDP heuristic search algorithms.
///
/// @tparam State      The state type of the MDP.
/// @tparam Action     The action type of the MDP.
/// @tparam StateInfoT The per-state record type; it must expose the static
///                    flags StorePolicy and UseInterval (read below).
///
/// NOTE(review): this rendered listing is missing source lines. The class
/// head (`class HeuristicSearchBase {`, original line 113) is absent, and
/// PolicyPickerType, MDPType, EvaluatorType and CostFunctionType are used
/// below without a visible declaration (presumably the aliases on original
/// lines 118-120 and 123). std::shared_ptr and std::optional are also used,
/// but <memory> and <optional> are not among the visible includes; confirm
/// they are included, at least transitively.
112template <typename State, typename Action, typename StateInfoT>
 /// const_if<b, T> is `const T` when b is true and plain `T` otherwise.
114 template <bool b, typename T>
115 using const_if = std::conditional_t<b, const T, T>;
116
117protected:
 /// Transition type over this algorithm's Action type.
121 using TransitionType = Transition<Action>;
122
124
125 // The FRET implementation has access to the internals of this base class.
126 template <typename, typename, typename, typename>
127 friend class fret::FRET;
128
129 template <typename, typename, typename>
130 friend class fret::PolicyGraph;
131
132 template <typename, typename, typename>
133 friend class fret::ValueGraph;
134
135public:
 /// The per-state record type.
136 using StateInfo = StateInfoT;
137
 /// Whether state records store a greedy policy (required by update_policy()).
138 static constexpr bool StorePolicy = StateInfo::StorePolicy;
 /// Whether values are intervals: AlgorithmValueType is Interval if true,
 /// value_t otherwise.
139 static constexpr bool UseInterval = StateInfo::UseInterval;
140
141 using AlgorithmValueType = AlgorithmValue<UseInterval>;
142
143private:
144 // Algorithm parameters
 // Tie-breaking strategy for choosing among greedy transitions (presumably a
 // PolicyPicker; see select_greedy_transition()).
145 const std::shared_ptr<PolicyPickerType> policy_chooser_;
146
147protected:
148 // Algorithm state
 // Per-state search records, indexed by StateID.
149 internal::StateInfos<StateInfo> state_infos_;
150
 // Search counters (see internal::Statistics).
151 internal::Statistics statistics_;
152
 // A Bellman value paired with an optional transition (presumably the greedy
 // transition achieving it; confirm at use sites).
153 struct BellmanResult {
154 AlgorithmValueType best_value;
155 std::optional<TransitionType> transition;
156 };
157
158public:
 /// Constructs the base with the given tie-breaking policy picker.
159 explicit HeuristicSearchBase(
160 std::shared_ptr<PolicyPickerType> policy_chooser);
161
 /// Looks up the current value interval of state_id.
165 [[nodiscard]]
166 Interval lookup_bounds(StateID state_id) const;
167
 /// Checks if the state represented by state_id has been visited yet.
172 [[nodiscard]]
173 bool was_visited(StateID state_id) const;
174
 /// Computes the Bellman operator value for a state from its transitions.
 /// termination_cost is presumably the cost of terminating in state_id;
 /// confirm in the implementation.
178 AlgorithmValueType compute_bellman(
179 CostFunctionType& cost_function,
180 StateID state_id,
181 const std::vector<TransitionType>& transitions,
182 value_t termination_cost) const;
183
 /// Computes the Bellman operator value for a state, as well as the
 /// transitions whose value is within epsilon (default g_epsilon) of it.
 /// transitions and qvalues are non-const references and are presumably
 /// filtered down to the greedy ones in place; confirm in the implementation.
200 AlgorithmValueType compute_bellman_and_greedy(
201 CostFunctionType& cost_function,
202 StateID state_id,
203 std::vector<TransitionType>& transitions,
204 value_t termination_cost,
205 std::vector<AlgorithmValueType>& qvalues,
206 value_t epsilon = g_epsilon) const;
207
 /// Selects a greedy transition from greedy_transitions through the policy
 /// picker. previous_greedy_action is presumably used to keep tie-breaking
 /// stable across updates; confirm. The result is a std::optional.
218 std::optional<TransitionType> select_greedy_transition(
219 MDPType& mdp,
220 std::optional<Action> previous_greedy_action,
221 std::vector<TransitionType>& greedy_transitions);
222
 /// Updates the value of the state associated with the given storage using
 /// `other`, with epsilon (default g_epsilon) as comparison tolerance.
 /// Presumably returns whether the value changed; confirm.
229 bool update_value(
230 StateInfo& state_info,
231 AlgorithmValueType other,
232 value_t epsilon = g_epsilon);
233
 /// Updates the current greedy action of the state associated with the given
 /// storage. Only available when StorePolicy is true. Presumably returns
 /// whether the policy changed; confirm.
240 bool update_policy(
241 StateInfo& state_info,
242 const std::optional<TransitionType>& transition)
243 requires(StorePolicy);
244
245protected:
 // State initialization/expansion helpers. h is presumably the heuristic
 // evaluator used to initialize newly encountered states; confirm.
246 void initialize_initial_state(
247 MDPType& mdp,
248 EvaluatorType& h,
249 param_type<State> state);
250
251 void expand_and_initialize(
252 MDPType& mdp,
253 EvaluatorType& h,
254 param_type<State> state,
255 StateInfo& state_info,
256 std::vector<TransitionType>& transitions);
257
258 void generate_non_tip_transitions(
259 MDPType& mdp,
260 param_type<State> state,
261 std::vector<TransitionType>& transitions) const;
262
 // Prints statistics to the given stream (presumably statistics_).
263 void print_statistics(std::ostream& out) const;
264
265private:
 // Internal helpers for the Bellman computations above; defined out of line
 // (presumably in heuristic_search_base_impl.h).
266 void initialize(
267 MDPType& mdp,
268 EvaluatorType& h,
269 param_type<State> state,
270 StateInfo& state_info);
271
272 AlgorithmValueType compute_qvalue(
273 value_t action_cost,
274 StateID state_id,
275 const TransitionType& transition) const;
276
277 AlgorithmValueType compute_q_values(
278 CostFunctionType& cost_function,
279 StateID state_id,
280 std::vector<TransitionType>& transitions,
281 value_t termination_cost,
282 std::vector<AlgorithmValueType>& qvalues) const;
283
284 AlgorithmValueType filter_greedy_transitions(
285 std::vector<TransitionType>& transitions,
286 std::vector<AlgorithmValueType>& qvalues,
287 const AlgorithmValueType& best_value,
288 value_t epsilon = g_epsilon) const;
289};
290
299template <typename State, typename Action, typename StateInfoT>
301 : public MDPAlgorithm<State, Action>
302 , public HeuristicSearchBase<State, Action, StateInfoT> {
303 using AlgorithmBase = typename HeuristicSearchAlgorithm::MDPAlgorithm;
304 using HSBase = typename HeuristicSearchAlgorithm::HeuristicSearchBase;
305
306public:
307 using TransitionType = HSBase::TransitionType;
308 using AlgorithmValueType = HSBase::AlgorithmValueType;
309
310protected:
311 using PolicyType = typename AlgorithmBase::PolicyType;
312
313 using MDPType = typename AlgorithmBase::MDPType;
314 using EvaluatorType = typename AlgorithmBase::EvaluatorType;
315
316 using StateInfo = typename HSBase::StateInfo;
317 using PolicyPicker = typename HSBase::PolicyPickerType;
318
319public:
320 // Inherited constructor
321 using HSBase::HSBase;
322
323 Interval solve(
324 MDPType& mdp,
325 EvaluatorType& h,
326 param_type<State> state,
327 ProgressReport progress,
328 double max_time) final;
329
330 std::unique_ptr<PolicyType> compute_policy(
331 MDPType& mdp,
332 EvaluatorType& h,
333 param_type<State> state,
334 ProgressReport progress,
335 double max_time) final;
336
337 void print_statistics(std::ostream& out) const final;
338
345 MDPType& mdp,
346 EvaluatorType& h,
347 param_type<State> state,
348 ProgressReport& progress,
349 double max_time) = 0;
350
356 virtual void print_additional_statistics(std::ostream& out) const = 0;
357};
358
366template <typename State, typename Action, typename StateInfoT>
368 : public HeuristicSearchAlgorithm<State, Action, StateInfoT> {
369 using AlgorithmBase = typename FRETHeuristicSearchAlgorithm::MDPAlgorithm;
370 using HSBase =
371 typename FRETHeuristicSearchAlgorithm::HeuristicSearchAlgorithm;
372
373protected:
374 using PolicyType = typename AlgorithmBase::PolicyType;
375
376 using MDPType = typename AlgorithmBase::MDPType;
377 using EvaluatorType = typename AlgorithmBase::EvaluatorType;
378
379 using StateInfo = typename HSBase::StateInfo;
380 using PolicyPicker = typename HSBase::PolicyPickerType;
381
382public:
383 // Inherited constructor
384 using HSBase::HSBase;
385
392 virtual void reset_search_state() {}
393};
394
395} // namespace probfd::algorithms::heuristic_search
396
// The member function templates declared above are defined in a separate
// implementation header. The guard macro is defined only around this include,
// presumably so that heuristic_search_base_impl.h can reject being included
// directly; confirm there.
397#define GUARD_INCLUDE_PROBFD_ALGORITHMS_HEURISTIC_SEARCH_BASE_H
398#include "probfd/algorithms/heuristic_search_base_impl.h"
399#undef GUARD_INCLUDE_PROBFD_ALGORITHMS_HEURISTIC_SEARCH_BASE_H
400
401#endif // PROBFD_ALGORITHMS_HEURISTIC_SEARCH_BASE_H
Interface for MDP algorithm implementations.
Definition mdp_algorithm.h:29
A registry for print functions related to search progress.
Definition progress_report.h:33
A strategy interface used to break ties between multiple greedy actions for a state.
Definition policy_picker.h:57
Interface providing access to various state properties during heuristic search.
Definition state_properties.h:22
Implementation of the Find-Revise-Eliminate-Traps (FRET) framework (Kolobov et al., ICAPS 2011).
Definition heuristic_search_base.h:39
Heuristic search algorithm that can be used within FRET.
Definition heuristic_search_base.h:368
virtual void reset_search_state()
Resets the heuristic search algorithm object to a clean state.
Definition heuristic_search_base.h:392
Extends HeuristicSearchBase with default implementations for MDPAlgorithm.
Definition heuristic_search_base.h:302
virtual void print_additional_statistics(std::ostream &out) const =0
Prints additional statistics to the output stream.
void print_statistics(std::ostream &out) const final
Prints algorithm statistics to the specified output stream.
Definition heuristic_search_base_impl.h:484
virtual Interval do_solve(MDPType &mdp, EvaluatorType &h, param_type< State > state, ProgressReport &progress, double max_time)=0
Solves for the optimal state value of the input state.
The common base class for MDP heuristic search algorithms.
Definition heuristic_search_base.h:113
bool update_value(StateInfo &state_info, AlgorithmValueType other, value_t epsilon=g_epsilon)
Updates the value of the state associated with the given storage.
Definition heuristic_search_base_impl.h:153
AlgorithmValueType compute_bellman(CostFunctionType &cost_function, StateID state_id, const std::vector< TransitionType > &transitions, value_t termination_cost) const
Computes the Bellman operator value for a state.
Definition heuristic_search_base_impl.h:74
Interval lookup_bounds(StateID state_id) const
Looks up the current value interval of state_id.
Definition heuristic_search_base_impl.h:60
bool update_policy(StateInfo &state_info, const std::optional< TransitionType > &transition)
Updates the current greedy action of the state associated with the given storage.
Definition heuristic_search_base_impl.h:165
std::optional< TransitionType > select_greedy_transition(MDPType &mdp, std::optional< Action > previous_greedy_action, std::vector< TransitionType > &greedy_transitions)
Selects a greedy transition from the given list of greedy transitions through the policy selector pas...
Definition heuristic_search_base_impl.h:130
bool was_visited(StateID state_id) const
Checks if the state represented by state_id has been visited yet.
Definition heuristic_search_base_impl.h:67
AlgorithmValueType compute_bellman_and_greedy(CostFunctionType &cost_function, StateID state_id, std::vector< TransitionType > &transitions, value_t termination_cost, std::vector< AlgorithmValueType > &qvalues, value_t epsilon=g_epsilon) const
Computes the Bellman operator value for a state, as well as all transitions achieving a value epsilon...
Definition heuristic_search_base_impl.h:95
Namespace dedicated to the Find, Revise, Eliminate Traps (FRET) framework.
Definition fret.h:23
Namespace dedicated to the MDP heuristic search base implementation.
Definition heuristic_search_base.h:47
This namespace contains implementations of SSP search algorithms.
Definition acyclic_value_iteration.h:22
std::conditional_t< UseInterval, Interval, value_t > AlgorithmValue
Convenience value type alias for algorithms selecting interval iteration behaviour based on a templat...
Definition types.h:14
The top-level namespace of probabilistic Fast Downward.
Definition command_line.h:8
double value_t
Typedef for the state value type.
Definition aliases.h:7
typename std::conditional_t< is_cheap_to_copy_v< T >, T, const T & > param_type
Alias template defining the best way to pass a parameter of a given type.
Definition type_traits.h:25
value_t g_epsilon
The default tolerance value for approximate comparisons.
Represents a closed interval over the extended reals as a pair of lower and upper bound.
Definition interval.h:12
A StateID represents a state within a StateIDMap. Just like Fast Downward's StateID type,...
Definition types.h:22
Base statistics for MDP heuristic search.
Definition heuristic_search_base.h:54
void print(std::ostream &out) const
Prints the statistics to the specified output stream.
Definition heuristic_search_base_impl.h:26