KFoldCV< MLAlgorithm, Metric, MatType, PredictionsType, WeightsType > Class Template Reference

The class KFoldCV implements k-fold cross-validation for regression and classification algorithms.

Public Member Functions

 KFoldCV (const size_t k, const MatType &xs, const PredictionsType &ys, const bool shuffle=true)
 This constructor can be used for regression algorithms and for binary classification algorithms.

 
 KFoldCV (const size_t k, const MatType &xs, const PredictionsType &ys, const size_t numClasses, const bool shuffle=true)
 This constructor can be used for multiclass classification algorithms.

 
 KFoldCV (const size_t k, const MatType &xs, const data::DatasetInfo &datasetInfo, const PredictionsType &ys, const size_t numClasses, const bool shuffle=true)
 This constructor can be used for multiclass classification algorithms that can take a data::DatasetInfo parameter.

 
 KFoldCV (const size_t k, const MatType &xs, const PredictionsType &ys, const WeightsType &weights, const bool shuffle=true)
 This constructor can be used for regression and binary classification algorithms that support weighted learning.

 
 KFoldCV (const size_t k, const MatType &xs, const PredictionsType &ys, const size_t numClasses, const WeightsType &weights, const bool shuffle=true)
 This constructor can be used for multiclass classification algorithms that support weighted learning.

 
 KFoldCV (const size_t k, const MatType &xs, const data::DatasetInfo &datasetInfo, const PredictionsType &ys, const size_t numClasses, const WeightsType &weights, const bool shuffle=true)
 This constructor can be used for multiclass classification algorithms that can take a data::DatasetInfo parameter and support weighted learning.

 
template<typename... MLAlgorithmArgs>
double Evaluate (const MLAlgorithmArgs &...args)
 Run k-fold cross-validation.

 
MLAlgorithm & Model ()
 Access and modify a model from the last run of k-fold cross-validation.

 
template<bool Enabled = !Base::MIE::SupportsWeights, typename = typename std::enable_if<Enabled>::type>
void Shuffle ()
 Shuffle the data.

 
template<bool Enabled = Base::MIE::SupportsWeights, typename = typename std::enable_if<Enabled>::type, typename = void>
void Shuffle ()
 Shuffle the data.

 

Detailed Description


template<typename MLAlgorithm, typename Metric, typename MatType = arma::mat, typename PredictionsType = typename MetaInfoExtractor<MLAlgorithm, MatType>::PredictionsType, typename WeightsType = typename MetaInfoExtractor<MLAlgorithm, MatType, PredictionsType>::WeightsType>
class mlpack::cv::KFoldCV< MLAlgorithm, Metric, MatType, PredictionsType, WeightsType >

The class KFoldCV implements k-fold cross-validation for regression and classification algorithms.

To construct a KFoldCV object, pass the number of folds k followed by the arguments that specify the data. For example, 10-fold cross-validation of SoftmaxRegression can be run as follows.

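#include <mlpack/core.hpp>
#include <mlpack/core/cv/k_fold_cv.hpp>
#include <mlpack/core/cv/metrics/accuracy.hpp>
#include <mlpack/methods/softmax_regression/softmax_regression.hpp>
using namespace mlpack::cv;
using namespace mlpack::regression;  // SoftmaxRegression lives here.
int main()
{
  // 100-point 5-dimensional random dataset.
  arma::mat data = arma::randu<arma::mat>(5, 100);
  // Random labels in the [0, 4] interval.
  arma::Row<size_t> labels =
      arma::randi<arma::Row<size_t>>(100, arma::distr_param(0, 4));
  size_t numClasses = 5;
  KFoldCV<SoftmaxRegression, Accuracy> cv(10, data, labels, numClasses);
  double lambda = 0.1;
  double softmaxAccuracy = cv.Evaluate(lambda);
}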

Before calling Evaluate(), the data can be shuffled by calling Shuffle(). When the constructor's shuffle parameter is true (the default), the data are also shuffled once at construction time.
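The fold construction that k-fold cross-validation depends on can be sketched in standard C++. This is a minimal illustration under stated assumptions, not mlpack's implementation: `MakeFolds` is a hypothetical helper that optionally shuffles point indices with a seeded `std::mt19937` and then deals them into k contiguous folds whose sizes differ by at most one. mlpack may size or order its folds differently.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical helper (not part of mlpack): split the indices 0..n-1 into k
// folds. If shuffle is true, the indices are permuted once beforehand.
std::vector<std::vector<std::size_t>> MakeFolds(std::size_t n, std::size_t k,
                                                bool shuffle,
                                                unsigned seed = 42)
{
  assert(k >= 2 && k <= n);  // At least 2 folds, and no empty folds.

  std::vector<std::size_t> indices(n);
  std::iota(indices.begin(), indices.end(), 0);
  if (shuffle)
  {
    std::mt19937 rng(seed);
    std::shuffle(indices.begin(), indices.end(), rng);
  }

  // The first (n % k) folds get one extra point, so sizes differ by <= 1.
  std::vector<std::vector<std::size_t>> folds(k);
  std::size_t start = 0;
  for (std::size_t f = 0; f < k; ++f)
  {
    const std::size_t size = n / k + (f < n % k ? 1 : 0);
    folds[f].assign(indices.begin() + start, indices.begin() + start + size);
    start += size;
  }
  return folds;
}
```

Each fold then serves once as the validation set while the remaining k - 1 folds form the training set.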

Template Parameters
MLAlgorithm: A machine learning algorithm.
Metric: A metric to assess the quality of a trained model.
MatType: The type of the data (arma::mat by default).
PredictionsType: The type of predictions (should be passed when the predictions type is a template parameter of the Train methods of MLAlgorithm).
WeightsType: The type of weights (should be passed when weighted learning is supported and the weights type is a template parameter of the Train methods of MLAlgorithm).

Definition at line 65 of file k_fold_cv.hpp.

Constructor & Destructor Documentation

◆ KFoldCV() [1/6]

KFoldCV(const size_t k,
        const MatType& xs,
        const PredictionsType& ys,
        const bool shuffle = true)

This constructor can be used for regression algorithms and for binary classification algorithms.

Parameters
k: Number of folds (should be at least 2).
xs: Data points to cross-validate on.
ys: Predictions (labels for classification algorithms and responses for regression algorithms) for each data point.
shuffle: Whether or not to shuffle the data during construction.
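For the regression case, the roles of MLAlgorithm and Metric can be illustrated without mlpack. The sketch below cross-validates a trivial model that always predicts the training mean and scores each fold with mean squared error. `MeanModel` and `KFoldMSE` are hypothetical names, the folds are taken in order without shuffling, and averaging the per-fold errors is an assumption of the sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a regression MLAlgorithm: it ignores the features
// and always predicts the mean response seen during training.
struct MeanModel
{
  double mean = 0.0;
  explicit MeanModel(const std::vector<double>& responses)
  {
    for (double r : responses) mean += r;
    mean /= static_cast<double>(responses.size());
  }
  double Predict() const { return mean; }
};

// Unshuffled k-fold cross-validation of MeanModel, scored with MSE.
// Fold f holds points [f * n / k, (f + 1) * n / k).
double KFoldMSE(const std::vector<double>& ys, std::size_t k)
{
  const std::size_t n = ys.size();
  assert(k >= 2 && k <= n);

  double total = 0.0;
  for (std::size_t f = 0; f < k; ++f)
  {
    const std::size_t begin = f * n / k, end = (f + 1) * n / k;
    std::vector<double> train;
    for (std::size_t i = 0; i < n; ++i)
      if (i < begin || i >= end) train.push_back(ys[i]);

    const MeanModel model(train);  // "Train" on the other k - 1 folds.
    double sse = 0.0;              // Then validate on fold f.
    for (std::size_t i = begin; i < end; ++i)
    {
      const double err = ys[i] - model.Predict();
      sse += err * err;
    }
    total += sse / static_cast<double>(end - begin);
  }
  return total / static_cast<double>(k);  // Average over the k folds.
}
```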

◆ KFoldCV() [2/6]

KFoldCV(const size_t k,
        const MatType& xs,
        const PredictionsType& ys,
        const size_t numClasses,
        const bool shuffle = true)

This constructor can be used for multiclass classification algorithms.

Parameters
k: Number of folds (should be at least 2).
xs: Data points to cross-validate on.
ys: Labels for each data point.
numClasses: Number of classes in the dataset.
shuffle: Whether or not to shuffle the data during construction.
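The numClasses argument lets a multiclass model know how many labels exist even when a class is missing from a particular training fold. In this sketch, a hypothetical `MajorityClassifier` (not an mlpack class) sizes its class-count table from numClasses and predicts the most frequent training label; `FoldAccuracies` reports its accuracy on each unshuffled fold. Labels are assumed to lie in [0, numClasses).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical multiclass "model": predicts the most frequent training label.
// numClasses fixes the size of the count table, so a class that is absent
// from this training fold still has a (zero) slot.
struct MajorityClassifier
{
  std::size_t prediction = 0;
  MajorityClassifier(const std::vector<std::size_t>& labels,
                     std::size_t numClasses)
  {
    std::vector<std::size_t> counts(numClasses, 0);
    for (std::size_t l : labels)
    {
      assert(l < numClasses);
      ++counts[l];
    }
    for (std::size_t c = 1; c < numClasses; ++c)
      if (counts[c] > counts[prediction]) prediction = c;
  }
};

// Accuracy of MajorityClassifier on each of k unshuffled folds.
std::vector<double> FoldAccuracies(const std::vector<std::size_t>& labels,
                                   std::size_t numClasses, std::size_t k)
{
  const std::size_t n = labels.size();
  assert(k >= 2 && k <= n);
  std::vector<double> accuracies;
  for (std::size_t f = 0; f < k; ++f)
  {
    const std::size_t begin = f * n / k, end = (f + 1) * n / k;
    std::vector<std::size_t> train;
    for (std::size_t i = 0; i < n; ++i)
      if (i < begin || i >= end) train.push_back(labels[i]);

    const MajorityClassifier model(train, numClasses);
    std::size_t correct = 0;
    for (std::size_t i = begin; i < end; ++i)
      if (labels[i] == model.prediction) ++correct;
    accuracies.push_back(static_cast<double>(correct) / (end - begin));
  }
  return accuracies;
}
```

With labels {1, 1, 1, 1, 1, 2}, numClasses = 3 and k = 2, class 0 never appears and class 2 is absent from the second training fold, yet both still have a slot in the count table.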

◆ KFoldCV() [3/6]

KFoldCV(const size_t k,
        const MatType& xs,
        const data::DatasetInfo& datasetInfo,
        const PredictionsType& ys,
        const size_t numClasses,
        const bool shuffle = true)

This constructor can be used for multiclass classification algorithms that can take a data::DatasetInfo parameter.

Parameters
k: Number of folds (should be at least 2).
xs: Data points to cross-validate on.
datasetInfo: Type information for each dimension of the dataset.
ys: Labels for each data point.
numClasses: Number of classes in the dataset.
shuffle: Whether or not to shuffle the data during construction.

◆ KFoldCV() [4/6]

KFoldCV(const size_t k,
        const MatType& xs,
        const PredictionsType& ys,
        const WeightsType& weights,
        const bool shuffle = true)

This constructor can be used for regression and binary classification algorithms that support weighted learning.

Parameters
k: Number of folds (should be at least 2).
xs: Data points to cross-validate on.
ys: Predictions (labels for classification algorithms and responses for regression algorithms) for each data point.
weights: Observation weights (for boosting).
shuffle: Whether or not to shuffle the data during construction.
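Weighted learning adds one invariant: whenever points are shuffled or split, each observation weight has to stay attached to its point. A minimal sketch of that idea, using a hypothetical `ShuffleTogether` helper and plain `std::vector`s instead of Armadillo types, applies one random permutation to both the responses and the weights.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical helper: apply one random permutation to both the responses and
// the observation weights so that every weight stays attached to its point.
// (Shuffling the two vectors independently would silently reassign weights.)
void ShuffleTogether(std::vector<double>& ys, std::vector<double>& weights,
                     unsigned seed = 7)
{
  assert(ys.size() == weights.size());

  std::vector<std::size_t> order(ys.size());
  std::iota(order.begin(), order.end(), 0);
  std::mt19937 rng(seed);
  std::shuffle(order.begin(), order.end(), rng);

  std::vector<double> newYs(ys.size()), newWeights(weights.size());
  for (std::size_t i = 0; i < order.size(); ++i)
  {
    newYs[i] = ys[order[i]];
    newWeights[i] = weights[order[i]];
  }
  ys.swap(newYs);
  weights.swap(newWeights);
}
```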

◆ KFoldCV() [5/6]

KFoldCV(const size_t k,
        const MatType& xs,
        const PredictionsType& ys,
        const size_t numClasses,
        const WeightsType& weights,
        const bool shuffle = true)

This constructor can be used for multiclass classification algorithms that support weighted learning.

Parameters
k: Number of folds (should be at least 2).
xs: Data points to cross-validate on.
ys: Labels for each data point.
numClasses: Number of classes in the dataset.
weights: Observation weights (for boosting).
shuffle: Whether or not to shuffle the data during construction.

◆ KFoldCV() [6/6]

KFoldCV(const size_t k,
        const MatType& xs,
        const data::DatasetInfo& datasetInfo,
        const PredictionsType& ys,
        const size_t numClasses,
        const WeightsType& weights,
        const bool shuffle = true)

This constructor can be used for multiclass classification algorithms that can take a data::DatasetInfo parameter and support weighted learning.

Parameters
k: Number of folds (should be at least 2).
xs: Data points to cross-validate on.
datasetInfo: Type information for each dimension of the dataset.
ys: Labels for each data point.
numClasses: Number of classes in the dataset.
weights: Observation weights (for boosting).
shuffle: Whether or not to shuffle the data during construction.

Member Function Documentation

◆ Evaluate()

template<typename... MLAlgorithmArgs>
double Evaluate(const MLAlgorithmArgs&... args)

Run k-fold cross-validation.

Parameters
args: Arguments for MLAlgorithm, in addition to those already passed to the constructor.
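Evaluate() accepts extra arguments for MLAlgorithm on top of the data given to the constructor, as with `cv.Evaluate(lambda)` in the class description. The sketch below shows that forwarding pattern with a variadic template. `CrossValidate`, `ShrunkMean`, and `ConstantMSE` are hypothetical names, and returning the mean of the per-fold scores is an assumption made for the sketch rather than something this page documents.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical generic driver: builds a Model on each training split,
// forwarding any extra hyperparameters (args...) to its constructor, scores it
// with Metric on the held-out fold, and returns the mean score.
template<typename Model, typename Metric, typename... Args>
double CrossValidate(const std::vector<double>& ys, std::size_t k,
                     const Args&... args)
{
  const std::size_t n = ys.size();
  assert(k >= 2 && k <= n);
  double sum = 0.0;
  for (std::size_t f = 0; f < k; ++f)
  {
    const std::size_t begin = f * n / k, end = (f + 1) * n / k;
    std::vector<double> train, valid;
    for (std::size_t i = 0; i < n; ++i)
    {
      if (i >= begin && i < end)
        valid.push_back(ys[i]);
      else
        train.push_back(ys[i]);
    }
    const Model model(train, args...);  // Extra arguments go to the model.
    sum += Metric::Evaluate(model, valid);
  }
  return sum / static_cast<double>(k);
}

// Example model with one hyperparameter: the training mean shrunk toward 0.
struct ShrunkMean
{
  double value;
  ShrunkMean(const std::vector<double>& ys, double lambda) : value(0.0)
  {
    for (double y : ys) value += y;
    value /= (static_cast<double>(ys.size()) + lambda);
  }
};

// Example metric: mean squared error of the model's constant prediction.
struct ConstantMSE
{
  template<typename M>
  static double Evaluate(const M& model, const std::vector<double>& ys)
  {
    double sse = 0.0;
    for (double y : ys) sse += (y - model.value) * (y - model.value);
    return sse / static_cast<double>(ys.size());
  }
};
```

Here `CrossValidate<ShrunkMean, ConstantMSE>(ys, 5, 0.1)` would pass 0.1 to every `ShrunkMean` constructor call, in the same spirit as passing lambda to Evaluate().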

◆ Model()

MLAlgorithm& Model()

Access and modify a model from the last run of k-fold cross-validation.

◆ Shuffle() [1/2]

void Shuffle()

Shuffle the data.

This overload is selected when the model type does not support weights (that is, when Base::MIE::SupportsWeights is false).

◆ Shuffle() [2/2]

void Shuffle()

Shuffle the data.

This overload is selected when the model type supports weights (that is, when Base::MIE::SupportsWeights is true).
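The two Shuffle() overloads are chosen at compile time through std::enable_if on the SupportsWeights flag shown in the member summary, with an extra defaulted `typename = void` parameter so that the two templates have distinct signatures. The sketch below reproduces that pattern outside mlpack with hypothetical `PlainModel`, `WeightedModel`, and `Shuffler` types; the returned strings stand in for the actual shuffling work.

```cpp
#include <cassert>
#include <string>
#include <type_traits>

// Hypothetical model traits; in KFoldCV the flag comes from
// Base::MIE::SupportsWeights instead.
struct PlainModel    { static constexpr bool SupportsWeights = false; };
struct WeightedModel { static constexpr bool SupportsWeights = true; };

template<typename Model>
struct Shuffler
{
  // Viable only when the model does not support weights.
  template<bool Enabled = !Model::SupportsWeights,
           typename = typename std::enable_if<Enabled>::type>
  std::string Shuffle() { return "data and labels"; }

  // Viable only when the model supports weights. The extra defaulted
  // parameter gives this template a different template parameter list, so
  // the two declarations do not collide.
  template<bool Enabled = Model::SupportsWeights,
           typename = typename std::enable_if<Enabled>::type,
           typename = void>
  std::string Shuffle() { return "data, labels and weights"; }
};
```

Because `Enabled` is a template parameter of each member template, a false condition removes that overload by substitution failure instead of causing a hard error, so exactly one Shuffle() is viable for any model type.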


The documentation for this class was generated from the following file:
  • mlpack-3.4.1/src/mlpack/core/cv/k_fold_cv.hpp