gan.hpp
Go to the documentation of this file.
1 
11 #ifndef MLPACK_METHODS_ANN_GAN_GAN_HPP
12 #define MLPACK_METHODS_ANN_GAN_GAN_HPP
13 
14 #include <mlpack/core.hpp>
15 
23 
24 
25 namespace mlpack {
26 namespace ann {
27 
57 template<
58  typename Model,
59  typename InitializationRuleType,
60  typename Noise,
61  typename PolicyType = StandardGAN
62 >
63 class GAN
64 {
65  public:
83  GAN(Model generator,
84  Model discriminator,
85  InitializationRuleType& initializeRule,
86  Noise& noiseFunction,
87  const size_t noiseDim,
88  const size_t batchSize,
89  const size_t generatorUpdateStep,
90  const size_t preTrainSize,
91  const double multiplier,
92  const double clippingParameter = 0.01,
93  const double lambda = 10.0);
94 
96  GAN(const GAN&);
97 
99  GAN(GAN&&);
100 
107  void ResetData(arma::mat trainData);
108 
109  // Reset function.
110  void Reset();
111 
123  template<typename OptimizerType, typename... CallbackTypes>
124  double Train(arma::mat trainData,
125  OptimizerType& Optimizer,
126  CallbackTypes&&... callbacks);
127 
137  template<typename Policy = PolicyType>
138  typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
139  std::is_same<Policy, DCGAN>::value, double>::type
140  Evaluate(const arma::mat& parameters,
141  const size_t i,
142  const size_t batchSize);
143 
152  template<typename Policy = PolicyType>
153  typename std::enable_if<std::is_same<Policy, WGAN>::value,
154  double>::type
155  Evaluate(const arma::mat& parameters,
156  const size_t i,
157  const size_t batchSize);
158 
167  template<typename Policy = PolicyType>
168  typename std::enable_if<std::is_same<Policy, WGANGP>::value,
169  double>::type
170  Evaluate(const arma::mat& parameters,
171  const size_t i,
172  const size_t batchSize);
173 
184  template<typename GradType, typename Policy = PolicyType>
185  typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
186  std::is_same<Policy, DCGAN>::value, double>::type
187  EvaluateWithGradient(const arma::mat& parameters,
188  const size_t i,
189  GradType& gradient,
190  const size_t batchSize);
191 
202  template<typename GradType, typename Policy = PolicyType>
203  typename std::enable_if<std::is_same<Policy, WGAN>::value,
204  double>::type
205  EvaluateWithGradient(const arma::mat& parameters,
206  const size_t i,
207  GradType& gradient,
208  const size_t batchSize);
209 
220  template<typename GradType, typename Policy = PolicyType>
221  typename std::enable_if<std::is_same<Policy, WGANGP>::value,
222  double>::type
223  EvaluateWithGradient(const arma::mat& parameters,
224  const size_t i,
225  GradType& gradient,
226  const size_t batchSize);
227 
238  template<typename Policy = PolicyType>
239  typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
240  std::is_same<Policy, DCGAN>::value, void>::type
241  Gradient(const arma::mat& parameters,
242  const size_t i,
243  arma::mat& gradient,
244  const size_t batchSize);
245 
256  template<typename Policy = PolicyType>
257  typename std::enable_if<std::is_same<Policy, WGAN>::value, void>::type
258  Gradient(const arma::mat& parameters,
259  const size_t i,
260  arma::mat& gradient,
261  const size_t batchSize);
262 
273  template<typename Policy = PolicyType>
274  typename std::enable_if<std::is_same<Policy, WGANGP>::value,
275  void>::type
276  Gradient(const arma::mat& parameters,
277  const size_t i,
278  arma::mat& gradient,
279  const size_t batchSize);
280 
285  void Shuffle();
286 
292  void Forward(const arma::mat& input);
293 
300  void Predict(arma::mat input, arma::mat& output);
301 
303  const arma::mat& Parameters() const { return parameter; }
305  arma::mat& Parameters() { return parameter; }
306 
308  const Model& Generator() const { return generator; }
310  Model& Generator() { return generator; }
312  const Model& Discriminator() const { return discriminator; }
314  Model& Discriminator() { return discriminator; }
315 
317  size_t NumFunctions() const { return numFunctions; }
318 
320  const arma::mat& Responses() const { return responses; }
322  arma::mat& Responses() { return responses; }
323 
325  const arma::mat& Predictors() const { return predictors; }
327  arma::mat& Predictors() { return predictors; }
328 
330  template<typename Archive>
331  void serialize(Archive& ar, const unsigned int /* version */);
332 
333  private:
338  void ResetDeterministic();
339 
341  arma::mat predictors;
343  arma::mat parameter;
345  Model generator;
347  Model discriminator;
349  InitializationRuleType initializeRule;
351  Noise noiseFunction;
353  size_t noiseDim;
355  size_t numFunctions;
357  size_t batchSize;
359  size_t currentBatch;
361  size_t generatorUpdateStep;
363  size_t preTrainSize;
365  double multiplier;
367  double clippingParameter;
369  double lambda;
371  bool reset;
373  DeltaVisitor deltaVisitor;
375  arma::mat responses;
377  arma::mat currentInput;
379  arma::mat currentTarget;
381  OutputParameterVisitor outputParameterVisitor;
383  WeightSizeVisitor weightSizeVisitor;
385  ResetVisitor resetVisitor;
387  arma::mat gradient;
389  arma::mat gradientDiscriminator;
391  arma::mat noiseGradientDiscriminator;
393  arma::mat normGradientDiscriminator;
395  arma::mat noise;
397  arma::mat gradientGenerator;
399  bool deterministic;
401  size_t genWeights;
403  size_t discWeights;
404 };
405 
406 } // namespace ann
407 } // namespace mlpack
408 
409 // Include implementation.
410 #include "gan_impl.hpp"
411 #include "wgan_impl.hpp"
412 #include "wgangp_impl.hpp"
413 
414 
415 #endif
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type EvaluateWithGradient(const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the Standard GAN and DCGAN.
const Model & Discriminator() const
Return the discriminator of the GAN.
Definition: gan.hpp:312
size_t NumFunctions() const
Return the number of separable functions (the number of predictor points).
Definition: gan.hpp:317
Model & Generator()
Modify the generator of the GAN.
Definition: gan.hpp:310
void Forward(const arma::mat &input)
This function does a forward pass through the GAN network.
Linear algebra utility functions, generally performed on matrices or vectors.
void ResetData(arma::mat trainData)
Initialize the generator, discriminator and weights of the model for training.
WeightSizeVisitor returns the number of weights of the given module.
void serialize(Archive &ar, const unsigned int)
Serialize the model.
void Shuffle()
Shuffle the order of function visitation.
GAN(Model generator, Model discriminator, InitializationRuleType &initializeRule, Noise &noiseFunction, const size_t noiseDim, const size_t batchSize, const size_t generatorUpdateStep, const size_t preTrainSize, const double multiplier, const double clippingParameter=0.01, const double lambda=10.0)
Constructor for GAN class.
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, void >::type Gradient(const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for Standard GAN and DCGAN.
arma::mat & Parameters()
Modify the parameters of the network.
Definition: gan.hpp:305
ResetVisitor executes the Reset() function.
OutputParameterVisitor exposes the output parameter of the given module.
const arma::mat & Responses() const
Get the matrix of responses to the input data points.
Definition: gan.hpp:320
double Train(arma::mat trainData, OptimizerType &Optimizer, CallbackTypes &&... callbacks)
Train function.
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type Evaluate(const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the Standard GAN and DCGAN.
arma::mat & Predictors()
Modify the matrix of data points (predictors).
Definition: gan.hpp:327
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
const Model & Generator() const
Return the generator of the GAN.
Definition: gan.hpp:308
const arma::mat & Parameters() const
Return the parameters of the network.
Definition: gan.hpp:303
DeltaVisitor exposes the delta parameter of the given module.
The implementation of the standard GAN module.
Definition: gan.hpp:63
void Predict(arma::mat input, arma::mat &output)
This function predicts the output of the network on the given input.
arma::mat & Responses()
Modify the matrix of responses to the input data points.
Definition: gan.hpp:322
Model & Discriminator()
Modify the discriminator of the GAN.
Definition: gan.hpp:314
const arma::mat & Predictors() const
Get the matrix of data points (predictors).
Definition: gan.hpp:325