AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
cegar.h
1#ifndef PROBFD_PDBS_CEGAR_CEGAR_H
2#define PROBFD_PDBS_CEGAR_CEGAR_H
3
4#include "probfd/pdbs/types.h"
5
6#include "probfd/fdr_types.h"
7#include "probfd/value_type.h"
8
9#include <memory>
10#include <unordered_map>
11#include <unordered_set>
12#include <vector>
13
// Forward declarations. This header only names the types below (through
// smart pointers, references, and function declarations), so it does not
// need to include their definitions.
class VariablesProxy;

namespace utils {
class CountdownTimer;
class LogProxy;
class RandomNumberGenerator;
} // namespace utils

namespace probfd {
class ProbabilisticTaskProxy;

namespace pdbs::cegar {
struct Flaw;
class FlawFindingStrategy;
} // namespace pdbs::cegar
} // namespace probfd
31
32namespace probfd::pdbs::cegar {
33
34struct CEGARResult {
35 std::unique_ptr<ProjectionCollection> projections;
36 std::unique_ptr<PPDBCollection> pdbs;
37
38 ~CEGARResult();
39};
40
41class CEGAR {
42 class PDBInfo;
43
44 // Random number generator
45 const std::shared_ptr<utils::RandomNumberGenerator> rng_;
46
47 // Flaw finding strategy
48 const std::shared_ptr<FlawFindingStrategy> flaw_strategy_;
49
50 // behavior defining parameters
51 const bool wildcard_;
52 const int max_pdb_size_;
53 const int max_collection_size_;
54
55 const std::vector<int> goals_;
56 std::unordered_set<int> blacklisted_variables_;
57
58 // the pattern collection in form of their pdbs plus stored plans.
59 std::vector<PDBInfo> pdb_infos_;
60 std::vector<PDBInfo>::iterator unsolved_end;
61 std::vector<PDBInfo>::iterator solved_end;
62
63 // Takes a variable as key and returns an iterator to the solutions-entry
64 // whose pattern contains said variable. Used for checking if a variable
65 // is already included in some pattern as well as for quickly finding
66 // the other partner for merging.
67 std::unordered_map<int, std::vector<PDBInfo>::iterator> variable_to_info_;
68
69 int remaining_size_ = max_collection_size_;
70
71public:
72 CEGAR(
73 const std::shared_ptr<utils::RandomNumberGenerator>& rng,
74 std::shared_ptr<cegar::FlawFindingStrategy> flaw_strategy,
75 bool wildcard,
76 int max_pdb_size,
77 int max_collection_size,
78 std::vector<int> goals,
79 std::unordered_set<int> blacklisted_variables = {});
80
81 ~CEGAR();
82
83 CEGARResult generate_pdbs(
84 ProbabilisticTaskProxy task_proxy,
85 const std::shared_ptr<FDRSimpleCostFunction>& task_cost_function,
86 double max_time,
87 utils::LogProxy log);
88
89private:
90 void generate_trivial_solution_collection(
91 ProbabilisticTaskProxy task_proxy,
92 std::shared_ptr<FDRSimpleCostFunction> task_cost_function,
93 utils::CountdownTimer& timer,
94 utils::LogProxy log);
95
96 std::vector<PDBInfo>::iterator get_flaws(
97 ProbabilisticTaskProxy task_proxy,
98 std::vector<Flaw>& flaws,
99 std::vector<int>& flaw_offsets,
100 utils::CountdownTimer& timer,
101 utils::LogProxy log);
102
103 bool can_add_variable_to_pattern(
104 const VariablesProxy& variables,
105 std::vector<PDBInfo>::iterator info_it,
106 int var) const;
107
108 bool can_merge_patterns(
109 std::vector<PDBInfo>::iterator info_it1,
110 std::vector<PDBInfo>::iterator info_it2) const;
111
112 void add_pattern_for_var(
113 ProbabilisticTaskProxy task_proxy,
114 std::shared_ptr<FDRSimpleCostFunction> task_cost_function,
115 int var,
116 utils::CountdownTimer& timer);
117
118 void add_variable_to_pattern(
119 ProbabilisticTaskProxy task_proxy,
120 std::shared_ptr<FDRSimpleCostFunction> task_cost_function,
121 std::vector<PDBInfo>::iterator info_it,
122 int var,
123 utils::CountdownTimer& timer);
124
125 void merge_patterns(
126 ProbabilisticTaskProxy task_proxy,
127 std::shared_ptr<FDRSimpleCostFunction> task_cost_function,
128 std::vector<PDBInfo>::iterator info_it1,
129 std::vector<PDBInfo>::iterator info_it2,
130 utils::CountdownTimer& timer);
131
132 void refine(
133 ProbabilisticTaskProxy task_proxy,
134 const std::shared_ptr<FDRSimpleCostFunction>& task_cost_function,
135 const std::vector<Flaw>& flaws,
136 const std::vector<int>& flaw_offsets,
137 utils::CountdownTimer& timer,
138 utils::LogProxy log);
139
140 void print_collection(utils::LogProxy log) const;
141};
142
143} // namespace probfd::pdbs::cegar
144
145#endif // PROBFD_PDBS_CEGAR_CEGAR_H
The top-level namespace of probabilistic Fast Downward.
Definition command_line.h:8