1#ifndef GUARD_INCLUDE_PROBFD_QUOTIENTS_QUOTIENT_SYSTEM_H
2#error "This file should only be included from quotient_system.h"
5#include "probfd/distribution.h"
6#include "probfd/transition.h"
8#include "downward/utils/collections.h"
10namespace probfd::quotients {
12template <
typename Action>
13struct QuotientInformation<Action>::StateInfo {
15 size_t num_outer_acts = 0;
16 size_t num_inner_acts = 0;
19template <
typename Action>
20size_t QuotientInformation<Action>::num_members()
const
22 return state_infos_.size();
25template <
typename Action>
26auto QuotientInformation<Action>::member_ids()
28 return std::views::transform(state_infos_, &StateInfo::state_id);
31template <
typename Action>
32auto QuotientInformation<Action>::member_ids()
const
34 return std::views::transform(state_infos_, &StateInfo::state_id);
37template <
typename Action>
38void QuotientInformation<Action>::filter_actions(
39 const std::ranges::input_range
auto& filter)
41 if (std::ranges::empty(filter)) {
45 total_num_outer_acts_ = 0;
47 auto act_it = aops_.begin();
49 for (
auto& info : state_infos_) {
50 auto outer_end = std::stable_partition(
52 act_it + info.num_outer_acts,
53 [&info, &filter](Action a) {
54 return !std::ranges::contains(
56 QuotientAction<Action>(info.state_id, a));
59 const size_t num_total_acts = info.num_outer_acts + info.num_inner_acts;
61 info.num_outer_acts = std::distance(act_it, outer_end);
62 info.num_inner_acts = num_total_acts - info.num_outer_acts;
64 total_num_outer_acts_ += info.num_outer_acts;
66 act_it += num_total_acts;
69 assert(act_it == aops_.end());
/// Constructs a quotient state that wraps a single state of the underlying
/// MDP; `single` is moved into `single_or_quotient`.
// NOTE(review): the listing dropped original lines 74 and 76-77 here,
// presumably the first member initializer (taking `mdp`) and the empty
// constructor body. Restore them from the repository before compiling.
72template <
typename State,
typename Action>
73QuotientState<State, Action>::QuotientState(MDPType& mdp, State single)
75 , single_or_quotient(
std::move(single))
/// Constructs a quotient state that refers to an existing quotient;
/// the pointer `quotient` is stored in `single_or_quotient`.
// NOTE(review): the listing dropped original lines 81, 83 and 85-86 here,
// presumably a leading MDP parameter, the first member initializer and the
// empty constructor body. Restore them from the repository before compiling.
79template <
typename State,
typename Action>
80QuotientState<State, Action>::QuotientState(
82 const QuotientInformationType* quotient)
84 , single_or_quotient(quotient)
/// Evaluates `f` on the member states and combines the results with std::max.
///
/// For a quotient, each member ID from `quotient->member_ids()` is mapped
/// through a bound function (its arguments are missing from this listing;
/// presumably it resolves the ID to a state) and `f` is applied to the result.
/// For a single state, `f(single)` is returned directly.
///
/// `F` must be invocable with `param_type<State>`; the requires-clause further
/// constrains the convertibility of its result type.
// NOTE(review): the listing dropped original lines 93-94, 96-98, 100, 103-104,
// 106-108 and 110-111. Among them: the end of the requires-clause, the visitor
// call, the declaration/initial value of `res`, and the final return. Restore
// from the repository; do not guess the initial value of `res`.
88template <
typename State,
typename Action>
89template <std::invocable<param_type<State>> F>
90value_t QuotientState<State, Action>::member_maximum(F&& f)
const
91 requires(std::is_convertible_v<
92 std::invoke_result_t<F, param_type<State>>,
95 using namespace std::views;
// Quotient case: fold f over all member states.
99 [&](
const QuotientInformationType* quotient) {
101 for (param_type<State> state :
102 quotient->member_ids() | transform(std::bind_front(
105 res = std::max(res, f(state));
// Single-state case: the maximum over one element is f(single).
109 [&](param_type<State> single) {
return f(single); }},
/// Invokes a callable for the member states of this quotient state.
///
/// For a quotient, the member IDs are piped through a `std::views::transform`
/// with a bound function and traversed with `std::ranges::for_each`. For a
/// single state, `f(single)` is called exactly once.
// NOTE(review): the listing dropped original lines 116-118, 123-126 and
// 128-129 (visitor call, the arguments of `std::bind_front`, the callable
// passed to `for_each` — presumably `f` — and the visited variant). Restore
// from the repository.
113template <
typename State,
typename Action>
114void QuotientState<State, Action>::for_each_member_state(
115 std::invocable<param_type<State>>
auto&& f)
const
119 [&](
const QuotientInformationType* quotient) {
120 std::ranges::for_each(
121 quotient->member_ids() |
122 std::views::transform(std::bind_front(
127 [&](param_type<State> single) { f(single); }},
/// Returns the number of member states: `quotient->num_members()` for a
/// quotient, and 1 for a single state.
// NOTE(review): the listing dropped original lines 133-135, 138 and 140-141
// (opening brace, the visitor call wrapping the two lambdas, and the visited
// variant argument). Restore from the repository.
131template <
typename State,
typename Action>
132size_t QuotientState<State, Action>::num_members()
const
136 [](
const QuotientInformationType* quotient) {
137 return quotient->num_members();
// The explicit size_t return type keeps both alternatives' results identical.
139 [](param_type<State>) ->
size_t {
return 1; }},
/// Appends the collapsed (inner) actions of this quotient state to `aops`.
///
/// For a quotient, each member's outer actions are skipped and each of its
/// inner actions is appended as a quotient action built from the member's
/// state ID and the action. For a single state nothing is appended.
///
/// @param aops Output list; existing entries are kept.
// NOTE(review): the listing dropped original lines 146-148, 159-163, 165 and
// 167-168 (visitor call, closing braces and the head of what looks like a
// size assertion ending on line 164). Restore from the repository.
143template <
typename State,
typename Action>
144void QuotientState<State, Action>::get_collapsed_actions(
145 std::vector<QuotientAction<Action>>& aops)
const
149 [&](
const QuotientInformationType* info) {
// Number of inner actions = all stored actions minus the outer ones.
150 aops.reserve(info->aops_.size() - info->total_num_outer_acts_);
152 auto aid = info->aops_.begin();
154 for (
const auto& sinfo : info->state_infos_) {
// Skip this member's outer actions; its inner actions follow directly.
155 aid += sinfo.num_outer_acts;
156 const auto inners_end = aid + sinfo.num_inner_acts;
157 for (; aid != inners_end; ++aid) {
158 aops.emplace_back(sinfo.state_id, *aid);
164 info->aops_.size() - info->total_num_outer_acts_);
// A single state has no collapsed actions.
166 [](param_type<State>) {
return; }},
/// Constructs an iterator over the quotient IDs of the quotient system `qs`.
// NOTE(review): the listing dropped original lines 173-178: the second
// parameter (call sites pass `0` and `quotient_ids_.size()`), the member
// initializers (presumably for `qs_` and `i_`) and the constructor body.
// Restore from the repository.
170template <
typename State,
typename Action>
171quotient_id_iterator<State, Action>::quotient_id_iterator(
172 const QuotientSystem<State, Action>* qs,
180template <
typename State,
typename Action>
181auto quotient_id_iterator<State, Action>::operator++() -> quotient_id_iterator&
183 while (++i_.id < qs_->quotient_ids_.size()) {
184 const StateID ref = qs_->quotient_ids_[i_];
185 if (i_ == (ref & QuotientSystem<State, Action>::MASK)) {
193template <
typename State,
typename Action>
195 const typename QuotientSystem<State, Action>::quotient_id_iterator& left,
196 const typename QuotientSystem<State, Action>::quotient_id_iterator& right)
198 return left.i_ == right.i_;
201template <
typename State,
typename Action>
202StateID quotient_id_iterator<State, Action>::operator*()
const
207template <
typename State,
typename Action>
208QuotientSystem<State, Action>::QuotientSystem(MDPType& mdp)
/// Returns the state ID of a quotient state.
///
/// For a quotient this is the ID of the first entry of `state_infos_`; for a
/// single state it is the ID the underlying MDP assigns to that state.
// NOTE(review): the listing dropped original lines 215-217 and 220 (opening
// brace and the visitor call wrapping the two lambdas). Restore from the
// repository.
213template <
typename State,
typename Action>
214StateID QuotientSystem<State, Action>::get_state_id(param_type<QState> state)
218 [&](
const QuotientInformationType* info) {
219 return info->state_infos_.front().state_id;
221 [&](param_type<State> s) {
return mdp_.get_state_id(s); }},
222 state.single_or_quotient);
225template <
typename State,
typename Action>
226auto QuotientSystem<State, Action>::get_state(StateID state_id) -> QState
228 const QuotientInformationType* info = get_quotient_info(state_id);
231 return QState(mdp_, info);
234 return QState(mdp_, mdp_.get_state(state_id));
/// Appends the applicable quotient actions of `state` to `aops`.
///
/// For a quotient, these are the outer actions of every member, each paired
/// with the ID of the member it belongs to; inner actions are skipped. For a
/// single state, every applicable action of the underlying MDP is paired with
/// the state's ID.
///
/// @param state The quotient state.
/// @param aops Output list of quotient actions.
// NOTE(review): the listing dropped original lines 241-243, 253, 255-256, 258,
// 268-269 and 271 (visitor call and closing braces). Restore from the
// repository.
237template <
typename State,
typename Action>
238void QuotientSystem<State, Action>::generate_applicable_actions(
239 param_type<QState> state,
240 std::vector<QAction>& aops)
244 [&](
const QuotientInformationType* info) {
245 aops.reserve(info->total_num_outer_acts_);
247 auto aid = info->aops_.begin();
249 for (
const auto& sinfo : info->state_infos_) {
// Emit this member's outer actions ...
250 const auto outers_end = aid + sinfo.num_outer_acts;
251 for (; aid != outers_end; ++aid) {
252 aops.emplace_back(sinfo.state_id, *aid);
// ... and skip its inner actions.
254 aid += sinfo.num_inner_acts;
// Holds only if `aops` was empty on entry.
257 assert(aops.size() == info->total_num_outer_acts_);
259 [&](param_type<State> state) {
260 std::vector<Action> orig;
261 mdp_.generate_applicable_actions(state, orig);
263 const StateID state_id = mdp_.get_state_id(state);
264 aops.reserve(orig.size());
266 for (
const Action& a : orig) {
267 aops.emplace_back(state_id, a);
270 state.single_or_quotient);
/// Generates the successor distribution of the quotient action `a`.
///
/// The wrapped action `a.action` is applied in the underlying MDP to the
/// state with ID `a.state_id`. Each successor ID is then translated with
/// `get_masked_state_id(...) & MASK` and added to `result`.
// NOTE(review): the listing dropped original lines 275-276 (the first two
// parameters; the body uses one of them as `a`), 278, and 286-288 (the second
// argument of `add_probability`, presumably `probability`, and closing
// braces). Restore from the repository.
273template <
typename State,
typename Action>
274void QuotientSystem<State, Action>::generate_action_transitions(
277 Distribution<StateID>& result)
279 Distribution<StateID> orig;
280 const State state = this->mdp_.get_state(a.state_id);
281 mdp_.generate_action_transitions(state, a.action, orig);
283 for (
const auto& [state_id, probability] : orig) {
284 result.add_probability(
285 get_masked_state_id(state_id) & MASK,
/// Appends all applicable quotient actions of `state` to `aops` and one
/// successor distribution per action to `successors`.
///
/// For a quotient, only the outer actions of each member are expanded. For a
/// single state, all applicable actions of the underlying MDP are expanded and
/// the successor IDs are translated with `get_masked_state_id(...) & MASK`.
// NOTE(review): the listing dropped original lines 295-297, 307, 310-311, 313,
// 315-316, 319, 331, 334, 338-341 and 343. Among them: the visitor call, the
// first arguments of `generate_action_transitions`, and the probability
// argument of `add_probability`. Restore from the repository.
290template <
typename State,
typename Action>
291void QuotientSystem<State, Action>::generate_all_transitions(
292 param_type<QState> state,
293 std::vector<QAction>& aops,
294 std::vector<Distribution<StateID>>& successors)
298 [&](
const QuotientInformationType* info) {
299 aops.reserve(info->total_num_outer_acts_);
300 successors.reserve(info->total_num_outer_acts_);
302 auto aop = info->aops_.begin();
// The loop variable `info` shadows the lambda parameter of the same name. The
// range expression is evaluated before the loop variable is declared, so it
// still refers to the outer pointer.
304 for (
const auto& info : info->state_infos_) {
305 const auto outers_end = aop + info.num_outer_acts;
306 for (; aop != outers_end; ++aop) {
308 aops.emplace_back(info.state_id, *aop);
309 generate_action_transitions(
312 successors.emplace_back());
// Skip the member's inner actions.
314 aop += info.num_inner_acts;
// Both assertions hold only if the output lists were empty on entry.
317 assert(aops.size() == info->total_num_outer_acts_);
318 assert(successors.size() == info->total_num_outer_acts_);
320 [&](param_type<State> state) {
321 std::vector<Action> orig_a;
322 mdp_.generate_applicable_actions(state, orig_a);
324 const StateID state_id = mdp_.get_state_id(state);
325 aops.reserve(orig_a.size());
326 successors.reserve(orig_a.size());
328 for (Action oa : orig_a) {
329 aops.emplace_back(state_id, oa);
330 auto& dist = successors.emplace_back();
332 Distribution<StateID> orig;
333 mdp_.generate_action_transitions(state, oa, orig);
// The structured binding `state_id` shadows the source state's ID declared above.
335 for (
const auto& [state_id, probability] : orig) {
336 dist.add_probability(
337 get_masked_state_id(state_id) & MASK,
342 state.single_or_quotient);
/// Appends one `Transition<QAction>` per applicable quotient action of `state`
/// to `transitions`.
///
/// For a quotient, only the outer actions of each member are expanded. For a
/// single state, all applicable actions of the underlying MDP are expanded and
/// the successor IDs are translated with `get_masked_state_id(...) & MASK`
/// before being added to the transition's `successor_dist`.
// NOTE(review): the listing dropped original lines 349-351, 363-366, 368-369,
// 371, 382, 385, 389-392 and 394. Among them: the visitor call, the arguments
// of `generate_action_transitions`, and the probability argument of
// `add_probability`. Restore from the repository.
345template <
typename State,
typename Action>
346void QuotientSystem<State, Action>::generate_all_transitions(
347 param_type<QState> state,
348 std::vector<Transition<QAction>>& transitions)
352 [&](
const QuotientInformationType* info) {
353 transitions.reserve(info->total_num_outer_acts_);
355 auto aop = info->aops_.begin();
// The loop variable `info` shadows the lambda parameter; the range expression
// still refers to the outer pointer.
357 for (
const auto& info : info->state_infos_) {
358 const auto outers_end = aop + info.num_outer_acts;
359 for (; aop != outers_end; ++aop) {
360 QAction qa(info.state_id, *aop);
361 Transition<QAction>& t = transitions.emplace_back(qa);
362 generate_action_transitions(
// Skip the member's inner actions.
367 aop += info.num_inner_acts;
// Holds only if `transitions` was empty on entry.
370 assert(transitions.size() == info->total_num_outer_acts_);
372 [&](param_type<State> state) {
373 std::vector<Action> orig_a;
374 mdp_.generate_applicable_actions(state, orig_a);
376 const StateID state_id = mdp_.get_state_id(state);
377 transitions.reserve(orig_a.size());
379 for (Action a : orig_a) {
380 QAction qa(state_id, a);
381 Transition<QAction>& t = transitions.emplace_back(qa);
383 Distribution<StateID> orig;
384 mdp_.generate_action_transitions(state, a, orig);
// The structured binding `state_id` shadows the source state's ID declared above.
386 for (
const auto& [state_id, probability] : orig) {
387 t.successor_dist.add_probability(
388 get_masked_state_id(state_id) & MASK,
393 state.single_or_quotient);
/// Returns the termination information of the quotient state `s`.
///
/// For a quotient this is the value cached in `termination_info_`; for a
/// single state the underlying MDP is queried.
// NOTE(review): the listing dropped original lines 397 (the return type),
// 399-401, 404, 407 and 409 (visitor call and closing braces). Restore from
// the repository.
396template <
typename State,
typename Action>
398QuotientSystem<State, Action>::get_termination_info(param_type<QState> s)
402 [&](
const QuotientInformationType* info) {
403 return info->termination_info_;
405 [&](param_type<State> state) {
406 return mdp_.get_termination_info(state);
408 s.single_or_quotient);
411template <
typename State,
typename Action>
412value_t QuotientSystem<State, Action>::get_action_cost(QAction qa)
414 return mdp_.get_action_cost(qa.action);
417template <
typename State,
typename Action>
418auto QuotientSystem<State, Action>::get_parent_mdp() -> MDPType&
423template <
typename State,
typename Action>
424auto QuotientSystem<State, Action>::begin() const -> const_iterator
426 return quotient_id_iterator(
this, 0);
429template <
typename State,
typename Action>
430auto QuotientSystem<State, Action>::end() const -> const_iterator
432 return quotient_id_iterator(
this, quotient_ids_.size());
435template <
typename State,
typename Action>
436auto QuotientSystem<State, Action>::translate_state(param_type<State> s)
const
439 StateID
id = mdp_.get_state_id(s);
440 const auto* info = get_quotient_info(get_masked_state_id(
id));
443 return QState(mdp_, info);
446 return QState(mdp_, s);
449template <
typename State,
typename Action>
450StateID QuotientSystem<State, Action>::translate_state_id(StateID sid)
const
452 return StateID(get_masked_state_id(sid) & MASK);
455template <
typename State,
typename Action>
456template <
typename Range>
457void QuotientSystem<State, Action>::build_quotient(Range& states)
460 std::views::zip(states, std::views::repeat(std::vector<QAction>()));
461 this->build_quotient(range, *range.begin());
/// Merges the states of a sub-MDP into one quotient.
///
/// `entry` is the sub-MDP entry of the representative: `get<0>(entry)` is its
/// state ID (`rid`) and `get<1>(entry)` the quotient actions to hide for it.
/// Every entry of `submdp` is folded into the quotient stored under `rid`.
/// Members that already are quotients have their bookkeeping moved over and
/// are erased; plain states get their applicable actions generated and
/// partitioned into outer and inner actions. The quotient's termination info
/// becomes a goal if any member is a goal, and otherwise a non-goal with the
/// minimum termination cost over all members.
// NOTE(review): the listing dropped many original lines here: 467 (presumably
// the `submdp` parameter), 469, and most lines that carried only braces,
// `else`, `continue` or branch conditions (e.g. 513-514, 529-531, 533-537,
// 561-562). The branch structure described in the comments below is inferred
// from the surviving statements — restore the exact code from the repository.
464template <
typename State,
typename Action>
465template <
typename SubMDP>
466void QuotientSystem<State, Action>::build_quotient(
468 std::ranges::range_reference_t<SubMDP> entry)
470 using namespace std::views;
472 const StateID rid = get<0>(entry);
473 const auto& raops = get<1>(entry);
// Accumulators for the termination info of the merged quotient.
475 value_t min_termination = INFINITE_VALUE;
476 bool is_goal =
false;
// operator[] default-constructs the quotient information if `rid` has none yet.
479 QuotientInformationType& qinfo = quotients_[rid];
// Case 1: the representative is not a quotient yet; add it as the first member.
483 if (qinfo.state_infos_.empty()) {
485 auto& b = qinfo.state_infos_.emplace_back(rid);
486 set_masked_state_id(rid, rid);
488 const State repr = mdp_.get_state(rid);
491 const auto repr_term = mdp_.get_termination_info(repr);
492 min_termination = std::min(min_termination, repr_term.get_cost());
493 is_goal = is_goal || repr_term.is_goal_state();
// Append the representative's actions and split only the newly added suffix.
497 const size_t prev_size = qinfo.aops_.size();
498 mdp_.generate_applicable_actions(repr, qinfo.aops_);
501 auto new_aops = qinfo.aops_ | drop(prev_size);
504 auto [pivot, last] = partition_actions(
506 raops | transform(&QAction::action));
508 b.num_outer_acts = std::distance(new_aops.begin(), pivot);
509 b.num_inner_acts = std::distance(pivot, last);
512 qinfo.total_num_outer_acts_ += b.num_outer_acts;
// Case 2 (presumably the else-branch): the representative already is a
// quotient; hide the given actions and reuse its cached termination info.
515 qinfo.filter_actions(raops);
518 const auto repr_term = qinfo.termination_info_;
519 min_termination = std::min(min_termination, repr_term.get_cost());
520 is_goal = is_goal || repr_term.is_goal_state();
// Fold every other entry of the sub-MDP into the quotient.
523 for (
const auto& entry : submdp) {
524 const StateID state_id = get<0>(entry);
525 const auto& aops = get<1>(entry);
// The representative was handled above (the skipping statement is missing
// from the listing).
528 if (state_id == rid) {
532 const StateID::size_type qsqid = get_masked_state_id(state_id);
// Presumably guarded by a check that `state_id` already belongs to a quotient:
// merge that quotient into `qinfo` and erase it.
538 auto qit = quotients_.find(qsqid & MASK);
539 QuotientInformationType& q = qit->second;
542 q.filter_actions(aops);
545 const auto mem_term = q.termination_info_;
546 min_termination = std::min(min_termination, mem_term.get_cost());
547 is_goal = is_goal || mem_term.is_goal_state();
550 for (
const auto& p : q.state_infos_) {
551 qinfo.state_infos_.push_back(p);
552 set_masked_state_id(p.state_id, rid);
556 std::ranges::move(q.aops_, std::back_inserter(qinfo.aops_));
557 qinfo.total_num_outer_acts_ += q.total_num_outer_acts_;
560 quotients_.erase(qit);
// Presumably the alternative branch: `state_id` is a plain state; add it as a
// new member.
563 auto& b = qinfo.state_infos_.emplace_back(state_id);
564 set_masked_state_id(state_id, rid);
566 const State mem = mdp_.get_state(state_id);
569 const auto mem_term = mdp_.get_termination_info(mem);
570 min_termination = std::min(min_termination, mem_term.get_cost());
571 is_goal = is_goal || mem_term.is_goal_state();
// Append the member's actions and split only the newly added suffix.
575 const size_t prev_size = qinfo.aops_.size();
576 mdp_.generate_applicable_actions(mem, qinfo.aops_);
579 auto new_aops = qinfo.aops_ | drop(prev_size);
581 auto [pivot, last] = partition_actions(
583 aops | std::views::transform(&QAction::action));
585 b.num_outer_acts = std::distance(new_aops.begin(), pivot);
586 b.num_inner_acts = std::distance(pivot, last);
588 qinfo.total_num_outer_acts_ += b.num_outer_acts;
// Goal if any member is a goal; otherwise the cheapest termination cost.
592 qinfo.termination_info_ =
593 is_goal ? TerminationInfo::from_goal()
594 : TerminationInfo::from_non_goal(min_termination);
/// Builds a quotient from states none of which belongs to a quotient yet.
///
/// `entry` is the sub-MDP entry of the representative: `get<0>(entry)` is its
/// state ID (`rid`) and `get<1>(entry)` its action filter. The representative
/// becomes the first member; every other entry of `submdp` is appended as a
/// further member. The quotient's termination info becomes a goal if any
/// member is a goal, and otherwise a non-goal with the minimum termination
/// cost over all members.
// NOTE(review): the listing dropped original lines 600 (presumably the
// `submdp` parameter), 602, 612-618 (this range presumably held the
// declarations of `min_termination` and `is_goal`, which are assigned below
// without a visible declaration), 647-649 and the closing braces. Restore
// from the repository.
// NOTE(review): unlike build_quotient, this function passes the whole
// `qinfo.aops_` list and the unprojected quotient-action filter to
// partition_actions, and measures the outer count from `qinfo.aops_.begin()`.
// For every member after the first this re-partitions and counts the actions
// of earlier members too. This looks unintended — verify before use.
597template <
typename State,
typename Action>
598template <
typename SubMDP>
599void QuotientSystem<State, Action>::build_new_quotient(
601 std::ranges::range_reference_t<SubMDP> entry)
603 const StateID rid = get<0>(entry);
604 const auto& raops = get<1>(entry);
// operator[] default-constructs the quotient information for `rid`.
607 QuotientInformationType& qinfo = quotients_[rid];
// Precondition: `rid` must not already be the representative of a quotient.
611 assert(qinfo.state_infos_.empty());
619 auto& b = qinfo.state_infos_.emplace_back(rid);
620 set_masked_state_id(rid, rid);
622 const State repr = mdp_.get_state(rid);
625 const auto repr_term = mdp_.get_termination_info(repr);
626 min_termination = repr_term.get_cost();
627 is_goal = repr_term.is_goal_state();
630 mdp_.generate_applicable_actions(repr, qinfo.aops_);
633 auto [pivot, last] = partition_actions(qinfo.aops_, raops);
635 b.num_outer_acts = std::distance(qinfo.aops_.begin(), pivot);
636 b.num_inner_acts = std::distance(pivot, last);
638 qinfo.total_num_outer_acts_ += b.num_outer_acts;
641 for (
const auto& entry : submdp) {
642 const StateID state_id = get<0>(entry);
643 const auto& aops = get<1>(entry);
// The representative was handled above (the skipping statement is missing
// from the listing).
646 if (state_id == rid) {
// Precondition: the member must not already belong to a quotient.
650 assert(!(get_masked_state_id(state_id) & FLAG));
653 auto& b = qinfo.state_infos_.emplace_back(state_id);
654 set_masked_state_id(state_id, rid);
656 const State mem = mdp_.get_state(state_id);
659 const auto mem_term = mdp_.get_termination_info(mem);
660 min_termination = std::min(min_termination, mem_term.get_cost());
661 is_goal = is_goal || mem_term.is_goal_state();
664 mdp_.generate_applicable_actions(mem, qinfo.aops_);
667 auto [pivot, last] = partition_actions(qinfo.aops_, aops);
669 b.num_outer_acts = std::distance(qinfo.aops_.begin(), pivot);
670 b.num_inner_acts = std::distance(pivot, last);
672 qinfo.total_num_outer_acts_ += b.num_outer_acts;
// Goal if any member is a goal; otherwise the cheapest termination cost.
675 qinfo.termination_info_ =
676 is_goal ? TerminationInfo::from_goal()
677 : TerminationInfo::from_non_goal(min_termination);
680template <
typename State,
typename Action>
681auto QuotientSystem<State, Action>::partition_actions(
682 std::ranges::input_range
auto&& aops,
683 const std::ranges::input_range
auto& filter)
const
685 if (filter.empty()) {
686 return std::ranges::subrange(aops.begin(), aops.end());
689 return std::ranges::stable_partition(aops, [&filter](
const Action& action) {
690 return std::ranges::find(filter, action) == filter.end();
694template <
typename State,
typename Action>
695auto QuotientSystem<State, Action>::get_quotient_info(StateID state_id)
696 -> QuotientInformationType*
698 const StateID::size_type qid = get_masked_state_id(state_id);
699 return qid & FLAG ? "ients_.find(qid & MASK)->second :
nullptr;
702template <
typename State,
typename Action>
703auto QuotientSystem<State, Action>::get_quotient_info(StateID state_id)
const
704 ->
const QuotientInformationType*
706 const StateID::size_type qid = get_masked_state_id(state_id);
707 return qid & FLAG ? "ients_.find(qid & MASK)->second :
nullptr;
710template <
typename State,
typename Action>
712QuotientSystem<State, Action>::get_masked_state_id(StateID sid)
const
714 return sid < quotient_ids_.size() ? quotient_ids_[sid] : sid.id;
717template <
typename State,
typename Action>
718void QuotientSystem<State, Action>::set_masked_state_id(
720 const StateID::size_type& qsid)
722 if (sid >= quotient_ids_.size()) {
723 for (
auto idx = quotient_ids_.size(); idx <= sid; ++idx) {
724 quotient_ids_.push_back(idx);
728 quotient_ids_[sid] = qsid | FLAG;
// value_t: typedef for `double`, the state value type (defined in aliases.h, line 7).