3 #ifndef GRAIL_EVALUATOR_HELPERS_H
4 #define GRAIL_EVALUATOR_HELPERS_H
6 #include "IConsiderationProvider.h"
7 #include "../../GrailData/UtilityModel/EvaluatorNodeModel.h"
8 #include "../../GrailData/UtilityModel/EvaluatorTreeModel.h"
9 #include "../../GrailData/UtilityModel/EvaluatorTreeReference.h"
10 #include "../../GrailLogger/LoggerManager.hh"
11 #include "../Aggregators/Aggregator.hh"
12 #include "../Aggregators/AverageAggregator.hh"
13 #include "../Aggregators/MaxAggregator.hh"
14 #include "../Aggregators/MinAggregator.hh"
15 #include "../Aggregators/ProductAggregator.hh"
16 #include "../Aggregators/SumAggregator.hh"
17 #include "../Curves/AllCurves.h"
18 #include "../Curves/Curve.hh"
35 template <
typename ContextType>
36 std::shared_ptr<Aggregator<ContextType>> ConstructAggregator(
37 const std::vector<std::shared_ptr<Evaluator<ContextType>>>& inputEvaluators,
38 data::EvaluatorType evaluatorType)
41 if (std::find(inputEvaluators.begin(), inputEvaluators.end(),
nullptr) != inputEvaluators.end())
48 case data::EvaluatorType::AGGREGATOR_MIN:
50 return std::make_shared<MinAggregator<EntityBlackboardPair>>(inputEvaluators);
52 case data::EvaluatorType::AGGREGATOR_MAX:
54 return std::make_shared<MaxAggregator<EntityBlackboardPair>>(inputEvaluators);
56 case data::EvaluatorType::AGGREGATOR_AVERAGE:
58 return std::make_shared<AverageAggregator<EntityBlackboardPair>>(inputEvaluators);
60 case data::EvaluatorType::AGGREGATOR_PRODUCT:
62 return std::make_shared<ProductAggregator<EntityBlackboardPair>>(inputEvaluators);
64 case data::EvaluatorType::AGGREGATOR_SUM:
66 return std::make_shared<SumAggregator<EntityBlackboardPair>>(inputEvaluators);
70 GRAIL_LOG(consts::DEFAULT_GRAIL_LOG_GROUP, grail::logger::Severity::ERRORS,
"Invalid aggregator type");
// Builds the utility curve selected by evaluatorType on top of inputEvaluator.
//
// @param inputEvaluator   evaluator whose output feeds the curve; it is checked against nullptr before
//                         anything is built.
// @param evaluatorType    expected to be one of the data::EvaluatorType::CURVE_* values; anything not
//                         matched below reaches the "Invalid curve data type" error log.
// @param curveParameters  flat parameter list whose layout depends on the curve type:
//                           CURVE_LINEAR_INTERPOLATED - consecutive pairs, each becoming one Vector2 point;
//                           CURVE_BEZIER              - groups of four: a Vector2 point, then a Vector2 tangent;
//                           CURVE_STAIRCASE           - consecutive (startPoint, value) pairs, one StepData each.
// @return the constructed curve. NOTE(review): the error paths presumably return nullptr - confirm.
//
// NOTE(review): every branch instantiates the curve class with EntityBlackboardPair instead of
// ContextType, so this template only compiles when ContextType is EntityBlackboardPair.
83 template <
typename ContextType>
84 std::shared_ptr<Curve<ContextType>> ConstructCurve(
85 const std::shared_ptr<Evaluator<ContextType>>& inputEvaluator,
86 data::EvaluatorType evaluatorType,
87 const std::vector<float>& curveParameters)
// A curve without an input evaluator cannot be built.
89 if (inputEvaluator ==
nullptr)
96 case data::EvaluatorType::CURVE_LINEAR:
98 return std::make_shared<LinearFunction<EntityBlackboardPair>>(inputEvaluator,
102 case data::EvaluatorType::CURVE_POWER:
104 return std::make_shared<PowerFunction<EntityBlackboardPair>>(inputEvaluator,
109 case data::EvaluatorType::CURVE_SIGMOID:
111 return std::make_shared<SigmoidFunction<EntityBlackboardPair>>(inputEvaluator,
117 case data::EvaluatorType::CURVE_UNIT_STEP:
119 return std::make_shared<UnitStepFunction<EntityBlackboardPair>>(inputEvaluator,
124 case data::EvaluatorType::CURVE_EXPONENTIAL:
126 return std::make_shared<ExponentialFunction<EntityBlackboardPair>>(inputEvaluator,
129 case data::EvaluatorType::CURVE_CONSTANT:
131 return std::make_shared<ConstantFunction<EntityBlackboardPair>>(inputEvaluator,
// Each consecutive pair of parameters becomes one interpolation point.
// NOTE(review): curveParameters[i + 1] is read unguarded, so an odd-length list reads one element
// past the end of the vector (undefined behavior); `i + 1 < curveParameters.size()` is the safe bound.
134 case data::EvaluatorType::CURVE_LINEAR_INTERPOLATED:
136 std::vector<Vector2> points{};
137 for(
size_t i = 0; i < curveParameters.size(); i += 2)
139 points.emplace_back(curveParameters[i], curveParameters[i + 1]);
142 return std::make_shared<LinearlyInterpolatedCurve<
143 EntityBlackboardPair>>(inputEvaluator, points);
// Each group of four parameters becomes one point and one tangent.
// NOTE(review): indices up to i + 3 are read unguarded, so a list whose length is not a multiple of
// four reads past the end of the vector; `i + 3 < curveParameters.size()` is the safe bound.
145 case data::EvaluatorType::CURVE_BEZIER:
147 std::vector<Vector2> points{};
148 std::vector<Vector2> tangents{};
150 for(
size_t i = 0; i < curveParameters.size(); i += 4)
152 points.emplace_back(curveParameters[i], curveParameters[i + 1]);
153 tangents.emplace_back(curveParameters[i + 2], curveParameters[i + 3]);
156 return std::make_shared<BezierSpline<EntityBlackboardPair>>(inputEvaluator,
// Each consecutive (startPoint, value) pair becomes one step; every step is marked as inclusive of
// its start point.
// NOTE(review): same unguarded curveParameters[i + 1] read as the interpolated curve - an
// odd-length list overruns the vector.
160 case data::EvaluatorType::CURVE_STAIRCASE:
162 std::vector<StepData> stepData{};
163 for(
size_t i = 0; i < curveParameters.size(); i += 2)
166 data.startPoint = curveParameters[i];
167 data.value = curveParameters[i + 1];
168 data.inclusiveStartPoint =
true;
169 stepData.emplace_back(data);
172 return std::make_shared<StaircaseFunction<
173 EntityBlackboardPair>>(inputEvaluator, stepData);
175 case data::EvaluatorType::CURVE_LOWER_BOUND:
177 return std::make_shared<LowerBound<EntityBlackboardPair>>(inputEvaluator,
180 case data::EvaluatorType::CURVE_UPPER_BOUND:
182 return std::make_shared<UpperBound<EntityBlackboardPair>>(inputEvaluator,
185 case data::EvaluatorType::CURVE_DOUBLE_SIDED_BOUND:
187 return std::make_shared<DoubleSidedBound<EntityBlackboardPair>>(inputEvaluator,
// Fallback for evaluator types not matched by any case above.
193 GRAIL_LOG(consts::DEFAULT_GRAIL_LOG_GROUP, grail::logger::Severity::ERRORS,
"Invalid curve data type");
// Coarse category of a data::EvaluatorType, as returned by GetEvaluatorTypeGroup().
// ConstructEvaluator switches on it and handles the CONSIDERATION, AGGREGATOR and CURVE groups.
201 enum class EvaluatorTypeGroup
209 EvaluatorTypeGroup GetEvaluatorTypeGroup(data::EvaluatorType evaluatorType);
210 data::EvaluatorNodeModel ResolveSubtreeDependencies(
const data::EvaluatorNodeModel& inputModel,
211 const std::vector<data::EvaluatorTreeModel>& savedSubtrees);
213 template <
typename ContextType>
214 std::shared_ptr<Evaluator<ContextType>> ConstructEvaluator(
const data::EvaluatorNodeModel& resolvedModel,
215 const std::shared_ptr<
const IConsiderationProvider<ContextType>>& considerationProvider,
216 std::map<std::string, std::shared_ptr<Consideration<ContextType>>>& considerationMapping)
218 EvaluatorTypeGroup evaluatorTypeGroup = GetEvaluatorTypeGroup(resolvedModel.evaluatorTypeId);
219 switch(evaluatorTypeGroup)
221 case EvaluatorTypeGroup::CONSIDERATION:
223 if(resolvedModel.inputEvaluatorModels.size() != 0)
225 std::stringstream ss;
226 ss <<
"Invalid config for consideration " << resolvedModel.considerationModel.name <<
227 ": considerations don't accept inputs";
228 GRAIL_LOG(consts::DEFAULT_GRAIL_LOG_GROUP, grail::logger::Severity::ERRORS, ss.str());
232 std::shared_ptr<grail::evaluator::Evaluator<ContextType>> evaluator =
nullptr;
233 if(
auto cachedConsideration = considerationMapping[resolvedModel.considerationModel.name]; cachedConsideration !=
nullptr)
235 evaluator = cachedConsideration;
239 auto producedConsideration = considerationProvider->GetConsiderationByName(resolvedModel.considerationModel.name);
240 evaluator = producedConsideration;
241 considerationMapping[resolvedModel.considerationModel.name] = producedConsideration;
244 if(evaluator ==
nullptr)
246 std::stringstream ss;
247 ss <<
"Invalid config for consideration " << resolvedModel.considerationModel.name <<
248 ": the current consideration provider does not provide consideration of this name; check its GetConsiderationByName() method.";
249 GRAIL_LOG(consts::DEFAULT_GRAIL_LOG_GROUP, grail::logger::Severity::ERRORS, ss.str());
255 case EvaluatorTypeGroup::AGGREGATOR:
257 std::vector<std::shared_ptr<Evaluator<EntityBlackboardPair>>> inputEvaluators;
258 std::transform(resolvedModel.inputEvaluatorModels.begin(),
259 resolvedModel.inputEvaluatorModels.end(),
260 std::back_inserter(inputEvaluators),
261 [&considerationProvider, &considerationMapping](
const data::EvaluatorNodeModel& model)
263 return ConstructEvaluator(model, considerationProvider, considerationMapping);
265 return ConstructAggregator(inputEvaluators, resolvedModel.evaluatorTypeId);
267 case EvaluatorTypeGroup::CURVE:
269 if(resolvedModel.inputEvaluatorModels.size() != 1)
271 std::stringstream ss;
272 ss <<
"Invalid config for utility curve: " << resolvedModel.inputEvaluatorModels.size() <<
273 " curve inputs provided, expected 1";
274 GRAIL_LOG(consts::DEFAULT_GRAIL_LOG_GROUP, grail::logger::Severity::ERRORS, ss.str());
278 return ConstructCurve(ConstructEvaluator(resolvedModel.inputEvaluatorModels[0],
279 considerationProvider, considerationMapping),
280 resolvedModel.evaluatorTypeId,
281 resolvedModel.curveDataModel.curveParameters);
300 template <
typename ContextType>
301 std::shared_ptr<Evaluator<ContextType>> ConstructEvaluator(
const data::EvaluatorNodeModel& inputModel,
302 const std::vector<data::EvaluatorTreeModel>& evaluatorTreeModels,
303 const std::shared_ptr<
const IConsiderationProvider<ContextType>>& considerationProvider,
304 std::map<std::string, std::shared_ptr<Consideration<ContextType>>>& considerationMapping)
306 return internal::ConstructEvaluator(internal::ResolveSubtreeDependencies(inputModel, evaluatorTreeModels),
307 considerationProvider, considerationMapping);
313 #endif //GRAIL_EVALUATOR_HELPERS_H