mlpack: a scalable c++ machine learning library
mlpack  2.0.2
Go to the documentation of this file.
15 #ifndef mlpack_METHODS_DET_DTREE_HPP
16 #define mlpack_METHODS_DET_DTREE_HPP
18 #include <mlpack/core.hpp>
20 namespace mlpack {
21 namespace det {
46 class DTree
47 {
48  public:
52  DTree();
62  DTree(const arma::vec& maxVals,
63  const arma::vec& minVals,
64  const size_t totalPoints);
74  DTree(arma::mat& data);
88  DTree(const arma::vec& maxVals,
89  const arma::vec& minVals,
90  const size_t start,
91  const size_t end,
92  const double logNegError);
105  DTree(const arma::vec& maxVals,
106  const arma::vec& minVals,
107  const size_t totalPoints,
108  const size_t start,
109  const size_t end);
112  ~DTree();
124  double Grow(arma::mat& data,
125  arma::Col<size_t>& oldFromNew,
126  const bool useVolReg = false,
127  const size_t maxLeafSize = 10,
128  const size_t minLeafSize = 5);
138  double PruneAndUpdate(const double oldAlpha,
139  const size_t points,
140  const bool useVolReg = false);
147  double ComputeValue(const arma::vec& query) const;
156  void WriteTree(FILE *fp, const size_t level = 0) const;
165  int TagTree(const int tag = 0);
173  int FindBucket(const arma::vec& query) const;
180  void ComputeVariableImportance(arma::vec& importances) const;
188  double LogNegativeError(const size_t totalPoints) const;
193  bool WithinRange(const arma::vec& query) const;
195  private:
196  // The indices in the complete set of points
197  // (after all forms of swapping in the original data
198  // matrix to align all the points in a node
199  // consecutively in the matrix. The 'old_from_new' array
200  // maps the points back to their original indices.
204  size_t start;
207  size_t end;
210  arma::vec maxVals;
212  arma::vec minVals;
215  size_t splitDim;
218  double splitValue;
221  double logNegError;
230  bool root;
233  double ratio;
236  double logVolume;
242  double alphaUpper;
249  public:
251  size_t Start() const { return start; }
253  size_t End() const { return end; }
255  size_t SplitDim() const { return splitDim; }
257  double SplitValue() const { return splitValue; }
259  double LogNegError() const { return logNegError; }
263  size_t SubtreeLeaves() const { return subtreeLeaves; }
266  double Ratio() const { return ratio; }
268  double LogVolume() const { return logVolume; }
270  DTree* Left() const { return left; }
272  DTree* Right() const { return right; }
274  bool Root() const { return root; }
276  double AlphaUpper() const { return alphaUpper; }
279  const arma::vec& MaxVals() const { return maxVals; }
281  arma::vec& MaxVals() { return maxVals; }
284  const arma::vec& MinVals() const { return minVals; }
286  arma::vec& MinVals() { return minVals; }
291  template<typename Archive>
292  void Serialize(Archive& ar, const unsigned int /* version */)
293  {
294  using data::CreateNVP;
296  ar & CreateNVP(start, "start");
297  ar & CreateNVP(end, "end");
298  ar & CreateNVP(maxVals, "maxVals");
299  ar & CreateNVP(minVals, "minVals");
300  ar & CreateNVP(splitDim, "splitDim");
301  ar & CreateNVP(splitValue, "splitValue");
302  ar & CreateNVP(logNegError, "logNegError");
303  ar & CreateNVP(subtreeLeavesLogNegError, "subtreeLeavesLogNegError");
304  ar & CreateNVP(subtreeLeaves, "subtreeLeaves");
305  ar & CreateNVP(root, "root");
306  ar & CreateNVP(ratio, "ratio");
307  ar & CreateNVP(logVolume, "logVolume");
308  ar & CreateNVP(bucketTag, "bucketTag");
309  ar & CreateNVP(alphaUpper, "alphaUpper");
311  if (Archive::is_loading::value)
312  {
313  if (left)
314  delete left;
315  if (right)
316  delete right;
317  }
319  ar & CreateNVP(left, "left");
320  ar & CreateNVP(right, "right");
321  }
323  private:
325  // Utility methods.
330  bool FindSplit(const arma::mat& data,
331  size_t& splitDim,
332  double& splitValue,
333  double& leftError,
334  double& rightError,
335  const size_t minLeafSize = 5) const;
340  size_t SplitData(arma::mat& data,
341  const size_t splitDim,
342  const double splitValue,
343  arma::Col<size_t>& oldFromNew) const;
345 };
347 } // namespace det
348 } // namespace mlpack
350 #endif // mlpack_METHODS_DET_DTREE_HPP
arma::vec & MaxVals()
Modify the maximum values.
Definition: dtree.hpp:281
DTree * Left() const
Return the left child.
Definition: dtree.hpp:270
double splitValue
The split value on the splitting dimension for this node.
Definition: dtree.hpp:218
double logNegError
log-negative-L2-error of the node.
Definition: dtree.hpp:221
Clean up memory allocated by the tree.
DTree * left
The left child.
Definition: dtree.hpp:245
Linear algebra utility functions, generally performed on matrices or vectors.
FirstShim< T > CreateNVP(T &t, const std::string &name, typename boost::enable_if< HasSerialize< T >>::type *=0)
Call this function to produce a name-value pair; this is similar to BOOST_SERIALIZATION_NVP(), but should be used for types that have a Serialize() function (or contain a type that has a Serialize() function) instead of a serialize() function.
double LogNegativeError(const size_t totalPoints) const
Compute the log-negative-error for this node, given the total number of points in the dataset...
arma::vec maxVals
Upper bound (per-dimension maximum values) of the bounding box for this node.
Definition: dtree.hpp:210
bool FindSplit(const arma::mat &data, size_t &splitDim, double &splitValue, double &leftError, double &rightError, const size_t minLeafSize=5) const
Find the dimension to split on.
double LogNegError() const
Return the log negative error of this node.
Definition: dtree.hpp:259
size_t subtreeLeaves
Number of leaves of the subtree.
Definition: dtree.hpp:227
const arma::vec & MinVals() const
Return the minimum values.
Definition: dtree.hpp:284
const arma::vec & MaxVals() const
Return the maximum values.
Definition: dtree.hpp:279
double SplitValue() const
Return the split value of this node.
Definition: dtree.hpp:257
DTree * Right() const
Return the right child.
Definition: dtree.hpp:272
size_t start
The index of the first point in the dataset contained in this node (and its children).
Definition: dtree.hpp:204
arma::vec & MinVals()
Modify the minimum values.
Definition: dtree.hpp:286
size_t splitDim
The splitting dimension for this node.
Definition: dtree.hpp:215
size_t SubtreeLeaves() const
Return the number of leaves which are descendants of this node.
Definition: dtree.hpp:263
double AlphaUpper() const
Return the upper part of the alpha sum.
Definition: dtree.hpp:276
double PruneAndUpdate(const double oldAlpha, const size_t points, const bool useVolReg=false)
Perform alpha pruning on a tree.
bool WithinRange(const arma::vec &query) const
Return whether a query point is within the range of this node.
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen docu...
size_t SplitData(arma::mat &data, const size_t splitDim, const double splitValue, arma::Col< size_t > &oldFromNew) const
Split the data, returning the number of points left of the split.
bool root
If true, this node is the root of the tree.
Definition: dtree.hpp:230
double ComputeValue(const arma::vec &query) const
Compute the logarithm of the density estimate of a given query point.
DTree * right
The right child.
Definition: dtree.hpp:247
size_t Start() const
Return the starting index of points contained in this node.
Definition: dtree.hpp:251
A density estimation tree is similar to both a decision tree and a space partitioning tree (like a kd...
Definition: dtree.hpp:46
double alphaUpper
Upper part of alpha sum; used for pruning.
Definition: dtree.hpp:242
double ratio
Ratio of the number of points in the node to the total number of points.
Definition: dtree.hpp:233
int bucketTag
The tag for the leaf, used for hashing points.
Definition: dtree.hpp:239
size_t SplitDim() const
Return the split dimension of this node.
Definition: dtree.hpp:255
void ComputeVariableImportance(arma::vec &importances) const
Compute the variable importance of each dimension in the learned tree.
double SubtreeLeavesLogNegError() const
Return the log negative error of all descendants of this node.
Definition: dtree.hpp:261
size_t End() const
Return the first index of a point not contained in this node.
Definition: dtree.hpp:253
double subtreeLeavesLogNegError
Sum of the error of the leaves of the subtree.
Definition: dtree.hpp:224
double logVolume
The logarithm of the volume of the node.
Definition: dtree.hpp:236
arma::vec minVals
Lower bound (per-dimension minimum values) of the bounding box for this node.
Definition: dtree.hpp:212
void Serialize(Archive &ar, const unsigned int)
Serialize the density estimation tree.
Definition: dtree.hpp:292
bool Root() const
Return whether or not this is the root of the tree.
Definition: dtree.hpp:274
Create an empty density estimation tree.
double Grow(arma::mat &data, arma::Col< size_t > &oldFromNew, const bool useVolReg=false, const size_t maxLeafSize=10, const size_t minLeafSize=5)
Greedily expand the tree.
int FindBucket(const arma::vec &query) const
Return the tag of the leaf containing the query.
void WriteTree(FILE *fp, const size_t level=0) const
Print the tree in a depth-first manner (this function is called recursively).
int TagTree(const int tag=0)
Index the buckets for possible usage later; this results in every leaf in the tree having a specific ...
size_t end
One past the index of the last point in the dataset contained in this node (and its children); the node covers [start, end).
Definition: dtree.hpp:207
double LogVolume() const
Return the logarithm of the volume of this node.
Definition: dtree.hpp:268
double Ratio() const
Return the ratio of points in this node to the points in the whole dataset.
Definition: dtree.hpp:266