AI 24/25 Project Software
Documentation for the AI 24/25 course programming project software
Loading...
Searching...
No Matches
graph_visualization.h
1#ifndef PROBFD_UTILS_GRAPH_VISUALIZATION_H
2#define PROBFD_UTILS_GRAPH_VISUALIZATION_H
3
#include "probfd/storage/per_state_storage.h"

#include "probfd/distribution.h"
#include "probfd/evaluator.h"
#include "probfd/mdp.h"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <deque>
#include <functional>
#include <iomanip>
#include <iostream>
#include <map>
#include <memory>
#include <ranges>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
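
/// This namespace contains code used for dumping search spaces as dot graphs.
namespace probfd::graphviz {
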
23namespace internal {
24class GraphBuilder {
25public:
26 class Configurable {
27 friend class GraphBuilder;
28 std::string name_;
29 std::map<std::string, std::string> attributes_;
30
31 public:
32 explicit Configurable(std::string name)
33 : name_(std::move(name))
34 {
35 }
36
37 void set_attribute(std::string attribute, std::string value)
38 {
39 attributes_.emplace(std::move(attribute), std::move(value));
40 }
41 };
42
43 class Node : public Configurable {
44 friend class GraphBuilder;
45 using Configurable::Configurable;
46 int rank_;
47
48 Node(std::string name, int rank)
49 : Configurable(std::move(name))
50 , rank_(rank)
51 {
52 }
53
54 public:
55 [[nodiscard]]
56 int get_rank() const
57 {
58 return rank_;
59 }
60 };
61
62private:
63 class StateNode : public Node {
64 friend class GraphBuilder;
65 using Node::Node;
66 };
67
68 class DummyNode : public Node {
69 friend class GraphBuilder;
70 using Node::Node;
71 };
72
73 class Edge : public Configurable {
74 friend class GraphBuilder;
75 Node& source_;
76 Node& target_;
77
78 Edge(std::string name, Node& source, Node& target)
79 : Configurable(std::move(name))
80 , source_(source)
81 , target_(target)
82 {
83 }
84 };
85
86 std::vector<std::unique_ptr<DummyNode>> dummynodes_;
87 std::vector<std::unique_ptr<StateNode>> nodes_;
88 std::vector<std::unique_ptr<Edge>> edges_;
89
90 std::map<StateID, StateNode*> id_to_nodes_;
91 StateID initial_;
92
93 std::map<int, std::vector<Node*>> ranked_nodes_;
94
95public:
96 explicit GraphBuilder(StateID initial)
97 : GraphBuilder(initial, get_default_node_name())
98 {
99 }
100
101 GraphBuilder(StateID initial, std::string name)
102 : initial_(initial)
103 {
104 insert_state_node(initial, 0, std::move(name));
105 }
106
107 Node& get_node(StateID id) { return *id_to_nodes_[id]; }
108
109 std::pair<Node*, bool> insert_node(StateID id)
110 {
111 return insert_state_node(id, 0, get_default_node_name());
112 }
113
114 Node* create_dummy_node(int rank)
115 {
116 return create_dummy_node(rank, get_default_dummy_node_name());
117 }
118
119 Node* create_dummy_node(int rank, std::string name)
120 {
121 auto& r =
122 dummynodes_.emplace_back(new DummyNode(std::move(name), rank));
123 ranked_nodes_[rank].push_back(r.get());
124 return r.get();
125 }
126
127 Node* create_state_node(StateID id, int rank, std::string name)
128 {
129 auto& r = nodes_.emplace_back(new StateNode(std::move(name), rank));
130 id_to_nodes_[id] = r.get();
131 r->rank_ = rank;
132
133 ranked_nodes_[rank].push_back(r.get());
134
135 return r.get();
136 }
137
138 std::pair<Node*, bool> insert_state_node(StateID id, int rank)
139 {
140 return insert_state_node(id, rank, get_default_node_name());
141 }
142
143 std::pair<Node*, bool>
144 insert_state_node(StateID id, int rank, std::string name)
145 {
146 auto it = id_to_nodes_.find(id);
147
148 if (it != id_to_nodes_.end()) {
149 return std::make_pair(it->second, false);
150 }
151
152 return std::make_pair(
153 create_state_node(id, rank, std::move(name)),
154 true);
155 }
156
157 Configurable& create_edge(Node& source, Node& target)
158 {
159 return create_edge(source, target, get_default_edge_name());
160 }
161
162 Configurable&
163 create_edge(Node& source_node, Node& target_node, std::string name)
164 {
165 return *edges_.emplace_back(
166 new Edge(std::move(name), source_node, target_node));
167 }
168
169 void emit(std::ostream& out)
170 {
171 out << "digraph {\n"
172 << " # Node and edge settings\n"
173 << " node [fontsize=8];\n"
174 << " edge [fontsize=8, arrowsize=0.4];\n\n"
175 << " # Initial Arrow\n"
176 << " initial_arrow_origin [label=\"\", shape=none];\n"
177 << " initial_arrow_origin -> " << id_to_nodes_[initial_]->name_
178 << " [arrowhead=vee];\n";
179
180 out << "\n # Node Ranking\n";
181
182 for (const auto& group : std::views::values(ranked_nodes_)) {
183 out << " { ";
184 emit_attribute(out, "rank", "same");
185 out << "; ";
186
187 for (const auto* node : group) {
188 out << node->name_ << "; ";
189 }
190
191 out << "}\n";
192 }
193
194 out << "\n # State Nodes\n";
195
196 for (const auto& node : nodes_) {
197 out << " ";
198 out << node->name_;
199 if (!node->attributes_.empty()) {
200 out << " ";
201 emit_attribute_list(out, node->attributes_);
202 }
203 out << ";\n";
204 }
205
206 out << "\n # Intermediate Nodes\n";
207
208 for (const auto& node : dummynodes_) {
209 out << " ";
210 out << node->name_;
211 if (!node->attributes_.empty()) {
212 out << " ";
213 emit_attribute_list(out, node->attributes_);
214 }
215 out << ";\n";
216 }
217
218 out << "\n # Edges\n";
219
220 for (const auto& edge : edges_) {
221 out << " ";
222 emit_edge(out, edge->source_, edge->target_);
223 if (!edge->attributes_.empty()) {
224 out << " ";
225 emit_attribute_list(out, edge->attributes_);
226 }
227 out << ";\n";
228 }
229
230 out << "}\n";
231 }
232
233private:
234 static void emit_attribute_list(
235 std::ostream& out,
236 const std::map<std::string, std::string>& attributes)
237 {
238 out << "[";
239
240 auto it = attributes.begin();
241 auto end = attributes.end();
242
243 assert(it != end);
244
245 emit_attribute(out, it->first, it->second);
246 while (++it != end) {
247 out << ", ";
248 emit_attribute(out, it->first, it->second);
249 }
250
251 out << "]";
252 }
253
254 static void emit_attribute(
255 std::ostream& out,
256 const std::string& attribute,
257 const std::string& value)
258 {
259 out << attribute << "=\"" << value << "\"";
260 }
261
262 static void emit_edge(std::ostream& out, Node& source, Node& target)
263 {
264 out << source.name_ << " -> " << target.name_;
265 }
266
267 [[nodiscard]]
268 std::string get_default_node_name() const
269 {
270 return "node_" + std::to_string(nodes_.size());
271 }
272
273 [[nodiscard]]
274 std::string get_default_dummy_node_name() const
275 {
276 return "intermediate_node_" + std::to_string(dummynodes_.size());
277 }
278
279 [[nodiscard]]
280 std::string get_default_edge_name() const
281 {
282 return "edge_" + std::to_string(edges_.size());
283 }
284};
285} // namespace internal
286
287template <typename State, typename Action>
288void dump_state_space_dot_graph(
289 std::ostream& out,
290 const State& initial_state,
292 Evaluator<State>* prune = nullptr,
293 std::function<std::string(const State&)> sstr =
294 [](const State&) { return ""; },
295 std::function<std::string(const Action&)> astr =
296 [](const Action&) { return ""; },
297 bool expand_terminal = false)
298{
299 struct SearchInfo {
300 StateID state_id;
301 State state;
302 internal::GraphBuilder::Node* node;
303
304 SearchInfo(
305 StateID state_id,
306 State state,
307 internal::GraphBuilder::Node* node)
308 : state_id(state_id)
309 , state(state)
310 , node(node)
311 {
312 }
313 };
314
315 StateID istateid = mdp->get_state_id(initial_state);
316 internal::GraphBuilder builder(istateid);
317 std::stringstream ss;
318 ss << std::setprecision(3);
319
320 std::deque<SearchInfo> open;
321 open.emplace_back(istateid, initial_state, &builder.get_node(istateid));
322
323 do {
324 auto& s = open.front();
325
326 const State& state = s.state;
327 auto* node = s.node;
328
329 node->set_attribute("label", sstr(state));
330 node->set_attribute("shape", "circle");
331
332 const auto term = mdp->get_termination_info(state);
333 bool expand = expand_terminal || !term.is_goal_state();
334
335 if (term.is_goal_state()) {
336 node->set_attribute("peripheries", std::to_string(2));
337 } else if (
338 expand && prune != nullptr &&
339 prune->evaluate(state) == term.get_cost()) {
340 expand = false;
341 node->set_attribute("peripheries", std::to_string(3));
342 }
343
344 open.pop_front();
345
346 if (!expand) {
347 continue;
348 }
349
350 std::vector<Action> aops;
351 std::vector<Distribution<StateID>> all_successors;
352 mdp->generate_all_transitions(state, aops, all_successors);
353
354 std::vector<std::pair<Action, Distribution<StateID>>> transitions;
355
356 for (std::size_t i = 0; i != aops.size(); ++i) {
357 transitions.emplace_back(aops[i], all_successors[i]);
358 }
359
360 auto less = [](const auto& left, const auto& right) {
361 return left.second < right.second;
362 };
363
364 auto equals = [](const auto& left, const auto& right) {
365 return left.second == right.second;
366 };
367
368 std::sort(transitions.begin(), transitions.end(), less);
369 transitions.erase(
370 std::unique(transitions.begin(), transitions.end(), equals),
371 transitions.end());
372
373 for (const auto& [act, successors] : transitions) {
374 const auto a_cost = mdp->get_action_cost(act);
375 if (a_cost != 0_vt) {
376 ss << a_cost << "\\n";
377 }
378 ss << astr(act);
379 std::string label_text = ss.str();
380 ss.str("");
381
382 if (successors.is_dirac()) {
383 const auto succ_id = successors.begin()->item;
384 auto [succ_node, inserted] = builder.insert_node(succ_id);
385
386 auto& direct_edge = builder.create_edge(*node, *succ_node);
387 direct_edge.set_attribute("arrowhead", "vee");
388 direct_edge.set_attribute("label", label_text);
389
390 if (inserted) {
391 open.emplace_back(
392 succ_id,
393 mdp->get_state(succ_id),
394 succ_node);
395 }
396
397 continue;
398 }
399
401 int my_rank = node->get_rank();
402 int max_rank = 0;
403
404 for (const auto& [succ_id, prob] : successors) {
405 auto [succ_node, inserted] =
406 builder.insert_state_node(succ_id, my_rank + 2);
407 max_rank = std::max(max_rank, succ_node->get_rank());
408 successor_nodes.add_probability(succ_node, prob);
409
410 if (inserted) {
411 open.emplace_back(
412 succ_id,
413 mdp->get_state(succ_id),
414 succ_node);
415 }
416 }
417
418 auto* intermediate = builder.create_dummy_node(
419 my_rank < max_rank ? my_rank + 1 : my_rank - 1);
420 intermediate->set_attribute("xlabel", astr(act));
421 // intermediate->set_attribute("tooltip", astr(act));
422 intermediate->set_attribute("shape", "point");
423 intermediate->set_attribute("style", "filled");
424 intermediate->set_attribute("fillcolor", "black");
425
426 auto& interm_edge = builder.create_edge(*node, *intermediate);
427 interm_edge.set_attribute("arrowhead", "none");
428 interm_edge.set_attribute("label", label_text);
429
430 for (const auto& [succ_node, prob] : successor_nodes) {
431 ss << prob;
432 std::string prob_text = ss.str();
433 ss.str("");
434
435 auto& edge = builder.create_edge(*intermediate, *succ_node);
436 edge.set_attribute("arrowhead", "vee");
437 edge.set_attribute("label", prob_text);
438 }
439 }
440
441 aops.clear();
442 } while (!open.empty());
443
444 builder.emit(out);
445}
446
447} // namespace probfd::graphviz
448
449#endif // PROBFD_UTILS_GRAPH_VISUALIZATION_H
virtual value_t get_action_cost(param_type< Action > action)=0
Gets the cost of an action.
virtual TerminationInfo get_termination_info(param_type< State > state)=0
Returns the cost to terminate in a given state and checks whether a state is a goal.
A convenience class that represents a finite probability distribution.
Definition task_state_space.h:27
The interface representing heuristic functions.
Definition mdp_algorithm.h:16
Basic interface for MDPs.
Definition mdp_algorithm.h:14
virtual void generate_all_transitions(param_type< State > state, std::vector< Action > &aops, std::vector< Distribution< StateID > > &successors)=0
Generates all applicable actions and their corresponding successor distributions for a given state.
virtual State get_state(StateID state_id)=0
Get the state mapped to a given state ID.
virtual StateID get_state_id(param_type< State > state)=0
Get the state ID for a given state.
bool is_goal_state() const
Check if this state is a goal.
Definition cost_function.h:34
This namespace contains code used for dumping search spaces as dot graphs.
Definition graph_visualization.h:21
A StateID represents a state within a StateIDMap. Just like Fast Downward's StateID type,...
Definition types.h:22