1#ifndef PROBFD_SOLVERS_MDP_HEURISTIC_SEARCH_H
2#define PROBFD_SOLVERS_MDP_HEURISTIC_SEARCH_H
4#include "probfd/solvers/mdp_solver.h"
6#include "probfd/algorithms/fret.h"
7#include "probfd/algorithms/policy_picker.h"
9#include "probfd/solvers/bisimulation_heuristic_search_algorithm.h"
/// Opens namespace probfd::quotients and begins a template declaration with
/// two unnamed type parameters.
// NOTE(review): the declared name(s) and the closing brace of this namespace
// are absent from this listing (the embedded numbering jumps from 20 to 33).
// Presumably these are forward declarations of the QuotientState /
// QuotientAction templates that are referenced further down - confirm against
// the original header.
19namespace probfd::quotients {
20template <
typename,
typename>
/// Alias template StateType<Bisimulation, Fret>.
///
/// Resolves through std::conditional_t to a type that depends on the two
/// boolean template parameters. MDPHeuristicSearchBase passes the result as
/// the first template argument of algorithms::PolicyPicker.
// The alternatives still visible here are instantiations of
// probfd::quotients::QuotientState, among them
// QuotientState<State, OperatorID>.
// NOTE(review): the selecting conditions and the remaining alternatives are
// missing from this listing (embedded numbering jumps 34 -> 38 -> 44), so the
// full mapping from (Bisimulation, Fret) to a type cannot be read off here.
33template <
bool Bisimulation,
bool Fret>
34using StateType = std::conditional_t<
38 probfd::quotients::QuotientState<
44 probfd::quotients::QuotientState<State, OperatorID>,
/// Alias template ActionType<Bisimulation, Fret>.
///
/// Resolves through std::conditional_t to a type that depends on the two
/// boolean template parameters. MDPHeuristicSearchBase passes the result as
/// the second template argument of algorithms::PolicyPicker.
// The alternatives still visible here are
// QuotientAction<probfd::bisimulation::QuotientAction> and
// QuotientAction<OperatorID>, both from probfd::quotients.
// NOTE(review): the selecting conditions and the remaining alternatives are
// missing from this listing (embedded numbering jumps 48 -> 52 -> 56), so the
// full mapping from (Bisimulation, Fret) to a type cannot be read off here.
47template <
bool Bisimulation,
bool Fret>
48using ActionType = std::conditional_t<
52 probfd::quotients::QuotientAction<probfd::bisimulation::QuotientAction>,
56 probfd::quotients::QuotientAction<OperatorID>,
/// Common base class template of the MDPHeuristicSearch specializations.
///
/// Derives publicly from MDPSolver and is parameterized on the two boolean
/// flags Bisimulation and Fret, which select the state and action types of
/// the policy picker through StateType / ActionType.
// NOTE(review): access specifiers and the closing brace of the class are not
// present in this listing, so member visibility cannot be determined here.
59template <
bool Bisimulation,
bool Fret>
60class MDPHeuristicSearchBase :
public MDPSolver {
// Policy picker type matching the state/action types for this flag pair.
62 using PolicyPicker = algorithms::PolicyPicker<
63 StateType<Bisimulation, Fret>,
64 ActionType<Bisimulation, Fret>>;
// Runtime flag; the specializations that test it pass `true` for the bool
// argument of their HS instantiation when it is set and `false` otherwise.
66 const bool dual_bounds_;
// Shared policy picker instance.
// NOTE(review): presumably initialised from the constructor's `policy`
// argument - the constructor definition is not part of this header listing.
67 const std::shared_ptr<PolicyPicker> tiebreaker_;
// Constructor declaration.
// NOTE(review): the parameters on original lines 71, 75, 78 and 79 are
// missing from this listing; only the parameters below are visible.
70 MDPHeuristicSearchBase(
72 std::shared_ptr<PolicyPicker> policy,
73 utils::Verbosity verbosity,
74 std::vector<std::shared_ptr<::Evaluator>> path_dependent_evaluators,
76 const std::shared_ptr<TaskEvaluatorFactory>& eval,
77 std::optional<value_t> report_epsilon,
80 std::string policy_filename,
81 bool print_fact_names);
// Overrides a virtual member function of a base class; declared here only.
83 void print_additional_statistics()
const override;
// Pure virtual hook returning a std::string. The bisimulation-based
// specializations pass its result to
// BisimulationBasedHeuristicSearchAlgorithm::create.
85 virtual std::string get_heuristic_search_name()
const = 0;
/// Primary template MDPHeuristicSearch<Bisimulation, Fret>.
///
/// Only declared here; the listing continues with specializations for the
/// four flag combinations <false, false>, <false, true>, <true, false> and
/// <true, true>.
88template <
bool Bisimulation,
bool Fret>
89class MDPHeuristicSearch;
/// Specialization MDPHeuristicSearch<false, false>.
///
/// Derives publicly from MDPHeuristicSearchBase<false, false> and creates
/// heuristic search algorithms instantiated directly over State / OperatorID.
// NOTE(review): the `template <>` head, access specifiers, braces and several
// other lines are missing from this listing (see the gaps in the embedded
// numbering).
92class MDPHeuristicSearch<false, false>
93 :
public MDPHeuristicSearchBase<false, false> {
// Constructor parameter list.
// NOTE(review): the line carrying the constructor name and the parameters on
// original lines 96, 100, 103 and 104 are missing from this listing.
97 std::shared_ptr<PolicyPicker> policy,
98 utils::Verbosity verbosity,
99 std::vector<std::shared_ptr<::Evaluator>> path_dependent_evaluators,
101 const std::shared_ptr<TaskEvaluatorFactory>& eval,
102 std::optional<value_t> report_epsilon,
105 std::string policy_filename,
106 bool print_fact_names);
// Overrides a virtual member function of a base class; declared here only.
108 std::string get_algorithm_name()
const override;
// Factory helper. HS is a class template taking two types and a bool; Args
// are perfectly forwarded to the constructor of the chosen instantiation.
// Returns the new algorithm as std::unique_ptr<FDRMDPAlgorithm>.
110 template <
template <
typename,
typename,
bool>
class HS,
typename... Args>
111 std::unique_ptr<FDRMDPAlgorithm>
112 create_heuristic_search_algorithm(Args&&... args)
// Two alternatives are visible: HS<State, OperatorID, true> and
// HS<State, OperatorID, false>.
// NOTE(review): the condition choosing between them (original lines 113-114
// and 119) and the leading constructor argument(s) (original lines 117 and
// 122) are missing from this listing. Presumably the branch tests
// dual_bounds_, as the sibling specializations do - confirm.
115 using HeuristicSearchType = HS<State, OperatorID, true>;
116 return std::make_unique<HeuristicSearchType>(
118 std::forward<Args>(args)...);
120 using HeuristicSearchType = HS<State, OperatorID, false>;
121 return std::make_unique<HeuristicSearchType>(
123 std::forward<Args>(args)...);
/// Specialization MDPHeuristicSearch<false, true>.
///
/// Derives publicly from MDPHeuristicSearchBase<false, true>. Its factory
/// helpers either build an HS instantiation directly or wrap an HS
/// instantiation over quotient state/action types inside a Fret object.
// NOTE(review): the `template <>` head, access specifiers, braces and several
// other lines are missing from this listing (see the gaps in the embedded
// numbering).
129class MDPHeuristicSearch<false, true>
130 :
public MDPHeuristicSearchBase<false, true> {
// Quotient state/action types over the plain State / OperatorID types.
131 using QState = quotients::QuotientState<State, OperatorID>;
132 using QAction = quotients::QuotientAction<OperatorID>;
// Runtime flag tested next to dual_bounds_ when picking the wrapper
// instantiation in create_heuristic_search_algorithm.
134 const bool fret_on_policy_;
// Constructor parameter list.
// NOTE(review): the line carrying the constructor name and the parameters on
// original lines 136-139, 143, 146 and 147 are missing from this listing.
140 std::shared_ptr<PolicyPicker> policy,
141 utils::Verbosity verbosity,
142 std::vector<std::shared_ptr<::Evaluator>> path_dependent_evaluators,
144 const std::shared_ptr<TaskEvaluatorFactory>& eval,
145 std::optional<value_t> report_epsilon,
148 std::string policy_filename,
149 bool print_fact_names);
// Overrides a virtual member function of a base class; declared here only.
151 std::string get_algorithm_name()
const override;
// Factory helper. HS is a class template taking two types and a bool; Args
// are perfectly forwarded. Dispatches on the two runtime flags to one of four
// calls of create_heuristic_search_algorithm_wrapper.
153 template <
template <
typename,
typename,
bool>
class HS,
typename... Args>
154 std::unique_ptr<FDRMDPAlgorithm>
155 create_heuristic_search_algorithm(Args&&... args)
// dual_bounds_ set: the last explicit template argument of the wrapper call
// is `true`.
// NOTE(review): the other explicit template arguments of each call (original
// lines 160-161, 165-166, 172-173, 177-178) are missing from this listing, so
// what differs between the fret_on_policy_ branches is not visible here.
157 if (this->dual_bounds_) {
158 if (this->fret_on_policy_) {
159 return this->
template create_heuristic_search_algorithm_wrapper<
162 true>(std::forward<Args>(args)...);
164 return this->
template create_heuristic_search_algorithm_wrapper<
167 true>(std::forward<Args>(args)...);
// Remaining branch: the last explicit template argument is `false`.
170 if (this->fret_on_policy_) {
171 return this->
template create_heuristic_search_algorithm_wrapper<
174 false>(std::forward<Args>(args)...);
176 return this->
template create_heuristic_search_algorithm_wrapper<
179 false>(std::forward<Args>(args)...);
// Factory helper that builds HS<State, OperatorID, true> or
// HS<State, OperatorID, false> with std::make_unique, forwarding Args.
// NOTE(review): the selecting condition and the leading constructor
// argument(s) are missing from this listing.
// NOTE(review): despite "quotient" in the name, the visible instantiations
// use State / OperatorID rather than QState / QAction - confirm this is
// intended.
184 template <
template <
typename,
typename,
bool>
class HS,
typename... Args>
185 std::unique_ptr<FDRMDPAlgorithm>
186 create_quotient_heuristic_search_algorithm(Args&&... args)
189 return std::make_unique<HS<State, OperatorID, true>>(
191 std::forward<Args>(args)...);
193 return std::make_unique<HS<State, OperatorID, false>>(
195 std::forward<Args>(args)...);
// Wrapper used by create_heuristic_search_algorithm.
// NOTE(review): the template head is only partly preserved; the names Fret,
// HS, Interval and Args used in the body are presumably its template
// parameters - confirm against the original header.
201 template <
typename,
typename,
typename>
203 template <
typename,
typename,
bool>
207 std::unique_ptr<FDRMDPAlgorithm>
208 create_heuristic_search_algorithm_wrapper(Args&&... args)
// StateInfoT is the nested StateInfo type of the HS instantiation over the
// quotient state/action types.
210 using StateInfoT =
typename HS<QState, QAction, Interval>::StateInfo;
// Builds Fret<State, OperatorID, StateInfoT> from a shared_ptr to a newly
// created HS<QState, QAction, Interval>; Args are forwarded into the visible
// call chain (lines missing from the listing hide the exact argument split).
211 return std::make_unique<Fret<State, OperatorID, StateInfoT>>(
212 std::make_shared<HS<QState, QAction, Interval>>(
214 std::forward<Args>(args)...));
/// Specialization MDPHeuristicSearch<true, false>.
///
/// Derives publicly from MDPHeuristicSearchBase<true, false> and delegates
/// algorithm construction to BisimulationBasedHeuristicSearchAlgorithm::create.
// NOTE(review): the `template <>` head, access specifiers, braces and several
// other lines are missing from this listing (see the gaps in the embedded
// numbering).
219class MDPHeuristicSearch<true, false>
220 :
public MDPHeuristicSearchBase<true, false> {
// Constructor parameter list.
// NOTE(review): the line carrying the constructor name and the parameters on
// original lines 223, 227, 230 and 231 are missing from this listing.
224 std::shared_ptr<PolicyPicker> policy,
225 utils::Verbosity verbosity,
226 std::vector<std::shared_ptr<::Evaluator>> path_dependent_evaluators,
228 const std::shared_ptr<TaskEvaluatorFactory>& eval,
229 std::optional<value_t> report_epsilon,
232 std::string policy_filename,
233 bool print_fact_names);
// Overrides a virtual member function of a base class; declared here only.
235 std::string get_algorithm_name()
const override;
// Factory helper. HS is a class template taking two types and a bool; Args
// are perfectly forwarded. Returns std::unique_ptr<FDRMDPAlgorithm>.
237 template <
template <
typename,
typename,
bool>
class HS,
typename... Args>
238 std::unique_ptr<FDRMDPAlgorithm>
239 create_heuristic_search_algorithm(Args&&... args)
// Two alternatives are visible: create<HS, true> and create<HS, false>. Both
// receive task_cost_function_ and the result of get_heuristic_search_name(),
// followed by the forwarded Args.
// NOTE(review): the selecting condition (original lines 240-241 and 248) and
// the arguments on original lines 243, 246, 250 and 253 are missing from this
// listing. Presumably the branch tests dual_bounds_ - confirm.
242 return BisimulationBasedHeuristicSearchAlgorithm::create<HS, true>(
244 this->task_cost_function_,
245 this->get_heuristic_search_name(),
247 std::forward<Args>(args)...);
249 return BisimulationBasedHeuristicSearchAlgorithm::create<HS, false>(
251 this->task_cost_function_,
252 this->get_heuristic_search_name(),
254 std::forward<Args>(args)...);
/// Specialization MDPHeuristicSearch<true, true>.
///
/// Derives publicly from MDPHeuristicSearchBase<true, true>. All visible
/// construction paths end in a call to
/// BisimulationBasedHeuristicSearchAlgorithm::create.
// NOTE(review): the `template <>` head, access specifiers, braces and several
// other lines are missing from this listing (see the gaps in the embedded
// numbering).
260class MDPHeuristicSearch<true, true>
261 :
public MDPHeuristicSearchBase<true, true> {
// Runtime flag tested next to dual_bounds_ when picking the wrapper
// instantiation in create_heuristic_search_algorithm.
262 const bool fret_on_policy_;
// Constructor parameter list.
// NOTE(review): the line carrying the constructor name and the parameters on
// original lines 264-267, 271, 274 and 275 are missing from this listing.
268 std::shared_ptr<PolicyPicker> policy,
269 utils::Verbosity verbosity,
270 std::vector<std::shared_ptr<::Evaluator>> path_dependent_evaluators,
272 const std::shared_ptr<TaskEvaluatorFactory>& eval,
273 std::optional<value_t> report_epsilon,
276 std::string policy_filename,
277 bool print_fact_names);
// Overrides a virtual member function of a base class; declared here only.
279 std::string get_algorithm_name()
const override;
// Factory helper. HS is a class template taking two types and a bool; Args
// are perfectly forwarded. Dispatches on the two runtime flags to one of four
// calls of heuristic_search_algorithm_factory_wrapper.
281 template <
template <
typename,
typename,
bool>
class HS,
typename... Args>
282 std::unique_ptr<FDRMDPAlgorithm>
283 create_heuristic_search_algorithm(Args&&... args)
// NOTE(review): in each of the four calls only the last explicit template
// argument (HS) is visible; the preceding arguments and the `return this`
// lines (original lines 287, 289-290, 293, 295-296, 301, 303-304, 307,
// 309-310) are missing from this listing, so what differs between the
// branches cannot be read off here.
285 if (this->dual_bounds_) {
286 if (this->fret_on_policy_) {
288 ->template heuristic_search_algorithm_factory_wrapper<
291 HS>(std::forward<Args>(args)...);
294 ->template heuristic_search_algorithm_factory_wrapper<
297 HS>(std::forward<Args>(args)...);
300 if (this->fret_on_policy_) {
302 ->template heuristic_search_algorithm_factory_wrapper<
305 HS>(std::forward<Args>(args)...);
308 ->template heuristic_search_algorithm_factory_wrapper<
311 HS>(std::forward<Args>(args)...);
// Factory helper with two visible alternatives, create<HS, true> and
// create<HS, false>; both receive task_cost_function_ and the result of
// get_heuristic_search_name(), followed by the forwarded Args.
// NOTE(review): the template head is only partly preserved, and the selecting
// condition plus the arguments on original lines 325, 328, 332 and 335 are
// missing from this listing.
317 template <
typename,
typename,
typename>
320 std::unique_ptr<FDRMDPAlgorithm>
321 create_quotient_heuristic_search_algorithm(Args&&... args)
324 return BisimulationBasedHeuristicSearchAlgorithm::create<HS, true>(
326 this->task_cost_function_,
327 this->get_heuristic_search_name(),
329 std::forward<Args>(args)...);
331 return BisimulationBasedHeuristicSearchAlgorithm::create<HS, false>(
333 this->task_cost_function_,
334 this->get_heuristic_search_name(),
336 std::forward<Args>(args)...);
// Wrapper used by create_heuristic_search_algorithm. Forwards to
// create<Fret, HS, Interval>, passing task_cost_function_, the result of
// get_heuristic_search_name() and the forwarded Args.
// NOTE(review): the template head is only partly preserved; Fret, HS,
// Interval and Args are presumably its template parameters. The arguments on
// original lines 353 and 356 are missing from this listing.
342 template <
typename,
typename,
typename>
345 template <
typename,
typename,
bool>
348 std::unique_ptr<FDRMDPAlgorithm>
349 heuristic_search_algorithm_factory_wrapper(Args&&... args)
351 return BisimulationBasedHeuristicSearchAlgorithm::
352 create<Fret, HS, Interval>(
354 this->task_cost_function_,
355 this->get_heuristic_search_name(),
357 std::forward<Args>(args)...);
FRET< State, Action, StateInfoT, PolicyGraph< State, Action, StateInfoT > > FRETPi
Implementation of FRET with trap elimination in the greedy policy graph of the last returned policy.
Definition fret.h:267
FRET< State, Action, StateInfoT, ValueGraph< State, Action, StateInfoT > > FRETV
Implementation of FRET with trap elimination in the greedy value graph of the MDP.
Definition fret.h:255
This namespace contains the implementation of deterministic bisimulation quotients for SSPs,...
Definition bisimilar_state_space.h:33
QuotientAction
Represents an action in the probabilistic bisimulation quotient.
Definition types.h:12
QuotientState
Represents a state in the probabilistic bisimulation quotient.
Definition types.h:9
This namespace contains the solver interface base class for various search algorithms.
Definition bisimulation_heuristic_search_algorithm.h:17