1#ifndef PER_STATE_INFORMATION_H
2#define PER_STATE_INFORMATION_H
4#include "downward/state_registry.h"
6#include "downward/algorithms/segmented_vector.h"
7#include "downward/algorithms/subscriber.h"
8#include "downward/utils/collections.h"
12#include <unordered_map>
40class PerStateInformation :
public subscriber::Subscriber<StateRegistry> {
41 const Entry default_value;
42 using EntryVectorMap = std::unordered_map<
44 segmented_vector::SegmentedVector<Entry>*>;
45 EntryVectorMap entries_by_registry;
47 mutable const StateRegistry* cached_registry;
48 mutable segmented_vector::SegmentedVector<Entry>* cached_entries;
56 segmented_vector::SegmentedVector<Entry>*
57 get_entries(
const StateRegistry* registry)
59 if (cached_registry != registry) {
60 cached_registry = registry;
61 auto it = entries_by_registry.find(registry);
62 if (it == entries_by_registry.end()) {
63 cached_entries =
new segmented_vector::SegmentedVector<Entry>();
64 entries_by_registry[registry] = cached_entries;
65 registry->subscribe(
this);
67 cached_entries = it->second;
71 cached_registry == registry &&
72 cached_entries == entries_by_registry[registry]);
73 return cached_entries;
82 const segmented_vector::SegmentedVector<Entry>*
83 get_entries(
const StateRegistry* registry)
const
85 if (cached_registry != registry) {
86 const auto it = entries_by_registry.find(registry);
87 if (it == entries_by_registry.end()) {
90 cached_registry = registry;
92 const_cast<segmented_vector::SegmentedVector<Entry>*
>(
96 assert(cached_registry == registry);
97 return cached_entries;
101 PerStateInformation()
103 , cached_registry(nullptr)
104 , cached_entries(nullptr)
108 explicit PerStateInformation(
const Entry& default_value_)
109 : default_value(default_value_)
110 , cached_registry(nullptr)
111 , cached_entries(nullptr)
115 PerStateInformation(
const PerStateInformation<Entry>&) =
delete;
116 PerStateInformation& operator=(
const PerStateInformation<Entry>&) =
delete;
// Destructor: iterates over every (registry, storage) pair held in
// entries_by_registry. `auto it` copies each pair by value; both the key
// (a registry pointer) and the value (a SegmentedVector pointer) are pointers,
// so the copy is cheap.
// NOTE(review): the loop body and closing braces are not visible in this
// view; presumably the loop deletes each it.second (allocated with `new` in
// get_entries) — confirm.
118 virtual ~PerStateInformation()
override
120 for (
auto it : entries_by_registry) {
125 Entry& operator[](
const State& state)
127 const StateRegistry* registry = state.get_registry();
129 std::cerr <<
"Tried to access per-state information with an "
130 <<
"unregistered state." << std::endl;
131 utils::exit_with(utils::ExitCode::SEARCH_CRITICAL_ERROR);
133 segmented_vector::SegmentedVector<Entry>* entries =
134 get_entries(registry);
135 int state_id = state.get_id().value;
136 assert(state.get_id() != StateID::no_state);
137 size_t virtual_size = registry->size();
138 assert(utils::in_bounds(state_id, *registry));
139 if (entries->size() < virtual_size) {
140 entries->resize(virtual_size, default_value);
142 return (*entries)[state_id];
145 const Entry& operator[](
const State& state)
const
147 const StateRegistry* registry = state.get_registry();
149 std::cerr <<
"Tried to access per-state information with an "
150 <<
"unregistered state." << std::endl;
151 utils::exit_with(utils::ExitCode::SEARCH_CRITICAL_ERROR);
153 const segmented_vector::SegmentedVector<Entry>* entries =
154 get_entries(registry);
156 return default_value;
158 int state_id = state.get_id().value;
159 assert(state.get_id() != StateID::no_state);
160 assert(utils::in_bounds(state_id, *registry));
161 int num_entries = entries->size();
162 if (state_id >= num_entries) {
163 return default_value;
165 return (*entries)[state_id];
169 notify_service_destroyed(
const StateRegistry* registry)
override
171 delete entries_by_registry[registry];
172 entries_by_registry.erase(registry);
173 if (registry == cached_registry) {
174 cached_registry =
nullptr;
175 cached_entries =
nullptr;