1#ifndef GUARD_INCLUDE_PROBFD_ALGORITHMS_FRET_H
2#error "This file should only be included from fret.h"
5#include "probfd/policies/map_policy.h"
7#include "probfd/quotients/quotient_max_heuristic.h"
9#include "downward/utils/countdown_timer.h"
// Statistics::print: writes FRET's statistics to `out`, one labelled line per
// value. The iteration count is always printed; the timer readouts follow the
// EXPENSIVE_STATISTICS guard.
15inline void Statistics::print(std::ostream& out)
const
17 out <<
" FRET iterations: " << iterations << std::endl;
18#if defined(EXPENSIVE_STATISTICS)
19 out <<
" Heuristic search: " << heuristic_search << std::endl;
// The trap_removal timer is started while the trap_identification timer is
// still in scope (both in find_and_remove_traps), so removal time is also
// counted by trap_identification. It is subtracted here to report
// identification time alone.
20 out <<
" Trap identification: " << (trap_identification() - trap_removal())
22 out <<
" Trap removal: " << trap_removal << std::endl;
32 typename GreedyGraphGenerator>
// Constructs FRET on top of a base heuristic search algorithm.
// FRET takes shared ownership of `algorithm` by moving it into base_algorithm_.
// It delegates the heuristic-search phase to that algorithm (see heuristic_search()).
33FRET<State, Action, StateInfoT, GreedyGraphGenerator>::FRET(
34 std::shared_ptr<QHeuristicSearchAlgorithm> algorithm)
35 : base_algorithm_(
std::move(algorithm))
43 typename GreedyGraphGenerator>
// compute_policy: solves the MDP with FRET on a quotient system, then extracts
// a MapPolicy over the states of the original MDP.
//
// Extraction is a breadth-first traversal over quotient states, starting from
// the initial state. For each visited quotient state that has a greedy action
// (stored in the base algorithm's state info), a decision is emplaced with that
// action. The action's owning state is recorded as `exiting_id`.
//
// If the quotient state is a collapsed trap (more than one member), every other
// member is routed towards `exiting_id`. This is done by a backwards BFS over
// the collapsed inner actions. All members share the quotient state's bounds.
//
// @param heuristic  evaluator for original states; wrapped in a
//                   QuotientMaxHeuristic for use on the quotient.
// @param state      initial state of the original MDP.
// @param progress   progress report for the solving phase (presumably
//                   forwarded to the solver; the call is not fully shown).
// @param max_time   time budget for solving.
// @return           the extracted policy, owned by the caller.
44auto FRET<State, Action, StateInfoT, GreedyGraphGenerator>::compute_policy(
46 EvaluatorType& heuristic,
47 param_type<State> state,
48 ProgressReport progress,
49 double max_time) -> std::unique_ptr<PolicyType>
51 QuotientSystem quotient(mdp);
52 quotients::QuotientMaxHeuristic<State, Action> qheuristic(heuristic);
56 quotient.translate_state(state),
// NOTE(review): std::make_unique would avoid the naked `new` here.
74 std::unique_ptr<policies::MapPolicy<State, Action>> policy(
75 new policies::MapPolicy<State, Action>(&mdp));
// The BFS is seeded with the original MDP's state ID for the initial state.
// Assumes that ID is also valid as a quotient state ID. TODO confirm.
77 const StateID initial_state_id = mdp.get_state_id(state);
79 std::deque<StateID> queue;
80 std::set<StateID> visited;
81 queue.push_back(initial_state_id);
82 visited.insert(initial_state_id);
85 const StateID quotient_id = queue.front();
86 const QState quotient_state = quotient.get_state(quotient_id);
// The greedy action the base algorithm selected for this quotient state.
89 auto& base_info = base_algorithm_->state_infos_[quotient_id];
90 std::optional quotient_action = base_info.get_policy();
// No greedy action is stored for this state. Presumably no decision is
// emitted for it. Confirm.
93 if (!quotient_action) {
// Bounds of the quotient state, reused for every member's decision below.
97 const Interval quotient_bound =
98 base_algorithm_->lookup_bounds(quotient_id);
// The member state that owns the greedy action, i.e. where the quotient is left.
100 const StateID exiting_id = quotient_action->state_id;
102 policy->emplace_decision(
104 quotient_action->action,
// Collapsed trap: route all other members towards the exiting state.
108 if (quotient_state.num_members() != 1) {
// parents[s] = collapsed inner actions whose successor support contains s.
109 std::unordered_map<StateID, std::set<QAction>> parents;
112 std::vector<QAction> inner_actions;
113 quotient_state.get_collapsed_actions(inner_actions);
115 for (
const QAction& qaction : inner_actions) {
116 StateID source_id = qaction.state_id;
117 Action action = qaction.action;
119 const State source = mdp.get_state(source_id);
121 Distribution<StateID> successors;
122 mdp.generate_action_transitions(source, action, successors);
124 for (
const StateID succ_id : successors.support()) {
125 parents[succ_id].insert(qaction);
// Backwards BFS from the exiting state. Each newly reached predecessor is
// assigned the inner action through which it reaches an already-routed member.
131 std::deque<StateID> inverse_queue;
132 std::set<StateID> inverse_visited;
133 inverse_queue.push_back(exiting_id);
134 inverse_visited.insert(exiting_id);
137 const StateID next_id = inverse_queue.front();
138 inverse_queue.pop_front();
140 for (
const auto& [pred_id, act] : parents[next_id]) {
141 if (inverse_visited.insert(pred_id).second) {
142 policy->emplace_decision(pred_id, act, quotient_bound);
143 inverse_queue.push_back(pred_id);
146 }
while (!inverse_queue.empty());
// Expand the quotient state along its greedy action; enqueue unseen successors.
150 Distribution<StateID> successors;
151 quotient.generate_action_transitions(
156 for (
const StateID succ_id : successors.support()) {
157 if (visited.insert(succ_id).second) {
158 queue.push_back(succ_id);
161 }
while (!queue.empty());
170 typename GreedyGraphGenerator>
// solve: public entry point on the original MDP.
// Wraps the MDP in a QuotientSystem, so that traps can later be collapsed via
// build_quotient. Wraps the heuristic in a QuotientMaxHeuristic.
// It then presumably forwards to the quotient-level solve with the translated
// initial state; the call itself is not fully shown.
// @return the value interval computed for `state`. Presumably these are bounds
//         on its optimal value. TODO confirm.
171Interval FRET<State, Action, StateInfoT, GreedyGraphGenerator>::solve(
173 EvaluatorType& heuristic,
174 param_type<State> state,
175 ProgressReport progress,
178 QuotientSystem quotient(mdp);
179 quotients::QuotientMaxHeuristic<State, Action> qheuristic(heuristic);
183 quotient.translate_state(state),
192 typename GreedyGraphGenerator>
// print_statistics: writes FRET's own statistics to `out` via Statistics::print.
// These are the iteration count and, with EXPENSIVE_STATISTICS, the phase timers.
194 std::ostream& out)
const
197 statistics_.print(out);
204 typename GreedyGraphGenerator>
// Quotient-level FRET loop. All phases share one CountdownTimer built from
// max_time. Each round has two steps:
//   1. Run the heuristic search (heuristic_search).
//   2. Search for traps and eliminate them (find_and_remove_traps).
// find_and_remove_traps returns true only when both trap_counter and
// `unexpanded` are zero for that pass, which presumably ends the loop.
// Otherwise the base algorithm's search state is reset before the next round.
206 QuotientSystem& quotient,
207 QEvaluator& heuristic,
212 utils::CountdownTimer timer(max_time);
// Progress output: current FRET iteration count and number of traps found.
215 out <<
"fret=" << statistics_.iterations
216 <<
", traps=" << statistics_.traps;
221 heuristic_search(quotient, heuristic, state, progress, timer);
223 if (find_and_remove_traps(quotient, state, timer)) {
// Traps were collapsed or values changed: restart the search on the new quotient.
227 base_algorithm_->reset_search_state();
235 typename GreedyGraphGenerator>
// heuristic_search: the heuristic-search phase of FRET.
// Runs the base algorithm on the quotient from `state`, passing it the time
// remaining on the shared timer, and returns its result.
// Timed under statistics_.heuristic_search when EXPENSIVE_STATISTICS is defined.
237FRET<State, Action, StateInfoT, GreedyGraphGenerator>::heuristic_search(
238 QuotientSystem& quotient,
239 QEvaluator& heuristic,
240 param_type<QState> state,
241 ProgressReport& progress,
242 utils::CountdownTimer& timer)
244#if defined(EXPENSIVE_STATISTICS)
245 TimerScope scoped(statistics_.heuristic_search);
248 return base_algorithm_->solve(
253 timer.get_remaining_time());
260 typename GreedyGraphGenerator>
// find_and_remove_traps: the trap-elimination phase of FRET.
//
// Runs an iterative (explicit-stack) Tarjan SCC search over the greedy graph
// produced by GreedyGraphGenerator, starting at `state`.
// - `exploration_queue` plays the role of the DFS call stack.
// - `stack` is Tarjan's SCC stack.
// - `lowlink` and `stack_index` are the usual Tarjan bookkeeping.
//
// A state's `is_leaf` flag is cleared as soon as it can reach another SCC.
// When an SCC is closed while its root is still flagged `is_leaf`, it is treated
// as a trap. The trap is collapsed into a single quotient state (build_quotient).
// The collapsed state is then marked as fringe and its greedy policy is cleared,
// presumably so the next heuristic search re-explores it.
//
// @return true iff this pass collapsed no trap (trap_counter == 0) and the
//         `unexpanded` counter maintained by push() is zero.
// @throws presumably whatever CountdownTimer::throw_if_expired throws once the
//         time budget is exhausted.
261bool FRET<State, Action, StateInfoT, GreedyGraphGenerator>::
262 find_and_remove_traps(
263 QuotientSystem& quotient,
264 param_type<QState> state,
265 utils::CountdownTimer& timer)
267 using namespace internal;
269#if defined(EXPENSIVE_STATISTICS)
// Also covers the nested trap_removal timer below. Statistics::print subtracts it.
270 TimerScope scoped(statistics_.trap_identification);
272 unsigned int trap_counter = 0;
273 unsigned int unexpanded = 0;
275 storage::StateHashMap<TarjanStateInformation> state_infos;
276 std::deque<ExplorationInfo> exploration_queue;
277 std::deque<StackInfo> stack;
279 StateID state_id = quotient.get_state_id(state);
280 TarjanStateInformation* sinfo = &state_infos[state_id];
// Presumably reached when the start state is not opened for exploration, so no
// SCC search takes place.
289 return unexpanded == 0;
292 ExplorationInfo* einfo = &exploration_queue.back();
296 timer.throw_if_expired();
298 const StateID succid = einfo->successors.back();
299 TarjanStateInformation& succ_info = state_infos[succid];
// Successor is on the Tarjan stack: an edge back into the current SCC candidate.
301 if (succ_info.is_on_stack()) {
303 std::min(sinfo->lowlink, succ_info.stack_index);
// Unexplored successor: try to open it and descend into it.
305 !succ_info.is_explored() && push(
312 einfo = &exploration_queue.back();
313 state_id = einfo->state_id;
314 sinfo = &state_infos[state_id];
317 einfo->is_leaf =
false;
320 einfo->successors.pop_back();
321 }
while (!einfo->successors.empty());
// All successors handled: check whether this state is the root of an SCC.
324 const unsigned last_lowlink = sinfo->lowlink;
325 const bool scc_found = last_lowlink == sinfo->stack_index;
// The parent can leave its own SCC if this child roots an SCC, or if the child
// could already reach one.
326 const bool can_reach_child_scc = scc_found || !einfo->is_leaf;
// The SCC is the suffix of `stack` starting at this state's stack index.
329 auto scc = stack | std::views::drop(sinfo->stack_index);
331 for (
const auto& info : scc) {
332 state_infos[info.state_id].close();
// No other SCC is reachable from this one: it is a trap, so collapse it.
335 if (einfo->is_leaf) {
337 assert(scc.size() > 1);
339#if defined(EXPENSIVE_STATISTICS)
340 TimerScope t(statistics_.trap_removal);
342 quotient.build_quotient(scc, *scc.begin());
345 auto& base_info = base_algorithm_->state_infos_[state_id];
346 base_info.set_on_fringe();
347 base_algorithm_->update_policy(base_info, std::nullopt);
// Pop the SCC off the Tarjan stack.
353 stack.erase(scc.begin(), scc.end());
// Return to the DFS parent.
356 exploration_queue.pop_back();
358 if (exploration_queue.empty()) {
359 ++statistics_.iterations;
360 return trap_counter == 0 && unexpanded == 0;
363 timer.throw_if_expired();
365 einfo = &exploration_queue.back();
366 state_id = einfo->state_id;
367 sinfo = &state_infos[state_id];
// Propagate the child's lowlink and its ability to reach another SCC.
369 sinfo->lowlink = std::min(sinfo->lowlink, last_lowlink);
370 if (can_reach_child_scc) {
371 einfo->is_leaf =
false;
374 einfo->successors.pop_back();
375 }
while (einfo->successors.empty());
383 typename GreedyGraphGenerator>
// push: tries to open `state_id` as a new DFS node of the trap search.
// - Goal/terminal states take a separate branch whose body is not shown;
//   presumably they are not expanded.
// - Otherwise the greedy successors and their actions are generated by
//   GreedyGraphGenerator. Its return value presumably feeds `unexpanded`.
// The state then receives the next Tarjan stack index (stack.size()). Entries
// are appended to the SCC stack (with the greedy actions) and to the exploration
// queue (with the successors).
// @return presumably whether the state was opened. Confirm.
384bool FRET<State, Action, StateInfoT, GreedyGraphGenerator>::push(
385 QuotientSystem& quotient,
386 std::deque<internal::ExplorationInfo>& queue,
387 std::deque<StackInfo>& stack,
388 internal::TarjanStateInformation& info,
390 unsigned int& unexpanded)
392 const auto& state_info = base_algorithm_->state_infos_[state_id];
394 if (state_info.is_goal_or_terminal()) {
// NOTE(review): a fresh generator is constructed on every push. Any scratch
// buffers it owns (ValueGraph appears to keep opt_transitions_/ids_) are
// therefore rebuilt on each call. Consider keeping one instance as a member.
398 GreedyGraphGenerator greedy_graph;
399 std::vector<QAction> aops;
400 std::vector<StateID> succs;
401 if (greedy_graph.get_successors(
414 info.open(stack.size());
415 stack.emplace_back(state_id, std::move(aops));
416 queue.emplace_back(state_id, std::move(succs));
// ValueGraph::get_successors: builds the value-greedy subgraph at `qstate`.
// First it performs a Bellman update of the state (compute_bellman_and_greedy,
// then update_value). It then reports every transition left in opt_transitions_,
// which are presumably the greedy-optimal ones computed by
// compute_bellman_and_greedy:
// - each transition's action is appended to `aops`;
// - the union of the successor supports is appended to `successors`, without
//   duplicates (tracked in `ids_`).
// `successors` must be empty on entry (asserted).
// @return true iff update_value reported that the state's value changed.
420template <
typename State,
typename Action,
typename StateInfoT>
421bool ValueGraph<State, Action, StateInfoT>::get_successors(
422 QuotientSystem& quotient,
423 QHeuristicSearchAlgorithm& base_algorithm,
425 std::vector<QAction>& aops,
426 std::vector<StateID>& successors)
428 assert(successors.empty());
430 auto& state_info = base_algorithm.state_infos_[qstate];
432 const QState state = quotient.get_state(qstate);
// Cost of terminating in this state. Presumably a candidate in the Bellman
// backup (the arguments of that call are not shown).
433 const value_t termination_cost =
434 quotient.get_termination_info(state).get_cost();
// Scratch buffers; presumably cleared by the guard when it goes out of scope.
436 ClearGuard _(opt_transitions_, ids_, q_values);
437 base_algorithm.generate_non_tip_transitions(
442 auto value = base_algorithm.compute_bellman_and_greedy(
449 bool value_changed = base_algorithm.update_value(state_info, value);
451 for (
const auto& transition : opt_transitions_) {
452 aops.push_back(transition.action);
454 for (
const StateID sid : transition.successor_dist.support()) {
455 if (ids_.insert(sid).second) {
456 successors.push_back(sid);
461 return value_changed;
// PolicyGraph::get_successors: builds the policy-greedy subgraph at
// `quotient_state_id`. Only the action currently stored as the base
// algorithm's policy for that state is followed.
// - If no policy action is stored, it returns false and reports no successors.
// - Otherwise it appends the support of that action's successor distribution
//   in the quotient to `successors`.
// NOTE(review): `t_` is reused across calls. If generate_action_transitions
// appends rather than overwrites, `t_` must be cleared before that call (the
// clearing line is not visible here). Otherwise successors from earlier calls
// would leak into this one.
template <
typename State,
typename Action,
typename StateInfoT>
465bool PolicyGraph<State, Action, StateInfoT>::get_successors(
466 QuotientSystem& quotient,
467 QHeuristicSearchAlgorithm& base_algorithm,
468 StateID quotient_state_id,
469 std::vector<QAction>& aops,
470 std::vector<StateID>& successors)
472 auto& base_info = base_algorithm.state_infos_[quotient_state_id];
473 auto a = base_info.get_policy();
475 if (!a.has_value())
return false;
479 const QState quotient_state = quotient.get_state(quotient_state_id);
480 quotient.generate_action_transitions(quotient_state, *a, t_);
482 for (StateID sid : t_.support()) {
483 successors.push_back(sid);
A registry for print functions related to search progress.
Definition progress_report.h:33
void register_print(Printer f)
Appends a new printer to the list of printers.
Implementation of the Find-Revise-Eliminate-Traps (FRET) framework [kolobov:etal:icaps-11].
Definition heuristic_search_base.h:39
void print_statistics(std::ostream &out) const override
Prints algorithm statistics to the specified output stream.
Definition fret_impl.h:193
Namespace dedicated to the Find-Revise-Eliminate-Traps (FRET) framework.
Definition fret.h:23
double value_t
Typedef for the state value type.
Definition aliases.h:7
typename std::conditional_t< is_cheap_to_copy_v< T >, T, const T & > param_type
Alias template defining the best way to pass a parameter of a given type.
Definition type_traits.h:25
Represents a closed interval over the extended reals as a pair of lower and upper bound.
Definition interval.h:12