The class KFoldCV implements k-fold cross-validation for regression and classification algorithms.
Public Member Functions
KFoldCV (const size_t k, const MatType &xs, const PredictionsType &ys, const bool shuffle=true)
    This constructor can be used for regression algorithms and for binary classification algorithms.
KFoldCV (const size_t k, const MatType &xs, const PredictionsType &ys, const size_t numClasses, const bool shuffle=true)
    This constructor can be used for multiclass classification algorithms.
KFoldCV (const size_t k, const MatType &xs, const data::DatasetInfo &datasetInfo, const PredictionsType &ys, const size_t numClasses, const bool shuffle=true)
    This constructor can be used for multiclass classification algorithms that can take a data::DatasetInfo parameter.
KFoldCV (const size_t k, const MatType &xs, const PredictionsType &ys, const WeightsType &weights, const bool shuffle=true)
    This constructor can be used for regression and binary classification algorithms that support weighted learning.
KFoldCV (const size_t k, const MatType &xs, const PredictionsType &ys, const size_t numClasses, const WeightsType &weights, const bool shuffle=true)
    This constructor can be used for multiclass classification algorithms that support weighted learning.
KFoldCV (const size_t k, const MatType &xs, const data::DatasetInfo &datasetInfo, const PredictionsType &ys, const size_t numClasses, const WeightsType &weights, const bool shuffle=true)
    This constructor can be used for multiclass classification algorithms that can take a data::DatasetInfo parameter and support weighted learning.
template<typename... MLAlgorithmArgs>
double Evaluate (const MLAlgorithmArgs &...args)
    Run k-fold cross-validation.
MLAlgorithm & Model ()
    Access and modify a model from the last run of k-fold cross-validation.
template
void Shuffle ()
    Shuffle the data.
template
void Shuffle ()
    Shuffle the data.
To construct a KFoldCV object, pass the number of folds k followed by the arguments that specify the data. For example, to run 10-fold cross-validation for SoftmaxRegression, construct the object with k = 10, the data, the labels, and the number of classes, then call Evaluate() with any remaining SoftmaxRegression arguments.
Before calling Evaluate(), it is possible to shuffle the data by calling the Shuffle() function. Shuffling is also performed at construction time if the shuffle parameter is set to true in the constructor.
Template Parameters
MLAlgorithm | A machine learning algorithm. |
Metric | A metric to assess the quality of a trained model. |
MatType | The type of data. |
PredictionsType | The type of predictions (should be passed when the predictions type is a template parameter in Train methods of MLAlgorithm). |
WeightsType | The type of weights (should be passed when weighted learning is supported, and the weights type is a template parameter in Train methods of MLAlgorithm). |
Definition at line 65 of file k_fold_cv.hpp.
KFoldCV (const size_t k, const MatType &xs, const PredictionsType &ys, const bool shuffle=true)
This constructor can be used for regression algorithms and for binary classification algorithms.
k | Number of folds (should be at least 2). |
xs | Data points to cross-validate on. |
ys | Predictions (labels for classification algorithms and responses for regression algorithms) for each data point. |
shuffle | Whether or not to shuffle the data during construction. |
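To illustrate what the fold count k controls, here is a minimal sketch of one common way to divide n data points into k contiguous validation folds. It is an assumption for illustration only, not mlpack's actual partitioning code; `FoldRanges` is a hypothetical helper name.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper (not part of mlpack): returns the half-open index range
// [begin, end) of each of the k validation folds over n data points.
std::vector<std::pair<std::size_t, std::size_t>> FoldRanges(const std::size_t n,
                                                            const std::size_t k)
{
  // Fewer than 2 folds leaves nothing to train on; more than n leaves empty folds.
  assert(k >= 2 && k <= n);

  std::vector<std::pair<std::size_t, std::size_t>> ranges;
  const std::size_t base = n / k;   // Minimum number of points per fold.
  const std::size_t extra = n % k;  // The first `extra` folds get one more point.

  std::size_t begin = 0;
  for (std::size_t i = 0; i < k; ++i)
  {
    const std::size_t size = base + (i < extra ? 1 : 0);
    ranges.emplace_back(begin, begin + size);
    begin += size;
  }
  return ranges;
}
```

With this scheme every point lands in exactly one validation fold, and fold sizes differ by at most one.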
KFoldCV (const size_t k, const MatType &xs, const PredictionsType &ys, const size_t numClasses, const bool shuffle=true)
This constructor can be used for multiclass classification algorithms.
k | Number of folds (should be at least 2). |
xs | Data points to cross-validate on. |
ys | Labels for each data point. |
numClasses | Number of classes in the dataset. |
shuffle | Whether or not to shuffle the data during construction. |
KFoldCV (const size_t k, const MatType &xs, const data::DatasetInfo &datasetInfo, const PredictionsType &ys, const size_t numClasses, const bool shuffle=true)
This constructor can be used for multiclass classification algorithms that can take a data::DatasetInfo parameter.
k | Number of folds (should be at least 2). |
xs | Data points to cross-validate on. |
datasetInfo | Type information for each dimension of the dataset. |
ys | Labels for each data point. |
numClasses | Number of classes in the dataset. |
shuffle | Whether or not to shuffle the data during construction. |
KFoldCV (const size_t k, const MatType &xs, const PredictionsType &ys, const WeightsType &weights, const bool shuffle=true)
This constructor can be used for regression and binary classification algorithms that support weighted learning.
k | Number of folds (should be at least 2). |
xs | Data points to cross-validate on. |
ys | Predictions (labels for classification algorithms and responses for regression algorithms) for each data point. |
weights | Observation weights (for boosting). |
shuffle | Whether or not to shuffle the data during construction. |
KFoldCV (const size_t k, const MatType &xs, const PredictionsType &ys, const size_t numClasses, const WeightsType &weights, const bool shuffle=true)
This constructor can be used for multiclass classification algorithms that support weighted learning.
k | Number of folds (should be at least 2). |
xs | Data points to cross-validate on. |
ys | Labels for each data point. |
numClasses | Number of classes in the dataset. |
weights | Observation weights (for boosting). |
shuffle | Whether or not to shuffle the data during construction. |
KFoldCV (const size_t k, const MatType &xs, const data::DatasetInfo &datasetInfo, const PredictionsType &ys, const size_t numClasses, const WeightsType &weights, const bool shuffle=true)
This constructor can be used for multiclass classification algorithms that can take a data::DatasetInfo parameter and support weighted learning.
k | Number of folds (should be at least 2). |
xs | Data points to cross-validate on. |
datasetInfo | Type information for each dimension of the dataset. |
ys | Labels for each data point. |
numClasses | Number of classes in the dataset. |
weights | Observation weights (for boosting). |
shuffle | Whether or not to shuffle the data during construction. |
double Evaluate (const MLAlgorithmArgs &...args)
Run k-fold cross-validation.
args | Arguments for MLAlgorithm (in addition to those passed in the constructor). |
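The general k-fold procedure behind a call like Evaluate() can be sketched in plain C++: train on k - 1 folds, score on the held-out fold, and average the k scores. This is an illustrative sketch under stated assumptions, not mlpack's implementation. `MeanModel`, `MeanSquaredError`, and `KFoldEvaluate` are hypothetical stand-ins for MLAlgorithm, Metric, and the evaluation loop, and averaging the per-fold scores is assumed here rather than taken from this page.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for MLAlgorithm: predicts the mean training response.
struct MeanModel
{
  double mean = 0.0;

  void Train(const std::vector<double>& ys)
  {
    double sum = 0.0;
    for (const double y : ys)
      sum += y;
    mean = ys.empty() ? 0.0 : sum / ys.size();
  }

  double Predict() const { return mean; }
};

// Hypothetical stand-in for Metric: mean squared error on held-out responses.
double MeanSquaredError(const MeanModel& model, const std::vector<double>& ys)
{
  double sum = 0.0;
  for (const double y : ys)
  {
    const double diff = y - model.Predict();
    sum += diff * diff;
  }
  return sum / ys.size();
}

// Trains on k - 1 folds, scores the held-out fold, and averages the k scores.
double KFoldEvaluate(const std::vector<double>& ys, const std::size_t k)
{
  assert(k >= 2 && k <= ys.size());
  const std::size_t n = ys.size();

  double total = 0.0;
  for (std::size_t fold = 0; fold < k; ++fold)
  {
    // Contiguous held-out range [begin, end) for this fold.
    const std::size_t begin = fold * n / k;
    const std::size_t end = (fold + 1) * n / k;

    std::vector<double> train, valid;
    for (std::size_t i = 0; i < n; ++i)
      (i >= begin && i < end ? valid : train).push_back(ys[i]);

    MeanModel model;
    model.Train(train);
    total += MeanSquaredError(model, valid);
  }
  return total / k;
}
```

Because each fold trains a fresh model, a sketch like this would naturally keep only the most recently trained model once the loop finishes.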
MLAlgorithm& Model ()
Access and modify a model from the last run of k-fold cross-validation.
void Shuffle ()
Shuffle the data.
This overload is called if weights are not supported by the model type.
void Shuffle ()
Shuffle the data.
This overload is called if weights are supported by the model type.
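The key requirement when shuffling is that data points, predictions, and (when present) weights are all reordered by the same permutation, so each point keeps its own label and weight. The following is a minimal sketch of that idea using standard containers. `ApplyOrder` and `ShuffleTogether` are hypothetical names, and the code is not taken from mlpack.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical helper (not mlpack code): builds a copy of `values` in which
// position i holds the element previously stored at position order[i].
template<typename T>
std::vector<T> ApplyOrder(const std::vector<T>& values,
                          const std::vector<std::size_t>& order)
{
  assert(values.size() == order.size());
  std::vector<T> result;
  result.reserve(values.size());
  for (const std::size_t index : order)
    result.push_back(values[index]);
  return result;
}

// Draws one random permutation and applies it to the points, the labels, and
// the weights. An empty weights vector models the "weights not supported" case.
void ShuffleTogether(std::vector<double>& xs,
                     std::vector<std::size_t>& ys,
                     std::vector<double>& weights,
                     const unsigned seed)
{
  std::vector<std::size_t> order(xs.size());
  std::iota(order.begin(), order.end(), std::size_t{0});

  std::mt19937 rng(seed);
  std::shuffle(order.begin(), order.end(), rng);

  xs = ApplyOrder(xs, order);
  ys = ApplyOrder(ys, order);
  if (!weights.empty())
    weights = ApplyOrder(weights, order);
}
```

Shuffling an index vector once and reusing it avoids the bug of shuffling each container independently, which would silently break the pairing between points and labels.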