base_layer.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_BASE_LAYER_HPP
14 #define MLPACK_METHODS_ANN_LAYER_BASE_LAYER_HPP
15 
16 #include <mlpack/prereqs.hpp>
30 
31 namespace mlpack {
32 namespace ann {
33 
60 template <
61  class ActivationFunction = LogisticFunction,
62  typename InputDataType = arma::mat,
63  typename OutputDataType = arma::mat
64 >
65 class BaseLayer
66 {
67  public:
72  {
73  // Nothing to do here.
74  }
75 
83  template<typename InputType, typename OutputType>
84  void Forward(const InputType& input, OutputType& output)
85  {
86  ActivationFunction::Fn(input, output);
87  }
88 
98  template<typename eT>
99  void Backward(const arma::Mat& input,
100  const arma::Mat& gy,
101  arma::Mat& g)
102  {
103  arma::Mat derivative;
104  ActivationFunction::Deriv(input, derivative);
105  g = gy % derivative;
106  }
107 
109  OutputDataType const& OutputParameter() const { return outputParameter; }
111  OutputDataType& OutputParameter() { return outputParameter; }
112 
114  OutputDataType const& Delta() const { return delta; }
116  OutputDataType& Delta() { return delta; }
117 
121  template<typename Archive>
122  void serialize(Archive& /* ar */, const unsigned int /* version */)
123  {
124  /* Nothing to do here */
125  }
126 
127  private:
129  OutputDataType delta;
130 
132  OutputDataType outputParameter;
133 }; // class BaseLayer
134 
135 // Convenience typedefs.
136 
140 template <
141  class ActivationFunction = LogisticFunction,
142  typename InputDataType = arma::mat,
143  typename OutputDataType = arma::mat
144 >
145 using SigmoidLayer = BaseLayer<
146  ActivationFunction, InputDataType, OutputDataType>;
147 
151 template <
152  class ActivationFunction = IdentityFunction,
153  typename InputDataType = arma::mat,
154  typename OutputDataType = arma::mat
155 >
156 using IdentityLayer = BaseLayer<
157  ActivationFunction, InputDataType, OutputDataType>;
158 
162 template <
163  class ActivationFunction = RectifierFunction,
164  typename InputDataType = arma::mat,
165  typename OutputDataType = arma::mat
166 >
167 using ReLULayer = BaseLayer<
168  ActivationFunction, InputDataType, OutputDataType>;
169 
173 template <
174  class ActivationFunction = TanhFunction,
175  typename InputDataType = arma::mat,
176  typename OutputDataType = arma::mat
177 >
178 using TanHLayer = BaseLayer<
179  ActivationFunction, InputDataType, OutputDataType>;
180 
184 template <
185  class ActivationFunction = SoftplusFunction,
186  typename InputDataType = arma::mat,
187  typename OutputDataType = arma::mat
188 >
189 using SoftPlusLayer = BaseLayer<
190  ActivationFunction, InputDataType, OutputDataType>;
191 
195 template <
196  class ActivationFunction = HardSigmoidFunction,
197  typename InputDataType = arma::mat,
198  typename OutputDataType = arma::mat
199 >
201  ActivationFunction, InputDataType, OutputDataType>;
202 
206 template <
207  class ActivationFunction = SwishFunction,
208  typename InputDataType = arma::mat,
209  typename OutputDataType = arma::mat
210 >
212  ActivationFunction, InputDataType, OutputDataType>;
213 
217 template <
218  class ActivationFunction = MishFunction,
219  typename InputDataType = arma::mat,
220  typename OutputDataType = arma::mat
221 >
223  ActivationFunction, InputDataType, OutputDataType>;
224 
228 template <
229  class ActivationFunction = LiSHTFunction,
230  typename InputDataType = arma::mat,
231  typename OutputDataType = arma::mat
232 >
234  ActivationFunction, InputDataType, OutputDataType>;
235 
239 template <
240  class ActivationFunction = GELUFunction,
241  typename InputDataType = arma::mat,
242  typename OutputDataType = arma::mat
243 >
245  ActivationFunction, InputDataType, OutputDataType>;
246 
250 template <
251  class ActivationFunction = ElliotFunction,
252  typename InputDataType = arma::mat,
253  typename OutputDataType = arma::mat
254 >
256  ActivationFunction, InputDataType, OutputDataType>;
257 
261 template <
262  class ActivationFunction = ElishFunction,
263  typename InputDataType = arma::mat,
264  typename OutputDataType = arma::mat
265 >
267  ActivationFunction, InputDataType, OutputDataType>;
268 
272 template <
273  class ActivationFunction = GaussianFunction,
274  typename InputDataType = arma::mat,
275  typename OutputDataType = arma::mat
276 >
278  ActivationFunction, InputDataType, OutputDataType>;
279 
280 } // namespace ann
281 } // namespace mlpack
282 
283 #endif
The identity function, defined by.
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: base_layer.hpp:84
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: base_layer.hpp:111
BaseLayer()
Create the BaseLayer object.
Definition: base_layer.hpp:71
The LiSHT function, defined by.
OutputDataType & Delta()
Modify the delta.
Definition: base_layer.hpp:116
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f.
Definition: base_layer.hpp:99
void serialize(Archive &, const unsigned int)
Serialize the layer.
Definition: base_layer.hpp:122
The tanh function, defined by.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
The ELiSH function, defined by.
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: base_layer.hpp:109
OutputDataType const & Delta() const
Get the delta.
Definition: base_layer.hpp:114
Implementation of the base layer.
Definition: base_layer.hpp:65
The Mish function, defined by.
The logistic function, defined by.
The gaussian function, defined by.
The Elliot function, defined by.
The swish function, defined by.
The softplus function, defined by.
The hard sigmoid function, defined by.
The GELU function, defined by.
The rectifier function, defined by.