AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
distribution.h
1#ifndef PROBFD_DISTRIBUTION_H
2#define PROBFD_DISTRIBUTION_H
3
4#include "probfd/value_type.h"
5
6#include "downward/utils/rng.h"
7
8#include <algorithm>
9#include <cassert>
10#include <compare>
11#include <ranges>
12#include <type_traits>
13#include <utility>
14#include <vector>
15
16namespace probfd {
17
19template <typename T, typename PrType = value_t>
21 template <
22 typename... Args,
23 typename... Args2,
24 size_t... Indices,
25 size_t... Indices2>
27 std::tuple<Args...> args,
28 std::tuple<Args2...> args2,
29 std::index_sequence<Indices...>,
30 std::index_sequence<Indices2...>)
31 : item(std::get<Indices>(args)...)
32 , probability(std::get<Indices2>(args2)...)
33 {
34 }
35
36public:
37 T item;
38 PrType probability;
39
42 requires(std::is_default_constructible_v<T> &&
43 std::is_default_constructible_v<PrType>)
44 = default;
45
46 template <typename A, typename B>
47 requires(std::is_constructible_v<T, A> &&
48 std::is_constructible_v<PrType, B>)
49 explicit(!std::is_convertible_v<T, A> || !std::is_convertible_v<PrType, B>)
50 ItemProbabilityPair(std::pair<A, B> p)
51 : item(std::get<0>(p))
52 , probability(std::get<1>(p))
53 {
54 }
55
56 template <typename A, typename B>
57 requires(std::is_constructible_v<T, A> &&
58 std::is_constructible_v<PrType, B>)
59 explicit(!std::is_convertible_v<T, A> || !std::is_convertible_v<PrType, B>)
60 ItemProbabilityPair(std::tuple<A, B> p)
61 : item(std::get<0>(p))
62 , probability(std::get<1>(p))
63 {
64 }
65
66 template <typename A>
67 requires(std::is_constructible_v<T, A> &&
68 std::is_constructible_v<PrType, A>)
69 explicit(!std::is_convertible_v<T, A> || !std::is_convertible_v<PrType, A>)
70 ItemProbabilityPair(std::array<A, 2> p)
71 : item(std::get<0>(p))
72 , probability(std::get<1>(p))
73 {
74 }
75
77 template <typename A, typename B>
78 requires(std::is_constructible_v<T, A> &&
79 std::is_constructible_v<PrType, B>)
80 explicit(!std::is_convertible_v<T, A> || !std::is_convertible_v<PrType, B>)
82 : item(std::forward<A>(item))
83 , probability(std::forward<B>(probability))
84 {
85 }
86
89 template <typename... Args, typename... Args2>
91 std::piecewise_construct_t,
92 std::tuple<Args...> t1,
93 std::tuple<Args2...> t2)
95 t1,
96 t2,
97 std::index_sequence_for<Args...>{},
98 std::index_sequence_for<Args2...>{})
99 {
100 }
101
103 friend auto operator<=>(
105 const ItemProbabilityPair<T, PrType>& right) = default;
106
107 operator std::pair<T, PrType>() const
108 {
109 return std::make_pair(item, probability);
110 }
111
112 template <std::size_t Index>
113 auto& get() &
114 {
115 if constexpr (Index == 0) return item;
116 if constexpr (Index == 1) return probability;
117 }
118
119 template <std::size_t Index>
120 const auto& get() const&
121 {
122 if constexpr (Index == 0) return item;
123 if constexpr (Index == 1) return probability;
124 }
125
126 template <std::size_t Index>
127 auto&& get() &&
128 {
129 if constexpr (Index == 0) return std::move(item);
130 if constexpr (Index == 1) return std::move(probability);
131 }
132
133 template <std::size_t Index>
134 const auto&& get() const&&
135 {
136 if constexpr (Index == 0) return std::move(item);
137 if constexpr (Index == 1) return std::move(probability);
138 }
139};
140
143
147
154template <typename T>
155class Distribution {
156 std::vector<ItemProbabilityPair<T>> distribution_;
157
158public:
159 using iterator = typename std::vector<ItemProbabilityPair<T>>::iterator;
160 using const_iterator =
161 typename std::vector<ItemProbabilityPair<T>>::const_iterator;
162
163 Distribution() = default;
164
165 Distribution(std::initializer_list<ItemProbabilityPair<T>> list)
166 : distribution_(list)
167 {
168 normalize();
169 }
170
171 Distribution(
172 std::initializer_list<ItemProbabilityPair<T>> list,
173 no_normalize_t)
174 : distribution_(list)
175 {
176 }
177
178 template <std::ranges::input_range R>
179 requires(std::convertible_to<
180 std::ranges::range_reference_t<R>,
181 ItemProbabilityPair<T>>)
182 explicit Distribution(std::from_range_t, R&& pair_range)
183 : distribution_(std::from_range, std::forward<R>(pair_range))
184 {
185 normalize();
186 }
187
188 template <std::ranges::input_range R>
189 requires(std::convertible_to<
190 std::ranges::range_reference_t<R>,
191 ItemProbabilityPair<T>>)
192 explicit Distribution(std::from_range_t, no_normalize_t, R&& pair_range)
193#ifdef __cpp_lib_containers_ranges
194 : distribution_(std::from_range, std::forward<R>(pair_range))
195#else
196 : distribution_(
197 std::ranges::begin(pair_range),
198 std::ranges::end(pair_range))
199#endif
200 {
201 }
202
207 void reserve(size_t capacity) { distribution_.reserve(capacity); }
208
212 [[nodiscard]]
213 bool empty() const
214 {
215 return distribution_.empty();
216 }
217
221 [[nodiscard]]
222 size_t size() const
223 {
224 return distribution_.size();
225 }
226
227 void clear() { distribution_.clear(); }
228
229 void swap(Distribution<T>& other)
230 {
231 other.distribution_.swap(distribution_);
232 }
233
234 void add_probability(T t, value_t prob)
235 {
236 assert(prob > 0.0);
237
238 auto it = this->find(t);
239
240 if (it != end()) {
241 it->probability += prob;
242 return;
243 }
244
245 distribution_.emplace(it, std::move(t), prob);
246 }
247
248 auto find(this auto&& self, const T& t)
249 {
250 auto it = std::ranges::lower_bound(
251 self.distribution_,
252 t,
253 std::less<>{},
255
256 if (it == self.end() || it->item == t) {
257 return it;
258 }
259
260 return self.end();
261 }
262
267 bool is_dirac(const T& t) const
268 {
269 return size() == 1 && distribution_.front().item == t;
270 }
271
275 [[nodiscard]]
276 bool is_dirac() const
277 {
278 return size() == 1;
279 }
280
285 template <typename RandomVariable>
286 requires(std::invocable<RandomVariable, T>)
287 value_t expectation(RandomVariable rv) const
288 {
289 value_t ex = 0;
290 for (const auto [succ, prob] : distribution_) {
291 ex += prob * rv(succ);
292 }
293 return ex;
294 }
295
300 template <typename RandomVariable>
301 requires requires(RandomVariable& rv, const T& t) {
302 { rv[t] } -> std::convertible_to<value_t>;
303 }
304 value_t expectation(RandomVariable rv) const
305 {
306 value_t ex = 0_vt;
307 for (const auto [succ, prob] : distribution_) {
308 ex += prob * rv[succ];
309 }
310 return ex;
311 }
312
316 void normalize(value_t scale)
317 {
318 for (auto& pair : distribution_) {
319 pair.probability *= scale;
320 }
321 }
322
327 {
328 if (empty()) {
329 return;
330 }
331 value_t sum = 0;
332 for (auto& pair : distribution_) {
333 sum += pair.probability;
334 }
335 normalize(1_vt / sum);
336 }
337
338 auto sample(utils::RandomNumberGenerator& rng)
339 {
340 assert(!empty());
341
342 // Important!
343 normalize();
344
345 const value_t r = rng.random();
346
347 auto it = distribution_.begin();
348 value_t sum = it->probability;
349
350 while (sum <= r) {
351 sum += (++it)->probability;
352 }
353
354 return it;
355 }
356
357 auto sample(utils::RandomNumberGenerator& rng) const
358 {
359 return static_cast<const_iterator>(
360 const_cast<Distribution<T>*>(this)->sample(rng));
361 }
362
369 iterator erase(iterator it) { return distribution_.erase(it); }
370
377 iterator erase(iterator it, iterator last)
378 {
379 return distribution_.erase(it, last);
380 }
381
382 template <typename UnaryPredicate>
383 size_t remove_if(UnaryPredicate pred)
384 {
385 return std::erase_if(distribution_, pred);
386 }
387
388 template <typename UnaryPredicate>
389 value_t remove_if_normalize(UnaryPredicate pred)
390 {
391 value_t normalize_factor = 0_vt;
392
393 std::erase_if(
394 this->distribution_,
395 [&pred, &normalize_factor](auto& target) {
396 if (pred(target)) {
397 normalize_factor += target.probability;
398 return true;
399 }
400 return false;
401 });
402
403 if (normalize_factor > 0) {
404 this->normalize(1_vt / (1_vt - normalize_factor));
405 }
406
407 return normalize_factor;
408 }
409
410 value_t remove_if_normalize(const T& t)
411 {
412 return remove_if_normalize(
413 [&](const auto& elem) { return elem.item == t; });
414 }
415
416 auto begin(this auto&& self) { return self.distribution_.begin(); }
417 auto end(this auto&& self) { return self.distribution_.end(); }
418
419 auto support(this auto&& self)
420 {
421 return std::views::transform(
422 self.distribution_,
424 }
425
426 friend auto
427 operator<=>(const Distribution<T>& left, const Distribution<T>& right) =
428 default;
429};
430
431namespace detail {
432
433template <typename T>
434struct is_item_prob_pair : std::false_type {};
435
436template <typename T>
437struct is_item_prob_pair<ItemProbabilityPair<T>> : std::true_type {};
438
439template <typename T>
440constexpr bool is_item_prob_pair_v = is_item_prob_pair<T>::value;
441
442template <typename T>
443struct item {};
444
445template <typename T>
446struct item<ItemProbabilityPair<T>> {
447 using type = T;
448};
449
450template <typename T>
451using item_t = typename item<T>::type;
452
453}
454
455template <std::ranges::input_range R>
456 requires(detail::is_item_prob_pair_v<std::ranges::range_value_t<R>>)
457Distribution(std::from_range_t, R&&)
458 -> Distribution<detail::item_t<std::ranges::range_value_t<R>>>;
459
460} // namespace probfd
461
462template <typename T, typename F>
463struct std::tuple_size<probfd::ItemProbabilityPair<T, F>>
464 : public integral_constant<std::size_t, 2> {};
465
466template <std::size_t I, typename T, typename F>
467struct std::tuple_element<I, probfd::ItemProbabilityPair<T, F>> {
468 static_assert(false, "Invalid index");
469};
470
471template <typename T, typename F>
472struct std::tuple_element<0, probfd::ItemProbabilityPair<T, F>> {
473 using type = T;
474};
475
476template <typename T, typename F>
477struct std::tuple_element<1, probfd::ItemProbabilityPair<T, F>> {
478 using type = F;
479};
480
481#endif // PROBFD_DISTRIBUTION_H
void reserve(size_t capacity)
Reserves space for capacity number of elements in the support of the distribution.
Definition distribution.h:207
value_t expectation(RandomVariable rv) const
Computes the expectation over a real random variable according to the distribution.
Definition distribution.h:287
size_t size() const
Returns the size of the support.
Definition distribution.h:222
iterator erase(iterator it)
Removes the item-probability pair pointed to by the iterator `it`.
Definition distribution.h:369
bool is_dirac() const
Checks if the distribution is a Dirac distribution.
Definition distribution.h:276
void normalize(value_t scale)
Scales all element probabilities by a common factor.
Definition distribution.h:316
bool is_dirac(const T &t) const
Checks if the distribution is a Dirac distribution wrt an element.
Definition distribution.h:267
value_t expectation(RandomVariable rv) const
Computes the expectation over a real random variable according to the distribution.
Definition distribution.h:304
void normalize()
Normalizes the probabilities of the elements to sum up to one.
Definition distribution.h:326
iterator erase(iterator it, iterator last)
Removes the range [it, last) of item-probability pairs.
Definition distribution.h:377
bool empty() const
Checks if the distribution is in an empty state.
Definition distribution.h:213
An item-probability pair.
Definition distribution.h:20
PrType probability
The probability of the item.
Definition distribution.h:38
ItemProbabilityPair()=default
Pairs a default-constructed item with an indeterminate probability.
T item
The item.
Definition distribution.h:37
ItemProbabilityPair(std::piecewise_construct_t, std::tuple< Args... > t1, std::tuple< Args2... > t2)
Pairs an item constructed from a tuple of constructor arguments with a given probability.
Definition distribution.h:90
friend auto operator<=>(const ItemProbabilityPair< T, PrType > &left, const ItemProbabilityPair< T, PrType > &right)=default
Lexicographical comparison.
The top-level namespace of probabilistic Fast Downward.
Definition command_line.h:8
constexpr no_normalize_t no_normalize
Disambiguator tag for Distribution constructors to indicate that the probabilities are already normalized.
Definition distribution.h:146
double value_t
Typedef for the state value type.
Definition aliases.h:7
STL namespace.
Disambiguator tag type.
Definition distribution.h:142