AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
interval_iteration_impl.h
1#ifndef GUARD_INCLUDE_PROBFD_ALGORITHMS_INTERVAL_ITERATION_H
2#error "This file should only be included from interval_iteration.h"
3#endif
4
5#include "probfd/quotients/quotient_max_heuristic.h"
6
7#include "probfd/utils/not_implemented.h"
8
9#include "probfd/progress_report.h"
10
11#include "downward/utils/collections.h"
12
13#include <iterator>
14
16
17template <typename State, typename Action>
18IntervalIteration<State, Action>::IntervalIteration(
19 bool extract_probability_one_states,
20 bool expand_goals)
21 : extract_probability_one_states_(extract_probability_one_states)
22 , qr_analysis_(expand_goals)
23 , ec_decomposer_(expand_goals)
24 , vi_(expand_goals)
25{
26}
27
28template <typename State, typename Action>
29Interval IntervalIteration<State, Action>::solve(
30 MDPType& mdp,
31 EvaluatorType& heuristic,
32 param_type<State> state,
33 ProgressReport,
34 double max_time)
35{
36 utils::CountdownTimer timer(max_time);
37 std::unique_ptr sys = create_quotient(mdp, heuristic, state, timer);
38 std::vector<StateID> dead, one;
39 storage::PerStateStorage<Interval> value_store;
40 return mysolve(mdp, heuristic, state, value_store, dead, one, *sys, timer);
41}
42
43template <typename State, typename Action>
44auto IntervalIteration<State, Action>::compute_policy(
45 MDPType&,
46 EvaluatorType&,
47 param_type<State>,
48 ProgressReport,
49 double) -> std::unique_ptr<PolicyType>
50{
51 not_implemented();
52}
53
// Prints the statistics collected by this solver (statistics_) to `out`.
// NOTE(review): the signature line (source line 55) is missing from this
// listing; presumably `void IntervalIteration<State, Action>::print_statistics(
// std::ostream& out) const` -- confirm against interval_iteration.h.
54template <typename State, typename Action>
56{
57 statistics_.print(out);
58}
59
// Solves from `state` like the plain solve() overload, but additionally
// writes the computed values, dead ends and one-states for every member
// state of each quotient state, not only for the quotient representatives.
//
// NOTE(review): two source lines are missing from this listing: line 62
// (presumably `Interval IntervalIteration<State, Action>::solve(`) and
// line 65 (presumably `param_type<State> state,`, since `state` is used
// below). Confirm against the original file.
60template <typename State, typename Action>
61template <typename ValueStoreT, typename SetLike, typename SetLike2>
63 MDPType& mdp,
64 EvaluatorType& heuristic,
66 ValueStoreT& value_store,
67 SetLike& dead_ends,
68 SetLike2& one_states,
69 double max_time)
70{
71 utils::CountdownTimer timer(max_time);
72
73 auto sys = create_quotient(mdp, heuristic, state, timer);
74
75 const Interval x = this->mysolve(
76 mdp,
77 heuristic,
78 state,
79 value_store,
80 dead_ends,
81 one_states,
82 *sys,
83 timer);
84
// Propagate each representative's results to all states merged into it.
85 for (StateID repr_id : *sys) {
// Copied by value rather than held by reference -- presumably because
// value_store[member_id] below may grow/reallocate the storage.
86 const auto value = value_store[repr_id];
// NOTE(review): if SetLike is a vector, utils::contains is presumably a
// linear scan, and the containers grow inside this loop, so the loop can
// become quadratic. Consider precomputing hash sets.
87 const bool dead = utils::contains(dead_ends, repr_id);
88 const bool one = utils::contains(one_states, repr_id);
89
// NOTE(review): if for_each_member_state also visits repr_id itself,
// repr_id is pushed into dead_ends/one_states a second time -- confirm.
90 sys->for_each_member_state(
91 repr_id,
92 [&, value, dead, one](StateID member_id) {
93 value_store[member_id] = value;
94 if (dead) dead_ends.push_back(member_id);
95 if (one) one_states.push_back(member_id);
96 });
97 }
98
99 return x;
100}
101
102template <typename State, typename Action>
103auto IntervalIteration<State, Action>::create_quotient(
104 MDPType& mdp,
105 EvaluatorType& heuristic,
106 param_type<State> state,
107 utils::CountdownTimer& timer) -> std::unique_ptr<QSystem>
108{
109 auto sys = ec_decomposer_.build_quotient_system(
110 mdp,
111 &heuristic,
112 state,
113 timer.get_remaining_time());
114
115 statistics_.ecd_statistics = ec_decomposer_.get_statistics();
116
117 return sys;
118}
119
120template <typename State, typename Action>
121template <typename ValueStoreT, typename SetLike, typename SetLike2>
122Interval IntervalIteration<State, Action>::mysolve(
123 MDPType& mdp,
124 EvaluatorType& heuristic,
125 param_type<State> state,
126 ValueStoreT& value_store,
127 SetLike& dead_ends,
128 SetLike2& one_states,
129 QSystem& sys,
130 utils::CountdownTimer& timer)
131{
132 QState qstate = sys.translate_state(state);
133
134 if (extract_probability_one_states_) {
135 qr_analysis_.run_analysis(
136 sys,
137 nullptr,
138 qstate,
139 std::back_inserter(dead_ends),
140 iterators::discarding_output_iterator(),
141 std::back_inserter(one_states),
142 timer.get_remaining_time());
143 assert(mdp.get_termination_info(mdp.get_state(one_states.front()))
144 .is_goal_state());
145 } else {
146 qr_analysis_.run_analysis(
147 sys,
148 nullptr,
149 qstate,
150 std::back_inserter(dead_ends),
151 iterators::discarding_output_iterator(),
152 iterators::discarding_output_iterator(),
153 timer.get_remaining_time());
154 }
155
156 assert(::utils::is_unique(dead_ends) && ::utils::is_unique(one_states));
157
158 sys.build_quotient(dead_ends);
159 sys.build_quotient(one_states);
160
161 const auto new_init_id = sys.translate_state_id(mdp.get_state_id(state));
162
163 quotients::QuotientMaxHeuristic<State, Action> qheuristic(heuristic);
164
165 const Interval result = vi_.solve(
166 sys,
167 qheuristic,
168 new_init_id,
169 value_store,
170 timer.get_remaining_time());
171 statistics_.tvi_statistics = vi_.get_statistics();
172 return result;
173}
174
175} // namespace probfd::algorithms::interval_iteration
Implementation of interval iteration (Haddad et al.; bibliography key haddad:etal:misc-17).
Definition interval_iteration.h:62
Namespace dedicated to interval iteration on MaxProb MDPs.
Definition interval_iteration.h:18
typename std::conditional_t< is_cheap_to_copy_v< T >, T, const T & > param_type
Alias template defining the best way to pass a parameter of a given type.
Definition type_traits.h:25
Represents a closed interval over the extended reals as a pair of lower and upper bound.
Definition interval.h:12
A StateID represents a state within a StateIDMap. Just like Fast Downward's StateID type,...
Definition types.h:22