AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
fret_impl.h
1#ifndef GUARD_INCLUDE_PROBFD_ALGORITHMS_FRET_H
2#error "This file should only be included from fret.h"
3#endif
4
5#include "probfd/policies/map_policy.h"
6
7#include "probfd/quotients/quotient_max_heuristic.h"
8
9#include "downward/utils/countdown_timer.h"
10
12
13namespace internal {
14
15inline void Statistics::print(std::ostream& out) const
16{
17 out << " FRET iterations: " << iterations << std::endl;
18#if defined(EXPENSIVE_STATISTICS)
19 out << " Heuristic search: " << heuristic_search << std::endl;
20 out << " Trap identification: " << (trap_identification() - trap_removal())
21 << std::endl;
22 out << " Trap removal: " << trap_removal << std::endl;
23#endif
24}
25
26} // namespace internal
27
28template <
29 typename State,
30 typename Action,
31 typename StateInfoT,
32 typename GreedyGraphGenerator>
33FRET<State, Action, StateInfoT, GreedyGraphGenerator>::FRET(
34 std::shared_ptr<QHeuristicSearchAlgorithm> algorithm)
35 : base_algorithm_(std::move(algorithm))
36{
37}
38
39template <
40 typename State,
41 typename Action,
42 typename StateInfoT,
43 typename GreedyGraphGenerator>
44auto FRET<State, Action, StateInfoT, GreedyGraphGenerator>::compute_policy(
45 MDPType& mdp,
46 EvaluatorType& heuristic,
47 param_type<State> state,
48 ProgressReport progress,
49 double max_time) -> std::unique_ptr<PolicyType>
50{
51 QuotientSystem quotient(mdp);
52 quotients::QuotientMaxHeuristic<State, Action> qheuristic(heuristic);
53 this->solve(
54 quotient,
55 qheuristic,
56 quotient.translate_state(state),
57 progress,
58 max_time);
59
60 /*
61 * The quotient policy only specifies the optimal actions between traps.
62 * We need to supplement the optimal actions within the traps, i.e.
63 * the actions which point every other member state of the trap towards
64 * that trap member state that owns the optimal quotient action.
65 *
66 * We fully explore the quotient policy starting from the initial state
67 * and compute the optimal 'inner' actions for each trap. To this end,
68 * we first generate the sub-MDP of the trap. Afterwards, we expand the
69 * trap graph backwards from the state that has the optimal quotient
70 * action. For each encountered state, we select the action with which
71 * it is encountered first as the policy action.
72 */
73
74 std::unique_ptr<policies::MapPolicy<State, Action>> policy(
75 new policies::MapPolicy<State, Action>(&mdp));
76
77 const StateID initial_state_id = mdp.get_state_id(state);
78
79 std::deque<StateID> queue;
80 std::set<StateID> visited;
81 queue.push_back(initial_state_id);
82 visited.insert(initial_state_id);
83
84 do {
85 const StateID quotient_id = queue.front();
86 const QState quotient_state = quotient.get_state(quotient_id);
87 queue.pop_front();
88
89 auto& base_info = base_algorithm_->state_infos_[quotient_id];
90 std::optional quotient_action = base_info.get_policy();
91
92 // Terminal states have no policy decision.
93 if (!quotient_action) {
94 continue;
95 }
96
97 const Interval quotient_bound =
98 base_algorithm_->lookup_bounds(quotient_id);
99
100 const StateID exiting_id = quotient_action->state_id;
101
102 policy->emplace_decision(
103 exiting_id,
104 quotient_action->action,
105 quotient_bound);
106
107 // Nothing else needs to be done if the trap has only one state.
108 if (quotient_state.num_members() != 1) {
109 std::unordered_map<StateID, std::set<QAction>> parents;
110
111 // Build the inverse graph
112 std::vector<QAction> inner_actions;
113 quotient_state.get_collapsed_actions(inner_actions);
114
115 for (const QAction& qaction : inner_actions) {
116 StateID source_id = qaction.state_id;
117 Action action = qaction.action;
118
119 const State source = mdp.get_state(source_id);
120
121 Distribution<StateID> successors;
122 mdp.generate_action_transitions(source, action, successors);
123
124 for (const StateID succ_id : successors.support()) {
125 parents[succ_id].insert(qaction);
126 }
127 }
128
129 // Now traverse the inverse graph starting from the exiting
130 // state
131 std::deque<StateID> inverse_queue;
132 std::set<StateID> inverse_visited;
133 inverse_queue.push_back(exiting_id);
134 inverse_visited.insert(exiting_id);
135
136 do {
137 const StateID next_id = inverse_queue.front();
138 inverse_queue.pop_front();
139
140 for (const auto& [pred_id, act] : parents[next_id]) {
141 if (inverse_visited.insert(pred_id).second) {
142 policy->emplace_decision(pred_id, act, quotient_bound);
143 inverse_queue.push_back(pred_id);
144 }
145 }
146 } while (!inverse_queue.empty());
147 }
148
149 // Push the successor traps.
150 Distribution<StateID> successors;
151 quotient.generate_action_transitions(
152 quotient_state,
153 *quotient_action,
154 successors);
155
156 for (const StateID succ_id : successors.support()) {
157 if (visited.insert(succ_id).second) {
158 queue.push_back(succ_id);
159 }
160 }
161 } while (!queue.empty());
162
163 return policy;
164}
165
166template <
167 typename State,
168 typename Action,
169 typename StateInfoT,
170 typename GreedyGraphGenerator>
171Interval FRET<State, Action, StateInfoT, GreedyGraphGenerator>::solve(
172 MDPType& mdp,
173 EvaluatorType& heuristic,
174 param_type<State> state,
175 ProgressReport progress,
176 double max_time)
177{
178 QuotientSystem quotient(mdp);
179 quotients::QuotientMaxHeuristic<State, Action> qheuristic(heuristic);
180 return solve(
181 quotient,
182 qheuristic,
183 quotient.translate_state(state),
184 progress,
185 max_time);
186}
187
188template <
189 typename State,
190 typename Action,
191 typename StateInfoT,
192 typename GreedyGraphGenerator>
194 std::ostream& out) const
195{
196 this->base_algorithm_->print_statistics(out);
197 statistics_.print(out);
198}
199
200template <
201 typename State,
202 typename Action,
203 typename StateInfoT,
204 typename GreedyGraphGenerator>
206 QuotientSystem& quotient,
207 QEvaluator& heuristic,
208 param_type<QState> state,
209 ProgressReport& progress,
210 double max_time)
211{
212 utils::CountdownTimer timer(max_time);
213
214 progress.register_print([&](std::ostream& out) {
215 out << "fret=" << statistics_.iterations
216 << ", traps=" << statistics_.traps;
217 });
218
219 for (;;) {
220 const Interval value =
221 heuristic_search(quotient, heuristic, state, progress, timer);
222
223 if (find_and_remove_traps(quotient, state, timer)) {
224 return value;
225 }
226
227 base_algorithm_->reset_search_state();
228 }
229}
230
231template <
232 typename State,
233 typename Action,
234 typename StateInfoT,
235 typename GreedyGraphGenerator>
236Interval
237FRET<State, Action, StateInfoT, GreedyGraphGenerator>::heuristic_search(
238 QuotientSystem& quotient,
239 QEvaluator& heuristic,
240 param_type<QState> state,
241 ProgressReport& progress,
242 utils::CountdownTimer& timer)
243{
244#if defined(EXPENSIVE_STATISTICS)
245 TimerScope scoped(statistics_.heuristic_search);
246#endif
247
248 return base_algorithm_->solve(
249 quotient,
250 heuristic,
251 state,
252 progress,
253 timer.get_remaining_time());
254}
255
256template <
257 typename State,
258 typename Action,
259 typename StateInfoT,
260 typename GreedyGraphGenerator>
261bool FRET<State, Action, StateInfoT, GreedyGraphGenerator>::
262 find_and_remove_traps(
263 QuotientSystem& quotient,
264 param_type<QState> state,
265 utils::CountdownTimer& timer)
266{
267 using namespace internal;
268
269#if defined(EXPENSIVE_STATISTICS)
270 TimerScope scoped(statistics_.trap_identification);
271#endif
272 unsigned int trap_counter = 0;
273 unsigned int unexpanded = 0;
274
275 storage::StateHashMap<TarjanStateInformation> state_infos;
276 std::deque<ExplorationInfo> exploration_queue;
277 std::deque<StackInfo> stack;
278
279 StateID state_id = quotient.get_state_id(state);
280 TarjanStateInformation* sinfo = &state_infos[state_id];
281
282 if (!push(
283 quotient,
284 exploration_queue,
285 stack,
286 *sinfo,
287 state_id,
288 unexpanded)) {
289 return unexpanded == 0;
290 }
291
292 ExplorationInfo* einfo = &exploration_queue.back();
293
294 for (;;) {
295 do {
296 timer.throw_if_expired();
297
298 const StateID succid = einfo->successors.back();
299 TarjanStateInformation& succ_info = state_infos[succid];
300
301 if (succ_info.is_on_stack()) {
302 sinfo->lowlink =
303 std::min(sinfo->lowlink, succ_info.stack_index);
304 } else if (
305 !succ_info.is_explored() && push(
306 quotient,
307 exploration_queue,
308 stack,
309 succ_info,
310 succid,
311 unexpanded)) {
312 einfo = &exploration_queue.back();
313 state_id = einfo->state_id;
314 sinfo = &state_infos[state_id];
315 continue;
316 } else {
317 einfo->is_leaf = false;
318 }
319
320 einfo->successors.pop_back();
321 } while (!einfo->successors.empty());
322
323 do {
324 const unsigned last_lowlink = sinfo->lowlink;
325 const bool scc_found = last_lowlink == sinfo->stack_index;
326 const bool can_reach_child_scc = scc_found || !einfo->is_leaf;
327
328 if (scc_found) {
329 auto scc = stack | std::views::drop(sinfo->stack_index);
330
331 for (const auto& info : scc) {
332 state_infos[info.state_id].close();
333 }
334
335 if (einfo->is_leaf) {
336 // Terminal and self-loop leaf SCCs are always pruned
337 assert(scc.size() > 1);
338 {
339#if defined(EXPENSIVE_STATISTICS)
340 TimerScope t(statistics_.trap_removal);
341#endif
342 quotient.build_quotient(scc, *scc.begin());
343 }
344
345 auto& base_info = base_algorithm_->state_infos_[state_id];
346 base_info.set_on_fringe();
347 base_algorithm_->update_policy(base_info, std::nullopt);
348
349 ++statistics_.traps;
350 ++trap_counter;
351 }
352
353 stack.erase(scc.begin(), scc.end());
354 }
355
356 exploration_queue.pop_back();
357
358 if (exploration_queue.empty()) {
359 ++statistics_.iterations;
360 return trap_counter == 0 && unexpanded == 0;
361 }
362
363 timer.throw_if_expired();
364
365 einfo = &exploration_queue.back();
366 state_id = einfo->state_id;
367 sinfo = &state_infos[state_id];
368
369 sinfo->lowlink = std::min(sinfo->lowlink, last_lowlink);
370 if (can_reach_child_scc) {
371 einfo->is_leaf = false;
372 }
373
374 einfo->successors.pop_back();
375 } while (einfo->successors.empty());
376 }
377}
378
379template <
380 typename State,
381 typename Action,
382 typename StateInfoT,
383 typename GreedyGraphGenerator>
384bool FRET<State, Action, StateInfoT, GreedyGraphGenerator>::push(
385 QuotientSystem& quotient,
386 std::deque<internal::ExplorationInfo>& queue,
387 std::deque<StackInfo>& stack,
388 internal::TarjanStateInformation& info,
389 StateID state_id,
390 unsigned int& unexpanded)
391{
392 const auto& state_info = base_algorithm_->state_infos_[state_id];
393
394 if (state_info.is_goal_or_terminal()) {
395 return false;
396 }
397
398 GreedyGraphGenerator greedy_graph;
399 std::vector<QAction> aops;
400 std::vector<StateID> succs;
401 if (greedy_graph.get_successors(
402 quotient,
403 *base_algorithm_,
404 state_id,
405 aops,
406 succs)) {
407 ++unexpanded;
408 }
409
410 if (succs.empty()) {
411 return false;
412 }
413
414 info.open(stack.size());
415 stack.emplace_back(state_id, std::move(aops));
416 queue.emplace_back(state_id, std::move(succs));
417 return true;
418}
419
420template <typename State, typename Action, typename StateInfoT>
421bool ValueGraph<State, Action, StateInfoT>::get_successors(
422 QuotientSystem& quotient,
423 QHeuristicSearchAlgorithm& base_algorithm,
424 StateID qstate,
425 std::vector<QAction>& aops,
426 std::vector<StateID>& successors)
427{
428 assert(successors.empty());
429
430 auto& state_info = base_algorithm.state_infos_[qstate];
431
432 const QState state = quotient.get_state(qstate);
433 const value_t termination_cost =
434 quotient.get_termination_info(state).get_cost();
435
436 ClearGuard _(opt_transitions_, ids_, q_values);
437 base_algorithm.generate_non_tip_transitions(
438 quotient,
439 state,
440 opt_transitions_);
441
442 auto value = base_algorithm.compute_bellman_and_greedy(
443 quotient,
444 qstate,
445 opt_transitions_,
446 termination_cost,
447 q_values);
448
449 bool value_changed = base_algorithm.update_value(state_info, value);
450
451 for (const auto& transition : opt_transitions_) {
452 aops.push_back(transition.action);
453
454 for (const StateID sid : transition.successor_dist.support()) {
455 if (ids_.insert(sid).second) {
456 successors.push_back(sid);
457 }
458 }
459 }
460
461 return value_changed;
462}
463
464template <typename State, typename Action, typename StateInfoT>
465bool PolicyGraph<State, Action, StateInfoT>::get_successors(
466 QuotientSystem& quotient,
467 QHeuristicSearchAlgorithm& base_algorithm,
468 StateID quotient_state_id,
469 std::vector<QAction>& aops,
470 std::vector<StateID>& successors)
471{
472 auto& base_info = base_algorithm.state_infos_[quotient_state_id];
473 auto a = base_info.get_policy();
474
475 if (!a.has_value()) return false;
476
477 ClearGuard _(t_);
478
479 const QState quotient_state = quotient.get_state(quotient_state_id);
480 quotient.generate_action_transitions(quotient_state, *a, t_);
481
482 for (StateID sid : t_.support()) {
483 successors.push_back(sid);
484 }
485
486 aops.push_back(*a);
487
488 return false;
489}
490
491} // namespace probfd::algorithms::fret
A registry for print functions related to search progress.
Definition progress_report.h:33
void register_print(Printer f)
Appends a new printer to the list of printers.
Implementation of the Find-Revise-Eliminate-Traps (FRET) framework [kolobov:etal:icaps-11].
Definition heuristic_search_base.h:39
void print_statistics(std::ostream &out) const override
Prints algorithm statistics to the specified output stream.
Definition fret_impl.h:193
Namespace dedicated to the Find, Revise, Eliminate Traps (FRET) framework.
Definition fret.h:23
double value_t
Typedef for the state value type.
Definition aliases.h:7
typename std::conditional_t< is_cheap_to_copy_v< T >, T, const T & > param_type
Alias template defining the best way to pass a parameter of a given type.
Definition type_traits.h:25
STL namespace.
Represents a closed interval over the extended reals as a pair of lower and upper bound.
Definition interval.h:12