AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
mdp_heuristic_search.h
1#ifndef PROBFD_SOLVERS_MDP_HEURISTIC_SEARCH_H
2#define PROBFD_SOLVERS_MDP_HEURISTIC_SEARCH_H
3
4#include "probfd/solvers/mdp_solver.h"
5
6#include "probfd/algorithms/fret.h"
7#include "probfd/algorithms/policy_picker.h"
8
9#include "probfd/solvers/bisimulation_heuristic_search_algorithm.h"
10
11#include <memory>
12#include <string>
13#include <type_traits>
14#include <utility>
15
// Forward declarations of the concrete (non-quotient) state and action types
// that the StateType/ActionType aliases in this header refer to by name.
16class State;
17class OperatorID;
18
// Quotient state/action class templates; declared only, not defined, here.
19namespace probfd::quotients {
20template <typename, typename>
21struct QuotientState;
22template <typename>
23struct QuotientAction;
24} // namespace probfd::quotients
25
// Opaque enum declarations for the state and action types of the
// probabilistic bisimulation quotient.
26namespace probfd::bisimulation {
27enum class QuotientState;
28enum class QuotientAction;
29} // namespace probfd::bisimulation
30
31namespace probfd::solvers {
32
// Compile-time selection of the state type for a solver configuration.
//
// Without bisimulation (second alternative of the outer conditional):
//   Fret == true  -> probfd::quotients::QuotientState<State, OperatorID>
//   Fret == false -> State
//
// NOTE(review): this listing skips source lines 39-41, so the bisimulation
// branch is incomplete as shown. Presumably it mirrors ActionType and selects
// a quotient over the bisimulation state type when Fret is set and the plain
// bisimulation state type otherwise -- confirm against the real header.
33template <bool Bisimulation, bool Fret>
34using StateType = std::conditional_t<
35 Bisimulation,
36 std::conditional_t<
37 Fret,
38 probfd::quotients::QuotientState<
42 std::conditional_t<
43 Fret,
44 probfd::quotients::QuotientState<State, OperatorID>,
45 State>>;
46
// Compile-time selection of the action type for a solver configuration.
//
// With bisimulation and Fret:
//   probfd::quotients::QuotientAction<probfd::bisimulation::QuotientAction>
// Without bisimulation:
//   Fret == true  -> probfd::quotients::QuotientAction<OperatorID>
//   Fret == false -> OperatorID
//
// NOTE(review): this listing skips source line 53, i.e. the alternative for
// bisimulation without Fret. Presumably it is
// probfd::bisimulation::QuotientAction -- confirm against the real header.
47template <bool Bisimulation, bool Fret>
48using ActionType = std::conditional_t<
49 Bisimulation,
50 std::conditional_t<
51 Fret,
52 probfd::quotients::QuotientAction<probfd::bisimulation::QuotientAction>,
54 std::conditional_t<
55 Fret,
56 probfd::quotients::QuotientAction<OperatorID>,
57 OperatorID>>;
58
// Common base of the MDPHeuristicSearch specializations, parameterised by
// whether bisimulation and/or FRET is used. Derives from MDPSolver and keeps
// the two settings every specialization reads: the dual-bounds flag and the
// policy tiebreaker.
59template <bool Bisimulation, bool Fret>
60class MDPHeuristicSearchBase : public MDPSolver {
61protected:
// Policy picker over the state/action types chosen for this configuration.
62 using PolicyPicker = algorithms::PolicyPicker<
63 StateType<Bisimulation, Fret>,
64 ActionType<Bisimulation, Fret>>;
65
// Runtime switch; the specializations translate it into the boolean template
// argument of the heuristic search they instantiate.
66 const bool dual_bounds_;
// Shared policy picker handed to every algorithm the specializations create.
67 const std::shared_ptr<PolicyPicker> tiebreaker_;
68
69public:
// Declared here, defined elsewhere.
// NOTE(review): presumably `dual_bounds` and `policy` initialise the two
// members above and the remaining arguments go to MDPSolver -- the definition
// is not part of this header, confirm there.
// NOTE(review): std::vector and std::optional are used in this signature but
// <vector>/<optional> are not included by this header directly; presumably
// they arrive transitively -- confirm.
70 MDPHeuristicSearchBase(
71 bool dual_bounds,
72 std::shared_ptr<PolicyPicker> policy,
73 utils::Verbosity verbosity,
74 std::vector<std::shared_ptr<::Evaluator>> path_dependent_evaluators,
75 bool cache,
76 const std::shared_ptr<TaskEvaluatorFactory>& eval,
77 std::optional<value_t> report_epsilon,
78 bool report_enabled,
79 double max_time,
80 std::string policy_filename,
81 bool print_fact_names);
82
// Overrides a virtual function of a base class; defined elsewhere.
83 void print_additional_statistics() const override;
84
// Pure virtual: concrete solvers supply the heuristic search name. The
// bisimulation specializations pass the result on to
// BisimulationBasedHeuristicSearchAlgorithm::create.
85 virtual std::string get_heuristic_search_name() const = 0;
86};
87
// Primary template, declared only. This header provides an explicit
// specialization for each of the four <Bisimulation, Fret> combinations.
88template <bool Bisimulation, bool Fret>
89class MDPHeuristicSearch;
90
// Specialization without bisimulation and without FRET: the heuristic search
// is instantiated directly on State and OperatorID.
91template <>
92class MDPHeuristicSearch<false, false>
93 : public MDPHeuristicSearchBase<false, false> {
94public:
// Same parameter list as the base class constructor; defined elsewhere.
95 MDPHeuristicSearch(
96 bool dual_bounds,
97 std::shared_ptr<PolicyPicker> policy,
98 utils::Verbosity verbosity,
99 std::vector<std::shared_ptr<::Evaluator>> path_dependent_evaluators,
100 bool cache,
101 const std::shared_ptr<TaskEvaluatorFactory>& eval,
102 std::optional<value_t> report_epsilon,
103 bool report_enabled,
104 double max_time,
105 std::string policy_filename,
106 bool print_fact_names);
107
108 std::string get_algorithm_name() const override;
109
// Creates an HS<State, OperatorID, B> where B is the runtime value of
// dual_bounds_. Both instantiations are compiled; the branch picks one.
// The algorithm is constructed from tiebreaker_ followed by the
// perfectly-forwarded `args`, and returned as a unique_ptr<FDRMDPAlgorithm>.
110 template <template <typename, typename, bool> class HS, typename... Args>
111 std::unique_ptr<FDRMDPAlgorithm>
112 create_heuristic_search_algorithm(Args&&... args)
113 {
114 if (dual_bounds_) {
115 using HeuristicSearchType = HS<State, OperatorID, true>;
116 return std::make_unique<HeuristicSearchType>(
117 tiebreaker_,
118 std::forward<Args>(args)...);
119 } else {
120 using HeuristicSearchType = HS<State, OperatorID, false>;
121 return std::make_unique<HeuristicSearchType>(
122 tiebreaker_,
123 std::forward<Args>(args)...);
124 }
125 }
126};
127
// Specialization with FRET but without bisimulation. The heuristic search
// runs on quotient states/actions built over State and OperatorID and is
// wrapped in a FRET algorithm.
128template <>
129class MDPHeuristicSearch<false, true>
130 : public MDPHeuristicSearchBase<false, true> {
131 using QState = quotients::QuotientState<State, OperatorID>;
132 using QAction = quotients::QuotientAction<OperatorID>;
133
// Runtime switch read by create_heuristic_search_algorithm below.
134 const bool fret_on_policy_;
135
136public:
// Base class parameter list with an extra leading `fret_on_policy` flag;
// defined elsewhere.
137 MDPHeuristicSearch(
138 bool fret_on_policy,
139 bool dual_bounds,
140 std::shared_ptr<PolicyPicker> policy,
141 utils::Verbosity verbosity,
142 std::vector<std::shared_ptr<::Evaluator>> path_dependent_evaluators,
143 bool cache,
144 const std::shared_ptr<TaskEvaluatorFactory>& eval,
145 std::optional<value_t> report_epsilon,
146 bool report_enabled,
147 double max_time,
148 std::string policy_filename,
149 bool print_fact_names);
150
151 std::string get_algorithm_name() const override;
152
// Four-way runtime dispatch on (dual_bounds_, fret_on_policy_) into
// create_heuristic_search_algorithm_wrapper<Fret, HS, Interval>.
//
// NOTE(review): this listing skips source lines 160, 165, 172 and 177, which
// hold the first (Fret) template argument of each call; as printed the two
// arms of each inner `if` look identical. Presumably the fret_on_policy_ arm
// passes FRETPi and the other arm FRETV (both are referenced from fret.h by
// this page) -- confirm against the real header.
153 template <template <typename, typename, bool> class HS, typename... Args>
154 std::unique_ptr<FDRMDPAlgorithm>
155 create_heuristic_search_algorithm(Args&&... args)
156 {
157 if (this->dual_bounds_) {
158 if (this->fret_on_policy_) {
159 return this->template create_heuristic_search_algorithm_wrapper<
161 HS,
162 true>(std::forward<Args>(args)...);
163 } else {
164 return this->template create_heuristic_search_algorithm_wrapper<
166 HS,
167 true>(std::forward<Args>(args)...);
168 }
169 } else {
170 if (this->fret_on_policy_) {
171 return this->template create_heuristic_search_algorithm_wrapper<
173 HS,
174 false>(std::forward<Args>(args)...);
175 } else {
176 return this->template create_heuristic_search_algorithm_wrapper<
178 HS,
179 false>(std::forward<Args>(args)...);
180 }
181 }
182 }
183
// Creates HS<State, OperatorID, B> (B = dual_bounds_) directly, without a
// FRET wrapper, from tiebreaker_ and the forwarded `args`.
//
// NOTE(review): tiebreaker_ here is a PolicyPicker over the quotient types
// (StateType<false, true> / ActionType<false, true>), while HS is
// instantiated on State/OperatorID rather than QState/QAction -- confirm the
// type arguments are intended.
184 template <template <typename, typename, bool> class HS, typename... Args>
185 std::unique_ptr<FDRMDPAlgorithm>
186 create_quotient_heuristic_search_algorithm(Args&&... args)
187 {
188 if (dual_bounds_) {
189 return std::make_unique<HS<State, OperatorID, true>>(
190 tiebreaker_,
191 std::forward<Args>(args)...);
192 } else {
193 return std::make_unique<HS<State, OperatorID, false>>(
194 tiebreaker_,
195 std::forward<Args>(args)...);
196 }
197 }
198
199private:
// Builds the inner heuristic search HS<QState, QAction, Interval> (held by
// shared_ptr, constructed from tiebreaker_ and the forwarded `args`) and
// wraps it in Fret<State, OperatorID, StateInfoT>, where StateInfoT is the
// inner search's nested StateInfo type.
200 template <
201 template <typename, typename, typename>
202 class Fret,
203 template <typename, typename, bool>
204 class HS,
205 bool Interval,
206 typename... Args>
207 std::unique_ptr<FDRMDPAlgorithm>
208 create_heuristic_search_algorithm_wrapper(Args&&... args)
209 {
210 using StateInfoT = typename HS<QState, QAction, Interval>::StateInfo;
211 return std::make_unique<Fret<State, OperatorID, StateInfoT>>(
212 std::make_shared<HS<QState, QAction, Interval>>(
213 tiebreaker_,
214 std::forward<Args>(args)...));
215 }
216};
217
// Specialization with bisimulation but without FRET. Algorithm construction
// is delegated to BisimulationBasedHeuristicSearchAlgorithm.
218template <>
219class MDPHeuristicSearch<true, false>
220 : public MDPHeuristicSearchBase<true, false> {
221public:
// Same parameter list as the base class constructor; defined elsewhere.
222 MDPHeuristicSearch(
223 bool dual_bounds,
224 std::shared_ptr<PolicyPicker> policy,
225 utils::Verbosity verbosity,
226 std::vector<std::shared_ptr<::Evaluator>> path_dependent_evaluators,
227 bool cache,
228 const std::shared_ptr<TaskEvaluatorFactory>& eval,
229 std::optional<value_t> report_epsilon,
230 bool report_enabled,
231 double max_time,
232 std::string policy_filename,
233 bool print_fact_names);
234
235 std::string get_algorithm_name() const override;
236
// Calls BisimulationBasedHeuristicSearchAlgorithm::create<HS, B> with
// B = dual_bounds_, passing the task, its cost function, the heuristic search
// name, the tiebreaker and the forwarded `args`.
// NOTE(review): task_ and task_cost_function_ are not declared in this
// header; presumably they are members inherited from MDPSolver -- confirm.
237 template <template <typename, typename, bool> class HS, typename... Args>
238 std::unique_ptr<FDRMDPAlgorithm>
239 create_heuristic_search_algorithm(Args&&... args)
240 {
241 if (dual_bounds_) {
242 return BisimulationBasedHeuristicSearchAlgorithm::create<HS, true>(
243 this->task_,
244 this->task_cost_function_,
245 this->get_heuristic_search_name(),
246 this->tiebreaker_,
247 std::forward<Args>(args)...);
248 } else {
249 return BisimulationBasedHeuristicSearchAlgorithm::create<HS, false>(
250 this->task_,
251 this->task_cost_function_,
252 this->get_heuristic_search_name(),
253 this->tiebreaker_,
254 std::forward<Args>(args)...);
255 }
256 }
257};
258
// Specialization with both bisimulation and FRET. Algorithm construction is
// delegated to BisimulationBasedHeuristicSearchAlgorithm, with the FRET
// variant passed as an additional template argument.
259template <>
260class MDPHeuristicSearch<true, true>
261 : public MDPHeuristicSearchBase<true, true> {
// Runtime switch read by create_heuristic_search_algorithm below.
262 const bool fret_on_policy_;
263
264public:
// Base class parameter list with an extra leading `fret_on_policy` flag;
// defined elsewhere.
265 MDPHeuristicSearch(
266 bool fret_on_policy,
267 bool dual_bounds,
268 std::shared_ptr<PolicyPicker> policy,
269 utils::Verbosity verbosity,
270 std::vector<std::shared_ptr<::Evaluator>> path_dependent_evaluators,
271 bool cache,
272 const std::shared_ptr<TaskEvaluatorFactory>& eval,
273 std::optional<value_t> report_epsilon,
274 bool report_enabled,
275 double max_time,
276 std::string policy_filename,
277 bool print_fact_names);
278
279 std::string get_algorithm_name() const override;
280
// Four-way runtime dispatch on (dual_bounds_, fret_on_policy_) into
// heuristic_search_algorithm_factory_wrapper<Fret, Interval, HS>.
//
// NOTE(review): this listing skips source lines 289, 295, 303 and 309, which
// hold the first (Fret) template argument of each call; as printed the two
// arms of each inner `if` look identical. Presumably the fret_on_policy_ arm
// passes FRETPi and the other arm FRETV (both are referenced from fret.h by
// this page) -- confirm against the real header.
281 template <template <typename, typename, bool> class HS, typename... Args>
282 std::unique_ptr<FDRMDPAlgorithm>
283 create_heuristic_search_algorithm(Args&&... args)
284 {
285 if (this->dual_bounds_) {
286 if (this->fret_on_policy_) {
287 return this
288 ->template heuristic_search_algorithm_factory_wrapper<
290 true,
291 HS>(std::forward<Args>(args)...);
292 } else {
293 return this
294 ->template heuristic_search_algorithm_factory_wrapper<
296 true,
297 HS>(std::forward<Args>(args)...);
298 }
299 } else {
300 if (this->fret_on_policy_) {
301 return this
302 ->template heuristic_search_algorithm_factory_wrapper<
304 false,
305 HS>(std::forward<Args>(args)...);
306 } else {
307 return this
308 ->template heuristic_search_algorithm_factory_wrapper<
310 false,
311 HS>(std::forward<Args>(args)...);
312 }
313 }
314 }
315
// Calls BisimulationBasedHeuristicSearchAlgorithm::create<HS, B> with
// B = dual_bounds_ and no FRET argument, passing the task, its cost function,
// the heuristic search name, the tiebreaker and the forwarded `args`.
//
// NOTE(review): HS is declared here as a template taking three *type*
// parameters, whereas every sibling factory in this header takes
// <typename, typename, bool> -- confirm which shape create<HS, bool> expects.
316 template <
317 template <typename, typename, typename>
318 class HS,
319 typename... Args>
320 std::unique_ptr<FDRMDPAlgorithm>
321 create_quotient_heuristic_search_algorithm(Args&&... args)
322 {
323 if (dual_bounds_) {
324 return BisimulationBasedHeuristicSearchAlgorithm::create<HS, true>(
325 this->task_,
326 this->task_cost_function_,
327 this->get_heuristic_search_name(),
328 this->tiebreaker_,
329 std::forward<Args>(args)...);
330 } else {
331 return BisimulationBasedHeuristicSearchAlgorithm::create<HS, false>(
332 this->task_,
333 this->task_cost_function_,
334 this->get_heuristic_search_name(),
335 this->tiebreaker_,
336 std::forward<Args>(args)...);
337 }
338 }
339
340private:
// Single forwarding point to
// BisimulationBasedHeuristicSearchAlgorithm::create<Fret, HS, Interval>.
// Note the template parameter order here is (Fret, Interval, HS) while the
// callee is given (Fret, HS, Interval).
341 template <
342 template <typename, typename, typename>
343 class Fret,
344 bool Interval,
345 template <typename, typename, bool>
346 class HS,
347 typename... Args>
348 std::unique_ptr<FDRMDPAlgorithm>
349 heuristic_search_algorithm_factory_wrapper(Args&&... args)
350 {
351 return BisimulationBasedHeuristicSearchAlgorithm::
352 create<Fret, HS, Interval>(
353 this->task_,
354 this->task_cost_function_,
355 this->get_heuristic_search_name(),
356 this->tiebreaker_,
357 std::forward<Args>(args)...);
358 }
359};
360
361} // namespace probfd::solvers
362
363#endif // PROBFD_SOLVERS_MDP_HEURISTIC_SEARCH_H
FRET< State, Action, StateInfoT, PolicyGraph< State, Action, StateInfoT > > FRETPi
Implementation of FRET with trap elimination in the greedy policy graph of the last returned policy.
Definition fret.h:267
FRET< State, Action, StateInfoT, ValueGraph< State, Action, StateInfoT > > FRETV
Implementation of FRET with trap elimination in the greedy value graph of the MDP.
Definition fret.h:255
This namespace contains the implementation of deterministic bisimulation quotients for SSPs,...
Definition bisimilar_state_space.h:33
QuotientAction
Represents an action in the probabilistic bisimulation quotient.
Definition types.h:12
QuotientState
Represents a state in the probabilistic bisimulation quotient.
Definition types.h:9
This namespace contains the solver interface base class for various search algorithms.
Definition bisimulation_heuristic_search_algorithm.h:17