(C++)  1.0.0
A multi-platform, modular, universal engine for embedding advanced AI in games.
IDecisionTreeSerializer.hh
1 #ifndef GRAIL_IDECISION_TREE_SERIALIZER_H
2 #define GRAIL_IDECISION_TREE_SERIALIZER_H
3 
4 #include "../DT/ContinuousInnerDecisionNode.h"
5 #include "../DT/NominalInnerDecisionNode.h"
6 #include "../DT/LeafDecisionNode.h"
7 #include "../DT/DecisionNodeType.h"
8 #include "IDecisionSerializers.hh"
9 #include <memory>
10 
11 namespace grail
12 {
13  namespace simulation
14  {
19  template <class TDecisionType>
21  {
22  virtual void SerializeNodeType(DecisionNodeType nodeType) = 0;
23  virtual DecisionNodeType DeserializeNodeType() = 0;
24 
25  virtual void SerializeColumnIndex(int columnIndex) = 0;
26  virtual int DeserializeColumnIndex() = 0;
27 
28  virtual void SerializeValue(float value) = 0;
29  virtual float DeserializeValue() = 0;
30 
31  virtual void SerializeChildrenCount(int count) = 0;
32  virtual int DeserializeChildrenCount() = 0;
33 
34  virtual void SerializeDecision(const TDecisionType& decision) = 0;
35  virtual std::unique_ptr<TDecisionType> DeserializeDecision() = 0;
36 
37  virtual void Initialize()
38  {
39 
40  }
41 
42  void Serialize(DecisionNode<TDecisionType>& node)
43  {
44  switch (node.GetNodeType())
45  {
46  case DecisionNodeType::NOMINAL_INNER:
47  SerializeNominalInnerDecisionNode(static_cast<NominalInnerDecisionNode<TDecisionType>&>(node));
48  break;
49  case DecisionNodeType::CONTINUOUS_INNER:
50  SerializeContinuousInnerDecisionNode(static_cast<ContinuousInnerDecisionNode<TDecisionType>&>(node));
51  break;
52  case DecisionNodeType::LEAF:
53  SerializeLeafNode(static_cast<LeafDecisionNode<TDecisionType>&>(node));
54  break;
55  default:
56  break;
57  }
58  }
59 
60  void SerializeContinuousInnerDecisionNode(ContinuousInnerDecisionNode<TDecisionType>& node)
61  {
62  SerializeNodeType(DecisionNodeType::CONTINUOUS_INNER);
63  SerializeColumnIndex(node.columnIndex);
64  SerializeValue(node.splitValue);
65  Serialize(*node.children[0]);
66  Serialize(*node.children[1]);
67  }
68 
69  void SerializeNominalInnerDecisionNode(NominalInnerDecisionNode<TDecisionType>& node)
70  {
71  SerializeNodeType(DecisionNodeType::NOMINAL_INNER);
72  SerializeColumnIndex(node.columnIndex);
73  SerializeChildrenCount(static_cast<int>(node.children.size()));
74  for (auto& entry : node.children)
75  {
76  SerializeValue(entry.first);
77  Serialize(*entry.second);
78  }
79  }
80 
81  void SerializeLeafNode(LeafDecisionNode<TDecisionType>& node)
82  {
83  SerializeNodeType(DecisionNodeType::LEAF);
84  SerializeDecision(*node.decision);
85  }
86 
87  void DeserializeContinuousInnerDecisionNode(ContinuousInnerDecisionNode<TDecisionType>& node)
88  {
89  node.children.emplace_back(std::move(DeserializeNode()));
90  node.children.emplace_back(std::move(DeserializeNode()));
91  }
92 
93  void DeserializeNominalInnerDecisionNode(NominalInnerDecisionNode<TDecisionType>& node)
94  {
95  int childrenCount = DeserializeColumnIndex();
96  for (int i = 0; i < childrenCount; ++i)
97  {
98  float value = DeserializeValue();
99  node.children[value] = DeserializeNode();
100  }
101  }
102 
103  std::unique_ptr<DecisionNode<TDecisionType>> DeserializeNode()
104  {
105  DecisionNodeType type = DeserializeNodeType();
106  int columIndex = -1;
107  switch (type)
108  {
109  case DecisionNodeType::NOMINAL_INNER:
110  {
111  columIndex = DeserializeColumnIndex();
112  std::unique_ptr<NominalInnerDecisionNode<TDecisionType>> node = std::make_unique<NominalInnerDecisionNode<TDecisionType>>(columIndex);
113  DeserializeNominalInnerDecisionNode(*node);
114  return node;
115  }
116  case DecisionNodeType::CONTINUOUS_INNER:
117  {
118  columIndex = DeserializeColumnIndex();
119  float value = DeserializeValue();
120  auto node = std::make_unique<ContinuousInnerDecisionNode<TDecisionType>>(columIndex, value);
121  DeserializeContinuousInnerDecisionNode(*node);
122  return node;
123  }
124  case DecisionNodeType::LEAF:
125  {
126  return std::make_unique<LeafDecisionNode<TDecisionType>>(DeserializeDecision());
127  }
128  default:
129  return nullptr;
130  }
131  }
132  };
133  }
134 }
135 #endif //GRAIL_IDECISION_TREE_SERIALIZER_H
Class for internal usage. Decision tree node that corresponds to numeric conditionals.
Definition: ContinuousInnerDecisionNode.h:19
Class for internal usage. Decision tree node that corresponds to actual decisions. They are always leaves...
Definition: LeafDecisionNode.h:17
Class for internal usage.
Definition: NominalInnerDecisionNode.h:16
Class for internal usage. Decision tree node base type.
Definition: DecisionNode.h:21
A base class for an object that is passed to the Serialize() and Deserialize() methods of DecisionTree....
Definition: IDecisionTreeSerializer.hh:21