GRAIL (C++) 1.0.0
A multi-platform, modular, universal engine for embedding advanced AI in games.
ContinuousInnerDecisionNode.h
1 #ifndef GRAIL_CONTINUOUS_INNER_DECISION_NODE_H
2 #define GRAIL_CONTINUOUS_INNER_DECISION_NODE_H
3 
4 #include "DecisionNode.h"
5 #include <vector>
6 #include <memory>
7 #include <iostream>
8 
9 
10 namespace grail
11 {
12  namespace simulation
13  {
17  template <class TDecisionType>
18  class ContinuousInnerDecisionNode : public DecisionNode<TDecisionType>
19  {
20  public:
21  ContinuousInnerDecisionNode(int column, float splitValue) :
22  columnIndex { column},
23  splitValue{ splitValue }
24  {
25 
26  }
27 
28  void AddChild(std::unique_ptr<DecisionNode<TDecisionType>> childNode)
29  {
30  children.emplace_back(std::move(childNode));
31  }
32 
33  void Print(std::unordered_map<int, std::string>& columnNames, const std::string& indent) const override
34  {
35  if (columnNames.find(columnIndex) == columnNames.end())
36  columnNames[columnIndex] = "Column[" + std::to_string(columnIndex) + "]";
37 
38  std::cout << indent << columnNames.at(columnIndex) << " <= " << splitValue << std::endl;
39  children[0]->Print(columnNames, indent + " ");
40  std::cout << indent << columnNames.at(columnIndex) << " > " << splitValue << std::endl;
41  children[1]->Print(columnNames, indent + " ");
42  }
43 
44  const TDecisionType* Predict(std::vector<float>& data) const final
45  {
46  return (data[columnIndex] <= splitValue) ? children[0]->Predict(data) : children[1]->Predict(data);
47  }
48 
49  DecisionNodeType GetNodeType() const final
50  {
51  return DecisionNodeType::CONTINUOUS_INNER;
52  }
53  private:
54  std::vector<std::unique_ptr<DecisionNode<TDecisionType>>> children;
55  int columnIndex;
56  float splitValue;
57 
58  template <class T> friend struct IDecisionTreeSerializer;
59  };
60  }
61 }
62 
63 #endif //GRAIL_CONTINUOUS_INNER_DECISION_NODE_H
Class for internal usage. Decision tree node that corresponds to a numeric conditional.
Definition: ContinuousInnerDecisionNode.h:19
Class for internal usage. Decision tree node base type.
Definition: DecisionNode.h:21
A base class for an object that is passed to Serialize() and Deserialize() methods of DecisionTree....
Definition: IDecisionTreeSerializer.hh:21