Grail (C++)  1.4.0
A multi-platform, modular, universal engine for embedding advanced AI in games.
EvaluatorHelpers.h
1 // Copyright QED Software 2023.
2 
3 #ifndef GRAIL_EVALUATOR_HELPERS_H
4 #define GRAIL_EVALUATOR_HELPERS_H
5 
6 #include "IConsiderationProvider.h"
7 #include "../../GrailData/UtilityModel/EvaluatorNodeModel.h"
8 #include "../../GrailData/UtilityModel/EvaluatorTreeModel.h"
9 #include "../../GrailData/UtilityModel/EvaluatorTreeReference.h"
10 #include "../../GrailLogger/LoggerManager.hh"
11 #include "../Aggregators/Aggregator.hh"
12 #include "../Aggregators/AverageAggregator.hh"
13 #include "../Aggregators/MaxAggregator.hh"
14 #include "../Aggregators/MinAggregator.hh"
15 #include "../Aggregators/ProductAggregator.hh"
16 #include "../Aggregators/SumAggregator.hh"
17 #include "../Curves/AllCurves.h"
18 #include "../Curves/Curve.hh"
19 
20 #include <algorithm>
21 #include <sstream>
22 
23 namespace grail
24 {
25 namespace evaluator
26 {
27 namespace helpers
28 {
35  template <typename ContextType>
36  std::shared_ptr<Aggregator<ContextType>> ConstructAggregator(
37  const std::vector<std::shared_ptr<Evaluator<ContextType>>>& inputEvaluators,
38  data::EvaluatorType evaluatorType)
39 
40  {
41  if (std::find(inputEvaluators.begin(), inputEvaluators.end(), nullptr) != inputEvaluators.end())
42  {
43  return nullptr;
44  }
45 
46  switch(evaluatorType)
47  {
48  case data::EvaluatorType::AGGREGATOR_MIN:
49  {
50  return std::make_shared<MinAggregator<EntityBlackboardPair>>(inputEvaluators);
51  }
52  case data::EvaluatorType::AGGREGATOR_MAX:
53  {
54  return std::make_shared<MaxAggregator<EntityBlackboardPair>>(inputEvaluators);
55  }
56  case data::EvaluatorType::AGGREGATOR_AVERAGE:
57  {
58  return std::make_shared<AverageAggregator<EntityBlackboardPair>>(inputEvaluators);
59  }
60  case data::EvaluatorType::AGGREGATOR_PRODUCT:
61  {
62  return std::make_shared<ProductAggregator<EntityBlackboardPair>>(inputEvaluators);
63  }
64  case data::EvaluatorType::AGGREGATOR_SUM:
65  {
66  return std::make_shared<SumAggregator<EntityBlackboardPair>>(inputEvaluators);
67  }
68  default:
69  {
70  GRAIL_LOG(consts::DEFAULT_GRAIL_LOG_GROUP, grail::logger::Severity::ERRORS, "Invalid aggregator type");
71  return nullptr;
72  }
73  }
74  }
75 
83  template <typename ContextType>
84  std::shared_ptr<Curve<ContextType>> ConstructCurve(
85  const std::shared_ptr<Evaluator<ContextType>>& inputEvaluator,
86  data::EvaluatorType evaluatorType,
87  const std::vector<float>& curveParameters)
88  {
89  if (inputEvaluator == nullptr)
90  {
91  return nullptr;
92  }
93 
94  switch(evaluatorType)
95  {
96  case data::EvaluatorType::CURVE_LINEAR:
97  {
98  return std::make_shared<LinearFunction<EntityBlackboardPair>>(inputEvaluator,
99  curveParameters[0],
100  curveParameters[1]);
101  }
102  case data::EvaluatorType::CURVE_POWER:
103  {
104  return std::make_shared<PowerFunction<EntityBlackboardPair>>(inputEvaluator,
105  curveParameters[0],
106  curveParameters[1],
107  curveParameters[2]);
108  }
109  case data::EvaluatorType::CURVE_SIGMOID:
110  {
111  return std::make_shared<SigmoidFunction<EntityBlackboardPair>>(inputEvaluator,
112  curveParameters[0],
113  curveParameters[1],
114  curveParameters[2],
115  curveParameters[3]);
116  }
117  case data::EvaluatorType::CURVE_UNIT_STEP:
118  {
119  return std::make_shared<UnitStepFunction<EntityBlackboardPair>>(inputEvaluator,
120  curveParameters[0],
121  curveParameters[1],
122  curveParameters[2]);
123  }
124  case data::EvaluatorType::CURVE_EXPONENTIAL:
125  {
126  return std::make_shared<ExponentialFunction<EntityBlackboardPair>>(inputEvaluator,
127  curveParameters[0]);
128  }
129  case data::EvaluatorType::CURVE_CONSTANT:
130  {
131  return std::make_shared<ConstantFunction<EntityBlackboardPair>>(inputEvaluator,
132  curveParameters[0]);
133  }
134  case data::EvaluatorType::CURVE_LINEAR_INTERPOLATED:
135  {
136  std::vector<Vector2> points{};
137  for(size_t i = 0; i < curveParameters.size(); i += 2)
138  {
139  points.emplace_back(curveParameters[i], curveParameters[i + 1]);
140  }
141 
142  return std::make_shared<LinearlyInterpolatedCurve<
143  EntityBlackboardPair>>(inputEvaluator, points);
144  }
145  case data::EvaluatorType::CURVE_BEZIER:
146  {
147  std::vector<Vector2> points{};
148  std::vector<Vector2> tangents{};
149 
150  for(size_t i = 0; i < curveParameters.size(); i += 4)
151  {
152  points.emplace_back(curveParameters[i], curveParameters[i + 1]);
153  tangents.emplace_back(curveParameters[i + 2], curveParameters[i + 3]);
154  }
155 
156  return std::make_shared<BezierSpline<EntityBlackboardPair>>(inputEvaluator,
157  points,
158  tangents);
159  }
160  case data::EvaluatorType::CURVE_STAIRCASE:
161  {
162  std::vector<StepData> stepData{};
163  for(size_t i = 0; i < curveParameters.size(); i += 2)
164  {
165  StepData data;
166  data.startPoint = curveParameters[i];
167  data.value = curveParameters[i + 1];
168  data.inclusiveStartPoint = true;
169  stepData.emplace_back(data);
170  }
171 
172  return std::make_shared<StaircaseFunction<
173  EntityBlackboardPair>>(inputEvaluator, stepData);
174  }
175  case data::EvaluatorType::CURVE_LOWER_BOUND:
176  {
177  return std::make_shared<LowerBound<EntityBlackboardPair>>(inputEvaluator,
178  curveParameters[0]);
179  }
180  case data::EvaluatorType::CURVE_UPPER_BOUND:
181  {
182  return std::make_shared<UpperBound<EntityBlackboardPair>>(inputEvaluator,
183  curveParameters[0]);
184  }
185  case data::EvaluatorType::CURVE_DOUBLE_SIDED_BOUND:
186  {
187  return std::make_shared<DoubleSidedBound<EntityBlackboardPair>>(inputEvaluator,
188  curveParameters[0],
189  curveParameters[1]);
190  }
191  default:
192  {
193  GRAIL_LOG(consts::DEFAULT_GRAIL_LOG_GROUP, grail::logger::Severity::ERRORS, "Invalid curve data type");
194  return nullptr;
195  }
196  }
197  }
198 
199  namespace internal
200  {
201  enum class EvaluatorTypeGroup
202  {
203  CONSIDERATION,
204  CURVE,
205  AGGREGATOR,
206  SUBTREE_REFERENCE,
207  };
208 
209  EvaluatorTypeGroup GetEvaluatorTypeGroup(data::EvaluatorType evaluatorType);
210  data::EvaluatorNodeModel ResolveSubtreeDependencies(const data::EvaluatorNodeModel& inputModel,
211  const std::vector<data::EvaluatorTreeModel>& savedSubtrees);
212 
213  template <typename ContextType>
214  std::shared_ptr<Evaluator<ContextType>> ConstructEvaluator(const data::EvaluatorNodeModel& resolvedModel,
215  const std::shared_ptr<const IConsiderationProvider<ContextType>>& considerationProvider,
216  std::map<std::string, std::shared_ptr<Consideration<ContextType>>>& considerationMapping)
217  {
218  EvaluatorTypeGroup evaluatorTypeGroup = GetEvaluatorTypeGroup(resolvedModel.evaluatorTypeId);
219  switch(evaluatorTypeGroup)
220  {
221  case EvaluatorTypeGroup::CONSIDERATION:
222  {
223  if(resolvedModel.inputEvaluatorModels.size() != 0)
224  {
225  std::stringstream ss;
226  ss << "Invalid config for consideration " << resolvedModel.considerationModel.name <<
227  ": considerations don't accept inputs";
228  GRAIL_LOG(consts::DEFAULT_GRAIL_LOG_GROUP, grail::logger::Severity::ERRORS, ss.str());
229  return nullptr;
230  }
231 
232  std::shared_ptr<grail::evaluator::Evaluator<ContextType>> evaluator = nullptr;
233  if(auto cachedConsideration = considerationMapping[resolvedModel.considerationModel.name]; cachedConsideration != nullptr)
234  {
235  evaluator = cachedConsideration;
236  }
237  else
238  {
239  auto producedConsideration = considerationProvider->GetConsiderationByName(resolvedModel.considerationModel.name);
240  evaluator = producedConsideration;
241  considerationMapping[resolvedModel.considerationModel.name] = producedConsideration;
242  }
243 
244  if(evaluator == nullptr)
245  {
246  std::stringstream ss;
247  ss << "Invalid config for consideration " << resolvedModel.considerationModel.name <<
248  ": the current consideration provider does not provide consideration of this name; check its GetConsiderationByName() method.";
249  GRAIL_LOG(consts::DEFAULT_GRAIL_LOG_GROUP, grail::logger::Severity::ERRORS, ss.str());
250  return nullptr;
251  }
252 
253  return evaluator;
254  }
255  case EvaluatorTypeGroup::AGGREGATOR:
256  {
257  std::vector<std::shared_ptr<Evaluator<EntityBlackboardPair>>> inputEvaluators;
258  std::transform(resolvedModel.inputEvaluatorModels.begin(),
259  resolvedModel.inputEvaluatorModels.end(),
260  std::back_inserter(inputEvaluators),
261  [&considerationProvider, &considerationMapping](const data::EvaluatorNodeModel& model)
262  {
263  return ConstructEvaluator(model, considerationProvider, considerationMapping);
264  });
265  return ConstructAggregator(inputEvaluators, resolvedModel.evaluatorTypeId);
266  }
267  case EvaluatorTypeGroup::CURVE:
268  {
269  if(resolvedModel.inputEvaluatorModels.size() != 1)
270  {
271  std::stringstream ss;
272  ss << "Invalid config for utility curve: " << resolvedModel.inputEvaluatorModels.size() <<
273  " curve inputs provided, expected 1";
274  GRAIL_LOG(consts::DEFAULT_GRAIL_LOG_GROUP, grail::logger::Severity::ERRORS, ss.str());
275  return nullptr;
276  }
277 
278  return ConstructCurve(ConstructEvaluator(resolvedModel.inputEvaluatorModels[0],
279  considerationProvider, considerationMapping),
280  resolvedModel.evaluatorTypeId,
281  resolvedModel.curveDataModel.curveParameters);
282  }
283  default:
284  {
285  return nullptr;
286  }
287  }
288  }
289  }
290 
300  template <typename ContextType>
301  std::shared_ptr<Evaluator<ContextType>> ConstructEvaluator(const data::EvaluatorNodeModel& inputModel,
302  const std::vector<data::EvaluatorTreeModel>& evaluatorTreeModels,
303  const std::shared_ptr<const IConsiderationProvider<ContextType>>& considerationProvider,
304  std::map<std::string, std::shared_ptr<Consideration<ContextType>>>& considerationMapping)
305  {
306  return internal::ConstructEvaluator(internal::ResolveSubtreeDependencies(inputModel, evaluatorTreeModels),
307  considerationProvider, considerationMapping);
308  }
309 }
310 }
311 }
312 
313 #endif //GRAIL_EVALUATOR_HELPERS_H