AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
trap_aware_lrtdp_impl.h
#ifndef GUARD_INCLUDE_PROBFD_ALGORITHMS_TRAP_AWARE_LRTDP_H
#error "This file should only be included from trap_aware_lrtdp.h"
#endif

#include "probfd/algorithms/successor_sampler.h"

#include "probfd/quotients/quotient_max_heuristic.h"

#include "probfd/utils/guards.h"

#include "downward/utils/countdown_timer.h"

namespace probfd::algorithms::trap_aware_lrtdp {

namespace internal {

inline void Statistics::print(std::ostream& out) const
{
    out << " Trials: " << trials << std::endl;
    out << " Average trial length: "
        << (static_cast<double>(trial_length) / static_cast<double>(trials))
        << std::endl;
    out << " Bellman backups (trials): " << trial_bellman_backups << std::endl;
    out << " Bellman backups (check&solved): "
        << check_and_solve_bellman_backups << std::endl;
    out << " Trap removals: " << traps << std::endl;
    out << " Trap removal time: " << trap_timer << std::endl;
}

inline void Statistics::register_report(ProgressReport& report) const
{
    report.register_print([this](std::ostream& out) {
        out << "traps=" << traps << ", trials=" << trials;
    });
}

} // namespace internal

template <typename State, typename Action, bool UseInterval>
bool TALRTDPImpl<State, Action, UseInterval>::ExplorationInformation::
    next_successor()
{
    successors.pop_back();
    return !successors.empty();
}

template <typename State, typename Action, bool UseInterval>
StateID
TALRTDPImpl<State, Action, UseInterval>::ExplorationInformation::get_successor()
    const
{
    return successors.back();
}

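// Constructs the trap-aware LRTDP core. The policy picker (forwarded to the
// base class) is used when selecting among greedy transitions, stop_consistent
// chooses the trial termination condition, reexpand_traps controls whether a
// collapsed trap is immediately re-expanded during check_and_solve, and
// succ_sampler draws the next trial state from the greedy transition's
// successor distribution.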
template <typename State, typename Action, bool UseInterval>
TALRTDPImpl<State, Action, UseInterval>::TALRTDPImpl(
    std::shared_ptr<QuotientPolicyPicker> policy_chooser,
    TrialTerminationCondition stop_consistent,
    bool reexpand_traps,
    std::shared_ptr<QuotientSuccessorSampler> succ_sampler)
    : Base(policy_chooser)
    , stop_at_consistent_(stop_consistent)
    , reexpand_traps_(reexpand_traps)
    , sample_(succ_sampler)
    , stack_index_(STATE_UNSEEN)
{
}

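// Runs trap-aware LRTDP on the quotient system: registers the value bound of
// the initial quotient state with the progress report, then performs trials
// until one of them ends with the initial state labeled solved (or the time
// limit is hit), and finally returns the initial state's value bounds.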
template <typename State, typename Action, bool UseInterval>
Interval TALRTDPImpl<State, Action, UseInterval>::solve_quotient(
    QuotientSystem& quotient,
    QEvaluator& heuristic,
    param_type<QState> state,
    ProgressReport& progress,
    double max_time)
{
    utils::CountdownTimer timer(max_time);

    Base::initialize_initial_state(quotient, heuristic, state);

    const StateID state_id = quotient.get_state_id(state);
    const StateInfo& state_info = this->state_infos_[state_id];

    progress.register_bound("v", [&state_info]() {
        return as_interval(state_info.value);
    });

    this->statistics_.register_report(progress);

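    // Keep sampling trials until check_and_solve has labeled the initial
    // state as solved; each trial ends with a labeling pass over its states.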
    bool terminate;
    do {
        terminate = trial(quotient, heuristic, state_id, timer);
        assert(state_id == quotient.translate_state_id(state_id));
        statistics_.trials++;
        progress.print();
    } while (!terminate);

    return state_info.get_bounds();
}

template <typename State, typename Action, bool UseInterval>
void TALRTDPImpl<State, Action, UseInterval>::print_statistics(
    std::ostream& out) const
{
    this->statistics_.print(out);
}

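// Samples a single trial: the current state is backed up (Bellman update plus
// greedy policy update) and a successor of the greedy transition is sampled,
// until a solved state is reached or the configured termination condition
// fires. The visited states are then handed to check_and_solve in reverse
// order; the trial reports success iff all of them could be labeled solved.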
template <typename State, typename Action, bool UseInterval>
bool TALRTDPImpl<State, Action, UseInterval>::trial(
    QuotientSystem& quotient,
    QEvaluator& heuristic,
    StateID start_state,
    utils::CountdownTimer& timer)
{
    using enum TrialTerminationCondition;

    assert(current_trial_.empty());

    ClearGuard guard(current_trial_);
    current_trial_.push_back(start_state);
    for (;;) {
        timer.throw_if_expired();

        StateID stateid = current_trial_.back();
        auto& info = this->state_infos_[stateid];

        if (info.is_solved()) {
            current_trial_.pop_back();
            break;
        }

        const QState state = quotient.get_state(stateid);
        const value_t termination_cost =
            quotient.get_termination_info(state).get_cost();

        ClearGuard _(transitions_, qvalues_);
        if (info.is_on_fringe()) {
            this->expand_and_initialize(
                quotient,
                heuristic,
                state,
                info,
                transitions_);
        } else {
            this->generate_non_tip_transitions(quotient, state, transitions_);
        }

        const auto value = this->compute_bellman_and_greedy(
            quotient,
            stateid,
            transitions_,
            termination_cost,
            qvalues_);

        statistics_.trial_bellman_backups++;

        auto transition = this->select_greedy_transition(
            quotient,
            info.get_policy(),
            transitions_);

        bool value_changed = this->update_value(info, value);
        this->update_policy(info, transition);

        if (!transition.has_value()) {
            info.set_solved();
            current_trial_.pop_back();
            break;
        }

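        // Trial termination: CONSISTENT stops at the first state whose value
        // did not change in the backup, INCONSISTENT at the first state whose
        // value did change, and REVISITED as soon as a state occurs twice
        // within the same trial.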
        if ((stop_at_consistent_ == CONSISTENT && !value_changed) ||
            (stop_at_consistent_ == INCONSISTENT && value_changed) ||
            (stop_at_consistent_ == REVISITED && info.is_on_trial())) {
            break;
        }

        if (stop_at_consistent_ == REVISITED) {
            info.set_on_trial();
        }

        auto next = sample_->sample(
            stateid,
            transition->action,
            transition->successor_dist,
            this->state_infos_);

        current_trial_.push_back(next);
    }

    statistics_.trial_length += current_trial_.size();
    if (stop_at_consistent_ == REVISITED) {
        for (const StateID state : current_trial_) {
            this->state_infos_[state].clear_trial_flag();
        }
    }

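    // Labeling pass: walk the trial backwards and try to mark each state as
    // solved; give up on this trial as soon as one state cannot be labeled.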
    do {
        timer.throw_if_expired();

        if (!check_and_solve(quotient, heuristic, timer)) {
            return false;
        }

        current_trial_.pop_back();
    } while (!current_trial_.empty());

    return true;
}

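// Labeling step of trap-aware LRTDP. Performs a Tarjan-style depth-first
// exploration of the current greedy policy graph rooted at the last trial
// state. Strongly connected components whose selected greedy actions all have
// zero cost are traps and are collapsed into a single quotient state
// (optionally followed by an immediate re-expansion). States of non-trap
// components are labeled solved once their values are consistent and all
// greedy successors are solved. Returns true iff the root state ends up
// solved.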
template <typename State, typename Action, bool UseInterval>
bool TALRTDPImpl<State, Action, UseInterval>::check_and_solve(
    QuotientSystem& quotient,
    QEvaluator& heuristic,
    utils::CountdownTimer& timer)
{
    assert(!this->current_trial_.empty());

    push(quotient.translate_state_id(this->current_trial_.back()));

    ExplorationInformation* einfo;
    StateInfo* sinfo;

    for (;;) {
        do {
            einfo = &queue_.back();
            sinfo = &this->state_infos_[einfo->state];
        } while (this->initialize(
                     quotient,
                     heuristic,
                     einfo->state,
                     *sinfo,
                     *einfo) &&
                 this->push_successor(quotient, *einfo, timer));

        do {
            if (einfo->is_root) {
                const StateID state_id = einfo->state;
                const unsigned stack_index = stack_index_[state_id];
                auto scc = stack_ | std::views::drop(stack_index);

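                // A root SCC with more than one state whose greedy actions
                // are all zero-cost forms a trap: collapse it into a single
                // quotient state, then either re-expand it or back it up once.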
                if (einfo->is_trap && scc.size() > 1) {
                    assert(einfo->rv);
                    for (const auto& entry : scc) {
                        stack_index_[entry.state_id] = STATE_CLOSED;
                    }

                    TimerScope scope(statistics_.trap_timer);
                    quotient.build_quotient(scc, *scc.begin());
                    sinfo->update_policy(std::nullopt);
                    ++statistics_.traps;
                    stack_.erase(scc.begin(), scc.end());

                    if (reexpand_traps_) {
                        queue_.pop_back();
                        push(state_id);
                        break;
                    }

                    ++statistics_.check_and_solve_bellman_backups;

                    const QState state = quotient.get_state(state_id);
                    const value_t termination_cost =
                        quotient.get_termination_info(state).get_cost();

                    {
                        ClearGuard _(transitions_, qvalues_);
                        this->generate_non_tip_transitions(
                            quotient,
                            state,
                            transitions_);

                        auto value = this->compute_bellman_and_greedy(
                            quotient,
                            state_id,
                            transitions_,
                            termination_cost,
                            qvalues_);

                        auto transition = this->select_greedy_transition(
                            quotient,
                            sinfo->get_policy(),
                            transitions_);

                        this->update_value(*sinfo, value);
                        this->update_policy(*sinfo, transition);
                    }

                    einfo->rv = false;
                } else {
                    for (const auto& entry : scc) {
                        const StateID id = entry.state_id;
                        StateInfo& info = this->state_infos_[id];
                        stack_index_[id] = STATE_CLOSED;
                        if (info.is_solved()) continue;
                        if (einfo->rv) {
                            info.set_solved();
                        } else {
                            const QState state = quotient.get_state(id);
                            const value_t termination_cost =
                                quotient.get_termination_info(state).get_cost();

                            ClearGuard _(transitions_, qvalues_);
                            this->generate_non_tip_transitions(
                                quotient,
                                state,
                                transitions_);

                            ++this->statistics_.check_and_solve_bellman_backups;

                            auto value = this->compute_bellman_and_greedy(
                                quotient,
                                id,
                                transitions_,
                                termination_cost,
                                qvalues_);

                            auto transition = this->select_greedy_transition(
                                quotient,
                                info.get_policy(),
                                transitions_);

                            this->update_value(info, value);
                            this->update_policy(info, transition);
                        }
                    }
                    stack_.erase(scc.begin(), scc.end());
                }

                einfo->is_trap = false;
            }

            ExplorationInformation bt_einfo = std::move(*einfo);

            queue_.pop_back();

            if (queue_.empty()) {
                assert(stack_.empty());
                stack_index_.clear();
                return sinfo->is_solved();
            }

            timer.throw_if_expired();

            einfo = &queue_.back();
            sinfo = &this->state_infos_[einfo->state];

            einfo->update(bt_einfo);
        } while (!einfo->next_successor() ||
                 !this->push_successor(quotient, *einfo, timer));
    }
}

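// Pushes the next unvisited successor of einfo onto the exploration queue and
// returns true. Successors that are still on the Tarjan stack instead
// propagate their (smaller) stack index to einfo's state, marking it as a
// non-root of its SCC; successors that are already closed merely fold their
// status into einfo. Returns false once all successors are exhausted.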
template <typename State, typename Action, bool UseInterval>
bool TALRTDPImpl<State, Action, UseInterval>::push_successor(
    QuotientSystem& quotient,
    ExplorationInformation& einfo,
    utils::CountdownTimer& timer)
{
    do {
        timer.throw_if_expired();

        const StateID succ = quotient.translate_state_id(einfo.get_successor());
        StateInfo& succ_info = this->state_infos_[succ];
        int& sidx = stack_index_[succ];
        if (sidx == STATE_UNSEEN) {
            push(succ);
            return true;
        } else if (sidx >= 0) {
            int& sidx2 = stack_index_[einfo.state];
            if (sidx < sidx2) {
                sidx2 = sidx;
                einfo.is_root = false;
            }
        } else {
            einfo.update(succ_info);
        }
    } while (einfo.next_successor());

    return false;
}

template <typename State, typename Action, bool UseInterval>
void TALRTDPImpl<State, Action, UseInterval>::push(StateID state)
{
    queue_.emplace_back(state);
    stack_index_[state] = stack_.size();
    stack_.emplace_back(state);
}

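// Backs up a state encountered during check_and_solve. Returns false (ending
// the descent at this state) if the state is already solved, has no greedy
// transition left, or its value changed in the backup. Otherwise the greedy
// transition's successors are recorded for exploration and is_trap notes
// whether the greedy action has zero cost.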
template <typename State, typename Action, bool UseInterval>
bool TALRTDPImpl<State, Action, UseInterval>::initialize(
    QuotientSystem& quotient,
    QEvaluator& heuristic,
    StateID state_id,
    StateInfo& state_info,
    ExplorationInformation& e_info)
{
    assert(quotient.translate_state_id(state_id) == state_id);

    if (state_info.is_solved()) {
        e_info.is_trap = false;
        return false;
    }

    const QState state = quotient.get_state(state_id);
    const value_t termination_cost =
        quotient.get_termination_info(state).get_cost();

    ClearGuard _(transitions_, qvalues_);

    if (state_info.is_on_fringe()) {
        this->expand_and_initialize(
            quotient,
            heuristic,
            state,
            state_info,
            transitions_);
    } else {
        this->generate_non_tip_transitions(quotient, state, transitions_);
    }

    ++this->statistics_.check_and_solve_bellman_backups;

    const auto value = this->compute_bellman_and_greedy(
        quotient,
        state_id,
        transitions_,
        termination_cost,
        qvalues_);

    auto transition = this->select_greedy_transition(
        quotient,
        state_info.get_policy(),
        transitions_);

    bool value_changed = this->update_value(state_info, value);
    this->update_policy(state_info, transition);

    if (!transition) {
        e_info.rv = e_info.rv && !value_changed;
        e_info.is_trap = false;
        return false;
    }

    if (value_changed) {
        e_info.rv = false;
        e_info.is_trap = false;
        return false;
    }

    for (const StateID sel : transition->successor_dist.support()) {
        if (sel != state_id) {
            e_info.successors.push_back(sel);
        }
    }

    assert(!e_info.successors.empty());
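    // Only zero-cost actions can keep the policy inside a trap, so the SCC
    // remains a trap candidate only if this greedy action has zero cost.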
    e_info.is_trap = quotient.get_action_cost(transition->action) == 0;
    stack_.back().aops.emplace_back(transition->action);
    return true;
}

template <typename State, typename Action, bool UseInterval>
TALRTDP<State, Action, UseInterval>::TALRTDP(
    std::shared_ptr<QuotientPolicyPicker> policy_chooser,
    TrialTerminationCondition stop_consistent,
    bool reexpand_traps,
    std::shared_ptr<QuotientSuccessorSampler> succ_sampler)
    : algorithm_(policy_chooser, stop_consistent, reexpand_traps, succ_sampler)
{
}

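// Public solve interface: wraps the MDP in a quotient system so traps can be
// collapsed on the fly, lifts the heuristic to quotient states via
// QuotientMaxHeuristic, and runs the core algorithm on the quotient state of
// the given initial state.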
template <typename State, typename Action, bool UseInterval>
Interval TALRTDP<State, Action, UseInterval>::solve(
    MDPType& mdp,
    EvaluatorType& heuristic,
    param_type<State> s,
    ProgressReport progress,
    double max_time)
{
    QuotientSystem quotient(mdp);
    quotients::QuotientMaxHeuristic<State, Action> qheuristic(heuristic);
    return algorithm_.solve_quotient(
        quotient,
        qheuristic,
        quotient.translate_state(s),
        progress,
        max_time);
}

template <typename State, typename Action, bool UseInterval>
auto TALRTDP<State, Action, UseInterval>::compute_policy(
    MDPType& mdp,
    EvaluatorType& heuristic,
    param_type<State> state,
    ProgressReport progress,
    double max_time) -> std::unique_ptr<PolicyType>
{
    QuotientSystem quotient(mdp);
    quotients::QuotientMaxHeuristic<State, Action> qheuristic(heuristic);

    QState qinit = quotient.translate_state(state);
    algorithm_.solve_quotient(quotient, qheuristic, qinit, progress, max_time);

    /*
     * The quotient policy only specifies the optimal actions between
     * traps. We need to supplement the optimal actions within the
     * traps, i.e. the actions which point every other member state of
     * the trap towards that trap member state that owns the optimal
     * quotient action.
     *
     * We fully explore the quotient policy starting from the initial
     * state and compute the optimal 'inner' actions for each trap. To
     * this end, we first generate the sub-MDP of the trap. Afterwards,
     * we expand the trap graph backwards from the state that has the
     * optimal quotient action. For each encountered state, we select
     * the action with which it is encountered first as the policy
     * action.
     */
    using MapPolicy = policies::MapPolicy<State, Action>;
    std::unique_ptr<MapPolicy> policy(new MapPolicy(&mdp));

    const StateID initial_state_id = quotient.get_state_id(qinit);

    std::deque<StateID> queue({initial_state_id});
    std::set<StateID> visited({initial_state_id});

    do {
        const StateID quotient_id = queue.front();
        const QState quotient_state = quotient.get_state(quotient_id);
        queue.pop_front();

        const auto& state_info = algorithm_.state_infos_[quotient_id];

        std::optional quotient_action = state_info.get_policy();

        // Terminal states have no policy decision.
        if (!quotient_action) {
            continue;
        }

        const Interval quotient_bound = as_interval(state_info.value);

        const StateID exiting_id = quotient_action->state_id;

        policy->emplace_decision(
            exiting_id,
            quotient_action->action,
            quotient_bound);

        // Nothing else needs to be done if the trap has only one state.
        if (quotient_state.num_members() != 1) {
            std::unordered_map<StateID, std::set<QAction>> parents;

            // Build the inverse graph
            std::vector<QAction> inner_actions;
            quotient_state.get_collapsed_actions(inner_actions);

            for (const QAction& qaction : inner_actions) {
                StateID source_id = qaction.state_id;
                Action action = qaction.action;

                const State source = mdp.get_state(source_id);

                Distribution<StateID> successors;
                mdp.generate_action_transitions(source, action, successors);

                for (const StateID succ_id : successors.support()) {
                    parents[succ_id].insert(qaction);
                }
            }

            // Now traverse the inverse graph starting from the exiting
            // state
            std::deque<StateID> inverse_queue({exiting_id});
            std::set<StateID> inverse_visited({exiting_id});

            do {
                const StateID next_id = inverse_queue.front();
                inverse_queue.pop_front();

                for (const auto& [pred_id, act] : parents[next_id]) {
                    if (inverse_visited.insert(pred_id).second) {
                        policy->emplace_decision(pred_id, act, quotient_bound);
                        inverse_queue.push_back(pred_id);
                    }
                }
            } while (!inverse_queue.empty());
        }

        // Push the successor traps.
        Distribution<StateID> successors;
        quotient.generate_action_transitions(
            quotient_state,
            *quotient_action,
            successors);

        for (const StateID succ_id : successors.support()) {
            if (visited.insert(succ_id).second) {
                queue.push_back(succ_id);
            }
        }
    } while (!queue.empty());

    return policy;
}

template <typename State, typename Action, bool UseInterval>
void TALRTDP<State, Action, UseInterval>::print_statistics(
    std::ostream& out) const
{
    return algorithm_.print_statistics(out);
}

615} // namespace probfd::algorithms::trap_aware_lrtdp
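For orientation, a minimal usage sketch (not part of the header above). All `my_*` names, `MyState`, and `MyAction` are hypothetical placeholders for an MDP, heuristic, policy picker, and successor sampler satisfying the interfaces referenced in this file; only the TALRTDP constructor and the solve/compute_policy signatures are taken from the code above.

// Hypothetical sketch -- placeholder components, not part of the project API.
using namespace probfd::algorithms::trap_aware_lrtdp;

TALRTDP<MyState, MyAction, /*UseInterval=*/true> lrtdp(
    my_policy_picker,                       // std::shared_ptr<QuotientPolicyPicker>
    TrialTerminationCondition::CONSISTENT,  // stop trials at value-consistent states
    /*reexpand_traps=*/true,
    my_successor_sampler);                  // std::shared_ptr<QuotientSuccessorSampler>

ProgressReport progress;                    // assuming a default-constructible report
const Interval bounds =
    lrtdp.solve(my_mdp, my_heuristic, my_initial_state, progress, /*max_time=*/60.0);
auto policy =
    lrtdp.compute_policy(my_mdp, my_heuristic, my_initial_state, progress, 60.0);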