1#ifndef GUARD_INCLUDE_PROBFD_ALGORITHMS_ACYCLIC_VALUE_ITERATION_H
2#error "This file should only be included from acyclic_value_iteration.h"
5#include "probfd/algorithms/utils.h"
7#include "probfd/policies/map_policy.h"
9#include "probfd/evaluator.h"
10#include "probfd/mdp.h"
12#include "downward/utils/countdown_timer.h"
// NOTE(review): only the template headers of four definitions survive in
// this view; the declarations and bodies that followed them are missing.
// Judging by the Doxygen cross-reference text in this file
// ("Definition acyclic_value_iteration_impl.h:17/24/30/36"), they are
// presumably the statistics observer's callbacks - confirm against the real
// source before relying on this.
// (original line 17: presumably on_state_selected_for_expansion)
16template <
typename State,
typename Action>
// (original line 24: presumably on_goal_state)
23template <
typename State,
typename Action>
// (original line 30: presumably on_terminal_state)
29template <
typename State,
typename Action>
// (original line 36: presumably on_pruned_state)
35template <
typename State,
typename Action>
// Writes the statistics observer's four counters to `out`, one per line, in
// this order: expanded, pruned, terminal and goal states.
// NOTE(review): the signature (original lines 42-43) is missing from this
// view; presumably a print_statistics(std::ostream& out)-style member of the
// statistics observer - confirm.
// NOTE(review): std::endl flushes the stream on every line; '\n' would avoid
// four flushes if an immediate flush is not required here.
41template <
typename State,
typename Action>
44 out <<
" Expanded state(s): " << state_expansions << std::endl;
45 out <<
" Pruned state(s): " << pruned_states << std::endl;
46 out <<
" Terminal state(s): " << terminal_states << std::endl;
47 out <<
" Goal state(s): " << goal_states << std::endl;
52template <
typename State,
typename Action>
53void AcyclicValueIterationObserverCollection<State, Action>::register_observer(
54 std::shared_ptr<Observer> observer)
56 observers_.push_back(std::move(observer));
59template <
typename State,
typename Action>
60void AcyclicValueIterationObserverCollection<State, Action>::
61 notify_state_selected_for_expansion(
const State& state)
63 for (
auto& observer : observers_) {
64 observer.on_state_selected_for_expansion(state);
68template <
typename State,
typename Action>
69void AcyclicValueIterationObserverCollection<State, Action>::notify_goal_state(
72 for (
auto& observer : observers_) {
73 observer.on_goal_state(state);
77template <
typename State,
typename Action>
78void AcyclicValueIterationObserverCollection<State, Action>::
79 notify_terminal_state(
const State& state)
81 for (
auto& observer : observers_) {
82 observer.on_terminal_state(state);
86template <
typename State,
typename Action>
87void AcyclicValueIterationObserverCollection<State, Action>::
88 notify_pruned_state(
const State& state)
90 for (
auto& observer : observers_) {
91 observer.on_pruned_state(state);
// Creates the DFS expansion frame for the state `state_id` and binds it to
// that state's bookkeeping record `state_info`.
// The frame writes its results through `state_info`, and solve() later reads
// them back from state_infos_. The member is therefore presumably a reference,
// which means the record must stay at a stable address while the frame lives.
// NOTE(review): original line 100 (the first member initializer, presumably
// `: state_id(state_id)`) and the constructor body braces are missing from
// this view - confirm against the real source.
97template <
typename State,
typename Action>
98AcyclicValueIteration<State, Action>::IncrementalExpansionInfo::
99 IncrementalExpansionInfo(StateID state_id, StateInfo& state_info)
101 , state_info(state_info)
105template <
typename State,
typename Action>
106void AcyclicValueIteration<State, Action>::IncrementalExpansionInfo::
107 setup_transition(MDPType& mdp)
109 assert(transition.empty());
110 auto& next_action = remaining_aops.back();
111 t_value = mdp.get_action_cost(next_action);
112 const State state = mdp.get_state(state_id);
113 mdp.generate_action_transitions(state, next_action, transition);
114 successor = transition.begin();
// Folds a solved successor into the current transition's estimate.
// t_value accumulates probability * successor value, on top of the action
// cost seeded by setup_transition. The successor is then marked CLOSED, so
// push_successor later folds its stored value directly instead of expanding
// it again.
// NOTE(review): original lines 120-121 are missing from this view.
117template <
typename State,
typename Action>
118void AcyclicValueIteration<State, Action>::IncrementalExpansionInfo::
119 backtrack_successor(
value_t probability, StateInfo& succ_info)
122 t_value += probability * succ_info.value;
123 succ_info.status = StateInfo::CLOSED;
// Moves the frame to its next unit of work: the next successor of the
// current transition, or, once that transition is exhausted, the next
// applicable action's transition. Because of the short-circuit, next_transition()
// is consulted only when next_successor() reports that no successor remains.
// NOTE(review): the parameter list (original lines 128-130) is missing from
// this view; the body uses `mdp` and `policy`.
126template <
typename State,
typename Action>
127bool AcyclicValueIteration<State, Action>::IncrementalExpansionInfo::advance(
131 return next_successor() || next_transition(mdp, policy);
// Steps the frame's successor iterator within the current transition.
// If another successor remains, the function presumably returns true; that
// return (original lines 140-141) is missing from this view.
// Otherwise the transition is complete, and finalize_transition() folds its
// Q-value into the state's best value. The trailing return is also missing.
// NOTE(review): the name line (original 136) is missing. advance() calls
// next_successor(), so this is presumably that member - confirm.
134template <
typename State,
typename Action>
135bool AcyclicValueIteration<State, Action>::IncrementalExpansionInfo::
// Precondition: the iterator still refers to a valid successor.
138 assert(successor != transition.end());
139 if (++successor != transition.end()) {
143 finalize_transition();
// Discards the action whose transition was just finished, which sits at the
// back of remaining_aops.
// - If no actions remain, the state's expansion is finalised: its best
//   decision is written to `policy`, if one is given.
// - Otherwise the next action's transition is set up.
// NOTE(review): the return statements (around original lines 157 and 162)
// are missing from this view. Presumably the function returns false when the
// state is exhausted and true otherwise, which is what advance() relies on.
148template <
typename State,
typename Action>
149bool AcyclicValueIteration<State, Action>::IncrementalExpansionInfo::
150 next_transition(MDPType& mdp, MapPolicy* policy)
152 assert(!remaining_aops.empty());
153 remaining_aops.pop_back();
155 if (remaining_aops.empty()) {
156 finalize_expansion(policy);
161 setup_transition(mdp);
// Called once every successor of the current transition has been folded in.
// If the transition's Q-value (t_value) is strictly lower than the state's
// current value, the action at the back of remaining_aops becomes the best
// action and its Q-value becomes the state's value (cost minimisation). The
// state's value starts at its termination cost (see expand_state), so an
// action is recorded only if it beats terminating.
// NOTE(review): original line 170 is missing from this view. setup_transition
// asserts `transition.empty()`, so `transition` must be cleared somewhere
// between transitions - possibly on that line; confirm.
166template <
typename State,
typename Action>
167void AcyclicValueIteration<State, Action>::IncrementalExpansionInfo::
168 finalize_transition()
171 if (t_value < state_info.value) {
172 state_info.best_action = remaining_aops.back();
173 state_info.value = t_value;
// Records the state's final decision (best action and value) in `policy`.
// Nothing is recorded when no policy is being built (policy == nullptr) or
// when best_action is unset, presumably because no action beat the
// termination cost.
// NOTE(review): the first argument of emplace_decision (original line 183)
// is missing from this view; presumably the state's id.
177template <
typename State,
typename Action>
178void AcyclicValueIteration<State, Action>::IncrementalExpansionInfo::
179 finalize_expansion(MapPolicy* policy)
181 if (!policy || !state_info.best_action)
return;
182 policy->emplace_decision(
184 *state_info.best_action,
185 Interval(state_info.value));
// Solves the MDP from `initial_state` with the given time budget. As a side
// effect, each fully expanded state that has a best action gets its decision
// recorded into a freshly allocated MapPolicy (via finalize_expansion). That
// policy is presumably returned on the missing original line 198.
// NOTE(review): parameters on original lines 190 and 192-193 are missing from
// this view; the body uses `mdp` and `initial_state`.
// NOTE(review): `std::unique_ptr<MapPolicy>(new MapPolicy(&mdp))` could be
// written as `std::make_unique<MapPolicy>(&mdp)`.
188template <
typename State,
typename Action>
189auto AcyclicValueIteration<State, Action>::compute_policy(
191 EvaluatorType& heuristic,
194 double max_time) -> std::unique_ptr<PolicyType>
196 std::unique_ptr<MapPolicy> policy(
new MapPolicy(&mdp));
197 this->solve(mdp, heuristic, initial_state, max_time, policy.get());
// Convenience overload: runs the main solve routine with a null policy, so
// no decisions are recorded and only the initial state's value interval is
// returned.
// NOTE(review): parameter lines (original 203 and 205-208) are missing from
// this view; the forwarded arguments suggest mdp, initial_state and max_time.
201template <
typename State,
typename Action>
202Interval AcyclicValueIteration<State, Action>::solve(
204 EvaluatorType& heuristic,
209 return solve(mdp, heuristic, initial_state, max_time,
nullptr);
// Iterative depth-first value iteration over an acyclic MDP rooted at
// `initial_state`.
// Each stack frame (IncrementalExpansionInfo) walks one state's applicable
// actions and their successors incrementally. When a frame is finished, it is
// popped, and its state's value is folded into the parent's current
// transition. The result is the initial state's value as a point Interval.
// If `policy` is non-null, decisions are recorded into it (finalize_expansion).
// The time budget is checked through utils::CountdownTimer::throw_if_expired.
// NOTE(review): parameter lines (original 214 and 216-219) and the loop
// skeleton (around 229-231, 235-236 and 251-252) are missing from this view.
// The visible `} while (...)` fragments indicate an outer loop that contains
// two do/while loops: descend, then ascend.
// NOTE(review): `iinfo` and every frame keep references into state_infos_
// while later state_infos_[...] lookups may add entries. This assumes that
// storage never relocates existing elements - confirm its container type.
212template <
typename State,
typename Action>
213Interval AcyclicValueIteration<State, Action>::solve(
215 EvaluatorType& heuristic,
220 utils::CountdownTimer timer(max_time);
// Seed the DFS with the initial state's frame.
222 const StateID initial_state_id = mdp.get_state_id(initial_state);
223 StateInfo& iinfo = state_infos_[initial_state_id];
225 expansion_stack_.emplace(initial_state_id, iinfo);
226 iinfo.status = StateInfo::ON_STACK;
228 IncrementalExpansionInfo* e;
// Descend: expand the top frame and push one of its unvisited successors.
// Repeat until a frame has nothing to push.
232 e = &expansion_stack_.top();
233 }
while (expand_state(mdp, heuristic, *e) &&
234 push_successor(mdp, policy, *e, timer));
// Ascend: the top frame is complete, so pop it.
237 expansion_stack_.pop();
// Popping the root means the whole reachable graph has been solved.
239 if (expansion_stack_.empty()) {
240 return Interval(iinfo.value);
243 timer.throw_if_expired();
245 e = &expansion_stack_.top();
// The popped state was the parent's current successor. Fold its value into
// the parent's transition estimate and mark it CLOSED.
247 const auto [succ_id, probability] = *e->successor;
248 e->backtrack_successor(probability, state_infos_[succ_id]);
// Keep ascending while the parent is exhausted, or while it has no
// successor that push_successor actually pushes. Otherwise descend again.
249 }
while (!e->advance(mdp, policy) ||
250 !push_successor(mdp, policy, *e, timer));
254template <
typename State,
typename Action>
255void AcyclicValueIteration<State, Action>::register_observer(
256 std::shared_ptr<Observer> observer)
258 observers_.register_observer(std::move(observer));
// Scans the current transition of frame `e`, starting at e.successor, for a
// successor that has not been visited yet. Successors are handled by status:
// - ON_STACK: the state space has a cycle, so the process aborts.
// - NEW: a frame is pushed for it.
// - CLOSED (already solved): its value is folded in immediately.
// The scan continues through e.advance() until a successor is pushed or the
// frame runs out of work.
// NOTE(review): parameter lines (original 263-264; the body uses `mdp` and
// `policy`) and the return statements are missing from this view. Given how
// solve() uses the result, the function presumably returns true when a new
// frame was pushed and false when `e` was exhausted.
261template <
typename State,
typename Action>
262bool AcyclicValueIteration<State, Action>::push_successor(
265 IncrementalExpansionInfo& e,
266 utils::CountdownTimer& timer)
// The time budget is checked once per inspected successor.
269 timer.throw_if_expired();
271 const auto [succ_id, probability] = *e.successor;
272 StateInfo& succ_info = state_infos_[succ_id];
// A successor that is still on the DFS stack closes a cycle. This algorithm
// only supports acyclic MDPs, so it is treated as a fatal error.
274 if (succ_info.status == StateInfo::ON_STACK) {
275 std::cerr <<
"State space is not acyclic!" << std::endl;
276 utils::exit_with(utils::ExitCode::SEARCH_CRITICAL_ERROR);
// Unvisited successor: push its frame so that solve() expands it next.
// NOTE(review): `e` refers to an element of expansion_stack_. That is safe
// across emplace only if the underlying container keeps references stable
// (std::deque, the std::stack default, does) - confirm.
279 if (succ_info.status == StateInfo::NEW) {
280 expansion_stack_.emplace(succ_id, succ_info);
281 succ_info.status = StateInfo::ON_STACK;
// Otherwise the successor is already solved; fold its value in directly.
285 assert(succ_info.status == StateInfo::CLOSED);
287 e.backtrack_successor(probability, succ_info);
288 }
while (e.advance(mdp, policy));
// Expands the state of frame `e_info`, which must already be ON_STACK.
// The state's value is first set to its termination cost. finalize_transition
// only overwrites it with a strictly cheaper Q-value.
// NOTE(review): the `mdp` parameter line (original 295) and the return
// statements are missing from this view. solve() continues to push_successor
// only when this returns true, so it presumably returns false for goal,
// pruned and terminal states and true once a transition is set up.
// NOTE(review): `succ_info` below is the expanded state's own record, not a
// successor's; the name is misleading.
293template <
typename State,
typename Action>
294bool AcyclicValueIteration<State, Action>::expand_state(
296 EvaluatorType& heuristic,
297 IncrementalExpansionInfo& e_info)
299 const State state = mdp.get_state(e_info.state_id);
300 StateInfo& succ_info = e_info.state_info;
302 assert(succ_info.status == StateInfo::ON_STACK);
304 const TerminationInfo term_info = mdp.get_termination_info(state);
305 const value_t term_value = term_info.get_cost();
307 succ_info.value = term_value;
// Goal states: report them as both selected-for-expansion and goal. No
// actions are generated for them.
309 if (term_info.is_goal_state()) {
310 observers_.notify_state_selected_for_expansion(state);
311 observers_.notify_goal_state(state);
// Pruning: if the heuristic estimate equals the termination cost, the state
// keeps its termination value and is not expanded.
// NOTE(review): this is an exact == on value_t (double). Presumably it is
// intended for exact sentinel values such as the dead-end cost - confirm.
315 if (heuristic.evaluate(state) == term_value) {
316 observers_.notify_pruned_state(state);
320 observers_.notify_state_selected_for_expansion(state);
322 assert(e_info.remaining_aops.empty());
// A state with no applicable actions is terminal and keeps its termination
// value.
324 mdp.generate_applicable_actions(state, e_info.remaining_aops);
325 if (e_info.remaining_aops.empty()) {
326 observers_.notify_terminal_state(state);
// Start with the transition of the action at the back of remaining_aops.
330 e_info.setup_transition(mdp);
Namespace dedicated to the acyclic value iteration algorithm.
Definition acyclic_value_iteration.h:22
double value_t
Typedef for the state value type.
Definition aliases.h:7
typename std::conditional_t< is_cheap_to_copy_v< T >, T, const T & > param_type
Alias template defining the best way to pass a parameter of a given type.
Definition type_traits.h:25
An observer that collects basic statistics of the acyclic value iteration algorithm.
Definition acyclic_value_iteration.h:58
void on_state_selected_for_expansion(const State &)
Called when the algorithm selects a state for expansion.
Definition acyclic_value_iteration_impl.h:17
void on_goal_state(const State &)
Called when a goal state is encountered during the expansion check.
Definition acyclic_value_iteration_impl.h:24
void on_pruned_state(const State &)
Called when a state is pruned during the expansion check.
Definition acyclic_value_iteration_impl.h:36
void on_terminal_state(const State &)
Called when a terminal state is encountered during the expansion check.
Definition acyclic_value_iteration_impl.h:30