mlpack: src/mlpack/tests/serialization.hpp Source File
serialization.hpp
Go to the documentation of this file.
1 
14 #ifndef mlpack_TESTS_SERIALIZATION_HPP
15 #define mlpack_TESTS_SERIALIZATION_HPP
16 
17 #include <boost/serialization/serialization.hpp>
18 #include <boost/archive/xml_iarchive.hpp>
19 #include <boost/archive/xml_oarchive.hpp>
20 #include <boost/archive/text_iarchive.hpp>
21 #include <boost/archive/text_oarchive.hpp>
22 #include <boost/archive/binary_iarchive.hpp>
23 #include <boost/archive/binary_oarchive.hpp>
24 #include <mlpack/core.hpp>
25 
26 #include <boost/test/unit_test.hpp>
28 
29 namespace mlpack {
30 
31 // Test function for loading and saving Armadillo objects.
32 template<typename CubeType,
33  typename IArchiveType,
34  typename OArchiveType>
35 void TestArmadilloSerialization(arma::Cube<CubeType>& x)
36 {
37  // First save it.
38  std::ofstream ofs("test", std::ios::binary);
39  OArchiveType o(ofs);
40 
41  bool success = true;
42  try
43  {
44  o << BOOST_SERIALIZATION_NVP(x);
45  }
46  catch (boost::archive::archive_exception& e)
47  {
48  success = false;
49  }
50 
51  BOOST_REQUIRE_EQUAL(success, true);
52  ofs.close();
53 
54  // Now load it.
55  arma::Cube<CubeType> orig(x);
56  success = true;
57  std::ifstream ifs("test", std::ios::binary);
58  IArchiveType i(ifs);
59 
60  try
61  {
62  i >> BOOST_SERIALIZATION_NVP(x);
63  }
64  catch (boost::archive::archive_exception& e)
65  {
66  success = false;
67  }
68 
69  BOOST_REQUIRE_EQUAL(success, true);
70 
71  BOOST_REQUIRE_EQUAL(x.n_rows, orig.n_rows);
72  BOOST_REQUIRE_EQUAL(x.n_cols, orig.n_cols);
73  BOOST_REQUIRE_EQUAL(x.n_elem_slice, orig.n_elem_slice);
74  BOOST_REQUIRE_EQUAL(x.n_slices, orig.n_slices);
75  BOOST_REQUIRE_EQUAL(x.n_elem, orig.n_elem);
76 
77  for(size_t slice = 0; slice != x.n_slices; ++slice){
78  auto const &orig_slice = orig.slice(slice);
79  auto const &x_slice = x.slice(slice);
80  for (size_t i = 0; i < x.n_cols; ++i){
81  for (size_t j = 0; j < x.n_rows; ++j){
82  if (double(orig_slice(j, i)) == 0.0)
83  BOOST_REQUIRE_SMALL(double(x_slice(j, i)), 1e-8);
84  else
85  BOOST_REQUIRE_CLOSE(double(orig_slice(j, i)), double(x_slice(j, i)), 1e-8);
86  }
87  }
88  }
89 
90  remove("test");
91 }
92 
93 // Test all serialization strategies.
94 template<typename CubeType>
95 void TestAllArmadilloSerialization(arma::Cube<CubeType>& x)
96 {
97  TestArmadilloSerialization<CubeType, boost::archive::xml_iarchive,
98  boost::archive::xml_oarchive>(x);
99  TestArmadilloSerialization<CubeType, boost::archive::text_iarchive,
100  boost::archive::text_oarchive>(x);
101  TestArmadilloSerialization<CubeType, boost::archive::binary_iarchive,
102  boost::archive::binary_oarchive>(x);
103 }
104 
105 // Test function for loading and saving Armadillo objects.
106 template<typename MatType,
107  typename IArchiveType,
108  typename OArchiveType>
110 {
111  // First save it.
112  std::ofstream ofs("test", std::ios::binary);
113  OArchiveType o(ofs);
114 
115  bool success = true;
116  try
117  {
118  o << BOOST_SERIALIZATION_NVP(x);
119  }
120  catch (boost::archive::archive_exception& e)
121  {
122  success = false;
123  }
124 
125  BOOST_REQUIRE_EQUAL(success, true);
126  ofs.close();
127 
128  // Now load it.
129  MatType orig(x);
130  success = true;
131  std::ifstream ifs("test", std::ios::binary);
132  IArchiveType i(ifs);
133 
134  try
135  {
136  i >> BOOST_SERIALIZATION_NVP(x);
137  }
138  catch (boost::archive::archive_exception& e)
139  {
140  success = false;
141  }
142 
143  BOOST_REQUIRE_EQUAL(success, true);
144 
145  BOOST_REQUIRE_EQUAL(x.n_rows, orig.n_rows);
146  BOOST_REQUIRE_EQUAL(x.n_cols, orig.n_cols);
147  BOOST_REQUIRE_EQUAL(x.n_elem, orig.n_elem);
148 
149  for (size_t i = 0; i < x.n_cols; ++i)
150  for (size_t j = 0; j < x.n_rows; ++j)
151  if (double(orig(j, i)) == 0.0)
152  BOOST_REQUIRE_SMALL(double(x(j, i)), 1e-8);
153  else
154  BOOST_REQUIRE_CLOSE(double(orig(j, i)), double(x(j, i)), 1e-8);
155 
156  remove("test");
157 }
158 
159 // Test all serialization strategies.
160 template<typename MatType>
162 {
163  TestArmadilloSerialization<MatType, boost::archive::xml_iarchive,
164  boost::archive::xml_oarchive>(x);
165  TestArmadilloSerialization<MatType, boost::archive::text_iarchive,
166  boost::archive::text_oarchive>(x);
167  TestArmadilloSerialization<MatType, boost::archive::binary_iarchive,
168  boost::archive::binary_oarchive>(x);
169 }
170 
171 // Save and load an mlpack object.
172 // The re-loaded copy is placed in 'newT'.
173 template<typename T, typename IArchiveType, typename OArchiveType>
174 void SerializeObject(T& t, T& newT)
175 {
176  std::ofstream ofs("test", std::ios::binary);
177  OArchiveType o(ofs);
178 
179  bool success = true;
180  try
181  {
182  o << data::CreateNVP(t, "t");
183  }
184  catch (boost::archive::archive_exception& e)
185  {
186  success = false;
187  }
188  ofs.close();
189 
190  BOOST_REQUIRE_EQUAL(success, true);
191 
192  std::ifstream ifs("test", std::ios::binary);
193  IArchiveType i(ifs);
194 
195  try
196  {
197  i >> data::CreateNVP(newT, "t");
198  }
199  catch (boost::archive::archive_exception& e)
200  {
201  success = false;
202  }
203  ifs.close();
204 
205  BOOST_REQUIRE_EQUAL(success, true);
206 }
207 
208 // Test mlpack serialization with all three archive types.
209 template<typename T>
210 void SerializeObjectAll(T& t, T& xmlT, T& textT, T& binaryT)
211 {
212  SerializeObject<T, boost::archive::text_iarchive,
213  boost::archive::text_oarchive>(t, textT);
214  SerializeObject<T, boost::archive::binary_iarchive,
215  boost::archive::binary_oarchive>(t, binaryT);
216  SerializeObject<T, boost::archive::xml_iarchive,
217  boost::archive::xml_oarchive>(t, xmlT);
218 }
219 
220 // Save and load a non-default-constructible mlpack object.
221 template<typename T, typename IArchiveType, typename OArchiveType>
222 void SerializePointerObject(T* t, T*& newT)
223 {
224  std::ofstream ofs("test", std::ios::binary);
225  OArchiveType o(ofs);
226 
227  bool success = true;
228  try
229  {
230  o << data::CreateNVP(*t, "t");
231  }
232  catch (boost::archive::archive_exception& e)
233  {
234  success = false;
235  }
236  ofs.close();
237 
238  BOOST_REQUIRE_EQUAL(success, true);
239 
240  std::ifstream ifs("test", std::ios::binary);
241  IArchiveType i(ifs);
242 
243  try
244  {
245  newT = new T(i);
246  }
247  catch (std::exception& e)
248  {
249  success = false;
250  }
251  ifs.close();
252 
253  BOOST_REQUIRE_EQUAL(success, true);
254 }
255 
256 template<typename T>
257 void SerializePointerObjectAll(T* t, T*& xmlT, T*& textT, T*& binaryT)
258 {
259  SerializePointerObject<T, boost::archive::text_iarchive,
260  boost::archive::text_oarchive>(t, textT);
261  SerializePointerObject<T, boost::archive::binary_iarchive,
262  boost::archive::binary_oarchive>(t, binaryT);
263  SerializePointerObject<T, boost::archive::xml_iarchive,
264  boost::archive::xml_oarchive>(t, xmlT);
265 }
266 
267 // Utility function to check the equality of two Armadillo matrices.
268 void CheckMatrices(const arma::mat& x,
269  const arma::mat& xmlX,
270  const arma::mat& textX,
271  const arma::mat& binaryX);
272 
273 void CheckMatrices(const arma::Mat<size_t>& x,
274  const arma::Mat<size_t>& xmlX,
275  const arma::Mat<size_t>& textX,
276  const arma::Mat<size_t>& binaryX);
277 
278 } // namespace mlpack
279 
280 #endif
void SerializePointerObject(T *t, T *&newT)
Linear algebra utility functions, generally performed on matrices or vectors.
FirstShim< T > CreateNVP(T &t, const std::string &name, typename boost::enable_if< HasSerialize< T >>::type *=0)
Call this function to produce a name-value pair; this is similar to BOOST_SERIALIZATION_NVP(), but should be used for types that have a Serialize() function (or contain a type that has a Serialize() function) instead of a serialize() function.
void CheckMatrices(const arma::mat &x, const arma::mat &xmlX, const arma::mat &textX, const arma::mat &binaryX)
void TestArmadilloSerialization(arma::Cube< CubeType > &x)
void SerializePointerObjectAll(T *t, T *&xmlT, T *&textT, T *&binaryT)
void SerializeObjectAll(T &t, T &xmlT, T &textT, T &binaryT)
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen docu...
void SerializeObject(T &t, T &newT)
void TestAllArmadilloSerialization(arma::Cube< CubeType > &x)