1#ifndef DOWNWARD_STATE_REGISTRY_H
2#define DOWNWARD_STATE_REGISTRY_H
4#include "downward/abstract_task.h"
5#include "downward/axioms.h"
6#include "downward/state_id.h"
8#include "downward/algorithms/int_hash_set.h"
9#include "downward/algorithms/int_packer.h"
10#include "downward/algorithms/segmented_vector.h"
11#include "downward/algorithms/subscriber.h"
12#include "downward/task_utils/task_properties.h"
13#include "downward/utils/hash.h"
110namespace int_packer {
// Element type of the packed per-state buffers kept in state_data_pool;
// an alias for int_packer::IntPacker::Bin.
114using PackedStateBin = int_packer::IntPacker::Bin;
116class StateRegistry :
public subscriber::SubscriberService<StateRegistry> {
// Hash functor for the registry's IntHashSet of state IDs. An ID is hashed by
// the contents of its packed buffer in state_data_pool, not by its numeric
// value, so IDs whose buffers hold identical bins get identical hashes.
117 struct StateIDSemanticHash {
// Pool of packed state buffers, indexed by state ID.
118 const segmented_vector::SegmentedArrayVector<PackedStateBin>&
122 const segmented_vector::SegmentedArrayVector<PackedStateBin>&
125 : state_data_pool(state_data_pool)
126 , state_size(state_size)
// Feeds the first state_size bins of the buffer stored for `id` into a
// utils::HashState and returns its get_hash32() result.
130 int_hash_set::HashType operator()(
int id)
const
132 const PackedStateBin* data = state_data_pool[id];
133 utils::HashState hash_state;
134 for (
int i = 0; i < state_size; ++i) {
135 hash_state.feed(data[i]);
137 return hash_state.get_hash32();
// Equality functor paired with StateIDSemanticHash in the registry's
// IntHashSet: two state IDs compare equal iff the first state_size bins of
// their packed buffers match element-wise.
141 struct StateIDSemanticEqual {
// Pool of packed state buffers, indexed by state ID.
142 const segmented_vector::SegmentedArrayVector<PackedStateBin>&
145 StateIDSemanticEqual(
146 const segmented_vector::SegmentedArrayVector<PackedStateBin>&
149 : state_data_pool(state_data_pool)
150 , state_size(state_size)
// Compares the buffers of `lhs` and `rhs` with std::equal over state_size bins.
154 bool operator()(
int lhs,
int rhs)
const
156 const PackedStateBin* lhs_data = state_data_pool[lhs];
157 const PackedStateBin* rhs_data = state_data_pool[rhs];
158 return std::equal(lhs_data, lhs_data + state_size, rhs_data);
168 int_hash_set::IntHashSet<StateIDSemanticHash, StateIDSemanticEqual>;
// Task view; used to create State objects and to test whether the task has axioms.
170 PlanningTaskProxy task_proxy;
// Writes variable values into PackedStateBin buffers (state_packer.set).
171 const int_packer::IntPacker& state_packer;
// Applied to a successor's unpacked values when the task has axioms.
172 AxiomEvaluator& axiom_evaluator;
// Fixed at construction; returned by get_num_variables().
173 const int num_variables;
// One packed buffer per state, indexed by state ID. Successor generation
// appends a candidate buffer here before it is checked for duplicates.
175 segmented_vector::SegmentedArrayVector<PackedStateBin> state_data_pool;
// IDs of all registered states, deduplicated by buffer contents through
// StateIDSemanticHash / StateIDSemanticEqual.
176 StateIDSet registered_states;
// NOTE(review): presumably filled lazily by get_initial_state(); confirm in the .cpp.
178 std::unique_ptr<State> cached_initial_state;
// Called right after a candidate buffer was appended to state_data_pool;
// returns the StateID to use for it. NOTE(review): by its name, presumably
// returns the existing ID and pops the new buffer when an equal state is
// already registered — confirm in the .cpp.
180 StateID insert_id_or_pop_state();
// NOTE(review): presumably the number of PackedStateBins per state buffer;
// implementation not shown.
181 int get_bins_per_state()
const;
// Creates a registry for the given task. `explicit` prevents an implicit
// conversion from PlanningTaskProxy to StateRegistry.
184 explicit StateRegistry(
const PlanningTaskProxy& task_proxy);
186 const PlanningTaskProxy& get_task_proxy()
const {
return task_proxy; }
188 int get_num_variables()
const {
return num_variables; }
// Access to the IntPacker used for this registry's state buffers.
// NOTE(review): presumably returns the state_packer member.
190 const int_packer::IntPacker& get_state_packer()
const
// Returns the registered state with ID `id`. NOTE(review): presumably the
// State is created without unpacked values (contrast with the overload below).
200 State lookup_state(StateID
id)
const;
// Overload that additionally takes ownership (rvalue reference) of the
// state's unpacked values. NOTE(review): presumably attached to the returned
// State to avoid unpacking again — confirm.
207 State lookup_state(StateID
id, std::vector<int>&& state_values)
const;
// Returns the initial state by reference. Non-const; NOTE(review): presumably
// because it registers the state and caches it in cached_initial_state.
214 const State& get_initial_state();
// Successor of `predecessor` when applying operator `op`. NOTE(review):
// presumably forwards to the effects-based template overload — confirm.
222 get_successor_state(
const State& predecessor,
const OperatorProxy& op);
// Builds the successor of `predecessor` under `effects` and registers it.
// The predecessor's packed buffer is copied to the end of state_data_pool,
// firing effects (per does_fire) are applied to that copy, and
// insert_id_or_pop_state() supplies the resulting StateID.
// `Effects` must be iterable; each element must support
// does_fire(effect, state) and effect.get_fact().get_pair().
224 template <
typename Effects>
225 State get_successor_state(
const State& predecessor,
const Effects& effects)
// NOTE(review): predecessor.get_buffer() may itself point into
// state_data_pool; this relies on SegmentedArrayVector::push_back not
// relocating existing elements — confirm.
227 state_data_pool.push_back(predecessor.get_buffer());
228 PackedStateBin* buffer = state_data_pool[state_data_pool.size() - 1];
// With axioms: apply the effects on unpacked values so that
// axiom_evaluator.evaluate can update them, then re-pack every variable.
231 if (task_properties::has_axioms(task_proxy)) {
// unpack() is callable on the const predecessor; NOTE(review): presumably
// it caches the unpacked values inside the State.
232 predecessor.unpack();
233 std::vector<int> new_values = predecessor.get_unpacked_values();
// Effect conditions are evaluated against `predecessor`, never against the
// partially updated new_values. NOTE(review): iterating by value copies each
// effect; `const auto&` would avoid that if the element type is not cheap.
234 for (
auto effect : effects) {
235 if (does_fire(effect, predecessor)) {
236 FactPair effect_pair = effect.get_fact().get_pair();
237 new_values[effect_pair.var] = effect_pair.value;
240 axiom_evaluator.evaluate(new_values);
// NOTE(review): `i` is size_t; if IntPacker::set takes an int index this
// converts implicitly.
241 for (
size_t i = 0; i < new_values.size(); ++i) {
242 state_packer.set(buffer, i, new_values[i]);
244 ::StateID
id = insert_id_or_pop_state();
// The unpacked values are moved into the created State.
246 .create_state(*
this,
id, buffer, std::move(new_values));
// Without axioms: no unpacking; the firing effects are applied directly to
// the buffer (NOTE(review): presumably via state_packer.set — confirm).
248 for (
auto effect : effects) {
249 if (does_fire(effect, predecessor)) {
250 FactPair effect_pair = effect.get_fact().get_pair();
257 ::StateID
id = insert_id_or_pop_state();
// Unlike the axiom branch, no unpacked values are handed to the State.
258 return task_proxy.create_state(*
this,
id, buffer);
265 size_t size()
const {
return registered_states.size(); }
// Memory used by one registered state. NOTE(review): implementation not shown
// here; presumably bins per state times sizeof(PackedStateBin).
267 int get_state_size_in_bytes()
const;
// Writes registry statistics to `log`.
269 void print_statistics(utils::LogProxy& log)
const;
// Forward iterator over the IDs of registered states: begin() starts at
// index 0 and end() sits at index size().
271 class const_iterator {
// Gives StateRegistry access to private members (e.g. the constructor used
// by begin()/end()).
279 friend class StateRegistry;
// Non-owning pointer to the registry being iterated; checked in operator==.
281 const StateRegistry* registry;
285 const_iterator(
const StateRegistry& registry,
size_t start)
287 : registry(®istry)
299 using iterator_category = std::forward_iterator_tag;
300 using value_type = StateID;
301 using pointer =
const StateID*;
302 using reference =
const StateID&;
303 using difference_type = void;
// Pre-increment. NOTE(review): presumably advances to the next state ID and
// returns *this; body not shown.
305 const_iterator& operator++()
312 operator==(
const const_iterator& lhs,
const const_iterator& rhs)
314 assert(&lhs.registry == &rhs.registry);
315 return lhs.pos == rhs.pos;
318 StateID operator*() {
return pos; }
320 StateID* operator->() {
return &pos; }
323 const_iterator begin()
const {
return const_iterator(*
this, 0); }
325 const_iterator end()
const {
return const_iterator(*
this, size()); }