1 #ifndef GRAIL_EVALUATOR_HELPERS_H
2 #define GRAIL_EVALUATOR_HELPERS_H
4 #include "../Curves/Curve.hh"
5 #include "../Aggregators/Aggregator.hh"
6 #include "../../GrailCore/Logger/LoggerManager.hh"
7 #include "../Aggregators/AverageAggregator.hh"
8 #include "../Aggregators/MaxAggregator.hh"
9 #include "../Aggregators/MinAggregator.hh"
10 #include "../Aggregators/ProductAggregator.hh"
11 #include "../Aggregators/SumAggregator.hh"
12 #include "../Curves/AllCurves.h"
13 #include "../../GrailData/UtilityModel/EvaluatorNodeModel.h"
14 #include "../../GrailData/UtilityModel/EvaluatorTreeModel.h"
15 #include "../../GrailData/UtilityModel/EvaluatorTreeReference.h"
16 #include "IConsiderationProvider.h"
32 template<
typename ContextType>
33 std::shared_ptr<utility::Aggregator<ContextType>> ConstructAggregator(
const std::vector<std::shared_ptr<utility::Evaluator<ContextType>>>& inputEvaluators,
34 EvaluatorType evaluatorType)
37 switch (evaluatorType)
39 case EvaluatorType::AGGREGATOR_MIN:
41 return std::make_shared<utility::MinAggregator<utility::EntityBlackboardPair>>(inputEvaluators);
43 case EvaluatorType::AGGREGATOR_MAX:
45 return std::make_shared<utility::MaxAggregator<utility::EntityBlackboardPair>>(inputEvaluators);
47 case EvaluatorType::AGGREGATOR_AVERAGE:
49 return std::make_shared<utility::AverageAggregator<utility::EntityBlackboardPair>>(inputEvaluators);
51 case EvaluatorType::AGGREGATOR_PRODUCT:
53 return std::make_shared<utility::ProductAggregator<utility::EntityBlackboardPair>>(inputEvaluators);
55 case EvaluatorType::AGGREGATOR_SUM:
57 return std::make_shared<utility::SumAggregator<utility::EntityBlackboardPair>>(inputEvaluators);
61 GRAIL_LOG(consts::DEFAULT_GRAIL_LOG_GROUP, grail::Severity::ERRORS,
"Invalid aggregator type");
// ConstructCurve - builds the curve object matching a CURVE_* evaluator type. The curve
// wraps inputEvaluator and takes its shape parameters from curveParameters.
// When evaluatorType matches none of the cases, it logs "Invalid curve data type". The
// value returned on that path is not visible in this extract.
// NOTE(review): this listing appears to be a lossy extract (leading source line numbers,
// missing lines). Extra constructor arguments for the Linear/Power/Sigmoid/UnitStep
// curves, the StepData declaration and field assignments, braces and the default label
// are not shown. Restore them from the real header before changing any code here.
// NOTE(review): ContextType is only used in the signature. Every curve is instantiated as
// curves::Xxx<utility::EntityBlackboardPair>, so this compiles only when ContextType is
// utility::EntityBlackboardPair. Consider instantiating with ContextType instead.
// NOTE(review): curveParameters is indexed without any size check:
//   - an empty vector reads curveParameters[0] out of bounds;
//   - an odd-sized vector overruns [i + 1] for CURVE_LINEAR_INTERPOLATED / CURVE_STAIRCASE;
//   - a size that is not a multiple of 4 overruns [i + 1..3] for CURVE_BEZIER.
74 template<
typename ContextType>
75 std::shared_ptr<curves::Curve<ContextType>> ConstructCurve(
const std::shared_ptr<utility::Evaluator<ContextType>>& inputEvaluator,
76 EvaluatorType evaluatorType,
const std::vector<float>& curveParameters)
78 switch (evaluatorType)
80 case EvaluatorType::CURVE_LINEAR:
82 return std::make_shared<curves::LinearFunction<utility::EntityBlackboardPair>>(inputEvaluator, curveParameters[0],
85 case EvaluatorType::CURVE_POWER:
87 return std::make_shared<curves::PowerFunction<utility::EntityBlackboardPair>>(inputEvaluator, curveParameters[0],
91 case EvaluatorType::CURVE_SIGMOID:
93 return std::make_shared<curves::SigmoidFunction<utility::EntityBlackboardPair>>(inputEvaluator, curveParameters[0],
98 case EvaluatorType::CURVE_UNIT_STEP:
100 return std::make_shared<curves::UnitStepFunction<utility::EntityBlackboardPair>>(inputEvaluator, curveParameters[0],
104 case EvaluatorType::CURVE_EXPONENTIAL:
106 return std::make_shared<curves::ExponentialFunction<utility::EntityBlackboardPair>>(inputEvaluator, curveParameters[0]);
108 case EvaluatorType::CURVE_CONSTANT:
110 return std::make_shared<curves::ConstantFunction<utility::EntityBlackboardPair>>(inputEvaluator, curveParameters[0]);
112 case EvaluatorType::CURVE_LINEAR_INTERPOLATED:
114 std::vector<Vector2> points{};
// Parameters are consumed as consecutive (x, y) pairs, one control point per pair.
115 for (
size_t i = 0; i < curveParameters.size(); i += 2)
117 points.emplace_back(curveParameters[i], curveParameters[i + 1]);
120 return std::make_shared<curves::LinearlyInterpolatedCurve<utility::EntityBlackboardPair>>(inputEvaluator, points);
122 case EvaluatorType::CURVE_BEZIER:
124 std::vector<Vector2> points{};
125 std::vector<Vector2> tangents{};
// Parameters are consumed in groups of four: point (x, y) followed by tangent (x, y).
127 for (
size_t i = 0; i < curveParameters.size(); i += 4)
129 points.emplace_back(curveParameters[i], curveParameters[i + 1]);
130 tangents.emplace_back(curveParameters[i + 2], curveParameters[i + 3]);
133 return std::make_shared<curves::BezierSpline<utility::EntityBlackboardPair>>(inputEvaluator, points, tangents);
135 case EvaluatorType::CURVE_STAIRCASE:
137 std::vector<grail::curves::StepData> stepData{};
// Parameters are consumed in pairs, and curveParameters[i + 1] becomes the step's value.
// NOTE(review): the declaration of `data` and the assignments of its other fields
// (presumably startPoint = curveParameters[i]) are not visible here. TODO confirm.
138 for (
size_t i = 0; i < curveParameters.size(); i += 2)
142 data.
value = curveParameters[i + 1];
144 stepData.emplace_back(data);
147 return std::make_shared<curves::StaircaseFunction<utility::EntityBlackboardPair>>(inputEvaluator, stepData);
149 case EvaluatorType::CURVE_LOWER_BOUND:
151 return std::make_shared<curves::LowerBound<utility::EntityBlackboardPair>>(inputEvaluator, curveParameters[0]);
153 case EvaluatorType::CURVE_UPPER_BOUND:
155 return std::make_shared<curves::UpperBound<utility::EntityBlackboardPair>>(inputEvaluator, curveParameters[0]);
157 case EvaluatorType::CURVE_DOUBLE_SIDED_BOUND:
// Uses two parameters: curveParameters[0] and curveParameters[1].
159 return std::make_shared<curves::DoubleSidedBound<utility::EntityBlackboardPair>>(inputEvaluator, curveParameters[0], curveParameters[1]);
// Reached for evaluator types that are not handled by any case above.
163 GRAIL_LOG(consts::DEFAULT_GRAIL_LOG_GROUP, grail::Severity::ERRORS,
"Invalid curve data type");
// EvaluatorTypeGroup - coarse category of an EvaluatorType. ConstructEvaluator switches on
// it to decide whether to build a consideration, an aggregator or a curve.
// NOTE(review): the enumerator list is not visible in this extract. CONSIDERATION,
// AGGREGATOR and CURVE are the values referenced by ConstructEvaluator.
172 enum class EvaluatorTypeGroup
// GetEvaluatorTypeGroup - returns the group that evaluatorType belongs to
// (declaration only; the definition is not in this header).
180 EvaluatorTypeGroup GetEvaluatorTypeGroup(EvaluatorType evaluatorType);
// ResolveSubtreeDependencies - returns a copy of inputModel resolved against savedSubtrees.
// Its result is the "resolved model" passed to internal::ConstructEvaluator.
// NOTE(review): presumably subtree references are replaced by the matching saved trees.
// The definition is not visible here. TODO confirm.
181 EvaluatorNodeModel ResolveSubtreeDependencies(
const EvaluatorNodeModel& inputModel,
182 const std::vector<EvaluatorTreeModel>& savedSubtrees);
// ConstructEvaluator (internal) - recursively builds an evaluator from a model whose
// subtree references have already been resolved. It dispatches on the model's group:
//  - CONSIDERATION: fetched by name from considerationProvider. A consideration model
//    that lists input evaluators is reported as an error.
//  - AGGREGATOR: every input model is built recursively, then combined by
//    ConstructAggregator.
//  - CURVE: exactly one input model is expected. It is built recursively and wrapped by
//    ConstructCurve using the model's curve parameters.
// NOTE(review): this appears to be a lossy extract. Braces, the lambda's closing "});",
// whether each error branch returns early, and any default branch are not visible.
// Confirm against the real header before changing code.
// NOTE(review): the aggregator branch collects inputs into a vector of
// Evaluator<utility::EntityBlackboardPair>, so this template only works when ContextType
// is utility::EntityBlackboardPair.
184 template<
typename ContextType>
185 std::shared_ptr<utility::Evaluator<ContextType>> ConstructEvaluator(
const EvaluatorNodeModel& resolvedModel,
186 const std::shared_ptr<
const IConsiderationProvider<ContextType>>& considerationProvider)
188 EvaluatorTypeGroup evaluatorTypeGroup = GetEvaluatorTypeGroup(resolvedModel.evaluatorTypeId);
189 switch (evaluatorTypeGroup)
191 case EvaluatorTypeGroup::CONSIDERATION:
// Considerations are leaves: they must not have input evaluators.
193 if (resolvedModel.inputEvaluatorModels.size() != 0)
195 std::stringstream ss;
196 ss <<
"Invalid config for consideration " << resolvedModel.considerationModel.name <<
197 ": considerations don't accept inputs";
198 GRAIL_LOG(consts::DEFAULT_GRAIL_LOG_GROUP, grail::Severity::ERRORS, ss.str());
// NOTE(review): no early return is visible after the log above, so the lookup may run
// even for an invalid config. TODO confirm.
201 return considerationProvider->GetConsiderationByName(resolvedModel.considerationModel.name);
203 case EvaluatorTypeGroup::AGGREGATOR:
// Build every child evaluator recursively. The same provider is shared by the whole tree.
// NOTE(review): inputEvaluators could reserve(inputEvaluatorModels.size()) up front.
205 std::vector<std::shared_ptr<utility::Evaluator<utility::EntityBlackboardPair>>> inputEvaluators;
206 std::transform(resolvedModel.inputEvaluatorModels.begin(), resolvedModel.inputEvaluatorModels.end(),
207 std::back_inserter(inputEvaluators),
208 [&considerationProvider](
const EvaluatorNodeModel& model)
210 return ConstructEvaluator(model, considerationProvider);
212 return ConstructAggregator(inputEvaluators, resolvedModel.evaluatorTypeId);
214 case EvaluatorTypeGroup::CURVE:
// Curves transform exactly one input evaluator.
216 if (resolvedModel.inputEvaluatorModels.size() != 1)
218 std::stringstream ss;
219 ss <<
"Invalid config for utility curve: " << resolvedModel.inputEvaluatorModels.size() <<
220 " curve inputs provided, expected 1";
221 GRAIL_LOG(consts::DEFAULT_GRAIL_LOG_GROUP, grail::Severity::ERRORS, ss.str());
// NOTE(review): inputEvaluatorModels[0] below is out of range when the list is empty, so
// the error branch above must return early. That return is not visible here. TODO confirm.
225 return ConstructCurve(ConstructEvaluator(resolvedModel.inputEvaluatorModels[0], considerationProvider),
226 resolvedModel.evaluatorTypeId, resolvedModel.curveDataModel.curveParameters);
244 template<
typename ContextType>
245 std::shared_ptr<utility::Evaluator<ContextType>> ConstructEvaluator(
const EvaluatorNodeModel& inputModel,
246 const std::vector<EvaluatorTreeModel>& evaluatorTreeModels,
247 const std::shared_ptr<
const IConsiderationProvider<ContextType>>& considerationProvider)
249 return internal::ConstructEvaluator(internal::ResolveSubtreeDependencies(inputModel, evaluatorTreeModels), considerationProvider);
The StepData struct - Helper structure describing the discrete steps used by the Staircase function.
Definition: StaircaseFunction.hh:14
float startPoint
startPoint - Start point of the half-line that extends to the right and holds the given constant value.
Definition: StaircaseFunction.hh:22
float value
value - Y-axis value.
Definition: StaircaseFunction.hh:18
bool inclusiveStartPoint
inclusiveStartPoint - Determines whether the exact start point also takes the given value. If not,...
Definition: StaircaseFunction.hh:26