AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
idual_impl.h
1#ifndef GUARD_INCLUDE_PROBFD_ALGORITHMS_IDUAL_H
2#error "This file should only be included from idual.h"
3#endif
4
5#include "probfd/algorithms/types.h"
6#include "probfd/algorithms/utils.h"
7
8#include "probfd/utils/not_implemented.h"
9
10#include "probfd/distribution.h"
11#include "probfd/evaluator.h"
12#include "probfd/transition.h"
13
14#include "downward/utils/countdown_timer.h"
15#include "probfd/policies/map_policy.h"
16
17#include <deque>
18
20
21inline void Statistics::print(std::ostream& out) const
22{
23 out << " Iterations: " << iterations << std::endl;
24 out << " Expansions: " << expansions << std::endl;
25 out << " Open states: " << open << std::endl;
26 out << " LP Variables: " << lp_variables << std::endl;
27 out << " LP Constraints: " << lp_constraints << std::endl;
28}
29
// Interns `val` and returns a stable id (its index in `values_`).
//
// The value is speculatively appended to `values_`, and the new index is then
// offered to `indices_`. If an equal value was already interned, the set
// insertion fails and the speculative append is rolled back; either way the
// id of the canonical entry is returned.
//
// NOTE(review): this pattern only works if `indices_` hashes/compares indices
// *through* `values_` (i.e. index i stands for values_[i]), which is why the
// push_back must precede the insert — confirm against the declaration of
// `indices_` in idual.h.
inline unsigned ValueGroup::get_id(value_t val)
{
    values_.push_back(val);
    auto it = indices_.insert(values_.size() - 1);
    if (!it.second) {
        // Duplicate value: undo the speculative append.
        values_.pop_back();
    }

    // *it.first is the index of the canonical (first-interned) entry.
    return *it.first;
}
40
// Constructs the i-dual algorithm with an LP solver backend of the given
// type; no other state is initialized here.
template <typename State, typename Action>
IDual<State, Action>::IDual(lp::LPSolverType solver_type)
    : lp_solver_(solver_type)
{
}
46
47template <typename State, typename Action>
48Interval IDual<State, Action>::solve(
49 MDPType& mdp,
50 EvaluatorType& heuristic,
51 param_type<State> initial_state,
52 ProgressReport progress,
53 double max_time)
54{
55 std::vector<double> primal_solution;
56 std::vector<double> dual_solution;
57 return solve(
58 mdp,
59 heuristic,
60 initial_state,
61 progress,
62 max_time,
63 primal_solution,
64 dual_solution);
65}
66
// Computes a deterministic optimal policy by solving the LP and then
// post-processing its dual (occupation-measure) solution.
template <typename State, typename Action>
auto IDual<State, Action>::compute_policy(
    MDPType& mdp,
    EvaluatorType& heuristic,
    param_type<State> initial_state,
    ProgressReport progress,
    double max_time) -> std::unique_ptr<PolicyType>
{
    std::vector<double> primal_solution;
    std::vector<double> dual_solution;
    solve(
        mdp,
        heuristic,
        initial_state,
        progress,
        max_time,
        primal_solution,
        dual_solution);

    // The entries in the solution with a positive value represent transitions
    // which have positive occupation measure under the computed policy, i.e.,
    // which are selected by the policy with positive probability.
    // The policy can however be stochastic, i.e., select multiple transitions,
    // if traps exist.
    // To transform the policy to a deterministic optimal one, we first
    // exhaustively explore the policy and build the inverse edge relation.
    // Then we exhaustively traverse the policy backwards from the terminal
    // states with duplicate checking. Upon generating a predecessor via an
    // inverse edge, we select the corresponding action for the policy before
    // marking it as visited and adding it to the exploration queue.

    // An inverse policy edge: `predecessor` reaches the keyed successor state
    // by taking `action` with positive probability.
    struct Edge {
        StateID predecessor;
        Action action;
    };

    const StateID initial_state_id = mdp.get_state_id(initial_state);

    // Explore the policy graph and build the inverse edge relation.
    storage::PerStateStorage<std::vector<Edge>> predecessor_edges;

    std::deque<StateID> queue{initial_state_id};
    std::unordered_set<StateID> visited{initial_state_id};

    // Terminal/absorbing states of the policy graph; these seed the
    // backward extraction pass below.
    std::deque<StateID> back_queue;

    do {
        const StateID state_id = queue.front();
        queue.pop_front();

        std::vector<Action> aops;
        const State state = mdp.get_state(state_id);
        mdp.generate_applicable_actions(state, aops);

        auto& state_info = state_infos_[state_id];

        // Every state reachable under the policy was either expanded by
        // solve() (CLOSED) or recognized as terminal there.
        assert(
            state_info.status == PerStateInfo::CLOSED ||
            state_info.status == PerStateInfo::TERMINAL);

        if (state_info.status == PerStateInfo::TERMINAL) {
            back_queue.push_back(state_id);
            continue;
        }

        // Index of this state's first LP constraint row; the dual values of
        // the per-action rows are the action occupation measures.
        unsigned int constraint_index = state_infos_[state_id].constraints_idx;
        assert(constraint_index != std::numeric_limits<unsigned>::max());

        // NOTE(review): the indexing `constraint_index + i` assumes exactly
        // one LP constraint per action of `aops`, in aops order. solve()
        // skips transitions that are Dirac self-loops when emitting
        // constraints — confirm this cannot desynchronize the offsets.
        size_t actions = 0;
        unsigned int i = 0;
        for (const Action& action : aops) {
            // Positive dual value <=> the action carries occupation measure,
            // i.e. the policy selects it with positive probability.
            if (dual_solution[constraint_index + i] > g_epsilon) {
                Distribution<StateID> distribution;
                mdp.generate_action_transitions(state, action, distribution);

                for (const StateID succ_id : distribution.support()) {
                    // Record the inverse edge unconditionally; only the
                    // forward expansion itself is duplicate-checked.
                    predecessor_edges[succ_id].emplace_back(state_id, action);
                    if (!visited.insert(succ_id).second) continue;
                    queue.push_back(succ_id);
                }

                ++actions;
            }
            ++i;
        }

        // No action with positive measure: absorbing leaf of the policy
        // graph; also seeds the backward pass.
        if (actions == 0) {
            back_queue.push_back(state_id);
        }
    } while (!queue.empty());

    // Now do the backwards exploration and extract a deterministic policy.
    visited.clear();

    visited.insert(back_queue.begin(), back_queue.end());

    auto policy = std::make_unique<policies::MapPolicy<State, Action>>(&mdp);

    while (!back_queue.empty()) {
        const StateID state_id = back_queue.front();
        back_queue.pop_front();

        for (auto& edge : predecessor_edges[state_id]) {
            // The first inverse edge reaching an unvisited predecessor fixes
            // that predecessor's single (deterministic) action choice; the
            // annotated interval is the state's primal LP value.
            if (!visited.insert(edge.predecessor).second) continue;
            policy->emplace_decision(
                edge.predecessor,
                edge.action,
                Interval(
                    primal_solution[state_infos_[edge.predecessor].var_idx]));
            back_queue.push_back(edge.predecessor);
        }
    }

    return policy;
}
182
// Core i-dual loop: incrementally builds and re-solves an LP over the states
// reachable with positive occupation measure, expanding the frontier until
// the dual solution assigns no inflow to any open state. On return,
// `primal_solution`/`dual_solution` hold the final LP solutions and the
// returned interval's lower bound is the LP objective (no upper bound is
// computed).
template <typename State, typename Action>
Interval IDual<State, Action>::solve(
    MDPType& mdp,
    EvaluatorType& heuristic,
    param_type<State> initial_state,
    ProgressReport progress,
    double max_time,
    std::vector<double>& primal_solution,
    std::vector<double>& dual_solution)
{
    // NOTE(review): this using-directive appears unused in this function.
    using namespace std::views;

    utils::CountdownTimer timer(max_time);

    const double inf = lp_solver_.get_infinity();

    // Highest StateID materialized so far. NOTE(review): the test
    // `succ_id > prev_state` below treats larger ids as never-seen, which
    // presumably relies on StateIDs being assigned monotonically in order of
    // first generation — confirm against the state space implementation.
    StateID prev_state = StateID::UNDEFINED;
    std::vector<StateID> frontier;
    storage::StateHashMap<FrontierStateInfo> open_states;

    {
        // initialize lp
        const TerminationInfo term = mdp.get_termination_info(initial_state);

        if (term.is_goal_state()) {
            return Interval(0_vt);
        }

        const value_t term_cost = term.get_cost();
        const value_t estimate = heuristic.evaluate(initial_state);

        // The heuristic must be a lower bound on the termination cost.
        assert(estimate <= term_cost);

        // Heuristic value is provably exact for the initial state.
        if (estimate == term_cost) {
            return Interval(estimate);
        }

        named_vector::NamedVector<lp::LPVariable> vars;
        named_vector::NamedVector<lp::LPConstraint> constraints;

        // Variable 0 is the initial state's value; it alone carries objective
        // weight 1.0 and is bounded above by the heuristic estimate.
        vars.emplace_back(-inf, estimate, 1.0);

        lp_solver_.load_problem(lp::LinearProgram(
            lp::LPObjectiveSense::MAXIMIZE,
            std::move(vars),
            std::move(constraints),
            inf));
        prev_state = mdp.get_state_id(initial_state);
        PerStateInfo& info = state_infos_[prev_state];
        info.var_idx = 0;
        info.status = PerStateInfo::CLOSED;
        frontier.push_back(prev_state);
    }

    std::vector<Transition<Action>> transitions;

    value_t objective = 0_vt;

    // Report the current LP objective as the lower bound of the value.
    progress.register_bound("v", [&] {
        return Interval(objective, INFINITE_VALUE);
    });

    progress.register_print([&](std::ostream& out) {
        out << "iteration=" << statistics_.iterations;
    });

    do {
        ++statistics_.iterations;
        statistics_.expansions += frontier.size();

        // Expand every frontier state: tighten its variable's upper bound
        // and add one Bellman constraint per non-self-loop action.
        for (const StateID state_id : frontier) {
            timer.throw_if_expired();

            const State state = mdp.get_state(state_id);
            const TerminationInfo term_info = mdp.get_termination_info(state);
            const auto t_cost = term_info.get_cost();

            auto& info = state_infos_[state_id];
            const unsigned var_id = info.var_idx;
            assert(info.status == PerStateInfo::CLOSED);
            // Remember where this state's constraint rows begin; used by
            // compute_policy() to read the dual occupation values.
            info.constraints_idx = lp_solver_.get_num_constraints();

            // A state's value can never exceed its termination cost.
            lp_solver_.set_variable_upper_bound(var_id, t_cost);

            if (term_info.is_goal_state()) {
                continue;
            }

            // ClearGuard presumably empties `transitions` on scope exit so
            // the buffer can be reused across states — confirm.
            ClearGuard _(transitions);
            mdp.generate_all_transitions(state, transitions);

            for (const auto [action, transition] : transitions) {
                // A guaranteed self-loop yields no usable constraint.
                if (transition.is_dirac(state_id)) continue;

                int next_constraint_id = lp_solver_.get_num_constraints();
                lp::LPConstraint c(-inf, inf);

                // Right-hand side accumulator: action cost plus the
                // probability-weighted values of terminal successors.
                double base_val = mdp.get_action_cost(action);
                StateID next_prev_state = prev_state;
                // Self-loop probability mass is folded into this state's own
                // coefficient: w = 1 - P(self).
                double w = 1.0;

                for (const auto& [succ_id, prob] : transition) {
                    if (succ_id == state_id) {
                        w -= prob;
                        continue;
                    }

                    PerStateInfo& succ_info = state_infos_[succ_id];

                    if (succ_id > prev_state) {
                        // First encounter of this successor (see NOTE above
                        // on the id-monotonicity assumption).
                        assert(
                            state_infos_[succ_id].var_idx ==
                            std::numeric_limits<unsigned>::max());

                        const State succ_state = mdp.get_state(succ_id);
                        const auto term = mdp.get_termination_info(succ_state);
                        const value_t term_cost = term.get_cost();
                        const value_t estimate = heuristic.evaluate(succ_state);

                        if (term_cost == estimate) {
                            // Heuristic value is provably exact: treat the
                            // successor as terminal; intern its constant
                            // value instead of creating an LP variable.
                            succ_info.status = PerStateInfo::TERMINAL;
                            succ_info.var_idx = terminals_.get_id(estimate);
                            base_val += prob * estimate;
                        } else {
                            // New open state: create its LP variable
                            // (objective weight 0) and link it into this
                            // constraint.
                            int next_var_id = lp_solver_.get_num_variables();
                            lp_solver_.add_variable(
                                lp::LPVariable(-inf, estimate, 0.0),
                                std::vector<int>(),
                                std::vector<double>());
                            succ_info.status = PerStateInfo::OPEN;
                            succ_info.var_idx = next_var_id;
                            c.insert(next_var_id, -prob);
                            open_states[succ_id].incoming.push_back(
                                next_constraint_id);
                        }

                        if (succ_id > next_prev_state) {
                            next_prev_state = succ_id;
                        }
                    } else {
                        assert(succ_info.status != PerStateInfo::NEW);

                        switch (succ_info.status) {
                        case PerStateInfo::OPEN:
                            // Record the inflow constraint so the open state
                            // can be promoted to the frontier later.
                            open_states[succ_id].incoming.push_back(
                                next_constraint_id);
                            [[fallthrough]];
                        case PerStateInfo::CLOSED:
                            c.insert(succ_info.var_idx, -prob);
                            break;
                        default:
                            // TERMINAL: fold the interned constant value into
                            // the right-hand side.
                            base_val += prob * terminals_[succ_info.var_idx];
                            break;
                        }
                    }
                }

                prev_state = next_prev_state;

                // Some probability mass must leave the state.
                assert(w > 0_vt);
                c.insert(var_id, w);
                c.set_upper_bound(base_val);
                lp_solver_.add_constraint(c);
            }
        }

        frontier.clear();

        lp_solver_.solve();

        timer.throw_if_expired();

        assert(lp_solver_.has_optimal_solution());
        primal_solution = lp_solver_.extract_solution();
        dual_solution = lp_solver_.extract_dual_solution();
        objective = lp_solver_.get_objective_value();

        // Promote every open state that receives positive dual inflow
        // (occupation measure) to the next frontier and close it.
        open_states.erase_if([&](const auto& pair) {
            for (auto& r : pair.second.incoming) {
                if (dual_solution[r] > g_epsilon) { // inflow > 0
                    state_infos_[pair.first.id].status = PerStateInfo::CLOSED;
                    frontier.emplace_back(pair.first);
                    return true;
                }
            }

            return false;
        });

        progress.print();
    } while (!frontier.empty());

    assert(!dual_solution.empty());

    statistics_.lp_variables = lp_solver_.get_num_variables();
    statistics_.lp_constraints = lp_solver_.get_num_constraints();
    statistics_.open = open_states.size();

    // Lower bound only; i-dual does not compute an upper bound here.
    return Interval(objective, INFINITE_VALUE);
}
383
384} // namespace probfd::algorithms::idual
Namespace dedicated to the i-dual MDP algorithm.
Definition idual.h:22
double value_t
Typedef for the state value type.
Definition aliases.h:7