decision_tree.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
14 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include "gini_gain.hpp"
18 #include "information_gain.hpp"
21 #include "all_dimension_select.hpp"
22 #include <type_traits>
23 
24 namespace mlpack {
25 namespace tree {
26 
34 template<typename FitnessFunction = GiniGain,
35  template<typename> class NumericSplitType = BestBinaryNumericSplit,
36  template<typename> class CategoricalSplitType = AllCategoricalSplit,
37  typename DimensionSelectionType = AllDimensionSelect,
38  typename ElemType = double,
39  bool NoRecursion = false>
40 class DecisionTree :
41  public NumericSplitType<FitnessFunction>::template
42  AuxiliarySplitInfo<ElemType>,
43  public CategoricalSplitType<FitnessFunction>::template
44  AuxiliarySplitInfo<ElemType>
45 {
46  public:
48  typedef NumericSplitType<FitnessFunction> NumericSplit;
50  typedef CategoricalSplitType<FitnessFunction> CategoricalSplit;
52  typedef DimensionSelectionType DimensionSelection;
53 
71  template<typename MatType, typename LabelsType>
72  DecisionTree(MatType data,
73  const data::DatasetInfo& datasetInfo,
74  LabelsType labels,
75  const size_t numClasses,
76  const size_t minimumLeafSize = 10,
77  const double minimumGainSplit = 1e-7,
78  const size_t maximumDepth = 0,
79  DimensionSelectionType dimensionSelector =
80  DimensionSelectionType());
81 
98  template<typename MatType, typename LabelsType>
99  DecisionTree(MatType data,
100  LabelsType labels,
101  const size_t numClasses,
102  const size_t minimumLeafSize = 10,
103  const double minimumGainSplit = 1e-7,
104  const size_t maximumDepth = 0,
105  DimensionSelectionType dimensionSelector =
106  DimensionSelectionType());
107 
127  template<typename MatType, typename LabelsType, typename WeightsType>
128  DecisionTree(
129  MatType data,
130  const data::DatasetInfo& datasetInfo,
131  LabelsType labels,
132  const size_t numClasses,
133  WeightsType weights,
134  const size_t minimumLeafSize = 10,
135  const double minimumGainSplit = 1e-7,
136  const size_t maximumDepth = 0,
137  DimensionSelectionType dimensionSelector = DimensionSelectionType(),
138  const std::enable_if_t<arma::is_arma_type<
139  typename std::remove_reference<WeightsType>::type>::value>* = 0);
140 
159  template<typename MatType, typename LabelsType, typename WeightsType>
160  DecisionTree(
161  const DecisionTree& other,
162  MatType data,
163  const data::DatasetInfo& datasetInfo,
164  LabelsType labels,
165  const size_t numClasses,
166  WeightsType weights,
167  const size_t minimumLeafSize = 10,
168  const double minimumGainSplit = 1e-7,
169  const std::enable_if_t<arma::is_arma_type<
170  typename std::remove_reference<WeightsType>::type>::value>* = 0);
189  template<typename MatType, typename LabelsType, typename WeightsType>
190  DecisionTree(
191  MatType data,
192  LabelsType labels,
193  const size_t numClasses,
194  WeightsType weights,
195  const size_t minimumLeafSize = 10,
196  const double minimumGainSplit = 1e-7,
197  const size_t maximumDepth = 0,
198  DimensionSelectionType dimensionSelector = DimensionSelectionType(),
199  const std::enable_if_t<arma::is_arma_type<
200  typename std::remove_reference<WeightsType>::type>::value>* = 0);
201 
220  template<typename MatType, typename LabelsType, typename WeightsType>
221  DecisionTree(
222  const DecisionTree& other,
223  MatType data,
224  LabelsType labels,
225  const size_t numClasses,
226  WeightsType weights,
227  const size_t minimumLeafSize = 10,
228  const double minimumGainSplit = 1e-7,
229  const size_t maximumDepth = 0,
230  DimensionSelectionType dimensionSelector = DimensionSelectionType(),
231  const std::enable_if_t<arma::is_arma_type<
232  typename std::remove_reference<WeightsType>::type>::value>* = 0);
233 
240  DecisionTree(const size_t numClasses = 1);
241 
248  DecisionTree(const DecisionTree& other);
249 
255  DecisionTree(DecisionTree&& other);
256 
263  DecisionTree& operator=(const DecisionTree& other);
264 
271 
275  ~DecisionTree();
276 
297  template<typename MatType, typename LabelsType>
298  double Train(MatType data,
299  const data::DatasetInfo& datasetInfo,
300  LabelsType labels,
301  const size_t numClasses,
302  const size_t minimumLeafSize = 10,
303  const double minimumGainSplit = 1e-7,
304  const size_t maximumDepth = 0,
305  DimensionSelectionType dimensionSelector =
306  DimensionSelectionType());
307 
326  template<typename MatType, typename LabelsType>
327  double Train(MatType data,
328  LabelsType labels,
329  const size_t numClasses,
330  const size_t minimumLeafSize = 10,
331  const double minimumGainSplit = 1e-7,
332  const size_t maximumDepth = 0,
333  DimensionSelectionType dimensionSelector =
334  DimensionSelectionType());
335 
357  template<typename MatType, typename LabelsType, typename WeightsType>
358  double Train(MatType data,
359  const data::DatasetInfo& datasetInfo,
360  LabelsType labels,
361  const size_t numClasses,
362  WeightsType weights,
363  const size_t minimumLeafSize = 10,
364  const double minimumGainSplit = 1e-7,
365  const size_t maximumDepth = 0,
366  DimensionSelectionType dimensionSelector =
367  DimensionSelectionType(),
368  const std::enable_if_t<arma::is_arma_type<typename
369  std::remove_reference<WeightsType>::type>::value>* = 0);
370 
390  template<typename MatType, typename LabelsType, typename WeightsType>
391  double Train(MatType data,
392  LabelsType labels,
393  const size_t numClasses,
394  WeightsType weights,
395  const size_t minimumLeafSize = 10,
396  const double minimumGainSplit = 1e-7,
397  const size_t maximumDepth = 0,
398  DimensionSelectionType dimensionSelector =
399  DimensionSelectionType(),
400  const std::enable_if_t<arma::is_arma_type<typename
401  std::remove_reference<WeightsType>::type>::value>* = 0);
402 
409  template<typename VecType>
410  size_t Classify(const VecType& point) const;
411 
421  template<typename VecType>
422  void Classify(const VecType& point,
423  size_t& prediction,
424  arma::vec& probabilities) const;
425 
433  template<typename MatType>
434  void Classify(const MatType& data,
435  arma::Row<size_t>& predictions) const;
436 
447  template<typename MatType>
448  void Classify(const MatType& data,
449  arma::Row<size_t>& predictions,
450  arma::mat& probabilities) const;
451 
455  template<typename Archive>
456  void serialize(Archive& ar, const unsigned int /* version */);
457 
459  size_t NumChildren() const { return children.size(); }
460 
462  const DecisionTree& Child(const size_t i) const { return *children[i]; }
464  DecisionTree& Child(const size_t i) { return *children[i]; }
465 
468  size_t SplitDimension() const { return splitDimension; }
469 
477  template<typename VecType>
478  size_t CalculateDirection(const VecType& point) const;
479 
483  size_t NumClasses() const;
484 
485  private:
487  std::vector<DecisionTree*> children;
489  size_t splitDimension;
492  size_t dimensionTypeOrMajorityClass;
500  arma::vec classProbabilities;
501 
505  typedef typename NumericSplit::template AuxiliarySplitInfo<ElemType>
506  NumericAuxiliarySplitInfo;
507  typedef typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>
508  CategoricalAuxiliarySplitInfo;
509 
513  template<bool UseWeights, typename RowType, typename WeightsRowType>
514  void CalculateClassProbabilities(const RowType& labels,
515  const size_t numClasses,
516  const WeightsRowType& weights);
517 
535  template<bool UseWeights, typename MatType>
536  double Train(MatType& data,
537  const size_t begin,
538  const size_t count,
539  const data::DatasetInfo& datasetInfo,
540  arma::Row<size_t>& labels,
541  const size_t numClasses,
542  arma::rowvec& weights,
543  const size_t minimumLeafSize,
544  const double minimumGainSplit,
545  const size_t maximumDepth,
546  DimensionSelectionType& dimensionSelector);
547 
564  template<bool UseWeights, typename MatType>
565  double Train(MatType& data,
566  const size_t begin,
567  const size_t count,
568  arma::Row<size_t>& labels,
569  const size_t numClasses,
570  arma::rowvec& weights,
571  const size_t minimumLeafSize,
572  const double minimumGainSplit,
573  const size_t maximumDepth,
574  DimensionSelectionType& dimensionSelector);
575 };
576 
580 template<typename FitnessFunction = GiniGain,
581  template<typename> class NumericSplitType = BestBinaryNumericSplit,
582  template<typename> class CategoricalSplitType = AllCategoricalSplit,
583  typename DimensionSelectType = AllDimensionSelect,
584  typename ElemType = double>
585 using DecisionStump = DecisionTree<FitnessFunction,
586  NumericSplitType,
587  CategoricalSplitType,
588  DimensionSelectType,
589  ElemType,
590  false>;
591 
600  double,
602 } // namespace tree
603 } // namespace mlpack
604 
605 // Include implementation.
606 #include "decision_tree_impl.hpp"
607 
608 #endif
size_t NumChildren() const
Get the number of children.
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
DecisionTree(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Construct the decision tree on the given data and labels, where the data can be both numeric and categorical.
The BestBinaryNumericSplit is a splitting function for decision trees that will exhaustively search a...
typename enable_if< B, T >::type enable_if_t
Definition: prereqs.hpp:58
strip_type.hpp
Definition: add_to_po.hpp:21
This class implements a generic decision tree learner.
size_t CalculateDirection(const VecType &point) const
Given a point and that this node is not a leaf, calculate the index of the child node this point would go to.
The core includes that mlpack expects; standard C++ includes and Armadillo.
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
const DecisionTree & Child(const size_t i) const
Get the child of the given index.
The AllCategoricalSplit is a splitting function that will split categorical features into many children.
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
void serialize(Archive &ar, const unsigned int)
Serialize the tree.
The standard information gain criterion, used for calculating gain in decision trees.
DecisionTree & operator=(const DecisionTree &other)
Copy another tree.
DecisionTree & Child(const size_t i)
Modify the child of the given index (be careful!).
DimensionSelectionType DimensionSelection
Allow access to the dimension selection type.
size_t NumClasses() const
Get the number of classes in the tree.
DecisionTree< InformationGain, BestBinaryNumericSplit, AllCategoricalSplit, AllDimensionSelect, double, true > ID3DecisionStump
Convenience typedef for ID3 decision stumps (single level decision trees made with the ID3 algorithm).
This dimension selection policy allows any dimension to be selected for splitting.
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision trees.
Definition: gini_gain.hpp:27
size_t Classify(const VecType &point) const
Classify the given point, using the entire tree.
size_t SplitDimension() const
Get the split dimension (only meaningful if this is a non-leaf in a trained tree).
double Train(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the decision tree on the given data.
~DecisionTree()
Clean up memory.