AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
per_state_information.h
1#ifndef PER_STATE_INFORMATION_H
2#define PER_STATE_INFORMATION_H
3
#include "downward/state_registry.h"

#include "downward/algorithms/segmented_vector.h"
#include "downward/algorithms/subscriber.h"
#include "downward/utils/collections.h"

#include <cassert>
#include <iostream>
#include <memory>
#include <unordered_map>
13
14/*
15 PerStateInformation is used to associate information with states.
16 PerStateInformation<Entry> logically behaves somewhat like an unordered map
17 from states to objects of class Entry. However, lookup of unknown states is
18 supported and leads to insertion of a default value (similar to the
19 defaultdict class in Python).
20
 For example, search algorithms can use it to associate g values or creating
 operators (the operator that generated a state) with a state.
23
24 Implementation notes: PerStateInformation is essentially implemented as a
25 kind of two-level map:
26 1. Find the correct SegmentedVector for the registry of the given state.
27 2. Look up the associated entry in the SegmentedVector based on the ID of
28 the state.
29 It is common in many use cases that we look up information for states from
 the same registry in sequence. Therefore, to make step 1 more efficient, we
31 remember (in "cached_registry" and "cached_entries") the results of the
32 previous lookup and reuse it on consecutive lookups for the same registry.
33
34 A PerStateInformation object subscribes to every StateRegistry for which it
35 stores information. Once a StateRegistry is destroyed, it notifies all
36 subscribed objects, which in turn destroy all information stored for states
37 in that registry.
38*/
39template <class Entry>
40class PerStateInformation : public subscriber::Subscriber<StateRegistry> {
41 const Entry default_value;
42 using EntryVectorMap = std::unordered_map<
43 const StateRegistry*,
44 segmented_vector::SegmentedVector<Entry>*>;
45 EntryVectorMap entries_by_registry;
46
47 mutable const StateRegistry* cached_registry;
48 mutable segmented_vector::SegmentedVector<Entry>* cached_entries;
49
50 /*
51 Returns the SegmentedVector associated with the given StateRegistry.
52 If no vector is associated with this registry yet, an empty one is
53 created. Both the registry and the returned vector are cached to speed up
54 consecutive calls with the same registry.
55 */
56 segmented_vector::SegmentedVector<Entry>*
57 get_entries(const StateRegistry* registry)
58 {
59 if (cached_registry != registry) {
60 cached_registry = registry;
61 auto it = entries_by_registry.find(registry);
62 if (it == entries_by_registry.end()) {
63 cached_entries = new segmented_vector::SegmentedVector<Entry>();
64 entries_by_registry[registry] = cached_entries;
65 registry->subscribe(this);
66 } else {
67 cached_entries = it->second;
68 }
69 }
70 assert(
71 cached_registry == registry &&
72 cached_entries == entries_by_registry[registry]);
73 return cached_entries;
74 }
75
76 /*
77 Returns the SegmentedVector associated with the given StateRegistry.
78 Returns nullptr, if no vector is associated with this registry yet.
79 Otherwise, both the registry and the returned vector are cached to speed
80 up consecutive calls with the same registry.
81 */
82 const segmented_vector::SegmentedVector<Entry>*
83 get_entries(const StateRegistry* registry) const
84 {
85 if (cached_registry != registry) {
86 const auto it = entries_by_registry.find(registry);
87 if (it == entries_by_registry.end()) {
88 return nullptr;
89 } else {
90 cached_registry = registry;
91 cached_entries =
92 const_cast<segmented_vector::SegmentedVector<Entry>*>(
93 it->second);
94 }
95 }
96 assert(cached_registry == registry);
97 return cached_entries;
98 }
99
100public:
101 PerStateInformation()
102 : default_value()
103 , cached_registry(nullptr)
104 , cached_entries(nullptr)
105 {
106 }
107
108 explicit PerStateInformation(const Entry& default_value_)
109 : default_value(default_value_)
110 , cached_registry(nullptr)
111 , cached_entries(nullptr)
112 {
113 }
114
115 PerStateInformation(const PerStateInformation<Entry>&) = delete;
116 PerStateInformation& operator=(const PerStateInformation<Entry>&) = delete;
117
118 virtual ~PerStateInformation() override
119 {
120 for (auto it : entries_by_registry) {
121 delete it.second;
122 }
123 }
124
125 Entry& operator[](const State& state)
126 {
127 const StateRegistry* registry = state.get_registry();
128 if (!registry) {
129 std::cerr << "Tried to access per-state information with an "
130 << "unregistered state." << std::endl;
131 utils::exit_with(utils::ExitCode::SEARCH_CRITICAL_ERROR);
132 }
133 segmented_vector::SegmentedVector<Entry>* entries =
134 get_entries(registry);
135 int state_id = state.get_id().value;
136 assert(state.get_id() != StateID::no_state);
137 size_t virtual_size = registry->size();
138 assert(utils::in_bounds(state_id, *registry));
139 if (entries->size() < virtual_size) {
140 entries->resize(virtual_size, default_value);
141 }
142 return (*entries)[state_id];
143 }
144
145 const Entry& operator[](const State& state) const
146 {
147 const StateRegistry* registry = state.get_registry();
148 if (!registry) {
149 std::cerr << "Tried to access per-state information with an "
150 << "unregistered state." << std::endl;
151 utils::exit_with(utils::ExitCode::SEARCH_CRITICAL_ERROR);
152 }
153 const segmented_vector::SegmentedVector<Entry>* entries =
154 get_entries(registry);
155 if (!entries) {
156 return default_value;
157 }
158 int state_id = state.get_id().value;
159 assert(state.get_id() != StateID::no_state);
160 assert(utils::in_bounds(state_id, *registry));
161 int num_entries = entries->size();
162 if (state_id >= num_entries) {
163 return default_value;
164 }
165 return (*entries)[state_id];
166 }
167
168 virtual void
169 notify_service_destroyed(const StateRegistry* registry) override
170 {
171 delete entries_by_registry[registry];
172 entries_by_registry.erase(registry);
173 if (registry == cached_registry) {
174 cached_registry = nullptr;
175 cached_entries = nullptr;
176 }
177 }
178};
179
180#endif