AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
heuristic_search_base_impl.h
1#ifndef GUARD_INCLUDE_PROBFD_ALGORITHMS_HEURISTIC_SEARCH_BASE_H
2#error "This file should only be included from heuristic_search_base.h"
3#endif
4
5#include "probfd/algorithms/policy_picker.h"
6
7#include "probfd/algorithms/utils.h"
8
9#include "probfd/policies/map_policy.h"
10
11#include "probfd/utils/language.h"
12
13#include "probfd/evaluator.h"
14#include "probfd/mdp.h"
15#include "probfd/transition.h"
16
17#include "downward/utils/collections.h"
18
19#include <cassert>
20#include <deque>
21
23
24namespace internal {
25
26inline void Statistics::print(std::ostream& out) const
27{
28 out << " Initial state value estimation: " << initial_state_estimate
29 << std::endl;
30 out << " Initial state value found terminal: "
31 << initial_state_found_terminal << std::endl;
32
33 out << " Evaluated state(s): " << evaluated_states << std::endl;
34 out << " Pruned state(s): " << pruned_states << std::endl;
35 out << " Goal state(s): " << goal_states << std::endl;
36 out << " Terminal state(s): " << terminal_states << std::endl;
37 out << " Self-loop state(s): " << self_loop_states << std::endl;
38 out << " Expanded state(s): " << expanded_states << std::endl;
39 out << " Number of value updates: " << value_updates << std::endl;
40 out << " Number of value changes: " << value_changes << std::endl;
41 out << " Number of policy updates: " << policy_updates << std::endl;
42 out << " Number of policy changes: " << policy_changes << std::endl;
43
44#if defined(EXPENSIVE_STATISTICS)
45 out << " Updating time: " << update_time << std::endl;
46 out << " Policy selection time: " << policy_selection_time << std::endl;
47#endif
48}
49
50} // namespace internal
51
52template <typename State, typename Action, typename StateInfoT>
54 std::shared_ptr<PolicyPickerType> policy_chooser)
55 : policy_chooser_(policy_chooser)
56{
57}
58
59template <typename State, typename Action, typename StateInfoT>
61 StateID state_id) const
62{
63 return state_infos_[state_id].get_bounds();
64}
65
66template <typename State, typename Action, typename StateInfoT>
68 StateID state_id) const
69{
70 return state_infos_[state_id].is_value_initialized();
71}
72
73template <typename State, typename Action, typename StateInfoT>
75 CostFunctionType& cost_function,
76 StateID state_id,
77 const std::vector<TransitionType>& transitions,
78 value_t termination_cost) const -> AlgorithmValueType
79{
80#if defined(EXPENSIVE_STATISTICS)
81 TimerScope scoped_upd_timer(statistics_.update_time);
82#endif
83
84 AlgorithmValueType best_value = AlgorithmValueType(termination_cost);
85
86 for (auto& transition : transitions) {
87 const value_t cost = cost_function.get_action_cost(transition.action);
88 set_min(best_value, compute_qvalue(cost, state_id, transition));
89 }
90
91 return best_value;
92}
93
94template <typename State, typename Action, typename StateInfoT>
96 CostFunctionType& cost_function,
97 StateID state_id,
98 std::vector<TransitionType>& transitions,
99 value_t termination_cost,
100 std::vector<AlgorithmValueType>& qvalues,
101 value_t epsilon) const -> AlgorithmValueType
102{
103#if defined(EXPENSIVE_STATISTICS)
104 TimerScope scoped_upd_timer(statistics_.update_time);
105#endif
106
107 if (transitions.empty()) {
108 return AlgorithmValueType(termination_cost);
109 }
110
111 AlgorithmValueType best_value = compute_q_values(
112 cost_function,
113 state_id,
114 transitions,
115 termination_cost,
116 qvalues);
117
118 if (as_lower_bound(best_value) == termination_cost) {
119 transitions.clear();
120 qvalues.clear();
121 return AlgorithmValueType(termination_cost);
122 }
123
124 filter_greedy_transitions(transitions, qvalues, best_value, epsilon);
125
126 return best_value;
127}
128
129template <typename State, typename Action, typename StateInfoT>
131 MDPType& mdp,
132 std::optional<Action> previous_greedy,
133 std::vector<TransitionType>& transitions) -> std::optional<TransitionType>
134{
135#if defined(EXPENSIVE_STATISTICS)
136 TimerScope scoped(statistics_.policy_selection_time);
137#endif
138
139 if (transitions.empty()) return std::nullopt;
140
141 const int index = this->policy_chooser_->pick_index(
142 mdp,
143 previous_greedy,
144 transitions,
145 state_infos_);
146
147 assert(utils::in_bounds(index, transitions));
148
149 return std::move(transitions[index]);
150}
151
152template <typename State, typename Action, typename StateInfoT>
154 StateInfo& state_info,
155 AlgorithmValueType other,
156 value_t epsilon)
157{
158 ++statistics_.value_updates;
159 bool b = algorithms::update(state_info.value, other, epsilon);
160 if (b) ++statistics_.value_changes;
161 return b;
162}
163
164template <typename State, typename Action, typename StateInfoT>
166 StateInfo& state_info,
167 const std::optional<TransitionType>& transition)
168 requires(StorePolicy)
169{
170 ++statistics_.policy_updates;
171 bool b = state_info.update_policy(transition);
172 if (b) ++statistics_.policy_changes;
173 return b;
174}
175
176template <typename State, typename Action, typename StateInfoT>
179 EvaluatorType& h,
180 param_type<State> state)
181{
182 StateInfo& info = this->state_infos_[mdp.get_state_id(state)];
183
184 if (info.is_value_initialized()) return;
185
186 initialize(mdp, h, state, info);
187
188 statistics_.initial_state_estimate = info.get_value();
189 statistics_.initial_state_found_terminal = info.is_goal_or_terminal();
190}
191
192template <typename State, typename Action, typename StateInfoT>
194 MDPType& mdp,
195 EvaluatorType& h,
196 param_type<State> state,
197 StateInfo& state_info,
198 std::vector<TransitionType>& transitions)
199{
200 assert(!state_info.is_goal_or_terminal());
201 assert(transitions.empty());
202 assert(state_info.is_on_fringe());
203
204 ++statistics_.expanded_states;
205 state_info.removed_from_fringe();
206
207 mdp.generate_all_transitions(state, transitions);
208
209 if (transitions.empty()) {
210 ++statistics_.terminal_states;
211 state_info.set_terminal();
212 return;
213 }
214
215 const StateID state_id = mdp.get_state_id(state);
216
217 erase_if(transitions, [&](auto& transition) {
218 bool loop = true;
219 auto it = transition.successor_dist.begin();
220 auto end = transition.successor_dist.end();
221
222 auto loop_it = end;
223
224 for (; it != end; ++it) {
225 const auto& [succ_id, prob] = *it;
226 if (succ_id == state_id) {
227 loop_it = it;
228 continue;
230 loop = false;
231 auto& succ_info = state_infos_[succ_id];
232 if (succ_info.is_value_initialized()) continue;
233 initialize(mdp, h, mdp.get_state(succ_id), succ_info);
234 }
235
236 if (!loop && loop_it != end) {
237 value_t prob = loop_it->probability;
238 transition.successor_dist.erase(loop_it);
239 transition.successor_dist.normalize(1 / (1 - prob));
241
242 return loop;
243 });
244
245 if (transitions.empty()) {
246 ++statistics_.self_loop_states;
247 state_info.set_terminal();
248 }
249}
250
251template <typename State, typename Action, typename StateInfoT>
252void HeuristicSearchBase<State, Action, StateInfoT>::
253 generate_non_tip_transitions(
254 MDPType& mdp,
255 param_type<State> state,
256 std::vector<TransitionType>& transitions) const
257{
258 assert(transitions.empty());
259
260 mdp.generate_all_transitions(state, transitions);
261
262 const StateID state_id = mdp.get_state_id(state);
263
264 std::erase_if(transitions, [&](auto& transition) {
265 bool loop = true;
266
267 for (StateID succ_id : transition.successor_dist.support()) {
268 if (succ_id != state_id) loop = false;
269 }
270
271 return loop;
272 });
273}
274
275template <typename State, typename Action, typename StateInfoT>
276void HeuristicSearchBase<State, Action, StateInfoT>::print_statistics(
277 std::ostream& out) const
278{
279 out << " Stored " << sizeof(StateInfo) << " bytes per state" << std::endl;
280 statistics_.print(out);
281}
282
283template <typename State, typename Action, typename StateInfoT>
284void HeuristicSearchBase<State, Action, StateInfoT>::initialize(
285 MDPType& mdp,
286 EvaluatorType& h,
287 param_type<State> state,
288 StateInfo& state_info)
289{
290 assert(!state_info.is_value_initialized());
291
292 statistics_.evaluated_states++;
293
294 TerminationInfo term = mdp.get_termination_info(state);
295 const value_t t_cost = term.get_cost();
296
297 if (term.is_goal_state()) {
298 statistics_.goal_states++;
299 state_info.set_goal();
300 state_info.value = AlgorithmValueType(t_cost);
301 return;
302 }
303
304 const value_t estimate = h.evaluate(state);
305
306 if constexpr (UseInterval) {
307 state_info.value = Interval(estimate, t_cost);
308 } else {
309 state_info.value = estimate;
310 }
311
312 if (estimate == t_cost) {
313 statistics_.pruned_states++;
314 state_info.set_terminal();
315 } else {
316 state_info.set_on_fringe();
317 }
318}
319
320template <typename State, typename Action, typename StateInfoT>
321auto HeuristicSearchBase<State, Action, StateInfoT>::compute_qvalue(
322 value_t action_cost,
323 StateID state_id,
324 const TransitionType& transition) const -> AlgorithmValueType
325{
326 AlgorithmValueType t_value(action_cost);
327
328 value_t non_loop_prob = 1_vt;
329
330 for (const auto& [succ_id, prob] : transition.successor_dist) {
331 if (state_id == succ_id) {
332 non_loop_prob -= prob;
333 continue;
334 }
335
336 t_value += prob * state_infos_[succ_id].value;
338
339 assert(non_loop_prob > 0_vt);
340
341 return t_value * (1_vt / non_loop_prob);
342}
343
344template <typename State, typename Action, typename StateInfoT>
345auto HeuristicSearchBase<State, Action, StateInfoT>::compute_q_values(
346 CostFunctionType& cost_function,
347 StateID state_id,
348 std::vector<TransitionType>& transitions,
349 value_t termination_cost,
350 std::vector<AlgorithmValueType>& qvalues) const -> AlgorithmValueType
351{
352 AlgorithmValueType best_value(termination_cost);
353
354 qvalues.reserve(transitions.size());
355
356 for (const auto& transition : transitions) {
357 const value_t cost = cost_function.get_action_cost(transition.action);
358 auto q = compute_qvalue(cost, state_id, transition);
359 set_min(best_value, q);
360 qvalues.push_back(q);
361 }
362
363 return best_value;
364}
365
366template <typename State, typename Action, typename StateInfoT>
367auto HeuristicSearchBase<State, Action, StateInfoT>::filter_greedy_transitions(
368 std::vector<TransitionType>& transitions,
369 std::vector<AlgorithmValueType>& qvalues,
370 const AlgorithmValueType& best_value,
371 value_t epsilon) const -> AlgorithmValueType
372{
373 auto view = std::views::zip(transitions, qvalues);
374 auto [it, end] = std::ranges::remove_if(
375 view,
376 [&](const AlgorithmValueType& value) {
377 return !is_approx_equal(
378 as_lower_bound(best_value),
379 as_lower_bound(value),
380 epsilon);
381 },
382 project<1>);
383
384 const size_t offset = std::distance(view.begin(), it);
385 transitions.erase(transitions.begin() + offset, transitions.end());
386
387 return best_value;
388}
389
390template <typename State, typename Action, typename StateInfoT>
391Interval HeuristicSearchAlgorithm<State, Action, StateInfoT>::solve(
392 MDPType& mdp,
393 EvaluatorType& h,
394 param_type<State> state,
395 ProgressReport progress,
396 double max_time)
397{
398 HSBase::initialize_initial_state(mdp, h, state);
399 return this->do_solve(mdp, h, state, progress, max_time);
400}
401
402template <typename State, typename Action, typename StateInfoT>
403auto HeuristicSearchAlgorithm<State, Action, StateInfoT>::compute_policy(
404 MDPType& mdp,
405 EvaluatorType& h,
406 param_type<State> initial_state,
407 ProgressReport progress,
408 double max_time) -> std::unique_ptr<PolicyType>
409{
410 this->solve(mdp, h, initial_state, progress, max_time);
411
412 /*
413 * Expand some greedy policy graph, starting from the initial state.
414 * Collect optimal actions along the way.
415 */
416 using MapPolicy = policies::MapPolicy<State, Action>;
417 std::unique_ptr<MapPolicy> policy(new MapPolicy(&mdp));
418
419 const StateID initial_state_id = mdp.get_state_id(initial_state);
420
421 std::deque<StateID> queue;
422 std::set<StateID> visited;
423 queue.push_back(initial_state_id);
424 visited.insert(initial_state_id);
425
426 std::vector<TransitionType> transitions;
427 std::vector<AlgorithmValueType> qvalues;
428
429 do {
430 const StateID state_id = queue.front();
431 queue.pop_front();
432
433 std::optional<Action> action;
434
435 if constexpr (HSBase::StorePolicy) {
436 const StateInfo& state_info = this->state_infos_[state_id];
437 action = state_info.get_policy();
438 } else {
439 const State state = mdp.get_state(state_id);
440 const value_t termination_cost =
441 mdp.get_termination_info(state).get_cost();
442
443 ClearGuard _(transitions, qvalues);
444 this->generate_non_tip_transitions(mdp, state, transitions);
445
446 this->compute_bellman_and_greedy(
447 mdp,
448 state_id,
449 transitions,
450 termination_cost,
451 qvalues);
452
453 action =
454 this->select_greedy_transition(mdp, std::nullopt, transitions)
455 .transform([](const auto& t) { return t.action; });
456 }
457
458 // Terminal states have no policy decision.
459 if (!action) {
460 continue;
461 }
462
463 const Interval bound = this->lookup_bounds(state_id);
464
465 policy->emplace_decision(state_id, *action, bound);
466
467 // Push the successor traps.
468 const State state = mdp.get_state(state_id);
469
470 Distribution<StateID> successors;
471 mdp.generate_action_transitions(state, *action, successors);
472
473 for (const StateID succ_id : successors.support()) {
474 if (visited.insert(succ_id).second) {
475 queue.push_back(succ_id);
476 }
477 }
478 } while (!queue.empty());
479
480 return policy;
481}
482
483template <typename State, typename Action, typename StateInfoT>
485 std::ostream& out) const
486{
487 HSBase::print_statistics(out);
488 this->print_additional_statistics(out);
489}
490
491} // namespace probfd::algorithms::heuristic_search
virtual StateID get_state_id(param_type< State > state)=0
Get the state ID for a given state.
void print_statistics(std::ostream &out) const final
Prints algorithm statistics to the specified output stream.
Definition heuristic_search_base_impl.h:484
The common base class for MDP h search algorithms.
Definition heuristic_search_base.h:113
bool update_value(StateInfo &state_info, AlgorithmValueType other, value_t epsilon=g_epsilon)
Updates the value of the state associated with the given storage.
Definition heuristic_search_base_impl.h:153
AlgorithmValueType compute_bellman(CostFunctionType &cost_function, StateID state_id, const std::vector< TransitionType > &transitions, value_t termination_cost) const
Computes the Bellman operator value for a state.
Definition heuristic_search_base_impl.h:74
Interval lookup_bounds(StateID state_id) const
Looks up the current value interval of state_id.
Definition heuristic_search_base_impl.h:60
bool update_policy(StateInfo &state_info, const std::optional< TransitionType > &transition)
Updates the current greedy action of the state associated with the given storage.
Definition heuristic_search_base_impl.h:165
std::optional< TransitionType > select_greedy_transition(MDPType &mdp, std::optional< Action > previous_greedy_action, std::vector< TransitionType > &greedy_transitions)
Selects a greedy transition from the given list of greedy transitions through the policy selector pas...
Definition heuristic_search_base_impl.h:130
bool was_visited(StateID state_id) const
Checks if the state represented by state_id has been visited yet.
Definition heuristic_search_base_impl.h:67
AlgorithmValueType compute_bellman_and_greedy(CostFunctionType &cost_function, StateID state_id, std::vector< TransitionType > &transitions, value_t termination_cost, std::vector< AlgorithmValueType > &qvalues, value_t epsilon=g_epsilon) const
Computes the Bellman operator value for a state, as well as all transitions achieving a value epsilon...
Definition heuristic_search_base_impl.h:95
Namespace dedicated to the MDP h search base implementation.
Definition heuristic_search_base.h:47
bool update(Interval &lhs, Interval rhs, value_t epsilon=g_epsilon)
Intersects two intervals and assigns the result to the left operand.
bool set_min(Interval &lhs, Interval rhs)
Computes the assignments lhs.lower <- min(lhs.lower, rhs.lower) and lower <- min(lhs....
value_t as_lower_bound(Interval interval)
Returns the lower bound of the interval.
double value_t
Typedef for the state value type.
Definition aliases.h:7
typename std::conditional_t< is_cheap_to_copy_v< T >, T, const T & > param_type
Alias template defining the best way to pass a parameter of a given type.
Definition type_traits.h:25
bool is_approx_equal(value_t v1, value_t v2, value_t epsilon=g_epsilon)
Equivalent to $|v_1 - v_2| \le \epsilon$.
Represents a closed interval over the extended reals as a pair of lower and upper bound.
Definition interval.h:12
A StateID represents a state within a StateIDMap. Just like Fast Downward's StateID type,...
Definition types.h:22
void print(std::ostream &out) const
Prints the statistics to the specified output stream.
Definition heuristic_search_base_impl.h:26