// GraphBuilder reads these fields directly when it writes the DOT output.
27 friend class GraphBuilder;
// Attribute name -> value pairs, written as name="value" by
// GraphBuilder::emit_attribute_list. Being a std::map, iteration (and hence
// emission) order is sorted by attribute name, so output is deterministic.
29 std::map<std::string, std::string> attributes_;
// Takes the object's DOT identifier by value and moves it into name_
// (GraphBuilder prints name_ as the node/edge identifier).
// NOTE(review): the constructor body is not shown in this listing.
32 explicit Configurable(std::string name)
33 : name_(std::move(name))
37 void set_attribute(std::string attribute, std::string value)
39 attributes_.emplace(std::move(attribute), std::move(value));
// A vertex of the emitted DOT graph. Besides the name and attributes
// inherited from Configurable, a node is created with an integer rank;
// GraphBuilder groups nodes by rank (ranked_nodes_) and emit() writes each
// group together with rank="same".
// NOTE(review): the members that store the rank (read through get_rank())
// are not shown in this listing.
43 class Node :
public Configurable {
44 friend class GraphBuilder;
45 using Configurable::Configurable;
48 Node(std::string name,
int rank)
49 : Configurable(std::move(name))
// Node that represents one state of the explored state space. GraphBuilder
// owns these in nodes_ and indexes them by StateID in id_to_nodes_.
63 class StateNode :
public Node {
64 friend class GraphBuilder;
// Auxiliary node that is not a state; GraphBuilder owns these in
// dummynodes_. dump_state_space_dot_graph creates one per transition whose
// successors form a distribution, and draws the successor edges from it.
68 class DummyNode :
public Node {
69 friend class GraphBuilder;
// Directed edge between two nodes. emit() writes it as "source -> target"
// followed by its attribute list when it has attributes.
// NOTE(review): the members holding the endpoints (source_ / target_, read
// by GraphBuilder::emit) and their initializers are not shown in this listing.
73 class Edge :
public Configurable {
74 friend class GraphBuilder;
78 Edge(std::string name, Node& source, Node& target)
79 : Configurable(std::move(name))
// These three vectors own every node and edge. Because each element is a
// heap-allocated unique_ptr, raw pointers to the objects stay valid when the
// vectors grow; the maps below store such non-owning pointers.
86 std::vector<std::unique_ptr<DummyNode>> dummynodes_;
87 std::vector<std::unique_ptr<StateNode>> nodes_;
88 std::vector<std::unique_ptr<Edge>> edges_;
// State id -> its node; filled by create_state_node.
90 std::map<StateID, StateNode*> id_to_nodes_;
// Rank -> nodes of that rank; emit() writes one rank="same" group per entry.
// NOTE(review): the initial_ member read by emit() is not shown in this listing.
93 std::map<int, std::vector<Node*>> ranked_nodes_;
96 explicit GraphBuilder(
StateID initial)
97 : GraphBuilder(initial, get_default_node_name())
// Creates a builder and inserts the node of the initial state `initial` at
// rank 0 under the given `name`.
// NOTE(review): part of this constructor (presumably the initializer of
// initial_, which emit() reads) is elided from this listing; confirm.
101 GraphBuilder(
StateID initial, std::string name)
104 insert_state_node(initial, 0, std::move(name));
107 Node& get_node(
StateID id) {
return *id_to_nodes_[id]; }
109 std::pair<Node*, bool> insert_node(
StateID id)
111 return insert_state_node(
id, 0, get_default_node_name());
114 Node* create_dummy_node(
int rank)
116 return create_dummy_node(rank, get_default_dummy_node_name());
// Creates a dummy node called `name`, stores it in dummynodes_ and registers
// it in rank group `rank` of ranked_nodes_.
// NOTE(review): the declaration of `r` (presumably the reference returned by
// emplace_back) and the return statement are elided from this listing.
// NOTE(review): emplace_back(new ...) leaks the node if the vector's
// reallocation throws; constructing the std::unique_ptr first and then
// emplacing it avoids this (std::make_unique may not be usable if the
// constructor is only accessible to the friend GraphBuilder).
119 Node* create_dummy_node(
int rank, std::string name)
122 dummynodes_.emplace_back(
new DummyNode(std::move(name), rank));
123 ranked_nodes_[rank].push_back(r.get());
// Unconditionally creates a state node for `id` named `name`, stores it in
// nodes_, records it in id_to_nodes_ (overwriting any previous entry for the
// same id) and adds it to rank group `rank`. insert_state_node checks
// id_to_nodes_ before calling this, which avoids duplicates.
// NOTE(review): some lines of this function (including the return) are
// elided from this listing.
// NOTE(review): emplace_back(new ...) leaks the node if the vector's
// reallocation throws; build the std::unique_ptr before emplacing it.
127 Node* create_state_node(
StateID id,
int rank, std::string name)
129 auto& r = nodes_.emplace_back(
new StateNode(std::move(name), rank));
130 id_to_nodes_[id] = r.get();
133 ranked_nodes_[rank].push_back(r.get());
138 std::pair<Node*, bool> insert_state_node(
StateID id,
int rank)
140 return insert_state_node(
id, rank, get_default_node_name());
// Returns the node of state `id`, creating it with `rank` and `name` if the
// state has none yet. If a node already exists it is returned together with
// `false`; its rank and name are NOT updated to the given arguments.
// NOTE(review): the second pair element for a newly created node is elided
// from this listing (presumably `true`); confirm.
143 std::pair<Node*, bool>
144 insert_state_node(
StateID id,
int rank, std::string name)
146 auto it = id_to_nodes_.find(
id);
148 if (it != id_to_nodes_.end()) {
149 return std::make_pair(it->second,
false);
152 return std::make_pair(
153 create_state_node(
id, rank, std::move(name)),
157 Configurable& create_edge(Node& source, Node& target)
159 return create_edge(source, target, get_default_edge_name());
// Creates an edge named `name` from `source_node` to `target_node`, stores it
// in edges_, and returns a reference to the new edge.
// NOTE(review): the return-type line is elided from this listing
// (Configurable&, judging from the two-argument overload).
// NOTE(review): emplace_back(new ...) leaks the edge if the vector's
// reallocation throws; build the std::unique_ptr before emplacing it.
163 create_edge(Node& source_node, Node& target_node, std::string name)
165 return *edges_.emplace_back(
166 new Edge(std::move(name), source_node, target_node));
// Writes the DOT description of the graph to `out`, in this order:
// global node/edge settings; an arrow from an unlabeled, shapeless origin
// node into the initial state's node; one rank="same" group per entry of
// ranked_nodes_; the state nodes; the intermediate (dummy) nodes; the edges.
// Attribute lists are only written for objects that have attributes.
// NOTE(review): several lines of this function (e.g. the start of the first
// output statement and, presumably, the statements printing each node's
// name) are elided from this listing.
// NOTE(review): id_to_nodes_[initial_] uses map::operator[], which would
// insert and then dereference a null pointer if initial_ had no node. The
// constructor inserts the initial state's node, so this presumably cannot
// happen; confirm that initial_ holds that same id.
169 void emit(std::ostream& out)
172 <<
" # Node and edge settings\n"
173 <<
" node [fontsize=8];\n"
174 <<
" edge [fontsize=8, arrowsize=0.4];\n\n"
175 <<
" # Initial Arrow\n"
176 <<
" initial_arrow_origin [label=\"\", shape=none];\n"
177 <<
" initial_arrow_origin -> " << id_to_nodes_[initial_]->name_
178 <<
" [arrowhead=vee];\n";
180 out <<
"\n # Node Ranking\n";
// One group per rank: rank="same", then the names of that rank's nodes.
182 for (
const auto& group : std::views::values(ranked_nodes_)) {
184 emit_attribute(out,
"rank",
"same");
187 for (
const auto* node : group) {
188 out << node->name_ <<
"; ";
194 out <<
"\n # State Nodes\n";
196 for (
const auto& node : nodes_) {
199 if (!node->attributes_.empty()) {
201 emit_attribute_list(out, node->attributes_);
206 out <<
"\n # Intermediate Nodes\n";
208 for (
const auto& node : dummynodes_) {
211 if (!node->attributes_.empty()) {
213 emit_attribute_list(out, node->attributes_);
218 out <<
"\n # Edges\n";
// Each edge: "source -> target", then its attribute list if it has one.
220 for (
const auto& edge : edges_) {
222 emit_edge(out, edge->source_, edge->target_);
223 if (!edge->attributes_.empty()) {
225 emit_attribute_list(out, edge->attributes_);
// Writes every attribute of `attributes` as name="value", in key order.
// NOTE(review): the stream parameter (`out`) and the lines that write the
// surrounding brackets and the separator between attributes are elided
// from this listing.
// NOTE(review): the first element is emitted right after begin() without a
// visible emptiness check, so this appears to require a non-empty map. Every
// call in emit() is guarded by !attributes_.empty(), unless an elided line
// here also checks.
234 static void emit_attribute_list(
236 const std::map<std::string, std::string>& attributes)
240 auto it = attributes.begin();
241 auto end = attributes.end();
245 emit_attribute(out, it->first, it->second);
246 while (++it != end) {
248 emit_attribute(out, it->first, it->second);
// Writes a single attribute as name="value".
// NOTE(review): the stream parameter line is elided from this listing.
// NOTE(review): `value` is written verbatim. Labels come from caller-supplied
// functions (e.g. sstr/astr in dump_state_space_dot_graph), so a value
// containing '"' would produce invalid DOT unless it is escaped first.
254 static void emit_attribute(
256 const std::string& attribute,
257 const std::string& value)
259 out << attribute <<
"=\"" << value <<
"\"";
262 static void emit_edge(std::ostream& out, Node& source, Node& target)
264 out << source.name_ <<
" -> " << target.name_;
268 std::string get_default_node_name()
const
270 return "node_" + std::to_string(nodes_.size());
274 std::string get_default_dummy_node_name()
const
276 return "intermediate_node_" + std::to_string(dummynodes_.size());
280 std::string get_default_edge_name()
const
282 return "edge_" + std::to_string(edges_.size());
287template <
typename State,
typename Action>
288void dump_state_space_dot_graph(
290 const State& initial_state,
293 std::function<std::string(
const State&)> sstr =
294 [](
const State&) {
return ""; },
295 std::function<std::string(
const Action&)> astr =
296 [](
const Action&) {
return ""; },
297 bool expand_terminal =
false)
302 internal::GraphBuilder::Node* node;
307 internal::GraphBuilder::Node* node)
316 internal::GraphBuilder builder(istateid);
317 std::stringstream ss;
318 ss << std::setprecision(3);
320 std::deque<SearchInfo> open;
321 open.emplace_back(istateid, initial_state, &builder.get_node(istateid));
324 auto& s = open.front();
326 const State& state = s.state;
329 node->set_attribute(
"label", sstr(state));
330 node->set_attribute(
"shape",
"circle");
335 if (term.is_goal_state()) {
336 node->set_attribute(
"peripheries", std::to_string(2));
338 expand && prune !=
nullptr &&
339 prune->evaluate(state) == term.get_cost()) {
341 node->set_attribute(
"peripheries", std::to_string(3));
350 std::vector<Action> aops;
351 std::vector<Distribution<StateID>> all_successors;
354 std::vector<std::pair<Action, Distribution<StateID>>> transitions;
356 for (std::size_t i = 0; i != aops.size(); ++i) {
357 transitions.emplace_back(aops[i], all_successors[i]);
360 auto less = [](
const auto& left,
const auto& right) {
361 return left.second < right.second;
364 auto equals = [](
const auto& left,
const auto& right) {
365 return left.second == right.second;
368 std::sort(transitions.begin(), transitions.end(), less);
370 std::unique(transitions.begin(), transitions.end(), equals),
373 for (
const auto& [act, successors] : transitions) {
375 if (a_cost != 0_vt) {
376 ss << a_cost <<
"\\n";
379 std::string label_text = ss.str();
382 if (successors.is_dirac()) {
383 const auto succ_id = successors.begin()->item;
384 auto [succ_node, inserted] = builder.insert_node(succ_id);
386 auto& direct_edge = builder.create_edge(*node, *succ_node);
387 direct_edge.set_attribute(
"arrowhead",
"vee");
388 direct_edge.set_attribute(
"label", label_text);
401 int my_rank = node->get_rank();
404 for (
const auto& [succ_id, prob] : successors) {
405 auto [succ_node, inserted] =
406 builder.insert_state_node(succ_id, my_rank + 2);
407 max_rank = std::max(max_rank, succ_node->get_rank());
408 successor_nodes.add_probability(succ_node, prob);
418 auto* intermediate = builder.create_dummy_node(
419 my_rank < max_rank ? my_rank + 1 : my_rank - 1);
420 intermediate->set_attribute(
"xlabel", astr(act));
422 intermediate->set_attribute(
"shape",
"point");
423 intermediate->set_attribute(
"style",
"filled");
424 intermediate->set_attribute(
"fillcolor",
"black");
426 auto& interm_edge = builder.create_edge(*node, *intermediate);
427 interm_edge.set_attribute(
"arrowhead",
"none");
428 interm_edge.set_attribute(
"label", label_text);
430 for (
const auto& [succ_node, prob] : successor_nodes) {
432 std::string prob_text = ss.str();
435 auto& edge = builder.create_edge(*intermediate, *succ_node);
436 edge.set_attribute(
"arrowhead",
"vee");
437 edge.set_attribute(
"label", prob_text);
442 }
while (!open.empty());