mlpack: src/mlpack/core/data/split_data.hpp Source File
split_data.hpp
Go to the documentation of this file.
1 
15 #ifndef mlpack_CORE_DATA_SPLIT_DATA_HPP
16 #define mlpack_CORE_DATA_SPLIT_DATA_HPP
17 
#include <stdexcept>

#include <mlpack/core.hpp>
19 
20 namespace mlpack {
21 namespace data {
50 template<typename T, typename U>
51 void Split(const arma::Mat<T>& input,
52  const arma::Row<U>& inputLabel,
53  arma::Mat<T>& trainData,
54  arma::Mat<T>& testData,
55  arma::Row<U>& trainLabel,
56  arma::Row<U>& testLabel,
57  const double testRatio)
58 {
59  const size_t testSize = static_cast<size_t>(input.n_cols * testRatio);
60  const size_t trainSize = input.n_cols - testSize;
61  trainData.set_size(input.n_rows, trainSize);
62  testData.set_size(input.n_rows, testSize);
63  trainLabel.set_size(trainSize);
64  testLabel.set_size(testSize);
65 
66  const arma::Col<size_t> order =
67  arma::shuffle(arma::linspace<arma::Col<size_t>>(0, input.n_cols - 1,
68  input.n_cols));
69 
70  for (size_t i = 0; i != trainSize; ++i)
71  {
72  trainData.col(i) = input.col(order[i]);
73  trainLabel(i) = inputLabel(order[i]);
74  }
75 
76  for (size_t i = 0; i != testSize; ++i)
77  {
78  testData.col(i) = input.col(order[i + trainSize]);
79  testLabel(i) = inputLabel(order[i + trainSize]);
80  }
81 }
82 
104 template<typename T>
105 void Split(const arma::Mat<T>& input,
106  arma::Mat<T>& trainData,
107  arma::Mat<T>& testData,
108  const double testRatio)
109 {
110  const size_t testSize = static_cast<size_t>(input.n_cols * testRatio);
111  const size_t trainSize = input.n_cols - testSize;
112  trainData.set_size(input.n_rows, trainSize);
113  testData.set_size(input.n_rows, testSize);
114 
115  const arma::Col<size_t> order =
116  arma::shuffle(arma::linspace<arma::Col<size_t>>(0, input.n_cols -1,
117  input.n_cols));
118 
119  for (size_t i = 0; i != trainSize; ++i)
120  {
121  trainData.col(i) = input.col(order[i]);
122  }
123  for (size_t i = 0; i != testSize; ++i)
124  {
125  testData.col(i) = input.col(order[i + trainSize]);
126  }
127 }
128 
148 template<typename T,typename U>
149 std::tuple<arma::Mat<T>, arma::Mat<T>, arma::Row<U>, arma::Row<U>>
150 Split(const arma::Mat<T>& input,
151  const arma::Row<U>& inputLabel,
152  const double testRatio)
153 {
154  arma::Mat<T> trainData;
155  arma::Mat<T> testData;
156  arma::Row<U> trainLabel;
157  arma::Row<U> testLabel;
158 
159  Split(input, inputLabel, trainData, testData, trainLabel, testLabel,
160  testRatio);
161 
162  return std::make_tuple(std::move(trainData),
163  std::move(testData),
164  std::move(trainLabel),
165  std::move(testLabel));
166 }
167 
184 template<typename T>
185 std::tuple<arma::Mat<T>, arma::Mat<T>>
186 Split(const arma::Mat<T>& input,
187  const double testRatio)
188 {
189  arma::Mat<T> trainData;
190  arma::Mat<T> testData;
191  Split(input, trainData, testData, testRatio);
192 
193  return std::make_tuple(std::move(trainData),
194  std::move(testData));
195 }
196 
197 } // namespace data
198 } // namespace mlpack
199 
200 #endif
void Split(const arma::Mat< T > &input, const arma::Row< U > &inputLabel, arma::Mat< T > &trainData, arma::Mat< T > &testData, arma::Row< U > &trainLabel, arma::Row< U > &testLabel, const double testRatio)
Given an input dataset and labels, split into a training set and test set.
Definition: split_data.hpp:51
Linear algebra utility functions, generally performed on matrices or vectors.
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.