339 assert(non_loop_prob > 0_vt);
341 return t_value * (1_vt / non_loop_prob);
// compute_q_values: for each transition, looks up the action's cost via
// cost_function.get_action_cost(transition.action) and evaluates
// compute_qvalue(cost, state_id, transition). Each result is appended to
// `qvalues`, so qvalues[i] corresponds to transitions[i] (assuming `qvalues`
// is empty on entry -- TODO confirm with callers). The running `best_value`
// is seeded with `termination_cost`.
// NOTE(review): this listing omits several lines of the definition: the
// declarations of `state_id` and `termination_cost` (used below), the update
// of `best_value` inside the loop, and the return statement. Presumably the
// function returns the best Q-value found, starting from the termination
// cost -- confirm against the full source.
344template <
typename State,
typename Action,
typename StateInfoT>
345auto HeuristicSearchBase<State, Action, StateInfoT>::compute_q_values(
346 CostFunctionType& cost_function,
348 std::vector<TransitionType>& transitions,
350 std::vector<AlgorithmValueType>& qvalues)
const -> AlgorithmValueType
352 AlgorithmValueType best_value(termination_cost);
// The loop below pushes exactly one Q-value per transition, so reserve up front.
354 qvalues.reserve(transitions.size());
356 for (
const auto& transition : transitions) {
357 const value_t cost = cost_function.get_action_cost(transition.action);
358 auto q = compute_qvalue(cost, state_id, transition);
360 qvalues.push_back(q);
// filter_greedy_transitions: removes the (transition, Q-value) pairs for which
// the predicate below returns true. It keeps `transitions` and `qvalues`
// index-aligned while doing so.
// std::views::zip pairs transitions[i] with qvalues[i]. Running
// std::ranges::remove_if over the zipped view moves both members of each pair
// together, so the surviving pairs end up at the front of both vectors in the
// same order. `it` marks the first removed pair; its offset from the start of
// the view is the number of survivors and is used to truncate `transitions`.
// NOTE(review): the projection argument (the lambda takes an
// AlgorithmValueType, so it presumably projects onto the Q-value), the
// predicate body (presumably comparing `value` against `best_value` within
// `epsilon`), the truncation of `qvalues` and the return statement are not
// visible in this listing -- confirm against the full source.
// NOTE(review): std::views::zip requires C++23.
366template <
typename State,
typename Action,
typename StateInfoT>
367auto HeuristicSearchBase<State, Action, StateInfoT>::filter_greedy_transitions(
368 std::vector<TransitionType>& transitions,
369 std::vector<AlgorithmValueType>& qvalues,
370 const AlgorithmValueType& best_value,
371 value_t epsilon)
const -> AlgorithmValueType
373 auto view = std::views::zip(transitions, qvalues);
374 auto [it, end] = std::ranges::remove_if(
376 [&](
const AlgorithmValueType& value) {
// Number of kept pairs == index of the first removed pair in the zipped view.
384 const size_t offset = std::distance(view.begin(), it);
385 transitions.erase(transitions.begin() + offset, transitions.end());
// solve: calls HSBase::initialize_initial_state(mdp, h, state) to set up the
// initial state. It then delegates to do_solve(mdp, h, state, progress,
// max_time) and returns that call's Interval result unchanged.
// NOTE(review): the parameter lines declaring `mdp`, `h`, `state` and
// `max_time` are missing from this listing.
390template <
typename State,
typename Action,
typename StateInfoT>
391Interval HeuristicSearchAlgorithm<State, Action, StateInfoT>::solve(
395 ProgressReport progress,
398 HSBase::initialize_initial_state(mdp, h, state);
399 return this->do_solve(mdp, h, state, progress, max_time);
// compute_policy: first runs solve() from `initial_state`. It then builds a
// MapPolicy by a FIFO (breadth-first) traversal of the states reachable from
// the initial state under the chosen actions.
// For each state taken from the front of the queue:
//  - the action is read from the stored per-state policy when
//    HSBase::StorePolicy is true;
//  - otherwise it is recomputed by generating the state's transitions and
//    selecting a greedy one;
//  - the decision is recorded together with the bounds from lookup_bounds;
//  - successors of that action that were not seen before are enqueued.
// NOTE(review): this listing omits lines of the definition: the remaining
// parameters, the loop head and the dequeue (pop_front) step, the `else`
// keyword, the arguments of compute_bellman_and_greedy, the handling of a
// state without an action, and the final return (presumably returning
// `policy`). Confirm against the full source.
402template <
typename State,
typename Action,
typename StateInfoT>
403auto HeuristicSearchAlgorithm<State, Action, StateInfoT>::compute_policy(
407 ProgressReport progress,
408 double max_time) -> std::unique_ptr<PolicyType>
410 this->solve(mdp, h, initial_state, progress, max_time);
416 using MapPolicy = policies::MapPolicy<State, Action>;
// NOTE(review): std::make_unique<MapPolicy>(&mdp) would avoid the naked `new`.
417 std::unique_ptr<MapPolicy> policy(
new MapPolicy(&mdp));
419 const StateID initial_state_id = mdp.get_state_id(initial_state);
// FIFO work list plus a visited set: a state is enqueued only when it is
// inserted into `visited` for the first time, so each state is processed once.
421 std::deque<StateID> queue;
422 std::set<StateID> visited;
423 queue.push_back(initial_state_id);
424 visited.insert(initial_state_id);
// Buffers reused across iterations to avoid reallocating per state.
426 std::vector<TransitionType> transitions;
427 std::vector<AlgorithmValueType> qvalues;
430 const StateID state_id = queue.front();
// Stays empty if no action is found for this state. It is dereferenced below,
// so the elided code after the greedy selection presumably skips such states.
433 std::optional<Action> action;
435 if constexpr (HSBase::StorePolicy) {
436 const StateInfo& state_info = this->state_infos_[state_id];
437 action = state_info.get_policy();
439 const State state = mdp.get_state(state_id);
440 const value_t termination_cost =
441 mdp.get_termination_info(state).get_cost();
// ClearGuard is declared elsewhere. It presumably clears the reused
// `transitions`/`qvalues` buffers when this scope ends -- confirm.
443 ClearGuard _(transitions, qvalues);
444 this->generate_non_tip_transitions(mdp, state, transitions);
446 this->compute_bellman_and_greedy(
454 this->select_greedy_transition(mdp, std::nullopt, transitions)
455 .transform([](
const auto& t) {
return t.action; });
// Record the chosen action together with this state's value bounds.
463 const Interval bound = this->lookup_bounds(state_id);
465 policy->emplace_decision(state_id, *action, bound);
// Expand the successors of the chosen action; only unseen states are enqueued.
468 const State state = mdp.get_state(state_id);
470 Distribution<StateID> successors;
471 mdp.generate_action_transitions(state, *action, successors);
473 for (
const StateID succ_id : successors.support()) {
474 if (visited.insert(succ_id).second) {
475 queue.push_back(succ_id);
478 }
while (!queue.empty());
483template <
typename State,
typename Action,
typename StateInfoT>
485 std::ostream& out)
const
487 HSBase::print_statistics(out);
488 this->print_additional_statistics(out);