mlpack: a scalable C++ machine learning library
mlpack  2.0.2
svd_batch_learning.hpp
Go to the documentation of this file.
1 
14 #ifndef mlpack_METHODS_AMF_UPDATE_RULES_SVD_BATCH_LEARNING_HPP
15 #define mlpack_METHODS_AMF_UPDATE_RULES_SVD_BATCH_LEARNING_HPP
16 
17 #include <mlpack/core.hpp>
18 
19 namespace mlpack {
20 namespace amf {
21 
44 {
45  public:
54  SVDBatchLearning(double u = 0.0002,
55  double kw = 0,
56  double kh = 0,
57  double momentum = 0.9)
58  : u(u), kw(kw), kh(kh), momentum(momentum)
59  {
60  // empty constructor
61  }
62 
70  template<typename MatType>
71  void Initialize(const MatType& dataset, const size_t rank)
72  {
73  const size_t n = dataset.n_rows;
74  const size_t m = dataset.n_cols;
75 
76  mW.zeros(n, rank);
77  mH.zeros(rank, m);
78  }
79 
89  template<typename MatType>
90  inline void WUpdate(const MatType& V,
91  arma::mat& W,
92  const arma::mat& H)
93  {
94  size_t n = V.n_rows;
95  size_t m = V.n_cols;
96 
97  size_t r = W.n_cols;
98 
99  // initialize the momentum of this iteration.
100  mW = momentum * mW;
101 
102  // Compute the step.
103  arma::mat deltaW;
104  deltaW.zeros(n, r);
105  for (size_t i = 0; i < n; i++)
106  {
107  for (size_t j = 0; j < m; j++)
108  {
109  const double val = V(i, j);
110  if (val != 0)
111  deltaW.row(i) += (val - arma::dot(W.row(i), H.col(j))) *
112  arma::trans(H.col(j));
113  }
114  // Add regularization.
115  if (kw != 0)
116  deltaW.row(i) -= kw * W.row(i);
117  }
118 
119  // Add the step to the momentum.
120  mW += u * deltaW;
121  // Add the momentum to the W matrix.
122  W += mW;
123  }
124 
134  template<typename MatType>
135  inline void HUpdate(const MatType& V,
136  const arma::mat& W,
137  arma::mat& H)
138  {
139  size_t n = V.n_rows;
140  size_t m = V.n_cols;
141 
142  size_t r = W.n_cols;
143 
144  // Initialize the momentum of this iteration.
145  mH = momentum * mH;
146 
147  // Compute the step.
148  arma::mat deltaH;
149  deltaH.zeros(r, m);
150  for (size_t j = 0; j < m; j++)
151  {
152  for (size_t i = 0; i < n; i++)
153  {
154  const double val = V(i, j);
155  if (val != 0)
156  deltaH.col(j) += (val - arma::dot(W.row(i), H.col(j))) * W.row(i).t();
157  }
158  // Add regularization.
159  if (kh != 0)
160  deltaH.col(j) -= kh * H.col(j);
161  }
162 
163  // Add this step to the momentum.
164  mH += u * deltaH;
165  // Add the momentum to H.
166  H += mH;
167  }
168 
170  template<typename Archive>
171  void Serialize(Archive& ar, const unsigned int /* version */)
172  {
173  using data::CreateNVP;
174  ar & CreateNVP(u, "u");
175  ar & CreateNVP(kw, "kw");
176  ar & CreateNVP(kh, "kh");
177  ar & CreateNVP(momentum, "momentum");
178  ar & CreateNVP(mW, "mW");
179  ar & CreateNVP(mH, "mH");
180  }
181 
private:
  //! Step size of the algorithm (scales each batch step before it enters the
  //! momentum).
  double u;
  //! Regularization parameter for matrix W; a value of 0 disables the term.
  double kw;
  //! Regularization parameter for matrix H; a value of 0 disables the term.
  double kh;
  //! Momentum value (between 0 and 1); the previous momentum matrices are
  //! multiplied by it at the start of every update.
  double momentum;

  //! Momentum matrix for matrix W; zeroed by Initialize(), updated by WUpdate().
  arma::mat mW;
  //! Momentum matrix for matrix H; zeroed by Initialize(), updated by HUpdate().
  arma::mat mH;
196 }; // class SVDBatchLearning
197 
200 
204 template<>
205 inline void SVDBatchLearning::WUpdate<arma::sp_mat>(const arma::sp_mat& V,
206  arma::mat& W,
207  const arma::mat& H)
208 {
209  const size_t n = V.n_rows;
210  const size_t r = W.n_cols;
211 
212  mW = momentum * mW;
213 
214  arma::mat deltaW;
215  deltaW.zeros(n, r);
216 
217  for (arma::sp_mat::const_iterator it = V.begin(); it != V.end(); ++it)
218  {
219  const size_t row = it.row();
220  const size_t col = it.col();
221  deltaW.row(it.row()) += (*it - arma::dot(W.row(row), H.col(col))) *
222  arma::trans(H.col(col));
223  }
224 
225  if (kw != 0)
226  deltaW -= kw * W;
227 
228  mW += u * deltaW;
229  W += mW;
230 }
231 
232 template<>
233 inline void SVDBatchLearning::HUpdate<arma::sp_mat>(const arma::sp_mat& V,
234  const arma::mat& W,
235  arma::mat& H)
236 {
237  const size_t m = V.n_cols;
238  const size_t r = W.n_cols;
239 
240  mH = momentum * mH;
241 
242  arma::mat deltaH;
243  deltaH.zeros(r, m);
244 
245  for (arma::sp_mat::const_iterator it = V.begin(); it != V.end(); ++it)
246  {
247  const size_t row = it.row();
248  const size_t col = it.col();
249  deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) *
250  W.row(row).t();
251  }
252 
253  if (kh != 0)
254  deltaH -= kh * H;
255 
256  mH += u * deltaH;
257  H += mH;
258 }
259 
260 } // namespace amf
261 } // namespace mlpack
262 
263 #endif // mlpack_METHODS_AMF_UPDATE_RULES_SVD_BATCH_LEARNING_HPP
void WUpdate(const MatType &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
double kh
Regularization parameter for matrix H.
void Serialize(Archive &ar, const unsigned int)
Serialize the SVDBatchLearning object.
double u
Step size of the algorithm.
Linear algebra utility functions, generally performed on matrices or vectors.
FirstShim< T > CreateNVP(T &t, const std::string &name, typename boost::enable_if< HasSerialize< T >>::type *=0)
Call this function to produce a name-value pair; this is similar to BOOST_SERIALIZATION_NVP(), but should be used for types that have a Serialize() function (or contain a type that has a Serialize() function) instead of a serialize() function.
SVDBatchLearning(double u=0.0002, double kw=0, double kh=0, double momentum=0.9)
SVD Batch learning constructor.
void Initialize(const MatType &dataset, const size_t rank)
Initialize parameters before factorization.
double kw
Regularization parameter for matrix W.
This class implements SVD batch learning with momentum.
void HUpdate(const MatType &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
arma::mat mW
Momentum matrix for matrix W.
arma::mat mH
Momentum matrix for matrix H.
double momentum
Momentum value (between 0 and 1).