print_input_processing.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_BINDINGS_PYTHON_PRINT_INPUT_PROCESSING_HPP
14 #define MLPACK_BINDINGS_PYTHON_PRINT_INPUT_PROCESSING_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include "get_arma_type.hpp"
18 #include "get_numpy_type.hpp"
19 #include "get_numpy_type_char.hpp"
20 #include "get_cython_type.hpp"
21 #include "strip_type.hpp"
22 
23 namespace mlpack {
24 namespace bindings {
25 namespace python {
26 
// Print the Python/Cython snippet that forwards a "standard" (non-vector,
// non-matrix, non-model, non-DatasetInfo-tuple) parameter from the generated
// Python wrapper into IO via SetParam[...] and IO.SetPassed(...).
//
// NOTE(review): the template arguments of the SFINAE guards below were lost
// in this rendering; the Doxygen signature for this overload elsewhere on this
// page lists them as disable_if<IsStdVector<T>>, disable_if<is_arma_type<T>>,
// disable_if<HasSerialize<T>>, and
// disable_if<is_same<T, tuple<DatasetInfo, arma::mat>>>.
// Whitespace inside the emitted string literals also appears collapsed here,
// so the exact indentation of the generated Python cannot be read from this
// page.
//
// @param d      Parameter metadata; d.name and d.required are read here.
// @param indent Number of spaces used to build `prefix`, which is prepended to
//               most (not all -- see the note near the TypeError lines) of the
//               emitted lines.
30 template<typename T>
32  util::ParamData& d,
33  const size_t indent,
34  const typename boost::disable_if<util::IsStdVector>::type* = 0,
35  const typename boost::disable_if>::type* = 0,
36  const typename boost::disable_if<data::HasSerialize>::type* = 0,
37  const typename boost::disable_if
38  std::tuple>>::type* = 0)
39 {
40  // The copy_all_inputs parameter must be handled first, and therefore is
41  // outside the scope of this code.
42  if (d.name == "copy_all_inputs")
43  return;
44 
45  const std::string prefix(indent, ' ');
46 
  // Sentinel that means "not passed" in the generated Python: "False" when the
  // is_same check holds (template args stripped here; presumably
  // is_same<T, bool>), otherwise "None". The emitted code compares with
  // `is not <def>`.
47  std::string def = "None";
48  if (std::is_same::value)
49  def = "False";
50 
51  // Make sure that we don't use names that are Python keywords.
52  std::string name = (d.name == "lambda") ? "lambda_" : d.name;
53 
65  std::cout << prefix << "# Detect if the parameter was passed; set if so."
66  << std::endl;
  // Optional parameter: emit a type check plus a "differs from sentinel"
  // check. For bools the isinstance test comes first; for other types the
  // sentinel test comes first.
67  if (!d.required)
68  {
69  if (GetPrintableType(d) == "bool")
70  {
71  std::cout << prefix << "if isinstance(" << name << ", "
72  << GetPrintableType(d) << "):" << std::endl;
73  std::cout << prefix << " if " << name << " is not " << def << ":"
74  << std::endl;
75  }
76  else
77  {
78  std::cout << prefix << "if " << name << " is not " << def << ":"
79  << std::endl;
80  std::cout << prefix << " if isinstance(" << name << ", "
81  << GetPrintableType(d) << "):" << std::endl;
82  }
83 
  // Forward the value. Python strings are encoded to UTF-8 bytes before being
  // handed to the Cython "string" SetParam.
84  std::cout << prefix << " SetParam[" << GetCythonType(d)
85  << "]( '" << d.name << "', ";
86  if (GetCythonType(d) == "string")
87  std::cout << name << ".encode(\"UTF-8\")";
88  else
89  std::cout << name;
90  std::cout << ")" << std::endl;
91  std::cout << prefix << " IO.SetPassed( '" << d.name
92  << "')" << std::endl;
93 
94  // If this parameter is "verbose", then enable verbose output.
95  if (d.name == "verbose")
96  std::cout << prefix << " EnableVerbose()" << std::endl;
97 
  // NOTE(review): unlike every other emitted line, these "else:" /
  // "raise TypeError" lines are not prefixed with `prefix`, so their
  // indentation does not follow `indent`. Confirm the generated Python nests
  // correctly for every indent value used by callers. Both arms of this `if`
  // look identical in this rendering (string whitespace may have been
  // collapsed).
98  if (GetPrintableType(d) == "bool")
99  {
100  std::cout << " else:" << std::endl;
101  std::cout << " raise TypeError(" <<"\"'"<< name
102  << "' must have type \'" << GetPrintableType(d)
103  << "'!\")" << std::endl;
104  }
105  else
106  {
107  std::cout << " else:" << std::endl;
108  std::cout << " raise TypeError(" <<"\"'"<< name
109  << "' must have type \'" << GetPrintableType(d)
110  << "'!\")" << std::endl;
111  }
112  }
113  else
114  {
  // Required parameter: the same checks as above are emitted.
115  if (GetPrintableType(d) == "bool")
116  {
117  std::cout << prefix << "if isinstance(" << name << ", "
118  << GetPrintableType(d) << "):" << std::endl;
119  std::cout << prefix << " if " << name << " is not " << def << ":"
120  << std::endl;
121  }
122  else
123  {
124  std::cout << prefix << "if " << name << " is not " << def << ":"
125  << std::endl;
126  std::cout << prefix << " if isinstance(" << name << ", "
127  << GetPrintableType(d) << "):" << std::endl;
128  }
129 
  // NOTE(review): the literal split across the next two lines looks like a
  // "<const string>" cast partially stripped by this rendering; the optional
  // branch above emits no such cast -- confirm the two branches should agree.
  // The vector[string] case is presumably unreachable, because this overload
  // is disabled for std::vector types.
130  std::cout << prefix << " SetParam[" << GetCythonType(d) << "](
131  << "string> '" << d.name << "', ";
132  if (GetCythonType(d) == "string")
133  std::cout << name << ".encode(\"UTF-8\")";
134  else if (GetCythonType(d) == "vector[string]")
135  std::cout << "[i.encode(\"UTF-8\") for i in " << name << "]";
136  else
137  std::cout << name;
138  std::cout << ")" << std::endl;
139  std::cout << prefix << " IO.SetPassed( '"
140  << d.name << "')" << std::endl;
141 
  // NOTE(review): same missing-`prefix` concern as in the optional branch.
142  if (GetPrintableType(d) == "bool")
143  {
144  std::cout << " else:" << std::endl;
145  std::cout << " raise TypeError(" <<"\"'"<< name
146  << "' must have type \'" << GetPrintableType(d)
147  << "'!\")" << std::endl;
148  }
149  else
150  {
151  std::cout << " else:" << std::endl;
152  std::cout << " raise TypeError(" <<"\"'"<< name
153  << "' must have type \'" << GetPrintableType(d)
154  << "'!\")" << std::endl;
155  }
156  }
157  std::cout << std::endl; // Extra line is to clear up the code a bit.
158 }
159 
// Print the Python/Cython snippet that forwards a list-valued parameter
// (the overload enabled for std::vector types; template args of the SFINAE
// guards are stripped in this rendering).
//
// The emitted code checks that the value is a Python list, that it is
// non-empty, and that its first element has the expected printable type.
// It then calls SetParam[...] (UTF-8 encoding each element for
// vector[string]) and IO.SetPassed(...). Otherwise it raises TypeError.
//
// @param d      Parameter metadata; d.name and d.required are read here.
// @param indent Number of spaces used to build `prefix` for emitted lines.
//
// NOTE(review): only element [0] is type-checked, so mixed-type lists are not
// rejected here.
// NOTE(review): the `len(...) > 0` test presumably has no matching `else`, so
// an empty list is neither forwarded nor marked passed, and no error is
// raised. Confirm that this is intended.
// NOTE(review): this overload uses d.name directly. The standard-option
// overload in this file maps "lambda" to "lambda_" to avoid the Python
// keyword; confirm that a vector parameter can never be named "lambda".
// NOTE(review): unlike the other overloads, no trailing blank line is emitted.
163 template<typename T>
165  util::ParamData& d,
166  const size_t indent,
167  const typename boost::disable_if>::type* = 0,
168  const typename boost::disable_if<data::HasSerialize>::type* = 0,
169  const typename boost::disable_if
170  std::tuple>>::type* = 0,
171  const typename boost::enable_if<util::IsStdVector>::type* = 0)
172 {
173  const std::string prefix(indent, ' ');
174 
189  std::cout << prefix << "# Detect if the parameter was passed; set if so."
190  << std::endl;
  // Optional parameter: the whole check is wrapped in `if <name> is not None:`.
191  if (!d.required)
192  {
193  std::cout << prefix << "if " << d.name << " is not None:"
194  << std::endl;
195  std::cout << prefix << " if isinstance(" << d.name << ", list):"
196  << std::endl;
197  std::cout << prefix << " if len(" << d.name << ") > 0:"
198  << std::endl;
199  std::cout << prefix << " if isinstance(" << d.name << "[0], "
200  << GetPrintableType(d) << "):" << std::endl;
201  std::cout << prefix << " SetParam[" << GetCythonType(d)
202  << "]( '" << d.name << "', ";
203  // Strings need special handling.
204  if (GetCythonType(d) == "vector[string]")
205  std::cout << "[i.encode(\"UTF-8\") for i in " << d.name << "]";
206  else
207  std::cout << d.name;
208  std::cout << ")" << std::endl;
209  std::cout << prefix << " IO.SetPassed( '" << d.name
210  << "')" << std::endl;
  // Presumably pairs with the element-type check above.
211  std::cout << prefix << " else:" << std::endl;
212  std::cout << prefix << " raise TypeError(" <<"\"'"<< d.name
213  << "' must have type \'" << GetPrintableType(d)
214  << "'!\")" << std::endl;
  // Presumably pairs with the isinstance(..., list) check.
215  std::cout << prefix << " else:" << std::endl;
216  std::cout << prefix << " raise TypeError(" <<"\"'"<< d.name
217  << "' must have type \'list'!\")" << std::endl;
218  }
219  else
220  {
  // Required parameter: same checks, without the outer None test.
221  std::cout << prefix << "if isinstance(" << d.name << ", list):"
222  << std::endl;
223  std::cout << prefix << " if len(" << d.name << ") > 0:"
224  << std::endl;
225  std::cout << prefix << " if isinstance(" << d.name << "[0], "
226  << GetPrintableType(d) << "):" << std::endl;
227  std::cout << prefix << " SetParam[" << GetCythonType(d)
228  << "]( '" << d.name << "', ";
229  // Strings need special handling.
230  if (GetCythonType(d) == "vector[string]")
231  std::cout << "[i.encode(\"UTF-8\") for i in " << d.name << "]";
232  else
233  std::cout << d.name;
234  std::cout << ")" << std::endl;
235  std::cout << prefix << " IO.SetPassed( '" << d.name
236  << "')" << std::endl;
237  std::cout << prefix << " else:" << std::endl;
238  std::cout << prefix << " raise TypeError(" <<"\"'"<< d.name
239  << "' must have type \'" << GetPrintableType(d)
240  << "'!\")" << std::endl;
241  std::cout << prefix << "else:" << std::endl;
242  std::cout << prefix << " raise TypeError(" <<"\"'"<< d.name
243  << "' must have type \'list'!\")" << std::endl;
244  }
245 }
246 
250 template<typename T>
252  util::ParamData& d,
253  const size_t indent,
254  const typename boost::disable_if<util::IsStdVector>::type* = 0,
255  const typename boost::enable_if>::type* = 0)
256 {
257  const std::string prefix(indent, ' ');
258 
274  std::cout << prefix << "# Detect if the parameter was passed; set if so."
275  << std::endl;
276  if (!d.required)
277  {
278  if (T::is_row || T::is_col)
279  {
280  std::cout << prefix << "if " << d.name << " is not None:" << std::endl;
281  std::cout << prefix << " " << d.name << "_tuple = to_matrix("
282  << d.name << ", dtype=" << GetNumpyType()
283  << ", copy=IO.HasParam('copy_all_inputs'))" << std::endl;
284  std::cout << prefix << " if len(" << d.name << "_tuple[0].shape) > 1:"
285  << std::endl;
286  std::cout << prefix << " if " << d.name << "_tuple[0]"
287  << ".shape[0] == 1 or " << d.name << "_tuple[0].shape[1] == 1:"
288  << std::endl;
289  std::cout << prefix << " " << d.name << "_tuple[0].shape = ("
290  << d.name << "_tuple[0].size,)" << std::endl;
291  std::cout << prefix << " " << d.name << "_mat = arma_numpy.numpy_to_"
292  << GetArmaType() << "_" << GetNumpyTypeChar() << "(" << d.name
293  << "_tuple[0], " << d.name << "_tuple[1])" << std::endl;
294  std::cout << prefix << " SetParam[" << GetCythonType(d)
295  << "]( '" << d.name << "', dereference("
296  << d.name << "_mat))"<< std::endl;
297  std::cout << prefix << " IO.SetPassed( '" << d.name
298  << "')" << std::endl;
299  std::cout << prefix << " del " << d.name << "_mat" << std::endl;
300  }
301  else
302  {
303  std::cout << prefix << "if " << d.name << " is not None:" << std::endl;
304  std::cout << prefix << " " << d.name << "_tuple = to_matrix("
305  << d.name << ", dtype=" << GetNumpyType()
306  << ", copy=IO.HasParam('copy_all_inputs'))" << std::endl;
307  std::cout << prefix << " if len(" << d.name << "_tuple[0].shape"
308  << ") < 2:" << std::endl;
309  std::cout << prefix << " " << d.name << "_tuple[0].shape = (" << d.name
310  << "_tuple[0].shape[0], 1)" << std::endl;
311  std::cout << prefix << " " << d.name << "_mat = arma_numpy.numpy_to_"
312  << GetArmaType() << "_" << GetNumpyTypeChar() << "(" << d.name
313  << "_tuple[0], " << d.name << "_tuple[1])" << std::endl;
314  std::cout << prefix << " SetParam[" << GetCythonType(d)
315  << "]( '" << d.name << "', dereference("
316  << d.name << "_mat))"<< std::endl;
317  std::cout << prefix << " IO.SetPassed( '" << d.name
318  << "')" << std::endl;
319  std::cout << prefix << " del " << d.name << "_mat" << std::endl;
320  }
321  }
322  else
323  {
324  if (T::is_row || T::is_col)
325  {
326  std::cout << prefix << d.name << "_tuple = to_matrix(" << d.name
327  << ", dtype=" << GetNumpyType()
328  << ", copy=IO.HasParam('copy_all_inputs'))" << std::endl;
329  std::cout << prefix << "if len(" << d.name << "_tuple[0].shape) > 1:"
330  << std::endl;
331  std::cout << prefix << " if " << d.name << "_tuple[0].shape[0] == 1 or "
332  << d.name << "_tuple[0].shape[1] == 1:" << std::endl;
333  std::cout << prefix << " " << d.name << "_tuple[0].shape = ("
334  << d.name << "_tuple[0].size,)" << std::endl;
335  std::cout << prefix << d.name << "_mat = arma_numpy.numpy_to_"
336  << GetArmaType() << "_" << GetNumpyTypeChar() << "(" << d.name
337  << "_tuple[0], " << d.name << "_tuple[1])" << std::endl;
338  std::cout << prefix << "SetParam[" << GetCythonType(d)
339  << "]( '" << d.name << "', dereference("
340  << d.name << "_mat))"<< std::endl;
341  std::cout << prefix << "IO.SetPassed( '" << d.name << "')"
342  << std::endl;
343  std::cout << prefix << "del " << d.name << "_mat" << std::endl;
344  }
345  else
346  {
347  std::cout << prefix << d.name << "_tuple = to_matrix(" << d.name
348  << ", dtype=" << GetNumpyType()
349  << ", copy=IO.HasParam('copy_all_inputs'))" << std::endl;
350  std::cout << prefix << "if len(" << d.name << "_tuple[0].shape) > 2:"
351  << std::endl;
352  std::cout << prefix << " " << d.name << "_tuple[0].shape = (" << d.name
353  << "_tuple[0].shape[0], 1)" << std::endl;
354  std::cout << prefix << d.name << "_mat = arma_numpy.numpy_to_"
355  << GetArmaType() << "_" << GetNumpyTypeChar() << "(" << d.name
356  << "_tuple[0], " << d.name << "_tuple[1])" << std::endl;
357  std::cout << prefix << "SetParam[" << GetCythonType(d)
358  << "]( '" << d.name << "', dereference(" << d.name
359  << "_mat))" << std::endl;
360  std::cout << prefix << "IO.SetPassed( '" << d.name << "')"
361  << std::endl;
362  std::cout << prefix << "del " << d.name << "_mat" << std::endl;
363  }
364  }
365  std::cout << std::endl;
366 }
367 
// Print the Python/Cython snippet that forwards a serializable model
// parameter (the overload enabled for data::HasSerialize types; template args
// are stripped in this rendering) by pointer via SetParamPtr[...].
//
// StripType() turns d.cppType into `strippedType`, which names the Cython
// wrapper class "<strippedType>Type" used in the emitted casts. Only
// strippedType is used below; printedType and defaultsType are filled but
// unused in this overload.
//
// The emitted code first tries a checked Cython cast (<XType?>). If that
// raises TypeError but the object's class is still named "XType", it falls
// back to an unchecked cast (<XType>). Otherwise it re-raises the error.
// NOTE(review): the fallback presumably accepts an instance of a same-named
// wrapper class compiled into a different extension module -- confirm.
// The ownership/copy flag passed to SetParamPtr is
// IO.HasParam('copy_all_inputs').
//
// @param d      Parameter metadata; d.name, d.required and d.cppType are read.
// @param indent Number of spaces used to build `prefix` for emitted lines.
371 template<typename T>
373  util::ParamData& d,
374  const size_t indent,
375  const typename boost::disable_if<util::IsStdVector>::type* = 0,
376  const typename boost::disable_if>::type* = 0,
377  const typename boost::enable_if<data::HasSerialize>::type* = 0)
378 {
379  // First, get the correct class name if needed.
380  std::string strippedType, printedType, defaultsType;
381  StripType(d.cppType, strippedType, printedType, defaultsType);
382 
383  const std::string prefix(indent, ' ');
384 
401  std::cout << prefix << "# Detect if the parameter was passed; set if so."
402  << std::endl;
  // Optional parameter: only forwarded when not None.
403  if (!d.required)
404  {
405  std::cout << prefix << "if " << d.name << " is not None:" << std::endl;
406  std::cout << prefix << " try:" << std::endl;
407  std::cout << prefix << " SetParamPtr[" << strippedType << "]('" << d.name
408  << "', (<" << strippedType << "Type?> " << d.name << ").modelptr, "
409  << "IO.HasParam('copy_all_inputs'))" << std::endl;
410  std::cout << prefix << " except TypeError as e:" << std::endl;
411  std::cout << prefix << " if type(" << d.name << ").__name__ == '"
412  << strippedType << "Type':" << std::endl;
413  std::cout << prefix << " SetParamPtr[" << strippedType << "]('"
414  << d.name << "', (<" << strippedType << "Type> " << d.name
415  << ").modelptr, IO.HasParam('copy_all_inputs'))" << std::endl;
416  std::cout << prefix << " else:" << std::endl;
417  std::cout << prefix << " raise e" << std::endl;
418  std::cout << prefix << " IO.SetPassed( '" << d.name << "')"
419  << std::endl;
420  }
421  else
422  {
  // Required parameter: same try/except logic without the None test.
423  std::cout << prefix << "try:" << std::endl;
424  std::cout << prefix << " SetParamPtr[" << strippedType << "]('" << d.name
425  << "', (<" << strippedType << "Type?> " << d.name << ").modelptr, "
426  << "IO.HasParam('copy_all_inputs'))" << std::endl;
427  std::cout << prefix << "except TypeError as e:" << std::endl;
428  std::cout << prefix << " if type(" << d.name << ").__name__ == '"
429  << strippedType << "Type':" << std::endl;
430  std::cout << prefix << " SetParamPtr[" << strippedType << "]('" << d.name
431  << "', (<" << strippedType << "Type> " << d.name << ").modelptr, "
432  << "IO.HasParam('copy_all_inputs'))" << std::endl;
433  std::cout << prefix << " else:" << std::endl;
434  std::cout << prefix << " raise e" << std::endl;
435  std::cout << prefix << "IO.SetPassed( '" << d.name << "')"
436  << std::endl;
437  }
438  std::cout << std::endl;
439 }
440 
// Print the Python/Cython snippet that forwards a matrix-with-dimension-info
// parameter. This is the overload enabled by the guard below; its template
// args are stripped in this rendering. The Doxygen signature of the
// standard-option overload on this page names the excluded type as
// std::tuple<data::DatasetInfo, arma::mat>, so this is presumably that type.
//
// The emitted code:
//   - declares `cdef np.ndarray <name>_dims` up front (presumably because
//     Cython does not allow cdef declarations inside nested blocks -- confirm);
//   - converts the value with to_matrix_with_info(..., dtype=np.double,
//     copy=IO.HasParam('copy_all_inputs'));
//   - reshapes a 1-D array to a single column (n, 1);
//   - converts element [0] with arma_numpy.numpy_to_mat_d and takes the
//     dimension array from element [2];
//   - calls SetParamWithInfo[arma.Mat[double]] with the matrix and
//     <name>_dims.data, then IO.SetPassed(...), and finally deletes
//     <name>_mat.
// NOTE(review): the literal split across the SetParamWithInfo lines looks like
// a "<const string>" cast partially stripped by this rendering.
//
// @param d      Parameter metadata; d.name and d.required are read here.
// @param indent Number of spaces used to build `prefix` for emitted lines.
444 template<typename T>
446  util::ParamData& d,
447  const size_t indent,
448  const typename boost::disable_if<util::IsStdVector>::type* = 0,
449  const typename boost::enable_if
450  std::tuple>>::type* = 0)
451 {
452  // The user should pass in a matrix type of some sort.
453  const std::string prefix(indent, ' ');
454 
466  std::cout << prefix << "cdef np.ndarray " << d.name << "_dims" << std::endl;
467  std::cout << prefix << "# Detect if the parameter was passed; set if so."
468  << std::endl;
  // Optional parameter: only forwarded when not None.
469  if (!d.required)
470  {
471  std::cout << prefix << "if " << d.name << " is not None:" << std::endl;
472  std::cout << prefix << " " << d.name << "_tuple = to_matrix_with_info("
473  << d.name << ", dtype=np.double, copy=IO.HasParam('copy_all_inputs'))"
474  << std::endl;
475  std::cout << prefix << " if len(" << d.name << "_tuple[0].shape"
476  << ") < 2:" << std::endl;
477  std::cout << prefix << " " << d.name << "_tuple[0].shape = (" << d.name
478  << "_tuple[0].shape[0], 1)" << std::endl;
479  std::cout << prefix << " " << d.name << "_mat = arma_numpy.numpy_to_mat_d("
480  << d.name << "_tuple[0], " << d.name << "_tuple[1])" << std::endl;
481  std::cout << prefix << " " << d.name << "_dims = " << d.name
482  << "_tuple[2]" << std::endl;
483  std::cout << prefix << " SetParamWithInfo[arma.Mat[double]](
484  << "string> '" << d.name << "', dereference(" << d.name << "_mat), "
485  << " " << d.name << "_dims.data)" << std::endl;
486  std::cout << prefix << " IO.SetPassed( '" << d.name
487  << "')" << std::endl;
488  std::cout << prefix << " del " << d.name << "_mat" << std::endl;
489  }
490  else
491  {
  // Required parameter: same conversion without the None test.
492  std::cout << prefix << d.name << "_tuple = to_matrix_with_info(" << d.name
493  << ", dtype=np.double, copy=IO.HasParam('copy_all_inputs'))"
494  << std::endl;
495  std::cout << prefix << "if len(" << d.name << "_tuple[0].shape"
496  << ") < 2:" << std::endl;
497  std::cout << prefix << " " << d.name << "_tuple[0].shape = (" << d.name
498  << "_tuple[0].shape[0], 1)" << std::endl;
499  std::cout << prefix << d.name << "_mat = arma_numpy.numpy_to_mat_d("
500  << d.name << "_tuple[0], " << d.name << "_tuple[1])" << std::endl;
501  std::cout << prefix << d.name << "_dims = " << d.name << "_tuple[2]"
502  << std::endl;
503  std::cout << prefix << "SetParamWithInfo[arma.Mat[double]](
504  << "string> '" << d.name << "', dereference(" << d.name << "_mat), "
505  << " " << d.name << "_dims.data)" << std::endl;
506  std::cout << prefix << "IO.SetPassed( '" << d.name << "')"
507  << std::endl;
508  std::cout << prefix << "del " << d.name << "_mat" << std::endl;
509  }
510  std::cout << std::endl;
511 }
512 
// Type-erased entry point. It reinterprets `input` as a pointer to a size_t
// holding the indent, then calls the typed PrintInputProcessing overload for
// `d`. The `output` argument is ignored.
// NOTE(review): in this rendering the function-name/`d` parameter line is
// missing, and the explicit template argument on the call below is truncated
// (presumably something like std::remove_pointer<T>::type -- confirm against
// the real header).
// NOTE(review): the C-style cast also discards const; in the real source,
// `*static_cast<const size_t*>(input)` would express the same read more safely.
524 template<typename T>
526  const void* input,
527  void* /* output */)
528 {
529  PrintInputProcessing::type>(d,
530  *((size_t*) input));
531 }
532 
533 } // namespace python
534 } // namespace bindings
535 } // namespace mlpack
536 
537 #endif
void PrintInputProcessing(util::ParamData &d, const size_t indent, const typename boost::disable_if< util::IsStdVector< T >>::type *=0, const typename boost::disable_if< arma::is_arma_type< T >>::type *=0, const typename boost::disable_if< data::HasSerialize< T >>::type *=0, const typename boost::disable_if< std::is_same< T, std::tuple< data::DatasetInfo, arma::mat >>>::type *=0)
Print input processing for a standard option type.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
python
Definition: CMakeLists.txt:6
This structure holds all of the information about a single parameter, including its value (which is s...
Definition: param_data.hpp:52
Metaprogramming structure for vector detection.
std::string name
Name of this parameter.
Definition: param_data.hpp:56
bool required
True if this option is required.
Definition: param_data.hpp:71
void StripType(const std::string &inputType, std::string &strippedType, std::string &printedType, std::string &defaultsType)
Given an input type like, e.g., "LogisticRegression<>", return three types that can be used in Python...
Definition: strip_type.hpp:28
std::string cppType
The true name of the type, as it would be written in C++.
Definition: param_data.hpp:84
if(NOT BUILD_GO_SHLIB) macro(add_go_binding name) endmacro() return() endif() endmacro() macro(post_go_setup) if(BUILD_GO_BINDINGS) file(APPEND "$
Definition: CMakeLists.txt:3