1#ifndef GUARD_INCLUDE_PROBFD_ALGORITHMS_TOPOLOGICAL_VALUE_ITERATION_H
2#error "This file should only be included from topological_value_iteration.h"
5#include "probfd/algorithms/utils.h"
7#include "probfd/policies/map_policy.h"
9#include "probfd/evaluator.h"
10#include "probfd/progress_report.h"
12#include "downward/utils/countdown_timer.h"
18inline void Statistics::print(std::ostream& out)
const
20 out <<
" Expanded state(s): " << expanded_states << std::endl;
21 out <<
" Terminal state(s): " << terminal_states << std::endl;
22 out <<
" Goal state(s): " << goal_states << std::endl;
23 out <<
" Pruned state(s): " << pruned << std::endl;
24 out <<
" Maximal SCCs: " << sccs <<
" (" << singleton_sccs
25 <<
" are singleton)" << std::endl;
26 out <<
" Bellman backups: " << bellman_backups << std::endl;
// Out-of-class definition of the ExplorationInfo constructor. It takes the
// id of the state being explored, a reference to that state's StackInfo
// entry on the Tarjan stack, and the entry's index on that stack.
// NOTE(review): original source lines 32 and 34-38 are missing from this
// listing, so only the `stack_info` member initializer is visible; the
// remaining initializers (presumably state_id, stackidx and lowlink) and the
// body cannot be confirmed from here.
29template <
typename State,
typename Action,
bool UseInterval>
30TopologicalValueIteration<State, Action, UseInterval>::ExplorationInfo::
31 ExplorationInfo(StateID state_id, StackInfo& stack_info,
unsigned stackidx)
33 , stack_info(stack_info)
39template <
typename State,
typename Action,
bool UseInterval>
40void TopologicalValueIteration<State, Action, UseInterval>::ExplorationInfo::
41 update_lowlink(
unsigned upd)
43 lowlink = std::min(lowlink, upd);
// Moves the exploration of this state on to its next transition.
// Visible behaviour: resets the accumulated self-loop probability to 0 and
// returns false if no applicable actions remain in `aops`, otherwise the
// result of forward_non_loop_transition() for the current state.
// NOTE(review): original lines 49-52 and 54 are missing from this listing
// (presumably the opening brace and removal of the finished action from
// `aops` — confirm against the real file).
46template <
typename State,
typename Action,
bool UseInterval>
47bool TopologicalValueIteration<State, Action, UseInterval>::ExplorationInfo::
48 next_transition(MDPType& mdp)
53 self_loop_prob = 0_vt;
55 return !aops.empty() &&
56 forward_non_loop_transition(mdp, mdp.get_state(state_id));
// Presumably ExplorationInfo::next_successor(). The name line (61) is
// missing here, but solve() and successor_loop() call explore.next_successor().
// Advances to the next non-self-loop successor of the current transition and
// returns true if one exists. Otherwise the transition's QValueInfo (the last
// entry of `stack_info.nconv_qs`) is finalized with the accumulated self-loop
// probability. If finalization reports that no non-converged successors
// remain, the transition's converged part is folded into the state's
// `conv_part` via set_min. When that improves the minimum, the action is
// recorded as `best_converged`.
// NOTE(review): the brace structure (lines 63, 65, 67, 71, 73-77) is missing.
// The pop_back() on line 72 presumably happens only in the converged case,
// and the final return value is not visible — confirm.
59template <
typename State,
typename Action,
bool UseInterval>
60bool TopologicalValueIteration<State, Action, UseInterval>::ExplorationInfo::
64 if (forward_non_loop_successor())
return true;
66 auto& tinfo = stack_info.nconv_qs.back();
68 if (tinfo.finalize_transition(self_loop_prob)) {
69 if (
set_min(stack_info.conv_part, tinfo.conv_part)) {
70 stack_info.best_converged = tinfo.action;
72 stack_info.nconv_qs.pop_back();
// Searches, starting from the last action in `aops`, for a transition that
// has at least one successor other than the current state.
// In a do-while loop that runs while `aops` is non-empty, it generates the
// transition of `aops.back()` and positions `successor` at its first
// element. If forward_non_loop_successor() then finds a non-self-loop
// successor, a new QValueInfo is appended to `stack_info.nconv_qs`. That
// QValueInfo is constructed with (among arguments on missing lines) the
// action's cost.
// NOTE(review): lines 81-82, 85, 88 and 90-94 are missing. The return
// values and the handling of transitions that consist only of self-loops
// cannot be confirmed from this listing.
78template <
typename State,
typename Action,
bool UseInterval>
79bool TopologicalValueIteration<State, Action, UseInterval>::ExplorationInfo::
80 forward_non_loop_transition(MDPType& mdp,
const State& state)
83 mdp.generate_action_transitions(state, aops.back(), transition);
84 successor = transition.begin();
86 if (forward_non_loop_successor()) {
87 stack_info.nconv_qs.emplace_back(
89 mdp.get_action_cost(aops.back()));
95 }
while (!aops.empty());
// Scans the current transition from `successor` towards `transition.end()`.
// If a successor differs from `state_id`, the hidden branch (lines 106-108)
// presumably stops the scan and returns true. Otherwise the successor is a
// self-loop: its probability is added to `self_loop_prob` and the scan
// continues.
// NOTE(review): lines 103-104 and 106-108 are missing (presumably `do {`,
// `return true;`) as is the final return after the loop (presumably false).
100template <
typename State,
typename Action,
bool UseInterval>
101bool TopologicalValueIteration<State, Action, UseInterval>::ExplorationInfo::
102 forward_non_loop_successor()
105 if (successor->item != state_id) {
109 self_loop_prob += successor->probability;
110 }
while (++successor != transition.end());
// ExplorationInfo accessor returning a reference to an Action. It is
// presumably get_current_action() returning the action currently being
// expanded (e.g. `aops.back()`), but the name line (117) and the body
// (118-121) are missing from this listing — confirm against the real file.
115template <
typename State,
typename Action,
bool UseInterval>
116Action& TopologicalValueIteration<State, Action, UseInterval>::ExplorationInfo::
// Returns the successor currently pointed at in the transition being
// explored, as a (state id, probability) pair. Callers destructure it as
// `[succ_id, prob]`.
// NOTE(review): the body (lines 126-129) is missing; presumably it returns
// `*successor`.
122template <
typename State,
typename Action,
bool UseInterval>
123ItemProbabilityPair<StateID>
124TopologicalValueIteration<State, Action, UseInterval>::ExplorationInfo::
125 get_current_successor()
// Out-of-class definition of the QValueInfo constructor. The only visible
// initializer seeds `conv_part` with the action cost. `conv_part` is the
// part of the transition's Q-value contributed by already-converged
// successors.
// NOTE(review): the parameter list and the other initializers (lines
// 132-134, 136-138) are missing from this listing.
130template <
typename State,
typename Action,
bool UseInterval>
131TopologicalValueIteration<State, Action, UseInterval>::QValueInfo::QValueInfo(
135 , conv_part(action_cost)
// Finishes one transition after its successors have been scanned.
// If the transition had a self-loop probability p != 0, it scales the
// converged part and the probabilities of all non-converged successors by
// 1 / (1 - p). This solves Q = c + p*Q + rest for Q, i.e. removes the
// self-loop.
// Returns true iff no non-converged successors remain, i.e. the
// transition's Q-value is fully given by `conv_part`.
// NOTE(review): this divides by (1 - self_loop_prob). It is presumably safe
// because a QValueInfo is only created for transitions with at least one
// successor other than the state itself (see forward_non_loop_transition),
// so p < 1 — confirm.
139template <
typename State,
typename Action,
bool UseInterval>
140bool TopologicalValueIteration<State, Action, UseInterval>::QValueInfo::
141 finalize_transition(value_t self_loop_prob)
143 if (self_loop_prob != 0_vt) {
145 const value_t normalization = 1_vt / (1_vt - self_loop_prob);
147 conv_part *= normalization;
149 for (
auto& pair : nconv_successors) {
150 pair.probability *= normalization;
154 return nconv_successors.empty();
// Computes the transition's current Q-value estimate. It starts from the
// converged part, then adds, for every non-converged successor, its
// probability times the successor's current value (read through the stored
// value pointer).
// NOTE(review): the closing lines (165-168, presumably `return res;`) are
// missing from this listing.
157template <
typename State,
typename Action,
bool UseInterval>
158auto TopologicalValueIteration<State, Action, UseInterval>::QValueInfo::
159 compute_q_value() const -> AlgorithmValueType
161 AlgorithmValueType res = conv_part;
163 for (
auto& [value, prob] : nconv_successors) {
164 res += prob * (*value);
// Out-of-class definition of the StackInfo constructor. Its last parameter
// `value_ref` is a reference to the state's entry in the value store
// (push_state passes `stack_.emplace_back(state_id, state_value)`).
// NOTE(review): the remaining parameters and the initializers (lines 172,
// 174-178) are missing from this listing.
170template <
typename State,
typename Action,
bool UseInterval>
171TopologicalValueIteration<State, Action, UseInterval>::StackInfo::StackInfo(
173 AlgorithmValueType& value_ref)
// Presumably StackInfo::update_value(). The name line (181) is missing, but
// scc_found() calls `it->update_value()`. Performs one Bellman backup for
// this state:
//  - it starts from the best value over already-converged actions
//    (`conv_part`, `best_converged`);
//  - it takes the set_min with the Q-value of every still non-converged
//    action, recording the minimizing action in `best_action`.
// In interval mode it returns true while the value's bounds are not yet
// approximately equal, i.e. while another sweep is needed.
// NOTE(review): the write-back of `v` into `*value` and the non-interval
// return value are presumably on the missing lines (189-191, 193, 195-198).
179template <
typename State,
typename Action,
bool UseInterval>
180bool TopologicalValueIteration<State, Action, UseInterval>::StackInfo::
183 AlgorithmValueType v = conv_part;
184 best_action = best_converged;
186 for (
const QValueInfo& info : nconv_qs) {
187 if (
set_min(v, info.compute_q_value())) {
188 best_action = info.action;
192 if constexpr (UseInterval) {
194 return !value->bounds_approximately_equal();
// Out-of-class definition of the TopologicalValueIteration constructor.
// It stores the `expand_goals` flag, which initialize_state() consults to
// decide whether goal states are expanded or pruned.
// NOTE(review): the signature (lines 201-202) and the rest of the
// definition (204-206) are missing from this listing.
200template <
typename State,
typename Action,
bool UseInterval>
203 : expand_goals_(expand_goals)
// Presumably compute_policy(). The name line is missing; the function
// returns std::unique_ptr<PolicyType>. It creates a local value store and a
// MapPolicy for `mdp`, and presumably runs the search from `state` (its id
// is passed on line 220) before returning the policy.
// NOTE(review): MapPolicy is allocated with a raw `new`; prefer
// std::make_unique<MapPolicy>(&mdp) if the constructor is accessible. Lines
// 208-209, 211-212, 214, 217-219 and 221-226 are missing.
207template <
typename State,
typename Action,
bool UseInterval>
210 EvaluatorType& heuristic,
213 double max_time) -> std::unique_ptr<PolicyType>
215 storage::PerStateStorage<AlgorithmValueType> value_store;
216 std::unique_ptr<MapPolicy> policy(
new MapPolicy(&mdp));
220 mdp.get_state_id(state),
// Public solve overload: creates a fresh local PerStateStorage and forwards
// to the ValueStore-templated solve() with the id of `state` and `max_time`.
// NOTE(review): the signature lines (228-229, 231-234) and line 236 are
// missing from this listing.
227template <
typename State,
typename Action,
bool UseInterval>
230 EvaluatorType& heuristic,
235 storage::PerStateStorage<AlgorithmValueType> value_store;
237 ->solve(mdp, heuristic, mdp.get_state_id(state), value_store, max_time);
240template <
typename State,
typename Action,
bool UseInterval>
242 std::ostream& out)
const
244 statistics_.print(out);
// Template header of the out-of-class get_statistics() definition. The
// declaration is `Statistics get_statistics() const`, which returns the
// collected statistics by value.
// NOTE(review): the rest of this definition (lines 248-253) is missing from
// this listing; only the template header is visible.
247template <
typename State,
typename Action,
bool UseInterval>
// Core TVI search, presumably the ValueStore-templated solve() (the public
// overload calls ->solve(mdp, heuristic, state_id, value_store, max_time)).
// It performs an iterative, Tarjan-style depth-first search from
// `init_state_id`, using `exploration_stack_` in place of recursion:
//  - the initial state is pushed, then states are initialized and their
//    successors scanned (initialize_state / successor_loop);
//  - when a state's stack index equals its lowlink, the SCC rooted at it
//    (the part of `stack_` from that index on) is passed to scc_found();
//  - on backtracking into the parent, the child's value is added to the
//    parent's current transition. It goes into the converged part if the
//    child closed an SCC; otherwise it is added as a non-converged successor
//    and the parent's lowlink is lowered.
// timer.throw_if_expired() is called between steps (presumably throwing on
// timeout). When the exploration stack becomes empty, the initial state's
// value is returned; line 295 builds Interval(init_value, INFINITE_VALUE),
// presumably in the non-interval branch.
// NOTE(review): many lines are missing (256-257, 259, 261-263, 265, 268,
// 270-271, 273-274, 278-280, 284, 287-288, 290, 293-294, 296-298, 300,
// 302, 306, 309, 312-313, 316-318), including where `policy` is declared.
254template <
typename State,
typename Action,
bool UseInterval>
255template <
typename ValueStore>
258 EvaluatorType& heuristic,
260 ValueStore& value_store,
264 utils::CountdownTimer timer(max_time);
266 StateInfo& iinfo = state_information_[init_state_id];
267 AlgorithmValueType& init_value = value_store[init_state_id];
269 push_state(init_state_id, iinfo, init_value);
272 ExplorationInfo* explore;
275 explore = &exploration_stack_.back();
276 }
while (initialize_state(mdp, heuristic, *explore, value_store) &&
277 successor_loop(mdp, *explore, value_store, timer));
281 const unsigned stack_id = explore->stackidx;
282 const unsigned lowlink = explore->lowlink;
// Tarjan: a state is the root of an SCC iff its stack index equals its lowlink.
283 const bool backtrack_from_scc = stack_id == lowlink;
285 if (backtrack_from_scc) {
286 scc_found(stack_ | std::views::drop(stack_id), policy, timer);
289 exploration_stack_.pop_back();
291 if (exploration_stack_.empty()) {
292 if constexpr (UseInterval) {
295 return Interval(init_value, INFINITE_VALUE);
299 timer.throw_if_expired();
301 explore = &exploration_stack_.back();
303 const auto [succ_id, prob] = explore->get_current_successor();
304 AlgorithmValueType& s_value = value_store[succ_id];
305 QValueInfo& tinfo = explore->stack_info.nconv_qs.back();
307 if (backtrack_from_scc) {
308 tinfo.conv_part += prob * s_value;
310 explore->update_lowlink(lowlink);
311 tinfo.nconv_successors.emplace_back(&s_value, prob);
314 (!explore->next_successor() && !explore->next_transition(mdp)) ||
315 !successor_loop(mdp, *explore, value_store, timer));
// Pushes a newly discovered state onto the Tarjan stack `stack_` (id plus a
// reference to its value) and onto `exploration_stack_`. The new
// ExplorationInfo is constructed with a reference to the just-created
// StackInfo. The state's stack index is then recorded in `state_info` and
// the state is marked ONSTACK.
// NOTE(review): ExplorationInfo stores a StackInfo& into `stack_`. This is
// only safe if `stack_.emplace_back` never relocates existing elements
// (e.g. std::deque rather than std::vector) — confirm the container type.
// Lines 320-321, 324, 327 and 329 are missing.
319template <
typename State,
typename Action,
bool UseInterval>
322 StateInfo& state_info,
323 AlgorithmValueType& state_value)
325 const std::size_t stack_size = stack_.size();
326 exploration_stack_.emplace_back(
328 stack_.emplace_back(state_id, state_value),
330 state_info.stack_id = stack_size;
331 state_info.status = StateInfo::ONSTACK;
// Initializes a freshly pushed (NEW) state and decides whether to expand it.
//  - The heuristic is evaluated, and the state's converged part is seeded
//    with the termination cost `t_cost`.
//  - Its value is set to [estimate, t_cost] in interval mode, otherwise to
//    the estimate.
//  - Goal states are counted and, unless `expand_goals_` is set, pruned.
//  - Otherwise, states whose estimate equals the termination cost are pruned.
//  - Remaining states have their applicable actions generated. A state
//    without actions is counted as terminal; otherwise the first transition
//    with a non-self-loop successor is searched via
//    forward_non_loop_transition().
// NOTE(review): `t_cost`, the goal test itself and the return statements are
// on missing lines (336, 339-340, 342, 344-346, 348, 350, 352, 356,
// 358-360, 362, 365-366, 369-371, 373, 375, 377, 379, 383-388).
334template <
typename State,
typename Action,
bool UseInterval>
335bool TopologicalValueIteration<State, Action, UseInterval>::initialize_state(
337 EvaluatorType& heuristic,
338 ExplorationInfo& exp_info,
341 assert(state_information_[exp_info.state_id].status == StateInfo::NEW);
343 const State state = mdp.get_state(exp_info.state_id);
347 const value_t estimate = heuristic.evaluate(state);
349 exp_info.stack_info.conv_part = AlgorithmValueType(t_cost);
351 AlgorithmValueType& state_value = value_store[exp_info.state_id];
353 if constexpr (UseInterval) {
354 state_value.lower = estimate;
355 state_value.upper = t_cost;
357 state_value = estimate;
361 ++statistics_.goal_states;
363 if (!expand_goals_) {
364 ++statistics_.pruned;
367 }
else if (estimate == t_cost) {
368 ++statistics_.pruned;
372 mdp.generate_applicable_actions(state, exp_info.aops);
374 const size_t num_aops = exp_info.aops.size();
376 exp_info.stack_info.nconv_qs.reserve(num_aops);
378 ++statistics_.expanded_states;
380 if (exp_info.aops.empty()) {
381 ++statistics_.terminal_states;
382 }
else if (exp_info.forward_non_loop_transition(mdp, state)) {
// Inner DFS loop for the state on top of the exploration stack. It iterates
// over the successors of the current transition (inner do-while on
// next_successor) and then over further transitions (outer do-while on
// next_transition). Each successor is handled by its status:
//  - NEW: the successor is pushed with push_state(). The loop presumably
//    returns here so the caller can descend into it (lines 413-415 missing).
//  - CLOSED: its final value times its probability is added to the
//    transition's converged part.
//  - ONSTACK: it belongs to the SCC under construction, so the lowlink is
//    lowered and it is recorded as a non-converged successor (via a pointer
//    to its value).
// The timer is checked at the top of the inner loop (line 402).
// NOTE(review): lines 392, 396-397, 400-401, 403, 408, 410, 413-415, 417,
// 421 and 424-427 are missing.
389template <
typename State,
typename Action,
bool UseInterval>
390template <
typename ValueStore>
391bool TopologicalValueIteration<State, Action, UseInterval>::successor_loop(
393 ExplorationInfo& explore,
394 ValueStore& value_store,
395 utils::CountdownTimer& timer)
398 assert(!explore.stack_info.nconv_qs.empty());
399 QValueInfo& tinfo = explore.stack_info.nconv_qs.back();
402 timer.throw_if_expired();
404 const auto [succ_id, prob] = explore.get_current_successor();
405 assert(succ_id != explore.state_id);
406 StateInfo& succ_info = state_information_[succ_id];
407 AlgorithmValueType& s_value = value_store[succ_id];
409 switch (succ_info.status) {
411 case StateInfo::NEW: {
412 push_state(succ_id, succ_info, s_value);
416 case StateInfo::CLOSED: tinfo.conv_part += prob * s_value;
break;
418 case StateInfo::ONSTACK:
419 explore.update_lowlink(succ_info.stack_id);
420 tinfo.nconv_successors.emplace_back(&s_value, prob);
422 }
while (explore.next_successor());
423 }
while (explore.next_transition(mdp));
// Solves a newly completed SCC (a view over the top part of `stack_`) and
// closes its states.
//  - Singleton SCC: the state's value is updated once from its converged
//    part, and the state is marked CLOSED.
//  - Larger SCC: every state is marked CLOSED. Value iteration restricted to
//    the SCC then runs: each sweep performs one update_value() Bellman
//    backup per state, and sweeps repeat until no backup reports a change.
//    The timer is checked before each sweep.
// Afterwards a policy decision (the state's best action together with its
// value; in non-interval mode presumably as Interval(value, INFINITE_VALUE))
// is recorded for the SCC's states, and the SCC is erased from `stack_`.
// NOTE(review): lines 430-431, 433, 435-437, 439-440, 447-448, 454-459,
// 461-462, 464-465, 470-472, 476, 478-479, 481 and 484-488 are missing,
// including the declaration/reset of `converged`.
428template <
typename State,
typename Action,
bool UseInterval>
429void TopologicalValueIteration<State, Action, UseInterval>::scc_found(
432 utils::CountdownTimer& timer)
434 assert(!scc.empty());
438 if (scc.size() == 1) {
441 ++statistics_.singleton_sccs;
442 StackInfo& single = scc.front();
443 StateInfo& state_info = state_information_[single.state_id];
444 update(*single.value, single.conv_part);
445 assert(state_info.status == StateInfo::ONSTACK);
446 state_info.status = StateInfo::CLOSED;
449 for (StackInfo& stk_info : scc) {
450 StateInfo& state_info = state_information_[stk_info.state_id];
451 assert(state_info.status == StateInfo::ONSTACK);
452 assert(!stk_info.nconv_qs.empty());
453 state_info.status = StateInfo::CLOSED;
460 timer.throw_if_expired();
463 auto it = scc.begin();
466 if (it->update_value()) converged =
false;
467 ++statistics_.bellman_backups;
468 }
while (++it != scc.end());
469 }
while (!converged);
473 for (StackInfo& stk_info : scc) {
474 if constexpr (UseInterval) {
475 policy->emplace_decision(
477 *stk_info.best_action,
480 policy->emplace_decision(
482 *stk_info.best_action,
483 Interval(*stk_info.value, INFINITE_VALUE));
489 stack_.erase(scc.begin(), scc.end());
A registry for print functions related to search progress.
Definition progress_report.h:33
Specifies the termination cost and goal status of a state.
Definition cost_function.h:13
bool is_goal_state() const
Check if this state is a goal.
Definition cost_function.h:34
value_t get_cost() const
Obtains the cost paid upon termination in the state.
Definition cost_function.h:41
Implements Topological Value Iteration (Dai et al., JAIR 2011).
Definition topological_value_iteration.h:68
Statistics get_statistics() const
Retrieve the algorithm statistics.
Definition topological_value_iteration_impl.h:249
void print_statistics(std::ostream &out) const override
Prints algorithm statistics to the specified output stream.
Definition topological_value_iteration_impl.h:241
Namespace dedicated to Topological Value Iteration (TVI).
Definition topological_value_iteration.h:27
bool update(Interval &lhs, Interval rhs, value_t epsilon=g_epsilon)
Intersects two intervals and assigns the result to the left operand.
bool set_min(Interval &lhs, Interval rhs)
Computes the assignments lhs.lower <- min(lhs.lower, rhs.lower) and lower <- min(lhs....
double value_t
Typedef for the state value type.
Definition aliases.h:7
typename std::conditional_t< is_cheap_to_copy_v< T >, T, const T & > param_type
Alias template defining the best way to pass a parameter of a given type.
Definition type_traits.h:25
Represents a closed interval over the extended reals as a pair of lower and upper bound.
Definition interval.h:12
A StateID represents a state within a StateIDMap. Just like Fast Downward's StateID type,...
Definition types.h:22
Topological value iteration statistics.
Definition topological_value_iteration.h:32