Grail (C++)  1.3.0
A multi-platform, modular, universal engine for embedding advanced AI in games.
SigmoidFunction.hh
1 // Copyright QED Software 2023.
2 
3 #ifndef GRAIL_SIGMOID_FUNCTION_H
4 #define GRAIL_SIGMOID_FUNCTION_H
5 
#include <cmath>
#include <memory>
#include <utility>
#include "Curve.hh"
8 
9 namespace grail
10 {
11 namespace evaluator
12 {
13  template <typename ContextType>
18  class SigmoidFunction final : public Curve<ContextType>
19  {
20  public:
29  SigmoidFunction(std::shared_ptr<Evaluator<ContextType>> childEvaluator,
30  float range,
31  float slope,
32  float threshold,
33  float displacement)
34  : Curve<ContextType>{childEvaluator},
35  range(range),
36  slope(slope),
37  threshold(threshold),
38  displacement(displacement)
39  {
40  }
41 
42  virtual float Sample(float argument) const override final
43  {
44  float x = slope * (argument - threshold);
45  if(x < 0)
46  {
47  float exp_x = std::exp(x);
48  return ((range * exp_x) / (1 + exp_x)) + displacement;
49  }
50  else
51  {
52  return (range / (1 + std::exp(-x))) + displacement;
53  }
54  }
55 
60  float GetRange() const { return range; }
65  float GetSlope() const { return slope; }
70  float GetThreshold() const { return threshold; }
75  float GetDisplacement() const { return displacement; }
76 
77  virtual data::EvaluatorType GetEvaluatorType() const override final { return data::EvaluatorType::CURVE_SIGMOID; }
78 
79  private:
80  float range = 0.0f;
81  float slope = 0.0f;
82  float threshold = 0.0f;
83  float displacement = 0.0f;
84  };
85 }
86 }
87 
88 #endif // GRAIL_SIGMOID_FUNCTION_H
grail::evaluator::SigmoidFunction::Sample
virtual float Sample(float argument) const override final
Sample - Transforms argument into output value depending on the type of Curve.
Definition: SigmoidFunction.hh:42
grail::evaluator::SigmoidFunction::GetThreshold
float GetThreshold() const
GetThreshold - Returns the argument value at which the sigmoid reaches its midpoint (displacement + range / 2).
Definition: SigmoidFunction.hh:70
grail::evaluator::Evaluator
The Evaluator class - base class being able to evaluate given context and output the result.
Definition: Evaluator.hh:22
grail::evaluator::SigmoidFunction::GetDisplacement
float GetDisplacement() const
GetDisplacement - Returns the constant offset added to every sampled value.
Definition: SigmoidFunction.hh:75
grail::evaluator::SigmoidFunction
The SigmoidFunction class - Curve that maps its argument through a scaled and shifted sigmoid (logistic) function.
Definition: SigmoidFunction.hh:18
grail::evaluator::SigmoidFunction::GetRange
float GetRange() const
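GetRange - Returns the multiplier applied to the logistic term, i.e. the height of the transition between the curve's two asymptotes.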
Definition: SigmoidFunction.hh:60
grail::evaluator::SigmoidFunction::SigmoidFunction
SigmoidFunction(std::shared_ptr< Evaluator< ContextType >> childEvaluator, float range, float slope, float threshold, float displacement)
SigmoidFunction - Constructor.
Definition: SigmoidFunction.hh:29
grail::evaluator::Curve
The Curve class - Defines objects transforming one value into the other.
Definition: Curve.hh:21
grail::evaluator::SigmoidFunction::GetSlope
float GetSlope() const
GetSlope - Returns the factor that scales (argument - threshold) before the logistic function is applied.
Definition: SigmoidFunction.hh:65