1 #ifndef GRAIL_CONTINUOUS_INNER_DECISION_NODE_H
2 #define GRAIL_CONTINUOUS_INNER_DECISION_NODE_H
4 #include "DecisionNode.h"
17 template <
class TDecisionType>
22 columnIndex { column},
23 splitValue{ splitValue }
30 children.emplace_back(std::move(childNode));
33 void Print(std::unordered_map<int, std::string>& columnNames,
const std::string& indent)
const override
35 if (columnNames.find(columnIndex) == columnNames.end())
36 columnNames[columnIndex] =
"Column[" + std::to_string(columnIndex) +
"]";
38 std::cout << indent << columnNames.at(columnIndex) <<
" <= " << splitValue << std::endl;
39 children[0]->Print(columnNames, indent +
" ");
40 std::cout << indent << columnNames.at(columnIndex) <<
" > " << splitValue << std::endl;
41 children[1]->Print(columnNames, indent +
" ");
44 const TDecisionType* Predict(std::vector<float>& data)
const final
46 return (data[columnIndex] <= splitValue) ? children[0]->Predict(data) : children[1]->Predict(data);
49 DecisionNodeType GetNodeType()
const final
51 return DecisionNodeType::CONTINUOUS_INNER;
54 std::vector<std::unique_ptr<DecisionNode<TDecisionType>>> children;
Class for internal use. Decision tree node that corresponds to numeric conditions.
Definition: ContinuousInnerDecisionNode.h:19
Class for internal usage. Decision tree node base type.
Definition: DecisionNode.h:21
A base class for an object that is passed to Serialize() and Deserialize() methods of DecisionTree....
Definition: IDecisionTreeSerializer.hh:21