AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
policy_verification.h
1#ifndef VERIFICATION_POLICY_VERIFICATION_H
2#define VERIFICATION_POLICY_VERIFICATION_H
3
#include "probfd/storage/per_state_storage.h"

#include "probfd/mdp.h"
#include "probfd/policy.h"

#include <algorithm>
#include <cstdlib>
#include <limits>
#include <optional>
#include <ranges>
#include <stack>
#include <vector>
12
13namespace tests {
14
15/*
16 * Verfies that a policy is proper for the initial state and that the state
17 * values are correct by verifying the Bellman optimality equation.
18 */
19template <typename State, typename Action>
20extern bool verify_policy(
23 probfd::StateID init_id)
24{
25 using namespace probfd;
26
27 struct StateInfo {
28 bool is_dead = true;
29 bool explored = false;
30 unsigned stack_id = std::numeric_limits<unsigned>::max();
31 };
32
33 struct ExplorationInfo {
34 ExplorationInfo(probfd::StateID state_id, unsigned stack_id)
35 : state_id(state_id)
36 , lowlink(stack_id)
37 {
38 }
39
40 probfd::StateID state_id;
41 unsigned lowlink = std::numeric_limits<unsigned>::max();
42
44 };
45
46 std::stack<ExplorationInfo> open;
47 std::vector<probfd::StateID> stack;
48 storage::PerStateStorage<StateInfo> state_infos;
49
50 open.emplace(init_id, 0);
51
52 for (;;) {
53 recurse:;
54 ExplorationInfo* info = &open.top();
55
56 probfd::StateID state_id = info->state_id;
57 State state = mdp.get_state(state_id);
58 StateInfo* state_info = &state_infos[state_id];
59
60 state_info->explored = true;
61 state_info->stack_id = stack.size();
62 stack.push_back(state_id);
63
64 std::optional decision = policy.get_decision(state);
65
66 // Check if goal. No decision in this case.
67 if (mdp.get_termination_info(state).is_goal_state()) {
68 if (decision) return false;
69 state_info->is_dead = false;
70 goto backtracking;
71 }
72
73 // Otherwise, a decision must be made.
74 if (!decision) return false;
75
76 // Generate successors.
78 state,
79 decision->action,
80 info->successors);
81
82 // Check Bellman equation
83 {
84 value_t expected_cost = mdp.get_action_cost(decision->action);
85
86 for (const auto [successor_id, probability] : info->successors) {
87 const State successor = mdp.get_state(successor_id);
88 std::optional succ_decision = policy.get_decision(successor);
89
90 const value_t succ_val =
91 succ_decision ? succ_decision->q_value_interval.lower
92 : 0_vt;
93
94 expected_cost += probability * succ_val;
95 }
96
97 if (!is_approx_equal(
98 decision->q_value_interval.lower,
99 expected_cost))
100 return false;
101 }
102
103 if (info->successors.empty()) abort();
104
105 for (;;) {
106 // DFS Expansion
107 do {
108 const probfd::StateID successor_id =
109 (info->successors.end() - 1)->item;
110 StateInfo& succ_info = state_infos[successor_id.id];
111
112 if (!succ_info.explored) {
113 open.emplace(successor_id, stack.size());
114 goto recurse;
115 }
116
117 state_info->is_dead = state_info->is_dead && succ_info.is_dead;
118 info->lowlink = std::min(info->lowlink, succ_info.stack_id);
119 info->successors.erase(info->successors.end() - 1);
120 } while (!info->successors.empty());
121
122 backtracking:;
123
124 // Backtracking
125 do {
126 const unsigned stack_id = state_info->stack_id;
127 const unsigned lowlink = info->lowlink;
128
129 // Check for SCC
130 if (stack_id == lowlink) {
131 // SCC must be able to reach the goal.
132 if (state_info->is_dead) return false;
133
134 std::ranges::subrange scc(
135 stack.begin() + stack_id,
136 stack.end());
137
138 // Erase the scc from the stack.
139 for (const probfd::StateID state_id : scc) {
140 state_infos[state_id.id].stack_id =
141 std::numeric_limits<unsigned>::max();
142 }
143
144 stack.erase(scc.begin(), scc.end());
145 }
146
147 // Backtrack from successor
148 open.pop();
149
150 if (open.empty()) return true;
151
152 info = &open.top();
153 state_id = info->state_id;
154 state = mdp.get_state(state_id);
155 state_info = &state_infos[state_id];
156
157 // The successor we backtracked from.
158 const probfd::StateID successor_id =
159 (info->successors.end() - 1)->item;
160
161 const StateInfo& succ_info = state_infos[successor_id.id];
162 state_info->is_dead = state_info->is_dead && succ_info.is_dead;
163 info->lowlink = std::min(info->lowlink, lowlink);
164 info->successors.erase(info->successors.end() - 1);
165 } while (info->successors.empty());
166 }
167 }
168}
169
170} // namespace tests
171
172#endif // VERIFICATION_POLICY_VERIFICATION_H
virtual value_t get_action_cost(param_type< Action > action)=0
Gets the cost of an action.
virtual TerminationInfo get_termination_info(param_type< State > state)=0
Returns the cost to terminate in a given state and checks whether a state is a goal.
A convenience class that represents a finite probability distribution.
Definition task_state_space.h:27
Basic interface for MDPs.
Definition mdp_algorithm.h:14
Represents a deterministic, stationary, partial policy.
Definition solver_interface.h:16
virtual std::optional< PolicyDecision< Action > > get_decision(const State &state) const =0
Retrieves the action and optimal state value interval specified by the policy for a given state.
virtual State get_state(StateID state_id)=0
Get the state mapped to a given state ID.
virtual void generate_action_transitions(param_type< State > state, param_type< Action > action, Distribution< StateID > &result)=0
Generates the successor distribution for a given state and action.
bool is_goal_state() const
Check if this state is a goal.
Definition cost_function.h:34
The top-level namespace of probabilistic Fast Downward.
Definition command_line.h:8
double value_t
Typedef for the state value type.
Definition aliases.h:7
bool is_approx_equal(value_t v1, value_t v2, value_t epsilon=g_epsilon)
Equivalent to $\lvert v_1 - v_2 \rvert \leq \epsilon$.
A StateID represents a state within a StateIDMap. Just like Fast Downward's StateID type,...
Definition types.h:22