1 #ifndef GRAIL_IDECISION_TREE_SERIALIZER_H
2 #define GRAIL_IDECISION_TREE_SERIALIZER_H
4 #include "../DT/ContinuousInnerDecisionNode.h"
5 #include "../DT/NominalInnerDecisionNode.h"
6 #include "../DT/LeafDecisionNode.h"
7 #include "../DT/DecisionNodeType.h"
8 #include "IDecisionSerializers.hh"
/// Serializer interface for decision trees, parameterized on the decision
/// payload type stored in leaves (TDecisionType).
///
/// The pure virtual members below are the primitive read/write hooks. The
/// tree-walking code later in this file calls them in matching order when
/// writing and reading nodes, so every Serialize* hook has a Deserialize*
/// counterpart that must read back exactly what was written.
19 template <
class TDecisionType>
/// Writes the tag identifying which kind of node follows. It is the first
/// field emitted for every node.
22 virtual void SerializeNodeType(DecisionNodeType nodeType) = 0;
/// Reads the tag written by SerializeNodeType().
23 virtual DecisionNodeType DeserializeNodeType() = 0;
/// Writes the column index of an inner (nominal or continuous) node.
25 virtual void SerializeColumnIndex(
int columnIndex) = 0;
/// Reads the value written by SerializeColumnIndex().
26 virtual int DeserializeColumnIndex() = 0;
/// Writes a float field: the split value of a continuous node, or the key of
/// one child of a nominal node.
28 virtual void SerializeValue(
float value) = 0;
/// Reads the value written by SerializeValue().
29 virtual float DeserializeValue() = 0;
/// Writes the number of children of a nominal node.
31 virtual void SerializeChildrenCount(
int count) = 0;
/// Reads the value written by SerializeChildrenCount().
32 virtual int DeserializeChildrenCount() = 0;
/// Writes the decision payload held by a leaf node.
34 virtual void SerializeDecision(
const TDecisionType& decision) = 0;
/// Reads a decision payload and returns ownership of it; the result is passed
/// to the LeafDecisionNode constructor when a leaf is rebuilt.
35 virtual std::unique_ptr<TDecisionType> DeserializeDecision() = 0;
/// Virtual hook declared without `= 0` on this line; its body is not visible
/// in this listing.
/// NOTE(review): presumably lets an implementation prepare its state before a
/// serialization or deserialization pass — confirm against the full header.
37 virtual void Initialize()
// Dispatches on the node's runtime type tag. The statements inside each case
// are not visible in this listing.
// NOTE(review): presumably each case forwards to the Serialize overload for
// the matching concrete node class — confirm.
44 switch (node.GetNodeType())
46 case DecisionNodeType::NOMINAL_INNER:
49 case DecisionNodeType::CONTINUOUS_INNER:
52 case DecisionNodeType::LEAF:
// Continuous inner node layout: type tag, column index, split value, then the
// two child subtrees in index order. The order is the wire format; the reader
// consumes the fields in the same sequence.
62 SerializeNodeType(DecisionNodeType::CONTINUOUS_INNER);
63 SerializeColumnIndex(node.columnIndex);
64 SerializeValue(node.splitValue);
// Children are accessed unchecked at indices 0 and 1.
// NOTE(review): assumes a continuous node always holds exactly two non-null
// children — confirm.
65 Serialize(*node.children[0]);
66 Serialize(*node.children[1]);
71 SerializeNodeType(DecisionNodeType::NOMINAL_INNER);
72 SerializeColumnIndex(node.columnIndex);
73 SerializeChildrenCount(
static_cast<int>(node.children.size()));
74 for (
auto& entry : node.children)
76 SerializeValue(entry.first);
77 Serialize(*entry.second);
// Leaf node layout: type tag followed by the decision payload.
83 SerializeNodeType(DecisionNodeType::LEAF);
// `decision` is dereferenced unchecked.
// NOTE(review): assumes a leaf always owns a non-null decision — confirm.
84 SerializeDecision(*node.decision);
89 node.children.emplace_back(std::move(DeserializeNode()));
90 node.children.emplace_back(std::move(DeserializeNode()));
95 int childrenCount = DeserializeColumnIndex();
96 for (
int i = 0; i < childrenCount; ++i)
98 float value = DeserializeValue();
99 node.children[value] = DeserializeNode();
/// Reads one node and returns ownership of it. The node-type tag is read
/// first and selects which fields follow. For inner nodes the child-reading
/// helpers call DeserializeNode() again for each child.
///
/// NOTE(review): the statements between the visible lines (the switch header,
/// the declaration of `columIndex`, the returns for inner nodes, any default
/// branch) are not visible in this listing.
103 std::unique_ptr<DecisionNode<TDecisionType>> DeserializeNode()
105 DecisionNodeType type = DeserializeNodeType();
// Nominal node: column index, then the children via the helper.
109 case DecisionNodeType::NOMINAL_INNER:
// NOTE(review): `columIndex` looks like a typo for `columnIndex`.
111 columIndex = DeserializeColumnIndex();
112 std::unique_ptr<NominalInnerDecisionNode<TDecisionType>> node = std::make_unique<NominalInnerDecisionNode<TDecisionType>>(columIndex);
113 DeserializeNominalInnerDecisionNode(*node);
// Continuous node: column index and split value, then the children via the helper.
116 case DecisionNodeType::CONTINUOUS_INNER:
118 columIndex = DeserializeColumnIndex();
119 float value = DeserializeValue();
120 auto node = std::make_unique<ContinuousInnerDecisionNode<TDecisionType>>(columIndex, value);
121 DeserializeContinuousInnerDecisionNode(*node);
// Leaf: the decision payload is read and handed straight to the leaf's constructor.
124 case DecisionNodeType::LEAF:
126 return std::make_unique<LeafDecisionNode<TDecisionType>>(DeserializeDecision());
Class for internal usage. Decision tree node that corresponds to numeric conditionals.
Definition: ContinuousInnerDecisionNode.h:19
Class for internal usage. Decision tree node that corresponds to an actual decision. It is always a lea...
Definition: LeafDecisionNode.h:17
Class for internal usage.
Definition: NominalInnerDecisionNode.h:16
Class for internal usage. Decision tree node base type.
Definition: DecisionNode.h:21
A base class for an object that is passed to Serialize() and Deserialize() methods of DecisionTree....
Definition: IDecisionTreeSerializer.hh:21