11 #ifndef MLPACK_METHODS_ANN_GAN_GAN_HPP 12 #define MLPACK_METHODS_ANN_GAN_GAN_HPP 59 typename InitializationRuleType,
61 typename PolicyType = StandardGAN
85 InitializationRuleType& initializeRule,
87 const size_t noiseDim,
88 const size_t batchSize,
89 const size_t generatorUpdateStep,
90 const size_t preTrainSize,
91 const double multiplier,
92 const double clippingParameter = 0.01,
93 const double lambda = 10.0);
/**
 * Train the GAN on the given data.
 *
 * @param trainData Data to train on (taken by value, so the caller's matrix
 *     is not modified).
 * @param Optimizer Optimizer instance used to fit the model.
 * @param callbacks Optional callbacks; presumably forwarded to the optimizer
 *     during training -- TODO confirm against gan_impl.hpp.
 * @return A double; presumably the final objective value -- TODO confirm.
 */
123 template<
typename OptimizerType,
typename... CallbackTypes>
124 double Train(arma::mat trainData,
125 OptimizerType& Optimizer,
126 CallbackTypes&&... callbacks);
137 template<
typename Policy = PolicyType>
138 typename std::enable_if
::value || 139 std::is_same
::value, double>::type 140 Evaluate(
const arma::mat& parameters,
142 const size_t batchSize);
152 template<
typename Policy = PolicyType>
153 typename std::enable_if
::value, 155 Evaluate(
const arma::mat& parameters,
157 const size_t batchSize);
167 template<
typename Policy = PolicyType>
168 typename std::enable_if
::value, 170 Evaluate(
const arma::mat& parameters,
172 const size_t batchSize);
184 template<
typename GradType,
typename Policy = PolicyType>
185 typename std::enable_if
::value || 186 std::is_same
::value, double>::type 190 const size_t batchSize);
202 template<
typename GradType,
typename Policy = PolicyType>
203 typename std::enable_if
::value, 208 const size_t batchSize);
220 template<
typename GradType,
typename Policy = PolicyType>
221 typename std::enable_if
::value, 226 const size_t batchSize);
238 template<
typename Policy = PolicyType>
239 typename std::enable_if
::value || 240 std::is_same
::value, void>::type 241 Gradient(
const arma::mat& parameters,
244 const size_t batchSize);
256 template<
typename Policy = PolicyType>
257 typename std::enable_if
::value, void>::type 258 Gradient(
const arma::mat& parameters,
261 const size_t batchSize);
273 template<
typename Policy = PolicyType>
274 typename std::enable_if
::value, 276 Gradient(
const arma::mat& parameters,
279 const size_t batchSize);
/**
 * Perform a forward pass through the GAN.
 *
 * @param input Input data for the forward pass.
 */
292 void Forward(
const arma::mat& input);
/**
 * Predict the output of the network on the given input.
 *
 * @param input Input to the network (taken by value, so the caller's matrix
 *     is not modified).
 * @param output Matrix that receives the prediction.
 */
300 void Predict(arma::mat input, arma::mat& output);
/**
 * Get the matrix of responses to the input data points.
 *
 * @return Const reference to the stored `responses` matrix (no copy is made).
 */
320 const arma::mat&
Responses()
const {
return responses; }
/**
 * Serialize the model.
 *
 * @param ar Archive to serialize to or from.
 * The second (version) parameter is unnamed, so it is unused here.
 */
330 template<
typename Archive>
331 void serialize(Archive& ar,
const unsigned int );
// NOTE(review): no description is visible for this member. The name suggests
// it resets the deterministic (training vs. inference) state of the
// generator/discriminator -- confirm against gan_impl.hpp.
338 void ResetDeterministic();
341 arma::mat predictors;
349 InitializationRuleType initializeRule;
361 size_t generatorUpdateStep;
367 double clippingParameter;
377 arma::mat currentInput;
379 arma::mat currentTarget;
389 arma::mat gradientDiscriminator;
391 arma::mat noiseGradientDiscriminator;
393 arma::mat normGradientDiscriminator;
397 arma::mat gradientGenerator;
410 #include "gan_impl.hpp" 411 #include "wgan_impl.hpp" 412 #include "wgangp_impl.hpp" std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type EvaluateWithGradient(const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the Standard GAN and DCGAN.
const Model & Discriminator() const
Return the discriminator of the GAN.
size_t NumFunctions() const
Return the number of separable functions (the number of predictor points).
Model & Generator()
Modify the generator of the GAN.
void Forward(const arma::mat &input)
This function performs a forward pass through the GAN.
Linear algebra utility functions, generally performed on matrices or vectors.
void ResetData(arma::mat trainData)
Initialize the generator, discriminator and weights of the model for training.
WeightSizeVisitor returns the number of weights of the given module.
void serialize(Archive &ar, const unsigned int)
Serialize the model.
void Shuffle()
Shuffle the order of function visitation.
GAN(Model generator, Model discriminator, InitializationRuleType &initializeRule, Noise &noiseFunction, const size_t noiseDim, const size_t batchSize, const size_t generatorUpdateStep, const size_t preTrainSize, const double multiplier, const double clippingParameter=0.01, const double lambda=10.0)
Constructor for GAN class.
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, void >::type Gradient(const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for the Standard GAN and DCGAN.
arma::mat & Parameters()
Modify the parameters of the network.
ResetVisitor executes the Reset() function.
OutputParameterVisitor exposes the output parameter of the given module.
const arma::mat & Responses() const
Get the matrix of responses to the input data points.
double Train(arma::mat trainData, OptimizerType &Optimizer, CallbackTypes &&... callbacks)
Train the GAN on trainData using the given optimizer and any supplied callbacks.
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type Evaluate(const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the Standard GAN and DCGAN.
arma::mat & Predictors()
Modify the matrix of data points (predictors).
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
const Model & Generator() const
Return the generator of the GAN.
const arma::mat & Parameters() const
Return the parameters of the network.
DeltaVisitor exposes the delta parameter of the given module.
The implementation of the standard GAN module.
void Predict(arma::mat input, arma::mat &output)
This function predicts the output of the network on the given input.
arma::mat & Responses()
Modify the matrix of responses to the input data points.
Model & Discriminator()
Modify the discriminator of the GAN.
const arma::mat & Predictors() const
Get the matrix of data points (predictors).