AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
quotient_system_impl.h
1#ifndef GUARD_INCLUDE_PROBFD_QUOTIENTS_QUOTIENT_SYSTEM_H
2#error "This file should only be included from quotient_system.h"
3#endif
4
5#include "probfd/distribution.h"
6#include "probfd/transition.h"
7
8#include "downward/utils/collections.h"
9
10namespace probfd::quotients {
11
12template <typename Action>
13struct QuotientInformation<Action>::StateInfo {
14 StateID state_id;
15 size_t num_outer_acts = 0;
16 size_t num_inner_acts = 0;
17};
18
19template <typename Action>
20size_t QuotientInformation<Action>::num_members() const
21{
22 return state_infos_.size();
23}
24
25template <typename Action>
26auto QuotientInformation<Action>::member_ids()
27{
28 return std::views::transform(state_infos_, &StateInfo::state_id);
29}
30
31template <typename Action>
32auto QuotientInformation<Action>::member_ids() const
33{
34 return std::views::transform(state_infos_, &StateInfo::state_id);
35}
36
37template <typename Action>
38void QuotientInformation<Action>::filter_actions(
39 const std::ranges::input_range auto& filter)
40{
41 if (std::ranges::empty(filter)) {
42 return;
43 }
44
45 total_num_outer_acts_ = 0;
46
47 auto act_it = aops_.begin();
48
49 for (auto& info : state_infos_) {
50 auto outer_end = std::stable_partition(
51 act_it,
52 act_it + info.num_outer_acts,
53 [&info, &filter](Action a) {
54 return !std::ranges::contains(
55 filter,
56 QuotientAction<Action>(info.state_id, a));
57 });
58
59 const size_t num_total_acts = info.num_outer_acts + info.num_inner_acts;
60
61 info.num_outer_acts = std::distance(act_it, outer_end);
62 info.num_inner_acts = num_total_acts - info.num_outer_acts;
63
64 total_num_outer_acts_ += info.num_outer_acts;
65
66 act_it += num_total_acts;
67 }
68
69 assert(act_it == aops_.end());
70}
71
72template <typename State, typename Action>
73QuotientState<State, Action>::QuotientState(MDPType& mdp, State single)
74 : mdp(mdp)
75 , single_or_quotient(std::move(single))
76{
77}
78
79template <typename State, typename Action>
80QuotientState<State, Action>::QuotientState(
81 MDPType& mdp,
82 const QuotientInformationType* quotient)
83 : mdp(mdp)
84 , single_or_quotient(quotient)
85{
86}
87
88template <typename State, typename Action>
89template <std::invocable<param_type<State>> F>
90value_t QuotientState<State, Action>::member_maximum(F&& f) const
91 requires(std::is_convertible_v<
92 std::invoke_result_t<F, param_type<State>>,
93 value_t>)
94{
95 using namespace std::views;
96
97 return std::visit(
98 overloaded{
99 [&](const QuotientInformationType* quotient) {
100 value_t res = -INFINITE_VALUE;
101 for (param_type<State> state :
102 quotient->member_ids() | transform(std::bind_front(
103 &MDPType::get_state,
104 std::ref(mdp)))) {
105 res = std::max(res, f(state));
106 }
107 return res;
108 },
109 [&](param_type<State> single) { return f(single); }},
110 single_or_quotient);
111}
112
113template <typename State, typename Action>
114void QuotientState<State, Action>::for_each_member_state(
115 std::invocable<param_type<State>> auto&& f) const
116{
117 std::visit(
118 overloaded{
119 [&](const QuotientInformationType* quotient) {
120 std::ranges::for_each(
121 quotient->member_ids() |
122 std::views::transform(std::bind_front(
123 &MDPType::get_state,
124 std::ref(mdp))),
125 f);
126 },
127 [&](param_type<State> single) { f(single); }},
128 single_or_quotient);
129}
130
131template <typename State, typename Action>
132size_t QuotientState<State, Action>::num_members() const
133{
134 return std::visit(
135 overloaded{
136 [](const QuotientInformationType* quotient) {
137 return quotient->num_members();
138 },
139 [](param_type<State>) -> size_t { return 1; }},
140 single_or_quotient);
141}
142
143template <typename State, typename Action>
144void QuotientState<State, Action>::get_collapsed_actions(
145 std::vector<QuotientAction<Action>>& aops) const
146{
147 std::visit(
148 overloaded{
149 [&](const QuotientInformationType* info) {
150 aops.reserve(info->aops_.size() - info->total_num_outer_acts_);
151
152 auto aid = info->aops_.begin();
153
154 for (const auto& sinfo : info->state_infos_) {
155 aid += sinfo.num_outer_acts; // Start with inner actions
156 const auto inners_end = aid + sinfo.num_inner_acts;
157 for (; aid != inners_end; ++aid) {
158 aops.emplace_back(sinfo.state_id, *aid);
159 }
160 }
161
162 assert(
163 aops.size() ==
164 info->aops_.size() - info->total_num_outer_acts_);
165 },
166 [](param_type<State>) { return; }},
167 single_or_quotient);
168}
169
170template <typename State, typename Action>
171quotient_id_iterator<State, Action>::quotient_id_iterator(
172 const QuotientSystem<State, Action>* qs,
173 StateID x)
174 : qs_(qs)
175 , i_(x)
176
177{
178}
179
180template <typename State, typename Action>
181auto quotient_id_iterator<State, Action>::operator++() -> quotient_id_iterator&
182{
183 while (++i_.id < qs_->quotient_ids_.size()) {
184 const StateID ref = qs_->quotient_ids_[i_];
185 if (i_ == (ref & QuotientSystem<State, Action>::MASK)) {
186 break;
187 }
188 }
189
190 return *this;
191}
192
193template <typename State, typename Action>
194bool operator==(
195 const typename QuotientSystem<State, Action>::quotient_id_iterator& left,
196 const typename QuotientSystem<State, Action>::quotient_id_iterator& right)
197{
198 return left.i_ == right.i_;
199}
200
201template <typename State, typename Action>
202StateID quotient_id_iterator<State, Action>::operator*() const
203{
204 return i_;
205}
206
207template <typename State, typename Action>
208QuotientSystem<State, Action>::QuotientSystem(MDPType& mdp)
209 : mdp_(mdp)
210{
211}
212
213template <typename State, typename Action>
214StateID QuotientSystem<State, Action>::get_state_id(param_type<QState> state)
215{
216 return std::visit(
217 overloaded{
218 [&](const QuotientInformationType* info) {
219 return info->state_infos_.front().state_id;
220 },
221 [&](param_type<State> s) { return mdp_.get_state_id(s); }},
222 state.single_or_quotient);
223}
224
225template <typename State, typename Action>
226auto QuotientSystem<State, Action>::get_state(StateID state_id) -> QState
227{
228 const QuotientInformationType* info = get_quotient_info(state_id);
229
230 if (info) {
231 return QState(mdp_, info);
232 }
233
234 return QState(mdp_, mdp_.get_state(state_id));
235}
236
237template <typename State, typename Action>
238void QuotientSystem<State, Action>::generate_applicable_actions(
239 param_type<QState> state,
240 std::vector<QAction>& aops)
241{
242 std::visit(
243 overloaded{
244 [&](const QuotientInformationType* info) {
245 aops.reserve(info->total_num_outer_acts_);
246
247 auto aid = info->aops_.begin();
248
249 for (const auto& sinfo : info->state_infos_) {
250 const auto outers_end = aid + sinfo.num_outer_acts;
251 for (; aid != outers_end; ++aid) {
252 aops.emplace_back(sinfo.state_id, *aid);
253 }
254 aid += sinfo.num_inner_acts; // Skip inner actions
255 }
256
257 assert(aops.size() == info->total_num_outer_acts_);
258 },
259 [&](param_type<State> state) {
260 std::vector<Action> orig;
261 mdp_.generate_applicable_actions(state, orig);
262
263 const StateID state_id = mdp_.get_state_id(state);
264 aops.reserve(orig.size());
265
266 for (const Action& a : orig) {
267 aops.emplace_back(state_id, a);
268 }
269 }},
270 state.single_or_quotient);
271}
272
273template <typename State, typename Action>
274void QuotientSystem<State, Action>::generate_action_transitions(
275 param_type<QState>,
276 QAction a,
277 Distribution<StateID>& result)
278{
279 Distribution<StateID> orig;
280 const State state = this->mdp_.get_state(a.state_id);
281 mdp_.generate_action_transitions(state, a.action, orig);
282
283 for (const auto& [state_id, probability] : orig) {
284 result.add_probability(
285 get_masked_state_id(state_id) & MASK,
286 probability);
287 }
288}
289
290template <typename State, typename Action>
291void QuotientSystem<State, Action>::generate_all_transitions(
292 param_type<QState> state,
293 std::vector<QAction>& aops,
294 std::vector<Distribution<StateID>>& successors)
295{
296 std::visit(
297 overloaded{
298 [&](const QuotientInformationType* info) {
299 aops.reserve(info->total_num_outer_acts_);
300 successors.reserve(info->total_num_outer_acts_);
301
302 auto aop = info->aops_.begin();
303
304 for (const auto& info : info->state_infos_) {
305 const auto outers_end = aop + info.num_outer_acts;
306 for (; aop != outers_end; ++aop) {
307 const QAction& a =
308 aops.emplace_back(info.state_id, *aop);
309 generate_action_transitions(
310 state,
311 a,
312 successors.emplace_back());
313 }
314 aop += info.num_inner_acts; // Skip inner actions
315 }
316
317 assert(aops.size() == info->total_num_outer_acts_);
318 assert(successors.size() == info->total_num_outer_acts_);
319 },
320 [&](param_type<State> state) {
321 std::vector<Action> orig_a;
322 mdp_.generate_applicable_actions(state, orig_a);
323
324 const StateID state_id = mdp_.get_state_id(state);
325 aops.reserve(orig_a.size());
326 successors.reserve(orig_a.size());
327
328 for (Action oa : orig_a) {
329 aops.emplace_back(state_id, oa);
330 auto& dist = successors.emplace_back();
331
332 Distribution<StateID> orig;
333 mdp_.generate_action_transitions(state, oa, orig);
334
335 for (const auto& [state_id, probability] : orig) {
336 dist.add_probability(
337 get_masked_state_id(state_id) & MASK,
338 probability);
339 }
340 }
341 }},
342 state.single_or_quotient);
343}
344
345template <typename State, typename Action>
346void QuotientSystem<State, Action>::generate_all_transitions(
347 param_type<QState> state,
348 std::vector<Transition<QAction>>& transitions)
349{
350 std::visit(
351 overloaded{
352 [&](const QuotientInformationType* info) {
353 transitions.reserve(info->total_num_outer_acts_);
354
355 auto aop = info->aops_.begin();
356
357 for (const auto& info : info->state_infos_) {
358 const auto outers_end = aop + info.num_outer_acts;
359 for (; aop != outers_end; ++aop) {
360 QAction qa(info.state_id, *aop);
361 Transition<QAction>& t = transitions.emplace_back(qa);
362 generate_action_transitions(
363 state,
364 qa,
365 t.successor_dist);
366 }
367 aop += info.num_inner_acts; // Skip inner actions
368 }
369
370 assert(transitions.size() == info->total_num_outer_acts_);
371 },
372 [&](param_type<State> state) {
373 std::vector<Action> orig_a;
374 mdp_.generate_applicable_actions(state, orig_a);
375
376 const StateID state_id = mdp_.get_state_id(state);
377 transitions.reserve(orig_a.size());
378
379 for (Action a : orig_a) {
380 QAction qa(state_id, a);
381 Transition<QAction>& t = transitions.emplace_back(qa);
382
383 Distribution<StateID> orig;
384 mdp_.generate_action_transitions(state, a, orig);
385
386 for (const auto& [state_id, probability] : orig) {
387 t.successor_dist.add_probability(
388 get_masked_state_id(state_id) & MASK,
389 probability);
390 }
391 }
392 }},
393 state.single_or_quotient);
394}
395
396template <typename State, typename Action>
397TerminationInfo
398QuotientSystem<State, Action>::get_termination_info(param_type<QState> s)
399{
400 return std::visit(
401 overloaded{
402 [&](const QuotientInformationType* info) {
403 return info->termination_info_;
404 },
405 [&](param_type<State> state) {
406 return mdp_.get_termination_info(state);
407 }},
408 s.single_or_quotient);
409}
410
411template <typename State, typename Action>
412value_t QuotientSystem<State, Action>::get_action_cost(QAction qa)
413{
414 return mdp_.get_action_cost(qa.action);
415}
416
417template <typename State, typename Action>
418auto QuotientSystem<State, Action>::get_parent_mdp() -> MDPType&
419{
420 return mdp_;
421}
422
423template <typename State, typename Action>
424auto QuotientSystem<State, Action>::begin() const -> const_iterator
425{
426 return quotient_id_iterator(this, 0);
427}
428
429template <typename State, typename Action>
430auto QuotientSystem<State, Action>::end() const -> const_iterator
431{
432 return quotient_id_iterator(this, quotient_ids_.size());
433}
434
435template <typename State, typename Action>
436auto QuotientSystem<State, Action>::translate_state(param_type<State> s) const
437 -> QState
438{
439 StateID id = mdp_.get_state_id(s);
440 const auto* info = get_quotient_info(get_masked_state_id(id));
441
442 if (info) {
443 return QState(mdp_, info);
444 }
445
446 return QState(mdp_, s);
447}
448
449template <typename State, typename Action>
450StateID QuotientSystem<State, Action>::translate_state_id(StateID sid) const
451{
452 return StateID(get_masked_state_id(sid) & MASK);
453}
454
455template <typename State, typename Action>
456template <typename Range>
457void QuotientSystem<State, Action>::build_quotient(Range& states)
458{
459 auto range =
460 std::views::zip(states, std::views::repeat(std::vector<QAction>()));
461 this->build_quotient(range, *range.begin());
462}
463
464template <typename State, typename Action>
465template <typename SubMDP>
466void QuotientSystem<State, Action>::build_quotient(
467 SubMDP submdp,
468 std::ranges::range_reference_t<SubMDP> entry)
469{
470 using namespace std::views;
471
472 const StateID rid = get<0>(entry);
473 const auto& raops = get<1>(entry);
474
475 value_t min_termination = INFINITE_VALUE;
476 bool is_goal = false;
477
478 // Get or create quotient
479 QuotientInformationType& qinfo = quotients_[rid];
480
481 // We handle the representative state first so that it
482 // appears first in the data structure.
483 if (qinfo.state_infos_.empty()) {
484 // Add this state to the quotient
485 auto& b = qinfo.state_infos_.emplace_back(rid);
486 set_masked_state_id(rid, rid);
487
488 const State repr = mdp_.get_state(rid);
489
490 // Merge goal state status and termination cost
491 const auto repr_term = mdp_.get_termination_info(repr);
492 min_termination = std::min(min_termination, repr_term.get_cost());
493 is_goal = is_goal || repr_term.is_goal_state();
494
495 // Generate the applicable actions and add them to the new
496 // quotient
497 const size_t prev_size = qinfo.aops_.size();
498 mdp_.generate_applicable_actions(repr, qinfo.aops_);
499
500 // Partition new actions
501 auto new_aops = qinfo.aops_ | drop(prev_size);
502
503 {
504 auto [pivot, last] = partition_actions(
505 new_aops,
506 raops | transform(&QAction::action));
507
508 b.num_outer_acts = std::distance(new_aops.begin(), pivot);
509 b.num_inner_acts = std::distance(pivot, last);
510 }
511
512 qinfo.total_num_outer_acts_ += b.num_outer_acts;
513 } else {
514 // Filter actions
515 qinfo.filter_actions(raops);
516
517 // Merge goal state status and termination cost
518 const auto repr_term = qinfo.termination_info_;
519 min_termination = std::min(min_termination, repr_term.get_cost());
520 is_goal = is_goal || repr_term.is_goal_state();
521 }
522
523 for (const auto& entry : submdp) {
524 const StateID state_id = get<0>(entry);
525 const auto& aops = get<1>(entry);
526
527 // Already handled.
528 if (state_id == rid) {
529 continue;
530 }
531
532 const StateID::size_type qsqid = get_masked_state_id(state_id);
533
534 // If the state is a quotient state, add all states it
535 // represents to the new quotient
536 if (qsqid & FLAG) {
537 // Get the old quotient
538 auto qit = quotients_.find(qsqid & MASK);
539 QuotientInformationType& q = qit->second;
540
541 // Filter actions
542 q.filter_actions(aops);
543
544 // Merge goal state status and termination cost
545 const auto mem_term = q.termination_info_;
546 min_termination = std::min(min_termination, mem_term.get_cost());
547 is_goal = is_goal || mem_term.is_goal_state();
548
549 // Insert all states belonging to it to the new quotient
550 for (const auto& p : q.state_infos_) {
551 qinfo.state_infos_.push_back(p);
552 set_masked_state_id(p.state_id, rid);
553 }
554
555 // Move the actions to the new quotient
556 std::ranges::move(q.aops_, std::back_inserter(qinfo.aops_));
557 qinfo.total_num_outer_acts_ += q.total_num_outer_acts_;
558
559 // Erase the old quotient
560 quotients_.erase(qit);
561 } else {
562 // Add this state to the quotient
563 auto& b = qinfo.state_infos_.emplace_back(state_id);
564 set_masked_state_id(state_id, rid);
565
566 const State mem = mdp_.get_state(state_id);
567
568 // Merge goal state status and termination cost
569 const auto mem_term = mdp_.get_termination_info(mem);
570 min_termination = std::min(min_termination, mem_term.get_cost());
571 is_goal = is_goal || mem_term.is_goal_state();
572
573 // Generate the applicable actions and add them to the new
574 // quotient
575 const size_t prev_size = qinfo.aops_.size();
576 mdp_.generate_applicable_actions(mem, qinfo.aops_);
577
578 // Partition new actions
579 auto new_aops = qinfo.aops_ | drop(prev_size);
580
581 auto [pivot, last] = partition_actions(
582 new_aops,
583 aops | std::views::transform(&QAction::action));
584
585 b.num_outer_acts = std::distance(new_aops.begin(), pivot);
586 b.num_inner_acts = std::distance(pivot, last);
587
588 qinfo.total_num_outer_acts_ += b.num_outer_acts;
589 }
590 }
591
592 qinfo.termination_info_ =
593 is_goal ? TerminationInfo::from_goal()
594 : TerminationInfo::from_non_goal(min_termination);
595}
596
597template <typename State, typename Action>
598template <typename SubMDP>
599void QuotientSystem<State, Action>::build_new_quotient(
600 SubMDP submdp,
601 std::ranges::range_reference_t<SubMDP> entry)
602{
603 const StateID rid = get<0>(entry);
604 const auto& raops = get<1>(entry);
605
606 // Get or create quotient
607 QuotientInformationType& qinfo = quotients_[rid];
608
609 // We handle the representative state first so that it
610 // appears first in the data structure.
611 assert(qinfo.state_infos_.empty());
612
613 // Merged goal state status and termination cost
614 value_t min_termination;
615 bool is_goal;
616
617 {
618 // Add this state to the quotient
619 auto& b = qinfo.state_infos_.emplace_back(rid);
620 set_masked_state_id(rid, rid);
621
622 const State repr = mdp_.get_state(rid);
623
624 // Merge goal state status and termination cost
625 const auto repr_term = mdp_.get_termination_info(repr);
626 min_termination = repr_term.get_cost();
627 is_goal = repr_term.is_goal_state();
628
629 // Generate the applicable actions
630 mdp_.generate_applicable_actions(repr, qinfo.aops_);
631
632 // Partition actions
633 auto [pivot, last] = partition_actions(qinfo.aops_, raops);
634
635 b.num_outer_acts = std::distance(qinfo.aops_.begin(), pivot);
636 b.num_inner_acts = std::distance(pivot, last);
637
638 qinfo.total_num_outer_acts_ += b.num_outer_acts;
639 }
640
641 for (const auto& entry : submdp) {
642 const StateID state_id = get<0>(entry);
643 const auto& aops = get<1>(entry);
644
645 // Already handled.
646 if (state_id == rid) {
647 continue;
648 }
649
650 assert(!(get_masked_state_id(state_id) & FLAG));
651
652 // Add this state to the quotient
653 auto& b = qinfo.state_infos_.emplace_back(state_id);
654 set_masked_state_id(state_id, rid);
655
656 const State mem = mdp_.get_state(state_id);
657
658 // Merge goal state status and termination cost
659 const auto mem_term = mdp_.get_termination_info(mem);
660 min_termination = std::min(min_termination, mem_term.get_cost());
661 is_goal = is_goal || mem_term.is_goal_state();
662
663 // Generate the applicable actions
664 mdp_.generate_applicable_actions(mem, qinfo.aops_);
665
666 // Partition actions
667 auto [pivot, last] = partition_actions(qinfo.aops_, aops);
668
669 b.num_outer_acts = std::distance(qinfo.aops_.begin(), pivot);
670 b.num_inner_acts = std::distance(pivot, last);
671
672 qinfo.total_num_outer_acts_ += b.num_outer_acts;
673 }
674
675 qinfo.termination_info_ =
676 is_goal ? TerminationInfo::from_goal()
677 : TerminationInfo::from_non_goal(min_termination);
678}
679
680template <typename State, typename Action>
681auto QuotientSystem<State, Action>::partition_actions(
682 std::ranges::input_range auto&& aops,
683 const std::ranges::input_range auto& filter) const
684{
685 if (filter.empty()) {
686 return std::ranges::subrange(aops.begin(), aops.end());
687 }
688
689 return std::ranges::stable_partition(aops, [&filter](const Action& action) {
690 return std::ranges::find(filter, action) == filter.end();
691 });
692}
693
694template <typename State, typename Action>
695auto QuotientSystem<State, Action>::get_quotient_info(StateID state_id)
696 -> QuotientInformationType*
697{
698 const StateID::size_type qid = get_masked_state_id(state_id);
699 return qid & FLAG ? &quotients_.find(qid & MASK)->second : nullptr;
700}
701
702template <typename State, typename Action>
703auto QuotientSystem<State, Action>::get_quotient_info(StateID state_id) const
704 -> const QuotientInformationType*
705{
706 const StateID::size_type qid = get_masked_state_id(state_id);
707 return qid & FLAG ? &quotients_.find(qid & MASK)->second : nullptr;
708}
709
710template <typename State, typename Action>
711StateID::size_type
712QuotientSystem<State, Action>::get_masked_state_id(StateID sid) const
713{
714 return sid < quotient_ids_.size() ? quotient_ids_[sid] : sid.id;
715}
716
717template <typename State, typename Action>
718void QuotientSystem<State, Action>::set_masked_state_id(
719 StateID sid,
720 const StateID::size_type& qsid)
721{
722 if (sid >= quotient_ids_.size()) {
723 for (auto idx = quotient_ids_.size(); idx <= sid; ++idx) {
724 quotient_ids_.push_back(idx);
725 }
726 }
727
728 quotient_ids_[sid] = qsid | FLAG;
729}
730
731} // namespace probfd::quotients
double value_t
Typedef for the state value type.
Definition aliases.h:7
STL namespace.