Grail (C++)  1.2.0
A multi-platform, modular, universal engine for embedding advanced AI in games.
Dataset.hh
1 // Copyright QED Software 2023.
2 
3 #ifndef GRAIL_DATASET_H
4 #define GRAIL_DATASET_H
5 
6 #include <functional>
7 #include <initializer_list>
8 #include <iostream>
9 #include <memory>
10 #include <vector>
11 #include "DatasetSample.hh"
12 #include "../DecisionTree/DecisionNode.h"
13 #include "../DecisionTree/DTConsiderationType.hh"
14 
15 namespace grail
16 {
17 namespace simgames
18 {
19  class ISimulatedGameAction;
20 
21 namespace learn
22 {
/// Holds a collection of DatasetSample objects (see Samples) together with the
/// consideration type of each column (see considerationTypes).
/// C45Algorithm is declared a friend, which gives it access to the private
/// entropy / splitting helpers declared at the bottom of the class.
23  class Dataset
24  {
25  public:
    /// Gets the number of measures (also known as considerations / decisions /
    /// columns in the dataset).
29  size_t DecisionVariablesCount() const;
30 
    /// Gets the types of the respective considerations, in order of appearance.
    /// NOTE(review): the generated summary says the types are "either NUMERIC or N..."
    /// (cut off) - presumably NOMINAL; confirm in DTConsiderationType.hh.
34  const std::vector<dt::DTConsiderationType>& GetConsiderationTypes() const;
35 
    /// Returns the type of the i-th consideration; i = columnIndex.
    /// NOTE(review): behavior for an out-of-range columnIndex is not visible here.
39  dt::DTConsiderationType GetConsiderationType(int columnIndex) const;
40 
    /// Creates a new dataset from a brace-enclosed list of consideration types.
45  Dataset(std::initializer_list<dt::DTConsiderationType> considerationTypes);
46 
    /// Creates a new dataset from a vector of consideration types.
51  Dataset(const std::vector<dt::DTConsiderationType>& considerationTypes);
52 
    /// Constructs and inserts a new sample into the dataset. The sample is built
    /// from @p decision (passed in by unique_ptr, so the caller gives it up) and
    /// the values in @p data.
    /// NOTE(review): whether data.size() is checked against the column count is
    /// not visible in this header.
54  void AddSample(std::unique_ptr<ISimulatedGameAction> decision, std::initializer_list<float> data);
55 
    /// Overload taking the sample values as a vector.
    /// NOTE(review): presumably equivalent to the initializer_list overload - confirm in Dataset.cpp.
57  void AddSample(std::unique_ptr<ISimulatedGameAction> decision, const std::vector<float>& data);
58 
    /// Overload taking an already constructed sample by unique_ptr.
    /// NOTE(review): presumably stored in Samples - confirm in Dataset.cpp.
60  void AddSample(std::unique_ptr<DatasetSample> sample);
61 
    /// Tests a decision tree (represented by its root node) against this dataset.
    /// The generated summary describes the return value as an accuracy; the rest
    /// of that description is truncated, so the exact range/definition should be
    /// confirmed in Dataset.cpp.
67  double ValidateBinary(dt::DecisionNode<ISimulatedGameAction>& decisionTreeNodeRoot);
68 
    /// Data stored in the Dataset. Public member: callers can read and modify it directly.
72  std::vector<std::unique_ptr<DatasetSample>> Samples;
73 
    /// Moves samples from @p sourceDataset into the dataset this function was called on.
    /// The generated summary mentions that a basic check is performed first
    /// (description truncated - presumably the compatibility check declared as
    /// CheckDataCompatibility; meaning of the bool result to be confirmed).
83  bool MoveFromOther(Dataset& sourceDataset);
84 
    /// Adds samples from @p sourceDataset to the dataset this function was called on.
    /// The generated summary mentions that a basic check is performed first
    /// (description truncated - presumably the compatibility check declared as
    /// CheckDataCompatibility; meaning of the bool result to be confirmed).
93  bool CopyFromOther(Dataset& sourceDataset);
94 
95  private:
    // NOTE(review): no documentation is visible for the private helpers; the
    // remarks on them are inferred from names/signatures only and need confirming
    // against Dataset.cpp.

    // Presumably reports whether `other` can be merged into this dataset.
96  bool CheckDataCompatibility(const Dataset& other) const;
97 
    // Presumably the entropy of the decisions over Samples.
98  double CalculateDecisionEntropy() const;
99 
    // Split helpers: each returns newly allocated Dataset object(s) held by unique_ptr.
    // SplitByFilter takes a predicate over a row of float values; SplitContinuous
    // takes a column and a split value; SplitNominal pairs each resulting
    // dataset with a float (presumably the nominal value it was split on).
100  std::unique_ptr<Dataset> SplitByFilter(std::function<bool(std::vector<float>&)> filter) const;
101  std::vector<std::unique_ptr<Dataset>> SplitContinuous(size_t column, float splitValue) const;
102  std::vector<std::pair<std::unique_ptr<Dataset>, float>> SplitNominal(int column) const;
103 
    // Type of each consideration (column); exposed read-only via GetConsiderationTypes().
104  std::vector<dt::DTConsiderationType> considerationTypes;
105 
    // C45Algorithm (the C4.5 decision-tree generator) may use the private helpers declared here.
106  friend class C45Algorithm;
107  };
108 }
109 }
110 }
111 
112 #endif //GRAIL_DATASET_H
grail::simgames::learn::Dataset::DecisionVariablesCount
size_t DecisionVariablesCount() const
Gets the number of measures (also known as considerations / decisions / columns in dataset).
Definition: Dataset.cpp:25
grail::simgames::learn::Dataset::Samples
std::vector< std::unique_ptr< DatasetSample > > Samples
Data stored in Dataset.
Definition: Dataset.hh:72
grail::simgames::learn::Dataset::AddSample
void AddSample(std::unique_ptr< ISimulatedGameAction > decision, std::initializer_list< float > data)
Constructs and inserts a new sample into the dataset. The sample is constructed using the @decision and va...
Definition: Dataset.cpp:40
grail::simgames::dt::DecisionNode
Class for internal usage. Decision tree node base type.
Definition: DecisionNode.h:24
grail::simgames::learn::Dataset::GetConsiderationTypes
const std::vector< dt::DTConsiderationType > & GetConsiderationTypes() const
Gets types of the respective consideration, in order of appearance. The types are either NUMERIC or N...
Definition: Dataset.cpp:30
grail::simgames::learn::Dataset::CopyFromOther
bool CopyFromOther(Dataset &sourceDataset)
Adds samples from another dataset to the dataset this function was called on. It performs a basic che...
Definition: Dataset.cpp:119
grail::simgames::learn::Dataset
Definition: Dataset.hh:23
grail::simgames::learn::Dataset::MoveFromOther
bool MoveFromOther(Dataset &sourceDataset)
Moves samples from another dataset to the dataset this function was called on. It performs a basic ch...
Definition: Dataset.cpp:105
grail::simgames::learn::C45Algorithm
This class encapsulates the C4.5 Algorithm used to generate a decision tree (see Grail....
Definition: C45Algorithm.h:25
grail::simgames::learn::Dataset::Dataset
Dataset(std::initializer_list< dt::DTConsiderationType > considerationTypes)
Creates a new dataset.
Definition: Dataset.cpp:15
grail::simgames::learn::Dataset::GetConsiderationType
dt::DTConsiderationType GetConsiderationType(int columnIndex) const
Returns the type of the i-th consideration; i = columnIndex.
Definition: Dataset.cpp:35
grail::simgames::learn::Dataset::ValidateBinary
double ValidateBinary(dt::DecisionNode< ISimulatedGameAction > &decisionTreeNodeRoot)
Tests a decision tree (represented by the root node) against a dataset. Returns the accuracy of decis...
Definition: Dataset.cpp:161