AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
fret.h
1#ifndef PROBFD_ALGORITHMS_FRET_H
2#define PROBFD_ALGORITHMS_FRET_H
3
4#include "probfd/algorithms/heuristic_search_base.h"
5
6#include "probfd/quotients/quotient_system.h"
7
8#include "probfd/progress_report.h"
9
10#if defined(EXPENSIVE_STATISTICS)
11#include "downward/utils/timer.h"
12#endif
13
14#include <limits>
15#include <type_traits>
16
17// Forward Declarations
// Only references to utils::CountdownTimer appear in this header (see the
// `timer` parameters of FRET::heuristic_search and find_and_remove_traps), so
// a forward declaration is sufficient here.
namespace utils { class CountdownTimer; } // namespace utils
21
24
25namespace internal {
26
27struct Statistics {
28 unsigned long long iterations = 0;
29 unsigned long long traps = 0;
30
31#if defined(EXPENSIVE_STATISTICS)
32 utils::Timer heuristic_search = utils::Timer(true);
33 utils::Timer trap_identification = utils::Timer(true);
34 utils::Timer trap_removal = utils::Timer(true);
35#endif
36
37 void print(std::ostream& out) const;
38};
39
40struct TarjanStateInformation {
41 static constexpr unsigned UNDEF = std::numeric_limits<unsigned int>::max();
42
43 unsigned stack_index = UNDEF;
44 unsigned lowlink = UNDEF;
45
46 [[nodiscard]]
47 bool is_explored() const
48 {
49 return lowlink != UNDEF;
50 }
51
52 [[nodiscard]]
53 bool is_on_stack() const
54 {
55 return stack_index != UNDEF;
56 }
57
58 void open(const unsigned x)
59 {
60 stack_index = x;
61 lowlink = x;
62 }
63
64 void close() { stack_index = UNDEF; }
65};
66
67struct ExplorationInfo {
68 ExplorationInfo(StateID state_id, std::vector<StateID> successors)
69 : state_id(state_id)
70 , successors(std::move(successors))
71 {
72 }
73
74 StateID state_id;
75 std::vector<StateID> successors;
76 bool is_leaf = true;
77};
78
79template <typename QAction>
80struct StackInfo {
81 StateID state_id;
82 std::vector<QAction> aops;
83
84 template <size_t i>
85 friend auto& get(StackInfo& info)
86 {
87 if constexpr (i == 0) return info.state_id;
88 if constexpr (i == 1) return info.aops;
89 }
90
91 template <size_t i>
92 friend const auto& get(const StackInfo& info)
93 {
94 if constexpr (i == 0) return info.state_id;
95 if constexpr (i == 1) return info.aops;
96 }
97};
98
99} // namespace internal
100
124template <
125 typename State,
126 typename Action,
127 typename StateInfoT,
128 typename GreedyGraphGenerator>
129class FRET : public MDPAlgorithm<State, Action> {
130 using Base = typename FRET::MDPAlgorithm;
131
132 using PolicyType = typename Base::PolicyType;
133 using MDPType = typename Base::MDPType;
134 using EvaluatorType = typename Base::EvaluatorType;
135
136 using QuotientSystem = quotients::QuotientSystem<State, Action>;
137 using QState = quotients::QuotientState<State, Action>;
138 using QAction = quotients::QuotientAction<Action>;
139 using QHeuristicSearchAlgorithm = heuristic_search::
140 FRETHeuristicSearchAlgorithm<QState, QAction, StateInfoT>;
141 using QEvaluator = probfd::Evaluator<QState>;
142
143 using StackInfo = internal::StackInfo<QAction>;
144
145 // Algorithm parameters
146 const std::shared_ptr<QHeuristicSearchAlgorithm> base_algorithm_;
147
148 internal::Statistics statistics_;
149
150public:
151 explicit FRET(std::shared_ptr<QHeuristicSearchAlgorithm> algorithm);
152
153 std::unique_ptr<PolicyType> compute_policy(
154 MDPType& mdp,
155 EvaluatorType& heuristic,
156 param_type<State> state,
157 ProgressReport progress,
158 double max_time) override;
159
160 Interval solve(
161 MDPType& mdp,
162 EvaluatorType& heuristic,
163 param_type<State> state,
164 ProgressReport progress,
165 double max_time) override;
166
167 void print_statistics(std::ostream& out) const override;
168
169private:
170 Interval solve(
171 QuotientSystem& quotient,
172 QEvaluator& heuristic,
173 param_type<QState> state,
174 ProgressReport& progress,
175 double max_time);
176
177 Interval heuristic_search(
178 QuotientSystem& quotient,
179 QEvaluator& heuristic,
180 param_type<QState> state,
181 ProgressReport& progress,
182 utils::CountdownTimer& timer);
183
184 bool find_and_remove_traps(
185 QuotientSystem& quotient,
186 param_type<QState> state,
187 utils::CountdownTimer& timer);
188
189 bool push(
190 QuotientSystem& quotient,
191 std::deque<internal::ExplorationInfo>& queue,
192 std::deque<StackInfo>& stack,
193 internal::TarjanStateInformation& info,
194 StateID state_id,
195 unsigned int& unexpanded);
196};
197
/// Successor generator for FRET's trap search over a quotient system.
/// Members: a set of state IDs (ids_), a list of quotient transitions
/// (opt_transitions_) and a list of algorithm values (q_values).
/// NOTE(review): presumably the GreedyGraphGenerator of FRETV, restricting
/// successors to value-greedy transitions — confirm in fret_impl.h.
198template <typename State, typename Action, typename StateInfoT>
199class ValueGraph {
200 using QuotientSystem = quotients::QuotientSystem<State, Action>;
201 using QState = quotients::QuotientState<State, Action>;
202 using QAction = quotients::QuotientAction<Action>;
203
// NOTE(review): the aliased type is not visible here; presumably
// heuristic_search::FRETHeuristicSearchAlgorithm<QState, QAction, StateInfoT>
// as in FRET — confirm.
204 using QHeuristicSearchAlgorithm =
206
207 using AlgorithmValueType =
208 typename QHeuristicSearchAlgorithm::AlgorithmValueType;
209
210 using QEvaluator = Evaluator<QState>;
211
212 std::unordered_set<StateID> ids_;
213 std::vector<Transition<QAction>> opt_transitions_;
// NOTE(review): lacks the trailing underscore used by the other members.
214 std::vector<AlgorithmValueType> q_values;
215
216public:
// Fills aops/successors (non-const references, presumably outputs) for
// quotient state @p qstate using @p base_algorithm. The meaning of the
// returned bool is defined out of line (presumably fret_impl.h).
217 bool get_successors(
218 QuotientSystem& quotient,
219 QHeuristicSearchAlgorithm& base_algorithm,
220 StateID qstate,
221 std::vector<QAction>& aops,
222 std::vector<StateID>& successors);
223};
224
/// Successor generator for FRET's trap search over a quotient system.
/// NOTE(review): presumably the GreedyGraphGenerator of FRETPi, following
/// only the base algorithm's current greedy policy — confirm in fret_impl.h.
225template <typename State, typename Action, typename StateInfoT>
226class PolicyGraph {
227 using QuotientSystem = quotients::QuotientSystem<State, Action>;
228 using QState = quotients::QuotientState<State, Action>;
229 using QAction = quotients::QuotientAction<Action>;
// NOTE(review): the aliased type is not visible here; presumably
// heuristic_search::FRETHeuristicSearchAlgorithm<QState, QAction, StateInfoT>
// as in FRET — confirm.
230 using QHeuristicSearchAlgorithm =
232
233 using QEvaluator = Evaluator<QState>;
234
236
237public:
// Fills aops/successors (non-const references, presumably outputs) for
// @p quotient_state_id using @p base_algorithm. The meaning of the returned
// bool is defined out of line (presumably fret_impl.h).
238 bool get_successors(
239 QuotientSystem& quotient,
240 QHeuristicSearchAlgorithm& base_algorithm,
241 StateID quotient_state_id,
242 std::vector<QAction>& aops,
243 std::vector<StateID>& successors);
244};
245
/// Alias for the FRET-V variant. NOTE(review): the aliased type is not
/// visible here; presumably FRET<State, Action, StateInfoT,
/// ValueGraph<State, Action, StateInfoT>> — confirm.
254template <typename State, typename Action, typename StateInfoT>
255using FRETV =
257
/// Alias for the FRET-pi variant. NOTE(review): the aliased type is not
/// visible here; presumably FRET<State, Action, StateInfoT,
/// PolicyGraph<State, Action, StateInfoT>> — confirm.
266template <typename State, typename Action, typename StateInfoT>
267using FRETPi =
269
270} // namespace probfd::algorithms::fret
271
272#define GUARD_INCLUDE_PROBFD_ALGORITHMS_FRET_H
273#include "probfd/algorithms/fret_impl.h"
274#undef GUARD_INCLUDE_PROBFD_ALGORITHMS_FRET_H
275
276#endif // PROBFD_ALGORITHMS_FRET_H
A convenience class that represents a finite probability distribution.
Definition task_state_space.h:27
The interface representing heuristic functions.
Definition mdp_algorithm.h:16
Interface for MDP algorithm implementations.
Definition mdp_algorithm.h:29
A registry for print functions related to search progress.
Definition progress_report.h:33
Implementation of the Find-Revise-Eliminate-Traps (FRET) framework [Kolobov et al., ICAPS 2011].
Definition heuristic_search_base.h:39
void print_statistics(std::ostream &out) const override
Prints algorithm statistics to the specified output stream.
Definition fret_impl.h:193
Extends HeuristicSearchBase with default implementations for MDPAlgorithm.
Definition heuristic_search_base.h:302
Namespace dedicated to the Find, Revise, Eliminate Traps (FRET) framework.
Definition fret.h:23
typename std::conditional_t< is_cheap_to_copy_v< T >, T, const T & > param_type
Alias template defining the best way to pass a parameter of a given type.
Definition type_traits.h:25
Represents a closed interval over the extended reals as a pair of lower and upper bound.
Definition interval.h:12
A StateID represents a state within a StateIDMap. Just like Fast Downward's StateID type,...
Definition types.h:22