AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
per_state_array.h
1#ifndef PER_STATE_ARRAY_H
2#define PER_STATE_ARRAY_H
3
4#include "downward/per_state_information.h"
5
6#include <cassert>
7#include <unordered_map>
8
/*
  ConstArrayView is a non-owning, read-only view of `size` consecutive
  objects of type T starting at address `p`. Copying a view copies only
  the pointer and the length, never the elements.
*/
template <class T>
class ConstArrayView {
    const T* first_;
    int length_;

public:
    // Creates a view of the `size` elements starting at `p`.
    ConstArrayView(const T* p, int size)
        : first_(p)
        , length_(size)
    {
    }

    ConstArrayView(const ConstArrayView<T>& other) = default;
    ConstArrayView<T>& operator=(const ConstArrayView<T>& other) = default;

    // Read access to the element at `index`; the index is only checked
    // by an assertion (i.e., in debug builds).
    const T& operator[](int index) const
    {
        assert(0 <= index && index < length_);
        return *(first_ + index);
    }

    // Number of elements covered by the view.
    int size() const { return length_; }
};
32
33template <class T>
34class ArrayView {
35 T* p;
36 int size_;
37
38public:
39 ArrayView(T* p, int size)
40 : p(p)
41 , size_(size)
42 {
43 }
44 ArrayView(const ArrayView<T>& other) = default;
45
46 ArrayView<T>& operator=(const ArrayView<T>& other) = default;
47
48 operator ConstArrayView<T>() const { return ConstArrayView<T>(p, size_); }
49
50 T& operator[](int index)
51 {
52 assert(index >= 0 && index < size_);
53 return p[index];
54 }
55
56 const T& operator[](int index) const
57 {
58 assert(index >= 0 && index < size_);
59 return p[index];
60 }
61
62 int size() const { return size_; }
63};
64
65/*
66 PerStateArray is used to associate array-like information with states.
67 PerStateArray<Entry> logically behaves somewhat like an unordered map
68 from states to equal-length arrays of class Entry. However, lookup of
69 unknown states is supported and leads to insertion of a default value
70 (similar to the defaultdict class in Python).
71
72 The implementation is similar to the one of PerStateInformation, which
73 also contains more documentation.
74*/
75
76template <class Element>
77class PerStateArray : public subscriber::Subscriber<StateRegistry> {
78 const std::vector<Element> default_array;
79 using EntryArrayVectorMap = std::unordered_map<
80 const StateRegistry*,
81 segmented_vector::SegmentedArrayVector<Element>*>;
82 EntryArrayVectorMap entry_arrays_by_registry;
83
84 mutable const StateRegistry* cached_registry;
85 mutable segmented_vector::SegmentedArrayVector<Element>* cached_entries;
86
87 segmented_vector::SegmentedArrayVector<Element>*
88 get_entries(const StateRegistry* registry)
89 {
90 if (cached_registry != registry) {
91 cached_registry = registry;
92 auto it = entry_arrays_by_registry.find(registry);
93 if (it == entry_arrays_by_registry.end()) {
94 cached_entries =
95 new segmented_vector::SegmentedArrayVector<Element>(
96 default_array.size());
97 entry_arrays_by_registry[registry] = cached_entries;
98 registry->subscribe(this);
99 } else {
100 cached_entries = it->second;
101 }
102 }
103 assert(
104 cached_registry == registry &&
105 cached_entries == entry_arrays_by_registry[registry]);
106 return cached_entries;
107 }
108
109 const segmented_vector::SegmentedArrayVector<Element>*
110 get_entries(const StateRegistry* registry) const
111 {
112 if (cached_registry != registry) {
113 const auto it = entry_arrays_by_registry.find(registry);
114 if (it == entry_arrays_by_registry.end()) {
115 return nullptr;
116 } else {
117 cached_registry = registry;
118 cached_entries = const_cast<
119 segmented_vector::SegmentedArrayVector<Element>*>(
120 it->second);
121 }
122 }
123 assert(cached_registry == registry);
124 return cached_entries;
125 }
126
127public:
128 explicit PerStateArray(const std::vector<Element>& default_array)
129 : default_array(default_array)
130 , cached_registry(nullptr)
131 , cached_entries(nullptr)
132 {
133 }
134
135 PerStateArray(const PerStateArray<Element>&) = delete;
136 PerStateArray& operator=(const PerStateArray<Element>&) = delete;
137
138 virtual ~PerStateArray() override
139 {
140 for (auto it : entry_arrays_by_registry) {
141 delete it.second;
142 }
143 }
144
145 ArrayView<Element> operator[](const State& state)
146 {
147 const StateRegistry* registry = state.get_registry();
148 if (!registry) {
149 std::cerr << "Tried to access per-state array with an unregistered "
150 << "state." << std::endl;
151 utils::exit_with(utils::ExitCode::SEARCH_CRITICAL_ERROR);
152 }
153 segmented_vector::SegmentedArrayVector<Element>* entries =
154 get_entries(registry);
155 int state_id = state.get_id().value;
156 assert(state.get_id() != StateID::no_state);
157 size_t virtual_size = registry->size();
158 assert(utils::in_bounds(state_id, *registry));
159 if (entries->size() < virtual_size) {
160 entries->resize(virtual_size, default_array.data());
161 }
162 return ArrayView<Element>((*entries)[state_id], default_array.size());
163 }
164
165 ConstArrayView<Element> operator[](const State& state) const
166 {
167 const StateRegistry* registry = state.get_registry();
168 if (!registry) {
169 std::cerr << "Tried to access per-state array with an unregistered "
170 << "state." << std::endl;
171 utils::exit_with(utils::ExitCode::SEARCH_CRITICAL_ERROR);
172 }
173 const segmented_vector::SegmentedArrayVector<Element>* entries =
174 get_entries(registry);
175 if (!entries) {
176 ABORT("PerStateArray::operator[] const tried to access "
177 "non-existing entry.");
178 }
179 int state_id = state.get_id().value;
180 assert(state.get_id() != StateID::no_state);
181 assert(utils::in_bounds(state_id, *registry));
182 int num_entries = entries->size();
183 if (state_id >= num_entries) {
184 ABORT("PerStateArray::operator[] const tried to access "
185 "non-existing entry.");
186 }
187 return ConstArrayView<Element>(
188 (*entries)[state_id],
189 default_array.size());
190 }
191
192 virtual void
193 notify_service_destroyed(const StateRegistry* registry) override
194 {
195 delete entry_arrays_by_registry[registry];
196 entry_arrays_by_registry.erase(registry);
197 if (registry == cached_registry) {
198 cached_registry = nullptr;
199 cached_entries = nullptr;
200 }
201 }
202};
203
204#endif