simple_tolerance_termination.hpp
Go to the documentation of this file.
1 
12 #ifndef _MLPACK_METHODS_AMF_SIMPLE_TOLERANCE_TERMINATION_HPP_INCLUDED
13 #define _MLPACK_METHODS_AMF_SIMPLE_TOLERANCE_TERMINATION_HPP_INCLUDED
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace amf {
19 
30 template <class MatType>
32 {
33  public:
35  SimpleToleranceTermination(const double tolerance = 1e-5,
36  const size_t maxIterations = 10000,
37  const size_t reverseStepTolerance = 3)
38  : tolerance(tolerance),
39  maxIterations(maxIterations),
40  reverseStepTolerance(reverseStepTolerance) {}
41 
47  void Initialize(const MatType& V)
48  {
49  residueOld = DBL_MAX;
50  iteration = 1;
51  residue = DBL_MIN;
52  reverseStepCount = 0;
53  isCopy = false;
54 
55  this->V = &V;
56 
57  c_index = 0;
58  c_indexOld = 0;
59 
60  reverseStepCount = 0;
61  }
62 
69  bool IsConverged(arma::mat& W, arma::mat& H)
70  {
71  arma::mat WH;
72 
73  WH = W * H;
74 
75  // compute residue
76  residueOld = residue;
77  size_t n = V->n_rows;
78  size_t m = V->n_cols;
79  double sum = 0;
80  size_t count = 0;
81  for (size_t i = 0; i < n; i++)
82  {
83  for (size_t j = 0; j < m; j++)
84  {
85  double temp = 0;
86  if ((temp = (*V)(i, j)) != 0)
87  {
88  temp = (temp - WH(i, j));
89  temp = temp * temp;
90  sum += temp;
91  count++;
92  }
93  }
94  }
95  residue = sum / count;
96  residue = sqrt(residue);
97 
98  // increment iteration count
99  iteration++;
100  Log::Info << "Iteration " << iteration << "; residue "
101  << ((residueOld - residue) / residueOld) << ".\n";
102 
103  // if residue tolerance is not satisfied
104  if ((residueOld - residue) / residueOld < tolerance && iteration > 4)
105  {
106  // check if this is a first of successive drops
107  if (reverseStepCount == 0 && isCopy == false)
108  {
109  // store a copy of W and H matrix
110  isCopy = true;
111  this->W = W;
112  this->H = H;
113  // store residue values
114  c_index = residue;
115  c_indexOld = residueOld;
116  }
117  // increase successive drop count
118  reverseStepCount++;
119  }
120  // if tolerance is satisfied
121  else
122  {
123  // initialize successive drop count
124  reverseStepCount = 0;
125  // if residue is droped below minimum scrap stored values
126  if (residue <= c_indexOld && isCopy == true)
127  {
128  isCopy = false;
129  }
130  }
131 
132  // check if termination criterion is met
133  if (reverseStepCount == reverseStepTolerance || iteration > maxIterations)
134  {
135  // if stored values are present replace them with current value as they
136  // represent the minimum residue point
137  if (isCopy)
138  {
139  W = this->W;
140  H = this->H;
141  residue = c_index;
142  }
143  return true;
144  }
145  else return false;
146  }
147 
149  const double& Index() const { return residue; }
150 
152  const size_t& Iteration() const { return iteration; }
153 
155  const size_t& MaxIterations() const { return maxIterations; }
156  size_t& MaxIterations() { return maxIterations; }
157 
159  const double& Tolerance() const { return tolerance; }
160  double& Tolerance() { return tolerance; }
161 
162  private:
164  double tolerance;
166  size_t maxIterations;
167 
169  const MatType* V;
170 
172  size_t iteration;
173 
175  double residueOld;
176  double residue;
177  double normOld;
178 
180  size_t reverseStepTolerance;
182  size_t reverseStepCount;
183 
186  bool isCopy;
187 
189  arma::mat W;
190  arma::mat H;
191  double c_indexOld;
192  double c_index;
193 }; // class SimpleToleranceTermination
194 
195 } // namespace amf
196 } // namespace mlpack
197 
198 #endif // _MLPACK_METHODS_AMF_SIMPLE_TOLERANCE_TERMINATION_HPP_INCLUDED
199 
.hpp
Definition: add_to_po.hpp:21
SimpleToleranceTermination(const double tolerance=1e-5, const size_t maxIterations=10000, const size_t reverseStepTolerance=3)
Construct the termination policy with the given tolerance, maximum iteration count, and reverse-step tolerance.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Initialize(const MatType &V)
Initializes the termination policy before starting the factorization.
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
const double & Tolerance() const
Access tolerance value.
bool IsConverged(arma::mat &W, arma::mat &H)
Check if the termination criterion is met.
This class implements a residue-tolerance termination policy.
const size_t & Iteration() const
Get current iteration count.
const double & Index() const
Get current value of residue.
const size_t & MaxIterations() const
Access upper limit of iteration count.