GRAIL (C++) 1.0.0
A multi-platform, modular, universal engine for embedding advanced AI in games.
NominalInnerDecisionNode.h
1 #ifndef GRAIL_NOMINAL_INNER_DECISION_NODE_H
2 #define GRAIL_NOMINAL_INNER_DECISION_NODE_H
3 
4 #include "DecisionNode.h"
5 #include <unordered_map>
6 #include <memory>
7 #include <iostream>
8 
9 namespace grail
10 {
11  namespace simulation
12  {
14  template <class TDecisionType>
15  class NominalInnerDecisionNode : public DecisionNode<TDecisionType>
16  {
17  public:
18 
19  NominalInnerDecisionNode(int column) :
20  children{},
21  columnIndex{ column }
22  {
23 
24  }
25 
26  void SetDefaultDecision(std::unique_ptr< TDecisionType> decision)
27  {
28  defaultDecision = std::move(decision);
29  }
30 
31  void AddChild(std::unique_ptr<DecisionNode<TDecisionType>> node, float key)
32  {
33  children[key] = std::move(node);
34  }
35 
36  void Print(std::unordered_map<int, std::string>& columnNames, const std::string& indent) const override
37  {
38  if (columnNames.find(columnIndex) == columnNames.end())
39  columnNames[columnIndex] = "Column[" + std::to_string(columnIndex) + "]";
40 
41  for(auto& child : children)
42  {
43  std::cout << indent << columnNames.at(columnIndex) << " = " << child.first << std::endl;
44  child.second->Print(columnNames, indent + " ");
45  }
46  }
47 
48  const TDecisionType* Predict(std::vector<float>& data) const final
49  {
50  auto it = children.find(data[columnIndex]);
51  if (it != children.end())
52  {
53  return it->second->Predict(data);
54  }
55  return defaultDecision.get();
56  }
57 
58  DecisionNodeType GetNodeType() const final
59  {
60  return DecisionNodeType::NOMINAL_INNER;
61  }
62  private:
63  std::unordered_map<float, std::unique_ptr<DecisionNode<TDecisionType>>> children;
64  int columnIndex;
65  std::unique_ptr<TDecisionType> defaultDecision = nullptr;
66 
67  template <class T> friend struct IDecisionTreeSerializer;
68  };
69  }
70 }
71 
72 #endif //GRAIL_NOMINAL_INNER_DECISION_NODE_H
Class for internal usage.
Definition: NominalInnerDecisionNode.h:16
Class for internal usage. Decision tree node base type.
Definition: DecisionNode.h:21
A base class for an object that is passed to Serialize() and Deserialize() methods of DecisionTree....
Definition: IDecisionTreeSerializer.hh:21