1#ifndef GUARD_INCLUDE_PROBFD_ALGORITHMS_IDUAL_H
2#error "This file should only be included from idual.h"
5#include "probfd/algorithms/types.h"
6#include "probfd/algorithms/utils.h"
8#include "probfd/utils/not_implemented.h"
10#include "probfd/distribution.h"
11#include "probfd/evaluator.h"
12#include "probfd/transition.h"
14#include "downward/utils/countdown_timer.h"
15#include "probfd/policies/map_policy.h"
21inline void Statistics::print(std::ostream& out)
const
23 out <<
" Iterations: " << iterations << std::endl;
24 out <<
" Expansions: " << expansions << std::endl;
25 out <<
" Open states: " << open << std::endl;
26 out <<
" LP Variables: " << lp_variables << std::endl;
27 out <<
" LP Constraints: " << lp_constraints << std::endl;
// Returns an id for the value `val` within this group.
// The visible lines append `val` to `values_` and insert its index
// (values_.size() - 1) into `indices_`; the rest of the body is elided in
// this listing. NOTE(review): presumably `indices_` deduplicates equal values
// and the speculative append is undone when `val` is already present —
// confirm against the full source.
30inline unsigned ValueGroup::get_id(
value_t val)
32 values_.push_back(val);
33 auto it = indices_.insert(values_.size() - 1);
// Creates an I-Dual solver whose LP solver member (`lp_solver_`) is
// constructed from `solver_type`.
41template <
typename State,
typename Action>
42IDual<State, Action>::IDual(lp::LPSolverType solver_type)
43 : lp_solver_(solver_type)
// Runs I-Dual from `initial_state` and returns the resulting value interval.
// Only the declarations of the scratch primal/dual LP solution vectors are
// visible in this listing. NOTE(review): presumably the elided remainder
// forwards to the solve() overload that takes these vectors as
// out-parameters — confirm against the full source.
47template <
typename State,
typename Action>
48Interval IDual<State, Action>::solve(
50 EvaluatorType& heuristic,
51 param_type<State> initial_state,
52 ProgressReport progress,
55 std::vector<double> primal_solution;
56 std::vector<double> dual_solution;
// Runs I-Dual from `initial_state` and extracts a policy from the final LP
// solution. The search call itself is elided in this listing.
//
// Visible steps:
//  1. Forward pass: BFS from the initial state. For each visited CLOSED
//     state, each applicable action whose dual value (read at the state's
//     constraint offset) exceeds g_epsilon is followed. A predecessor edge
//     (state, action) is recorded for every successor in the action's
//     support, and each successor is enqueued at most once. TERMINAL states
//     are collected in `back_queue`.
//  2. Backward pass: BFS over the recorded predecessor edges, starting from
//     `back_queue`. Each not-yet-visited predecessor receives the action of
//     the edge that reached it, together with a value taken from
//     `primal_solution` at the predecessor's LP variable.
//
// @param heuristic      Evaluator handed to the search (call elided here).
// @param initial_state  State the policy is computed for.
// @param progress       Progress reporter handed to the search.
// @param max_time       Time limit handed to the search.
// @return A MapPolicy holding the extracted decisions.
67template <
typename State,
typename Action>
68auto IDual<State, Action>::compute_policy(
70 EvaluatorType& heuristic,
71 param_type<State> initial_state,
72 ProgressReport progress,
73 double max_time) -> std::unique_ptr<PolicyType>
75 std::vector<double> primal_solution;
76 std::vector<double> dual_solution;
// Forward-pass bookkeeping: incoming edges per state, BFS queue, visited
// set, and the states the backward pass starts from.
103 const StateID initial_state_id = mdp.get_state_id(initial_state);
106 storage::PerStateStorage<std::vector<Edge>> predecessor_edges;
108 std::deque<StateID> queue{initial_state_id};
109 std::unordered_set<StateID> visited{initial_state_id};
111 std::deque<StateID> back_queue;
114 const StateID state_id = queue.front();
117 std::vector<Action> aops;
118 const State state = mdp.get_state(state_id);
119 mdp.generate_applicable_actions(state, aops);
121 auto& state_info = state_infos_[state_id];
124 state_info.status == PerStateInfo::CLOSED ||
125 state_info.status == PerStateInfo::TERMINAL);
// TERMINAL states seed the backward pass.
127 if (state_info.status == PerStateInfo::TERMINAL) {
128 back_queue.push_back(state_id);
// Index of this state's first LP constraint, recorded when solve()
// expanded the state.
132 unsigned int constraint_index = state_infos_[state_id].constraints_idx;
133 assert(constraint_index != std::numeric_limits<unsigned>::max());
// NOTE(review): `i` is defined and advanced in lines not shown here.
// solve() adds no constraint for a transition that is a Dirac self-loop, so
// the offset `constraint_index + i` is only correct if `i` skips such
// actions too and the action order matches generate_all_transitions() —
// confirm.
137 for (
const Action& action : aops) {
138 if (dual_solution[constraint_index + i] > g_epsilon) {
139 Distribution<StateID> distribution;
140 mdp.generate_action_transitions(state, action, distribution);
142 for (
const StateID succ_id : distribution.support()) {
143 predecessor_edges[succ_id].emplace_back(state_id, action);
144 if (!visited.insert(succ_id).second)
continue;
145 queue.push_back(succ_id);
154 back_queue.push_back(state_id);
156 }
while (!queue.empty());
// NOTE(review): unless `visited` is cleared in lines not shown here, it
// still holds every state reached by the forward pass, and the
// `visited.insert(edge.predecessor)` check below would skip every
// predecessor — confirm.
161 visited.insert(back_queue.begin(), back_queue.end());
163 auto policy = std::make_unique<policies::MapPolicy<State, Action>>(&mdp);
// Backward pass: walk predecessor edges outward from the seed states and
// assign each newly reached predecessor the action of the reaching edge.
165 while (!back_queue.empty()) {
166 const StateID state_id = back_queue.front();
167 back_queue.pop_front();
169 for (
auto& edge : predecessor_edges[state_id]) {
170 if (!visited.insert(edge.predecessor).second)
continue;
171 policy->emplace_decision(
175 primal_solution[state_infos_[edge.predecessor].var_idx]));
176 back_queue.push_back(edge.predecessor);
// Core I-Dual loop. Incrementally builds and re-solves an LP that maximizes
// the value variable of `initial_state`:
//  - every discovered non-TERMINAL state gets an LP variable; until the
//    state is expanded, the variable is bounded by the heuristic estimate,
//    and after expansion by the state's termination cost;
//  - expanding a state adds one constraint per transition that is not a
//    Dirac self-loop;
//  - after each LP solve, OPEN states with an incoming constraint whose dual
//    value exceeds g_epsilon are CLOSED and form the next frontier.
// The loop stops when the frontier becomes empty.
//
// @param heuristic       Evaluator for the initial and newly discovered states.
// @param initial_state   State whose value is bounded.
// @param progress        Receives a bound named "v" and an iteration printer.
// @param max_time        Time budget, checked via timer.throw_if_expired().
// @param primal_solution Out: primal solution of the last LP solve.
// @param dual_solution   Out: dual solution of the last LP solve.
// @return Interval(0) if the initial state is a goal; Interval(estimate) if
//         its estimate equals its termination cost; otherwise
//         Interval(objective, INFINITE_VALUE) for the final LP objective.
183template <
typename State,
typename Action>
184Interval IDual<State, Action>::solve(
186 EvaluatorType& heuristic,
187 param_type<State> initial_state,
188 ProgressReport progress,
190 std::vector<double>& primal_solution,
191 std::vector<double>& dual_solution)
193 using namespace std::views;
195 utils::CountdownTimer timer(max_time);
197 const double inf = lp_solver_.get_infinity();
// Largest StateID seen so far; ids above it are treated as newly discovered
// states below.
199 StateID prev_state = StateID::UNDEFINED;
200 std::vector<StateID> frontier;
201 storage::StateHashMap<FrontierStateInfo> open_states;
// Trivial cases: the initial state is a goal, or its estimate already equals
// its termination cost.
205 const TerminationInfo term = mdp.get_termination_info(initial_state);
207 if (term.is_goal_state()) {
208 return Interval(0_vt);
211 const value_t term_cost = term.get_cost();
212 const value_t estimate = heuristic.evaluate(initial_state);
214 assert(estimate <= term_cost);
216 if (estimate == term_cost) {
217 return Interval(estimate);
// Initial LP: one variable for the initial state with objective coefficient
// 1.0 under MAXIMIZE. Presumably LPVariable takes (lower, upper, objective),
// making the estimate an upper bound — confirm.
220 named_vector::NamedVector<lp::LPVariable> vars;
221 named_vector::NamedVector<lp::LPConstraint> constraints;
223 vars.emplace_back(-inf, estimate, 1.0);
225 lp_solver_.load_problem(lp::LinearProgram(
226 lp::LPObjectiveSense::MAXIMIZE,
228 std::move(constraints),
230 prev_state = mdp.get_state_id(initial_state);
231 PerStateInfo& info = state_infos_[prev_state];
233 info.status = PerStateInfo::CLOSED;
234 frontier.push_back(prev_state);
237 std::vector<Transition<Action>> transitions;
241 progress.register_bound(
"v", [&] {
242 return Interval(objective, INFINITE_VALUE);
245 progress.register_print([&](std::ostream& out) {
246 out <<
"iteration=" << statistics_.iterations;
// One iteration: expand every frontier state, re-solve the LP, then collect
// the next frontier.
250 ++statistics_.iterations;
251 statistics_.expansions += frontier.size();
253 for (
const StateID state_id : frontier) {
254 timer.throw_if_expired();
256 const State state = mdp.get_state(state_id);
257 const TerminationInfo term_info = mdp.get_termination_info(state);
258 const auto t_cost = term_info.get_cost();
260 auto& info = state_infos_[state_id];
261 const unsigned var_id = info.var_idx;
262 assert(info.status == PerStateInfo::CLOSED);
// First constraint of this state. compute_policy() reads dual values
// starting at this index.
263 info.constraints_idx = lp_solver_.get_num_constraints();
// Once expanded, the variable is bounded by the termination cost instead of
// the heuristic estimate.
265 lp_solver_.set_variable_upper_bound(var_id, t_cost);
267 if (term_info.is_goal_state()) {
271 ClearGuard _(transitions);
272 mdp.generate_all_transitions(state, transitions);
274 for (
const auto [action, transition] : transitions) {
// Transitions that stay in this state with certainty add no constraint.
275 if (transition.is_dirac(state_id))
continue;
// One constraint per remaining transition. The right-hand side `base_val`
// starts at the action cost and accumulates the contributions of TERMINAL
// successors.
277 int next_constraint_id = lp_solver_.get_num_constraints();
278 lp::LPConstraint c(-inf, inf);
280 double base_val = mdp.get_action_cost(action);
281 StateID next_prev_state = prev_state;
284 for (
const auto& [succ_id, prob] : transition) {
285 if (succ_id == state_id) {
290 PerStateInfo& succ_info = state_infos_[succ_id];
// New-state test: an id above every id seen so far is treated as newly
// discovered. The assert below checks that the state is still untouched.
// NOTE(review): relies on StateIDs being assigned in increasing order of
// discovery — confirm for the MDP implementations used.
292 if (succ_id > prev_state) {
294 state_infos_[succ_id].var_idx ==
295 std::numeric_limits<unsigned>::max());
297 const State succ_state = mdp.get_state(succ_id);
298 const auto term = mdp.get_termination_info(succ_state);
299 const value_t term_cost = term.get_cost();
300 const value_t estimate = heuristic.evaluate(succ_state);
// Estimate equals termination cost: the successor is TERMINAL, its value is
// a constant, and it is folded into the right-hand side.
302 if (term_cost == estimate) {
303 succ_info.status = PerStateInfo::TERMINAL;
304 succ_info.var_idx = terminals_.get_id(estimate);
305 base_val += prob * estimate;
// Otherwise the successor becomes OPEN with a fresh LP variable (objective
// coefficient 0.0) and enters this constraint with coefficient -prob.
307 int next_var_id = lp_solver_.get_num_variables();
308 lp_solver_.add_variable(
309 lp::LPVariable(-inf, estimate, 0.0),
311 std::vector<double>());
312 succ_info.status = PerStateInfo::OPEN;
313 succ_info.var_idx = next_var_id;
314 c.insert(next_var_id, -prob);
// Record this constraint as incoming to the open successor; its dual value
// decides when the successor gets closed.
315 open_states[succ_id].incoming.push_back(
319 if (succ_id > next_prev_state) {
320 next_prev_state = succ_id;
// Successor discovered earlier: use its existing variable or terminal value.
323 assert(succ_info.status != PerStateInfo::NEW);
325 switch (succ_info.status) {
326 case PerStateInfo::OPEN:
327 open_states[succ_id].incoming.push_back(
330 case PerStateInfo::CLOSED:
331 c.insert(succ_info.var_idx, -prob);
334 base_val += prob * terminals_[succ_info.var_idx];
340 prev_state = next_prev_state;
344 c.set_upper_bound(base_val);
345 lp_solver_.add_constraint(c);
353 timer.throw_if_expired();
355 assert(lp_solver_.has_optimal_solution());
356 primal_solution = lp_solver_.extract_solution();
357 dual_solution = lp_solver_.extract_dual_solution();
358 objective = lp_solver_.get_objective_value();
// Close every OPEN state that has an incoming constraint with dual value
// above g_epsilon. Those states form the next frontier.
360 open_states.erase_if([&](
const auto& pair) {
361 for (
auto& r : pair.second.incoming) {
362 if (dual_solution[r] > g_epsilon) {
363 state_infos_[pair.first.id].status = PerStateInfo::CLOSED;
364 frontier.emplace_back(pair.first);
373 }
while (!frontier.empty());
375 assert(!dual_solution.empty());
// Record the final LP size and the number of states left OPEN.
377 statistics_.lp_variables = lp_solver_.get_num_variables();
378 statistics_.lp_constraints = lp_solver_.get_num_constraints();
379 statistics_.open = open_states.size();
381 return Interval(objective, INFINITE_VALUE);
Namespace dedicated to the i-dual MDP algorithm.
Definition idual.h:22
double value_t
Typedef for the state value type.
Definition aliases.h:7