13 #ifndef MLPACK_CORE_DATA_SPLIT_DATA_HPP 14 #define MLPACK_CORE_DATA_SPLIT_DATA_HPP 50 template<
typename T,
typename U>
51 void Split(
const arma::Mat
& input, 52 const arma::Row
& inputLabel, 55 arma::Row
& trainLabel, 57 const double testRatio,
58 const bool shuffleData =
true)
60 const size_t testSize =
static_cast<size_t>(input.n_cols * testRatio);
61 const size_t trainSize = input.n_cols - testSize;
62 trainData.set_size(input.n_rows, trainSize);
63 testData.set_size(input.n_rows, testSize);
64 trainLabel.set_size(trainSize);
65 testLabel.set_size(testSize);
69 arma::uvec order = arma::shuffle(arma::linspace
( 70 0, input.n_cols - 1, input.n_cols));
73 trainData = input.cols(order.subvec(0, trainSize - 1));
74 trainLabel = inputLabel.cols(order.subvec(0, trainSize - 1));
76 if (trainSize < input.n_cols)
78 testData = input.cols(order.subvec(trainSize, input.n_cols - 1));
79 testLabel = inputLabel.cols(order.subvec(trainSize, input.n_cols - 1));
86 trainData = input.cols(0, trainSize - 1);
87 trainLabel = inputLabel.subvec(0, trainSize - 1);
89 if (trainSize < input.n_cols)
91 testData = input.cols(trainSize , input.n_cols - 1);
92 testLabel = inputLabel.subvec(trainSize , input.n_cols - 1);
122 arma::Mat
& trainData, 124 const double testRatio,
125 const bool shuffleData =
true)
127 const size_t testSize =
static_cast<size_t>(input.n_cols * testRatio);
128 const size_t trainSize = input.n_cols - testSize;
129 trainData.set_size(input.n_rows, trainSize);
130 testData.set_size(input.n_rows, testSize);
134 arma::uvec order = arma::shuffle(arma::linspace
( 135 0, input.n_cols - 1, input.n_cols));
138 trainData = input.cols(order.subvec(0, trainSize - 1));
140 if (trainSize < input.n_cols)
141 testData = input.cols(order.subvec(trainSize, input.n_cols - 1));
146 trainData = input.cols(0, trainSize - 1);
148 if (trainSize < input.n_cols)
149 testData = input.cols(trainSize , input.n_cols - 1);
174 template<
typename T,
typename U>
175 std::tuple
, arma::Mat, arma::Row, arma::Row> 177 const arma::Row
& inputLabel, 178 const double testRatio,
179 const bool shuffleData =
true)
183 arma::Row
trainLabel; 186 Split(input, inputLabel, trainData, testData, trainLabel, testLabel,
187 testRatio, shuffleData);
189 return std::make_tuple(std::move(trainData),
191 std::move(trainLabel),
192 std::move(testLabel));
214 std::tuple
, arma::Mat> 216 const double testRatio,
217 const bool shuffleData =
true)
221 Split(input, trainData, testData, testRatio, shuffleData);
223 return std::make_tuple(std::move(trainData),
224 std::move(testData));
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Split(const arma::Mat< T > &input, const arma::Row< U > &inputLabel, arma::Mat< T > &trainData, arma::Mat< T > &testData, arma::Row< U > &trainLabel, arma::Row< U > &testLabel, const double testRatio, const bool shuffleData=true)
Given an input dataset and labels, split into a training set and test set.