AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
acyclic_value_iteration.h
1#ifndef PROBFD_ALGORITHMS_ACYCLIC_VALUE_ITERATION_H
2#define PROBFD_ALGORITHMS_ACYCLIC_VALUE_ITERATION_H
3
4#include "probfd/storage/per_state_storage.h"
5
6#include "probfd/distribution.h"
7#include "probfd/mdp_algorithm.h"
8
9#include <stack>
10
11// Forward Declarations
12namespace utils {
13class CountdownTimer;
14}
15
16namespace probfd::policies {
17template <typename, typename>
18class MapPolicy;
19}
20
23
/**
 * @brief Models an observer subscribed to events of the acyclic value
 * iteration algorithm.
 *
 * Every callback is pure virtual, so a concrete observer must implement all
 * four of them.
 *
 * @tparam State  State type handed to every callback (by const reference).
 * @tparam Action Action type of the observed algorithm; it does not appear in
 * any callback signature.
 */
template <typename State, typename Action>
class AcyclicValueIterationObserver {
public:
    /// Virtual destructor so observers can be destroyed through a base
    /// pointer.
    virtual ~AcyclicValueIterationObserver() = default;

    /// Called when the algorithm selects a state for expansion.
    virtual void on_state_selected_for_expansion(const State&) = 0;

    /// Called when a goal state is encountered during the expansion check.
    virtual void on_goal_state(const State&) = 0;

    /// Called when a terminal state is encountered during the expansion
    /// check.
    virtual void on_terminal_state(const State&) = 0;

    /// Called when a state is pruned during the expansion check.
    virtual void on_pruned_state(const State&) = 0;
};
48
/// An observer that collects basic statistics of the acyclic value iteration
/// algorithm.
///
/// NOTE(review): the class-head line (listing line 57, between the template
/// header and the base-class clause) is absent from this extract, so the
/// type's name and class key are not visible here -- restore from the
/// original header.
56template <typename State, typename Action>
58 : public AcyclicValueIterationObserver<State, Action> {
 // Event counters, all starting at zero. Presumably each one is incremented
 // by the matching on_* callback below -- the definitions live in
 // acyclic_value_iteration_impl.h; TODO confirm.
59 unsigned long long state_expansions = 0;
60 unsigned long long terminal_states = 0;
61 unsigned long long goal_states = 0;
62 unsigned long long pruned_states = 0;
63
 /// Called when the algorithm selects a state for expansion.
64 void on_state_selected_for_expansion(const State&);
 /// Called when a goal state is encountered during the expansion check.
65 void on_goal_state(const State&);
 /// Called when a terminal state is encountered during the expansion check.
66 void on_terminal_state(const State&);
 /// Called when a state is pruned during the expansion check.
67 void on_pruned_state(const State&);
68
 /// Const member taking an output stream; presumably writes the collected
 /// counters to `out` -- the output format is defined elsewhere.
69 void print(std::ostream& out) const;
70};
71
72namespace internal {
73
/// Per-state bookkeeping record of acyclic value iteration.
///
/// @tparam Action Action type stored as the state's best action.
74template <typename Action>
75struct StateInfo {
 /// Processing status, stored in a single byte.
 /// NOTE(review): presumably NEW = not visited yet, ON_STACK = currently on
 /// the expansion stack, CLOSED = fully processed -- confirm against the
 /// implementation.
76 enum Status : uint8_t { NEW, ON_STACK, CLOSED };
 /// A default-constructed record starts out as NEW.
77 Status status = NEW;
 /// Empty until an action has been recorded for this state.
78 std::optional<Action> best_action = std::nullopt;
 /// State value, initialised to -INFINITE_VALUE.
79 value_t value = -INFINITE_VALUE;
80};
81
82template <typename State, typename Action>
83class AcyclicValueIterationObserverCollection {
84 using Observer = AcyclicValueIterationObserver<State, Action>;
85
86 std::vector<Observer> observers_;
87
88public:
89 void register_observer(std::shared_ptr<Observer> observer);
90
91 void notify_state_selected_for_expansion(const State&);
92 void notify_goal_state(const State&);
93 void notify_terminal_state(const State&);
94 void notify_pruned_state(const State&);
95};
96
97} // namespace internal
98
/// Implements acyclic Value Iteration.
///
/// @tparam State  State type of the MDP being solved.
/// @tparam Action Action type of the MDP being solved.
112template <typename State, typename Action>
113class AcyclicValueIteration : public MDPAlgorithm<State, Action> {
 // Names the MDPAlgorithm<State, Action> base through its injected class
 // name.
114 using Base = typename AcyclicValueIteration::MDPAlgorithm;
115
 // Types re-exported from the base class.
116 using PolicyType = typename Base::PolicyType;
117 using MDPType = typename Base::MDPType;
118 using EvaluatorType = typename Base::EvaluatorType;
119
120 using MapPolicy = policies::MapPolicy<State, Action>;
121
122 using StateInfo = internal::StateInfo<Action>;
123
125 using ObserverCollection =
126 internal::AcyclicValueIterationObserverCollection<State, Action>;
127
 // One frame of expansion_stack_: the state being expanded together with
 // the progress made through its actions and their successors.
128 struct IncrementalExpansionInfo {
129 const StateID state_id;
 // Reference to the bookkeeping record of state_id.
 // NOTE(review): presumably refers into state_infos_ -- confirm that the
 // storage keeps references stable while frames are on the stack.
130 StateInfo& state_info;
131
132 // Applicable operators left to expand
133 std::vector<Action> remaining_aops;
134
135 // The current transition and transition successor
136 Distribution<StateID> transition;
137 typename Distribution<StateID>::const_iterator successor;
138
139 // The current transition Q-value
140 value_t t_value;
141
142 IncrementalExpansionInfo(StateID state_id, StateInfo& state_info);
143
 // NOTE(review): the three members below are defined outside this header;
 // the descriptions are inferred from names and signatures -- confirm.
 // Presumably prepares `transition`/`successor` for the next action.
144 void setup_transition(MDPType& mdp);
145
 // Presumably folds a finished successor's value, weighted by
 // `probability`, into t_value.
146 void backtrack_successor(value_t probability, StateInfo& succ_info);
147
 // Presumably moves on to the next successor or transition; the meaning of
 // the bool result is not visible here.
148 bool advance(MDPType& mdp, MapPolicy* policy);
149
150 private:
151 bool next_successor();
152 bool next_transition(MDPType& mdp, MapPolicy* policy);
153
154 void finalize_transition();
155 void finalize_expansion(MapPolicy* policy);
156 };
157
 // Per-state bookkeeping records.
158 storage::PerStateStorage<StateInfo> state_infos_;
 // Stack of expansion frames (see IncrementalExpansionInfo).
159 std::stack<IncrementalExpansionInfo> expansion_stack_;
160
161 ObserverCollection observers_;
162
163public:
 /// Overrides the base-class entry point that returns a policy.
 /// `max_time` is a time limit; its unit is not visible in this header.
164 std::unique_ptr<PolicyType> compute_policy(
165 MDPType& mdp,
166 EvaluatorType& heuristic,
167 param_type<State> initial_state,
168 ProgressReport progress,
169 double max_time) override;
170
 /// Overrides the base-class entry point that returns an Interval, i.e. a
 /// pair of lower and upper bound.
171 Interval solve(
172 MDPType& mdp,
173 EvaluatorType& heuristic,
174 param_type<State> initial_state,
175 ProgressReport progress,
176 double max_time) override;
177
 /// Additional, non-overriding overload: takes no ProgressReport but a
 /// MapPolicy pointer.
 /// NOTE(review): presumably `policy` may be null when no policy should be
 /// recorded -- confirm against the implementation.
178 Interval solve(
179 MDPType& mdp,
180 EvaluatorType& heuristic,
181 param_type<State> initial_state,
182 double max_time,
183 MapPolicy* policy);
184
 /// Registers an observer for the algorithm's events.
 /// NOTE(review): no `Observer` alias is declared in the visible lines of
 /// this class (listing line 124 is missing from this extract) -- verify
 /// that the name resolves in the original header.
185 void register_observer(std::shared_ptr<Observer> observer);
186
187private:
188 bool push_successor(
189 MDPType& mdp,
190 MapPolicy* policy,
191 IncrementalExpansionInfo& e,
192 utils::CountdownTimer& timer);
193
194 bool expand_state(
195 MDPType& mdp,
196 EvaluatorType& heuristic,
197 IncrementalExpansionInfo& e_info);
198};
199
200} // namespace probfd::algorithms::acyclic_vi
201
202#define GUARD_INCLUDE_PROBFD_ALGORITHMS_ACYCLIC_VALUE_ITERATION_H
203
204#include "probfd/algorithms/acyclic_value_iteration_impl.h"
205
206#undef GUARD_INCLUDE_PROBFD_ALGORITHMS_ACYCLIC_VALUE_ITERATION_H
207
208#endif // PROBFD_ALGORITHMS_ACYCLIC_VALUE_ITERATION_H
A convenience class that represents a finite probability distribution.
Definition task_state_space.h:27
Interface for MDP algorithm implementations.
Definition mdp_algorithm.h:29
A registry for print functions related to search progress.
Definition progress_report.h:33
Implements acyclic Value Iteration.
Definition acyclic_value_iteration.h:113
Models an observer subscribed to events of the acyclic value iteration algorithm.
Definition acyclic_value_iteration.h:32
virtual void on_state_selected_for_expansion(const State &)=0
Called when the algorithm selects a state for expansion.
virtual void on_pruned_state(const State &)=0
Called when a state is pruned during the expansion check.
virtual void on_goal_state(const State &)=0
Called when a goal state is encountered during the expansion check.
virtual void on_terminal_state(const State &)=0
Called when a terminal state is encountered during the expansion check.
Namespace dedicated to the acyclic value iteration algorithm.
Definition acyclic_value_iteration.h:22
double value_t
Typedef for the state value type.
Definition aliases.h:7
typename std::conditional_t< is_cheap_to_copy_v< T >, T, const T & > param_type
Alias template defining the best way to pass a parameter of a given type.
Definition type_traits.h:25
Represents a closed interval over the extended reals as a pair of lower and upper bound.
Definition interval.h:12
A StateID represents a state within a StateIDMap. Just like Fast Downward's StateID type,...
Definition types.h:22
An observer that collects basic statistics of the acyclic value iteration algorithm.
Definition acyclic_value_iteration.h:58
void on_state_selected_for_expansion(const State &)
Called when the algorithm selects a state for expansion.
Definition acyclic_value_iteration_impl.h:17
void on_goal_state(const State &)
Called when a goal state is encountered during the expansion check.
Definition acyclic_value_iteration_impl.h:24
void on_pruned_state(const State &)
Called when a state is pruned during the expansion check.
Definition acyclic_value_iteration_impl.h:36
void on_terminal_state(const State &)
Called when a terminal state is encountered during the expansion check.
Definition acyclic_value_iteration_impl.h:30