get_param.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_BINDINGS_CLI_GET_PARAM_HPP
13 #define MLPACK_BINDINGS_CLI_GET_PARAM_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 #include "parameter_type.hpp"
17 
18 namespace mlpack {
19 namespace bindings {
20 namespace cli {
21 
28 template<typename T>
30  util::ParamData& d,
31  const typename boost::disable_if<arma::is_arma_type<T>>::type* = 0,
32  const typename boost::disable_if<data::HasSerialize<T>>::type* = 0,
33  const typename boost::disable_if<std::is_same<T,
34  std::tuple<mlpack::data::DatasetInfo, arma::mat>>>::type* = 0)
35 {
36  // No mapping is needed, so just cast it directly.
37  return *boost::any_cast<T>(&d.value);
38 }
39 
45 template<typename T>
47  util::ParamData& d,
48  const typename boost::enable_if<arma::is_arma_type<T>>::type* = 0)
49 {
50  // If the matrix is an input matrix, we have to load the matrix. 'value'
51  // contains the filename. It's possible we could load empty matrices many
52  // times, but I am not bothered by that---it shouldn't be something that
53  // happens.
54  typedef std::tuple<T, typename ParameterType<T>::type> TupleType;
55  TupleType& tuple = *boost::any_cast<TupleType>(&d.value);
56  const std::string& value = std::get<1>(tuple);
57  T& matrix = std::get<0>(tuple);
58  if (d.input && !d.loaded)
59  {
60  // Call correct data::Load() function.
61  if (arma::is_Row<T>::value || arma::is_Col<T>::value)
62  data::Load(value, matrix, true);
63  else
64  data::Load(value, matrix, true, !d.noTranspose);
65  d.loaded = true;
66  }
67 
68  return matrix;
69 }
70 
76 template<typename T>
78  util::ParamData& d,
79  const typename boost::enable_if<std::is_same<T,
80  std::tuple<mlpack::data::DatasetInfo, arma::mat>>>::type* = 0)
81 {
82  // If this is an input parameter, we need to load both the matrix and the
83  // dataset info.
84  typedef std::tuple<T, std::string> TupleType;
85  TupleType* tuple = boost::any_cast<TupleType>(&d.value);
86  const std::string& value = std::get<1>(*tuple);
87  T& t = std::get<0>(*tuple);
88  if (d.input && !d.loaded)
89  {
90  data::Load(value, std::get<1>(t), std::get<0>(t), true, !d.noTranspose);
91  d.loaded = true;
92  }
93 
94  return t;
95 }
96 
102 template<typename T>
104  util::ParamData& d,
105  const typename boost::disable_if<arma::is_arma_type<T>>::type* = 0,
106  const typename boost::enable_if<data::HasSerialize<T>>::type* = 0)
107 {
108  // If the model is an input model, we have to load it from file. 'value'
109  // contains the filename.
110  typedef std::tuple<T*, std::string> TupleType;
111  TupleType* tuple = boost::any_cast<TupleType>(&d.value);
112  const std::string& value = std::get<1>(*tuple);
113  if (d.input && !d.loaded)
114  {
115  T* model = new T();
116  data::Load(value, "model", *model, true);
117  d.loaded = true;
118  std::get<0>(*tuple) = model;
119  }
120  return std::get<0>(*tuple);
121 }
122 
131 template<typename T>
132 void GetParam(const util::ParamData& d, const void* /* input */, void* output)
133 {
134  // Cast to the correct type.
135  *((T**) output) = &GetParam<typename std::remove_pointer<T>::type>(
136  const_cast<util::ParamData&>(d));
137 }
138 
139 } // namespace cli
140 } // namespace bindings
141 } // namespace mlpack
142 
143 #endif
boost::any value
The actual value that is held.
Definition: param_data.hpp:82
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.
bool input
True if this option is an input option (otherwise, it is output).
Definition: param_data.hpp:73
This structure holds all of the information about a single parameter, including its value (which is set when ParseCommandLine() is called).
Definition: param_data.hpp:52
bool loaded
If this is an input parameter that needs extra loading, this indicates whether or not it has been loaded.
Definition: param_data.hpp:76
bool Load(const std::string &filename, arma::Mat< eT > &matrix, const bool fatal=false, const bool transpose=true)
Loads a matrix from file, guessing the filetype from the extension.
T & GetParam(util::ParamData &d, const typename boost::disable_if< arma::is_arma_type< T >>::type *=0, const typename boost::disable_if< data::HasSerialize< T >>::type *=0, const typename boost::disable_if< std::is_same< T, std::tuple< mlpack::data::DatasetInfo, arma::mat >>>::type *=0)
This overload is called when nothing special needs to happen to the name of the parameter.
Definition: get_param.hpp:29
bool noTranspose
True if this is a matrix that should not be transposed.
Definition: param_data.hpp:69