base_layer.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_BASE_LAYER_HPP
14 #define MLPACK_METHODS_ANN_LAYER_BASE_LAYER_HPP
15 
16 #include <mlpack/prereqs.hpp>
28 
29 namespace mlpack {
30 namespace ann {
31 
56 template <
57  class ActivationFunction = LogisticFunction,
58  typename InputDataType = arma::mat,
59  typename OutputDataType = arma::mat
60 >
61 class BaseLayer
62 {
63  public:
68  {
69  // Nothing to do here.
70  }
71 
79  template<typename InputType, typename OutputType>
80  void Forward(const InputType& input, OutputType& output)
81  {
82  ActivationFunction::Fn(input, output);
83  }
84 
94  template<typename eT>
95  void Backward(const arma::Mat<eT>& input,
96  const arma::Mat<eT>& gy,
97  arma::Mat<eT>& g)
98  {
99  arma::Mat<eT> derivative;
100  ActivationFunction::Deriv(input, derivative);
101  g = gy % derivative;
102  }
103 
105  OutputDataType const& OutputParameter() const { return outputParameter; }
107  OutputDataType& OutputParameter() { return outputParameter; }
108 
110  OutputDataType const& Delta() const { return delta; }
112  OutputDataType& Delta() { return delta; }
113 
117  template<typename Archive>
118  void serialize(Archive& /* ar */, const unsigned int /* version */)
119  {
120  /* Nothing to do here */
121  }
122 
123  private:
125  OutputDataType delta;
126 
128  OutputDataType outputParameter;
129 }; // class BaseLayer
130 
131 // Convenience typedefs.
132 
136 template <
137  class ActivationFunction = LogisticFunction,
138  typename InputDataType = arma::mat,
139  typename OutputDataType = arma::mat
140 >
141 using SigmoidLayer = BaseLayer<
142  ActivationFunction, InputDataType, OutputDataType>;
143 
147 template <
148  class ActivationFunction = IdentityFunction,
149  typename InputDataType = arma::mat,
150  typename OutputDataType = arma::mat
151 >
152 using IdentityLayer = BaseLayer<
153  ActivationFunction, InputDataType, OutputDataType>;
154 
158 template <
159  class ActivationFunction = RectifierFunction,
160  typename InputDataType = arma::mat,
161  typename OutputDataType = arma::mat
162 >
163 using ReLULayer = BaseLayer<
164  ActivationFunction, InputDataType, OutputDataType>;
165 
169 template <
170  class ActivationFunction = TanhFunction,
171  typename InputDataType = arma::mat,
172  typename OutputDataType = arma::mat
173 >
174 using TanHLayer = BaseLayer<
175  ActivationFunction, InputDataType, OutputDataType>;
176 
180 template <
181  class ActivationFunction = SoftplusFunction,
182  typename InputDataType = arma::mat,
183  typename OutputDataType = arma::mat
184 >
185 using SoftPlusLayer = BaseLayer<
186  ActivationFunction, InputDataType, OutputDataType>;
187 
191 template <
192  class ActivationFunction = HardSigmoidFunction,
193  typename InputDataType = arma::mat,
194  typename OutputDataType = arma::mat
195 >
197  ActivationFunction, InputDataType, OutputDataType>;
198 
202 template <
203  class ActivationFunction = SwishFunction,
204  typename InputDataType = arma::mat,
205  typename OutputDataType = arma::mat
206 >
208  ActivationFunction, InputDataType, OutputDataType>;
209 
213 template <
214  class ActivationFunction = MishFunction,
215  typename InputDataType = arma::mat,
216  typename OutputDataType = arma::mat
217 >
219  ActivationFunction, InputDataType, OutputDataType>;
220 
224 template <
225  class ActivationFunction = LiSHTFunction,
226  typename InputDataType = arma::mat,
227  typename OutputDataType = arma::mat
228 >
230  ActivationFunction, InputDataType, OutputDataType>;
231 
235 template <
236  class ActivationFunction = GELUFunction,
237  typename InputDataType = arma::mat,
238  typename OutputDataType = arma::mat
239 >
241  ActivationFunction, InputDataType, OutputDataType>;
242 
246 template <
247  class ActivationFunction = ElishFunction,
248  typename InputDataType = arma::mat,
249  typename OutputDataType = arma::mat
250 >
252  ActivationFunction, InputDataType, OutputDataType>;
253 
254 } // namespace ann
255 } // namespace mlpack
256 
257 #endif
The identity function, defined by $f(x) = x$.
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: base_layer.hpp:80
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: base_layer.hpp:107
BaseLayer()
Create the BaseLayer object.
Definition: base_layer.hpp:67
The LiSHT function, defined by $f(x) = x \tanh(x)$.
OutputDataType & Delta()
Modify the delta.
Definition: base_layer.hpp:112
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
Definition: base_layer.hpp:95
void serialize(Archive &, const unsigned int)
Serialize the layer.
Definition: base_layer.hpp:118
The tanh function, defined by $f(x) = \tanh(x) = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$.
strip_type.hpp
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.
The ELiSH function, defined by $f(x) = \frac{x}{1 + e^{-x}}$ for $x \ge 0$ and $f(x) = \frac{e^{x} - 1}{1 + e^{-x}}$ for $x < 0$.
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: base_layer.hpp:105
OutputDataType const & Delta() const
Get the delta.
Definition: base_layer.hpp:110
Implementation of the base layer.
Definition: base_layer.hpp:61
The logistic function, defined by $f(x) = \frac{1}{1 + e^{-x}}$.
The swish function, defined by $f(x) = \frac{x}{1 + e^{-x}}$.
The softplus function, defined by $f(x) = \ln(1 + e^{x})$.
The hard sigmoid function, defined by $f(x) = \min(1, \max(0, 0.2x + 0.5))$.
The GELU function, defined by $f(x) = 0.5\,x\left(1 + \tanh\left(\sqrt{2/\pi}\,(x + 0.044715\,x^{3})\right)\right)$.
The rectifier function, defined by $f(x) = \max(0, x)$.