split_data.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_CORE_DATA_SPLIT_DATA_HPP
14 #define MLPACK_CORE_DATA_SPLIT_DATA_HPP
15 
#include <mlpack/prereqs.hpp>

#include <stdexcept>
17 
18 namespace mlpack {
19 namespace data {
50 template<typename T, typename U>
51 void Split(const arma::Mat& input,
52  const arma::Row& inputLabel,
53  arma::Mat& trainData,
54  arma::Mat& testData,
55  arma::Row& trainLabel,
56  arma::Row& testLabel,
57  const double testRatio,
58  const bool shuffleData = true)
59 {
60  const size_t testSize = static_cast<size_t>(input.n_cols * testRatio);
61  const size_t trainSize = input.n_cols - testSize;
62  trainData.set_size(input.n_rows, trainSize);
63  testData.set_size(input.n_rows, testSize);
64  trainLabel.set_size(trainSize);
65  testLabel.set_size(testSize);
66 
67  if (shuffleData)
68  {
69  arma::uvec order = arma::shuffle(arma::linspace(
70  0, input.n_cols - 1, input.n_cols));
71  if (trainSize > 0)
72  {
73  trainData = input.cols(order.subvec(0, trainSize - 1));
74  trainLabel = inputLabel.cols(order.subvec(0, trainSize - 1));
75  }
76  if (trainSize < input.n_cols)
77  {
78  testData = input.cols(order.subvec(trainSize, input.n_cols - 1));
79  testLabel = inputLabel.cols(order.subvec(trainSize, input.n_cols - 1));
80  }
81  }
82  else
83  {
84  if (trainSize > 0)
85  {
86  trainData = input.cols(0, trainSize - 1);
87  trainLabel = inputLabel.subvec(0, trainSize - 1);
88  }
89  if (trainSize < input.n_cols)
90  {
91  testData = input.cols(trainSize , input.n_cols - 1);
92  testLabel = inputLabel.subvec(trainSize , input.n_cols - 1);
93  }
94  }
95 }
96 
120 template<typename T>
121 void Split(const arma::Mat& input,
122  arma::Mat& trainData,
123  arma::Mat& testData,
124  const double testRatio,
125  const bool shuffleData = true)
126 {
127  const size_t testSize = static_cast<size_t>(input.n_cols * testRatio);
128  const size_t trainSize = input.n_cols - testSize;
129  trainData.set_size(input.n_rows, trainSize);
130  testData.set_size(input.n_rows, testSize);
131 
132  if (shuffleData)
133  {
134  arma::uvec order = arma::shuffle(arma::linspace(
135  0, input.n_cols - 1, input.n_cols));
136 
137  if (trainSize > 0)
138  trainData = input.cols(order.subvec(0, trainSize - 1));
139 
140  if (trainSize < input.n_cols)
141  testData = input.cols(order.subvec(trainSize, input.n_cols - 1));
142  }
143  else
144  {
145  if (trainSize > 0)
146  trainData = input.cols(0, trainSize - 1);
147 
148  if (trainSize < input.n_cols)
149  testData = input.cols(trainSize , input.n_cols - 1);
150  }
151 }
152 
174 template<typename T, typename U>
175 std::tuple, arma::Mat, arma::Row, arma::Row>
176 Split(const arma::Mat& input,
177  const arma::Row& inputLabel,
178  const double testRatio,
179  const bool shuffleData = true)
180 {
181  arma::Mat trainData;
182  arma::Mat testData;
183  arma::Row trainLabel;
184  arma::Row testLabel;
185 
186  Split(input, inputLabel, trainData, testData, trainLabel, testLabel,
187  testRatio, shuffleData);
188 
189  return std::make_tuple(std::move(trainData),
190  std::move(testData),
191  std::move(trainLabel),
192  std::move(testLabel));
193 }
194 
213 template<typename T>
214 std::tuple, arma::Mat>
215 Split(const arma::Mat& input,
216  const double testRatio,
217  const bool shuffleData = true)
218 {
219  arma::Mat trainData;
220  arma::Mat testData;
221  Split(input, trainData, testData, testRatio, shuffleData);
222 
223  return std::make_tuple(std::move(trainData),
224  std::move(testData));
225 }
226 
227 } // namespace data
228 } // namespace mlpack
229 
230 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Split(const arma::Mat< T > &input, const arma::Row< U > &inputLabel, arma::Mat< T > &trainData, arma::Mat< T > &testData, arma::Row< U > &trainLabel, arma::Row< U > &testLabel, const double testRatio, const bool shuffleData=true)
Given an input dataset and labels, split into a training set and test set.
Definition: split_data.hpp:51