decision_tree.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
14 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include "gini_gain.hpp"
18 #include "information_gain.hpp"
19 #include "best_binary_numeric_split.hpp"
20 #include "all_categorical_split.hpp"
21 #include "all_dimension_select.hpp"
22 #include <type_traits>
23 
24 namespace mlpack {
25 namespace tree {
26 
// This class implements a generic decision tree learner.
//
// Template parameters (as declared below):
//  - FitnessFunction: fitness/gain measure handed to both split types
//    (default GiniGain).
//  - NumericSplitType / CategoricalSplitType: split policies, each
//    instantiated with FitnessFunction (defaults BestBinaryNumericSplit and
//    AllCategoricalSplit).
//  - DimensionSelectionType: dimension selection policy (default
//    AllDimensionSelect).
//  - ElemType: element type passed to each split's AuxiliarySplitInfo
//    (default double).
//  - NoRecursion: defaults to false.  NOTE(review): presumably `true` limits
//    training to a single level (a stump) -- confirm in
//    decision_tree_impl.hpp.
//
// The class publicly inherits AuxiliarySplitInfo<ElemType> from both the
// numeric and the categorical split type.
34 template<typename FitnessFunction = GiniGain,
35  template<typename> class NumericSplitType = BestBinaryNumericSplit,
36  template<typename> class CategoricalSplitType = AllCategoricalSplit,
37  typename DimensionSelectionType = AllDimensionSelect,
38  typename ElemType = double,
39  bool NoRecursion = false>
40 class DecisionTree :
41  public NumericSplitType<FitnessFunction>::template
42  AuxiliarySplitInfo<ElemType>,
43  public CategoricalSplitType<FitnessFunction>::template
44  AuxiliarySplitInfo<ElemType>
45 {
46  public:
  // Allow access to the numeric split type.
48  typedef NumericSplitType<FitnessFunction> NumericSplit;
  // Allow access to the categorical split type.
50  typedef CategoricalSplitType<FitnessFunction> CategoricalSplit;
  // Allow access to the dimension selection type.
52  typedef DimensionSelectionType DimensionSelection;
53 
  // Construct the decision tree on the given data and labels, where the data
  // can be both numeric and categorical; datasetInfo carries the auxiliary
  // dataset information.  Defaults: minimumLeafSize = 10,
  // minimumGainSplit = 1e-7, maximumDepth = 0 (presumably "no depth limit" --
  // confirm), and a default-constructed dimension selector.
71  template<typename MatType, typename LabelsType>
72  DecisionTree(MatType data,
73  const data::DatasetInfo& datasetInfo,
74  LabelsType labels,
75  const size_t numClasses,
76  const size_t minimumLeafSize = 10,
77  const double minimumGainSplit = 1e-7,
78  const size_t maximumDepth = 0,
79  DimensionSelectionType dimensionSelector =
80  DimensionSelectionType());
81 
  // Construct the decision tree on the given data and labels without a
  // DatasetInfo.  NOTE(review): presumably every dimension is then treated as
  // numeric -- confirm in the implementation.
98  template<typename MatType, typename LabelsType>
99  DecisionTree(MatType data,
100  LabelsType labels,
101  const size_t numClasses,
102  const size_t minimumLeafSize = 10,
103  const double minimumGainSplit = 1e-7,
104  const size_t maximumDepth = 0,
105  DimensionSelectionType dimensionSelector =
106  DimensionSelectionType());
107 
  // Weighted variant of the DatasetInfo constructor: takes an additional
  // weights argument.  The trailing unnamed enable_if_t parameter restricts
  // this overload to WeightsType values that (after removing references) are
  // Armadillo types.
127  template<typename MatType, typename LabelsType, typename WeightsType>
128  DecisionTree(
129  MatType data,
130  const data::DatasetInfo& datasetInfo,
131  LabelsType labels,
132  const size_t numClasses,
133  WeightsType weights,
134  const size_t minimumLeafSize = 10,
135  const double minimumGainSplit = 1e-7,
136  const size_t maximumDepth = 0,
137  DimensionSelectionType dimensionSelector = DimensionSelectionType(),
138  const std::enable_if_t<arma::is_arma_type<
139  typename std::remove_reference<WeightsType>::type>::value>* = 0);
140 
  // Weighted DatasetInfo constructor that additionally receives an existing
  // tree `other`.  NOTE(review): presumably settings are taken from `other`
  // before training on the given data -- confirm in the implementation.
  // NOTE(review): unlike the `other`-taking overload without datasetInfo
  // further down, this one has no maximumDepth or dimensionSelector
  // parameter; confirm that the asymmetry is intended.
159  template<typename MatType, typename LabelsType, typename WeightsType>
160  DecisionTree(
161  const DecisionTree& other,
162  MatType data,
163  const data::DatasetInfo& datasetInfo,
164  LabelsType labels,
165  const size_t numClasses,
166  WeightsType weights,
167  const size_t minimumLeafSize = 10,
168  const double minimumGainSplit = 1e-7,
169  const std::enable_if_t<arma::is_arma_type<
170  typename std::remove_reference<WeightsType>::type>::value>* = 0);
  // Weighted variant of the constructor without DatasetInfo; enabled only
  // when WeightsType (references removed) is an Armadillo type.
189  template<typename MatType, typename LabelsType, typename WeightsType>
190  DecisionTree(
191  MatType data,
192  LabelsType labels,
193  const size_t numClasses,
194  WeightsType weights,
195  const size_t minimumLeafSize = 10,
196  const double minimumGainSplit = 1e-7,
197  const size_t maximumDepth = 0,
198  DimensionSelectionType dimensionSelector = DimensionSelectionType(),
199  const std::enable_if_t<arma::is_arma_type<
200  typename std::remove_reference<WeightsType>::type>::value>* = 0);
201 
  // Weighted constructor without DatasetInfo that additionally receives an
  // existing tree `other` (see the review note on the DatasetInfo variant
  // regarding how `other` is presumably used).  Enabled only when WeightsType
  // (references removed) is an Armadillo type.
220  template<typename MatType, typename LabelsType, typename WeightsType>
221  DecisionTree(
222  const DecisionTree& other,
223  MatType data,
224  LabelsType labels,
225  const size_t numClasses,
226  WeightsType weights,
227  const size_t minimumLeafSize = 10,
228  const double minimumGainSplit = 1e-7,
229  const size_t maximumDepth = 0,
230  DimensionSelectionType dimensionSelector = DimensionSelectionType(),
231  const std::enable_if_t<arma::is_arma_type<
232  typename std::remove_reference<WeightsType>::type>::value>* = 0);
233 
  // Construct a tree without training data; numClasses defaults to 1, so
  // this also serves as the default constructor.  It is not marked explicit,
  // so a size_t converts implicitly to a DecisionTree.
240  DecisionTree(const size_t numClasses = 1);
241 
  // Copy constructor.
248  DecisionTree(const DecisionTree& other);
249 
  // Move constructor.
255  DecisionTree(DecisionTree&& other);
256 
  // Copy another tree.
263  DecisionTree& operator=(const DecisionTree& other);
264 
  // NOTE(review): a move constructor is declared above, but no move
  // assignment operator is visible in this listing (the listing's own line
  // numbering skips a line here).  Confirm whether
  // `DecisionTree& operator=(DecisionTree&& other);` was lost; without it,
  // assignment from an rvalue resolves to the copy assignment operator.
271 
  // Clean up memory.
275  ~DecisionTree();
276 
  // Train the decision tree on the given data (numeric and/or categorical,
  // described by datasetInfo).  Hyperparameter defaults match the
  // constructors.  Returns a double whose meaning is not visible in this
  // header (presumably a final gain/entropy value -- confirm).
296  template<typename MatType, typename LabelsType>
297  double Train(MatType data,
298  const data::DatasetInfo& datasetInfo,
299  LabelsType labels,
300  const size_t numClasses,
301  const size_t minimumLeafSize = 10,
302  const double minimumGainSplit = 1e-7,
303  const size_t maximumDepth = 0,
304  DimensionSelectionType dimensionSelector =
305  DimensionSelectionType());
306 
  // Train on the given data and labels without a DatasetInfo.
324  template<typename MatType, typename LabelsType>
325  double Train(MatType data,
326  LabelsType labels,
327  const size_t numClasses,
328  const size_t minimumLeafSize = 10,
329  const double minimumGainSplit = 1e-7,
330  const size_t maximumDepth = 0,
331  DimensionSelectionType dimensionSelector =
332  DimensionSelectionType());
333 
  // Weighted variant of Train() with a DatasetInfo; enabled only when
  // WeightsType (references removed) is an Armadillo type.
355  template<typename MatType, typename LabelsType, typename WeightsType>
356  double Train(MatType data,
357  const data::DatasetInfo& datasetInfo,
358  LabelsType labels,
359  const size_t numClasses,
360  WeightsType weights,
361  const size_t minimumLeafSize = 10,
362  const double minimumGainSplit = 1e-7,
363  const size_t maximumDepth = 0,
364  DimensionSelectionType dimensionSelector =
365  DimensionSelectionType(),
366  const std::enable_if_t<arma::is_arma_type<typename
367  std::remove_reference<WeightsType>::type>::value>* = 0);
368 
  // Weighted variant of Train() without a DatasetInfo; enabled only when
  // WeightsType (references removed) is an Armadillo type.
388  template<typename MatType, typename LabelsType, typename WeightsType>
389  double Train(MatType data,
390  LabelsType labels,
391  const size_t numClasses,
392  WeightsType weights,
393  const size_t minimumLeafSize = 10,
394  const double minimumGainSplit = 1e-7,
395  const size_t maximumDepth = 0,
396  DimensionSelectionType dimensionSelector =
397  DimensionSelectionType(),
398  const std::enable_if_t<arma::is_arma_type<typename
399  std::remove_reference<WeightsType>::type>::value>* = 0);
400 
  // Classify the given point, using the entire tree; the result is returned
  // as a size_t.
407  template<typename VecType>
408  size_t Classify(const VecType& point) const;
409 
  // Classify the given point, writing the result into the `prediction` and
  // `probabilities` output parameters.
419  template<typename VecType>
420  void Classify(const VecType& point,
421  size_t& prediction,
422  arma::vec& probabilities) const;
423 
  // Classify a whole data matrix, writing the results into `predictions`.
431  template<typename MatType>
432  void Classify(const MatType& data,
433  arma::Row<size_t>& predictions) const;
434 
  // Classify a whole data matrix, writing the results into `predictions` and
  // `probabilities`.
445  template<typename MatType>
446  void Classify(const MatType& data,
447  arma::Row<size_t>& predictions,
448  arma::mat& probabilities) const;
449 
  // Serialize the tree.  The version argument is unused.
453  template<typename Archive>
454  void serialize(Archive& ar, const unsigned int /* version */);
455 
  // Get the number of children.
457  size_t NumChildren() const { return children.size(); }
458 
  // Get the child of the given index.  No bounds check is performed on i.
460  const DecisionTree& Child(const size_t i) const { return *children[i]; }
  // Modify the child of the given index (be careful!).  No bounds check is
  // performed on i.
462  DecisionTree& Child(const size_t i) { return *children[i]; }
463 
  // Get the split dimension (only meaningful if this is a non-leaf in a
  // trained tree).
466  size_t SplitDimension() const { return splitDimension; }
467 
  // Given a point and that this node is not a leaf, calculate the index of
  // the child node the point is directed to.
475  template<typename VecType>
476  size_t CalculateDirection(const VecType& point) const;
477 
  // Get the number of classes in the tree.
481  size_t NumClasses() const;
482 
483  private:
  // Pointers to the child nodes; Child(i) dereferences children[i].
  // NOTE(review): presumably owned by this node and freed in the destructor
  // -- confirm in the implementation.
485  std::vector<DecisionTree*> children;
  // Value returned by SplitDimension().
487  size_t splitDimension;
  // NOTE(review): judging by the name, holds either the type of the split
  // dimension or the majority class, depending on the node -- confirm.
490  size_t dimensionTypeOrMajorityClass;
  // NOTE(review): presumably the per-class probabilities for this node (see
  // CalculateClassProbabilities()) -- confirm; other uses are not visible
  // here.
498  arma::vec classProbabilities;
499 
  // Shorthands for the auxiliary split information types this class inherits
  // from.
503  typedef typename NumericSplit::template AuxiliarySplitInfo<ElemType>
504  NumericAuxiliarySplitInfo;
505  typedef typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>
506  CategoricalAuxiliarySplitInfo;
507 
  // Calculate class probabilities from the given labels (and weights).
  // NOTE(review): presumably `weights` is only consulted when UseWeights is
  // true and the result is stored in classProbabilities -- confirm.
511  template<bool UseWeights, typename RowType, typename WeightsRowType>
512  void CalculateClassProbabilities(const RowType& labels,
513  const size_t numClasses,
514  const WeightsRowType& weights);
515 
  // Internal training overload with DatasetInfo.  data, labels and weights
  // are taken by non-const reference and no parameter has a default.
  // NOTE(review): presumably operates on the `count` points starting at
  // index `begin` and may reorder them in place -- confirm.
533  template<bool UseWeights, typename MatType>
534  double Train(MatType& data,
535  const size_t begin,
536  const size_t count,
537  const data::DatasetInfo& datasetInfo,
538  arma::Row<size_t>& labels,
539  const size_t numClasses,
540  arma::rowvec& weights,
541  const size_t minimumLeafSize,
542  const double minimumGainSplit,
543  const size_t maximumDepth,
544  DimensionSelectionType& dimensionSelector);
545 
  // Internal training overload without DatasetInfo; otherwise the same
  // parameter list as the overload above.
562  template<bool UseWeights, typename MatType>
563  double Train(MatType& data,
564  const size_t begin,
565  const size_t count,
566  arma::Row<size_t>& labels,
567  const size_t numClasses,
568  arma::rowvec& weights,
569  const size_t minimumLeafSize,
570  const double minimumGainSplit,
571  const size_t maximumDepth,
572  DimensionSelectionType& dimensionSelector);
573 };
574 
// Convenience alias for decision stumps (single level decision trees).
//
// All template parameters are forwarded to DecisionTree unchanged; only the
// final NoRecursion argument is fixed.  It is set to true so that the alias
// actually names a non-recursing tree: with false (the DecisionTree default)
// DecisionStump<...> would be exactly the same type as the corresponding
// DecisionTree<...>.  This matches the ID3DecisionStump typedef documented in
// this file, which also passes true.
// NOTE(review): assumes NoRecursion == true limits training to one level;
// confirm against decision_tree_impl.hpp.
578 template<typename FitnessFunction = GiniGain,
579  template<typename> class NumericSplitType = BestBinaryNumericSplit,
580  template<typename> class CategoricalSplitType = AllCategoricalSplit,
581  typename DimensionSelectType = AllDimensionSelect,
582  typename ElemType = double>
583 using DecisionStump = DecisionTree<FitnessFunction,
584  NumericSplitType,
585  CategoricalSplitType,
586  DimensionSelectType,
587  ElemType,
588  true>;
589 
// Convenience typedef for ID3 decision stumps (single level decision trees
// made with the ID3 algorithm): InformationGain as the fitness function, the
// default split and dimension selection policies, double elements, and
// NoRecursion set to true.
594 typedef DecisionTree<InformationGain,
595  BestBinaryNumericSplit,
596  AllCategoricalSplit,
597  AllDimensionSelect,
598  double,
599  true> ID3DecisionStump;
600 } // namespace tree
601 } // namespace mlpack
602 
603 // Include implementation.
604 #include "decision_tree_impl.hpp"
605 
606 #endif
size_t NumChildren() const
Get the number of children.
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
DecisionTree(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Construct the decision tree on the given data and labels, where the data can be both numeric and categorical.
The BestBinaryNumericSplit is a splitting function for decision trees that will exhaustively search a...
typename enable_if< B, T >::type enable_if_t
Definition: prereqs.hpp:70
Linear algebra utility functions, generally performed on matrices or vectors.
This class implements a generic decision tree learner.
size_t CalculateDirection(const VecType &point) const
Given a point and that this node is not a leaf, calculate the index of the child node this point woul...
The core includes that mlpack expects; standard C++ includes and Armadillo.
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
const DecisionTree & Child(const size_t i) const
Get the child of the given index.
The AllCategoricalSplit is a splitting function that will split categorical features into many childr...
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
void serialize(Archive &ar, const unsigned int)
Serialize the tree.
The standard information gain criterion, used for calculating gain in decision trees.
DecisionTree & operator=(const DecisionTree &other)
Copy another tree.
DecisionTree & Child(const size_t i)
Modify the child of the given index (be careful!).
DimensionSelectionType DimensionSelection
Allow access to the dimension selection type.
size_t NumClasses() const
Get the number of classes in the tree.
DecisionTree< InformationGain, BestBinaryNumericSplit, AllCategoricalSplit, AllDimensionSelect, double, true > ID3DecisionStump
Convenience typedef for ID3 decision stumps (single level decision trees made with the ID3 algorithm).
This dimension selection policy allows any dimension to be selected for splitting.
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision trees.
Definition: gini_gain.hpp:27
size_t Classify(const VecType &point) const
Classify the given point, using the entire tree.
size_t SplitDimension() const
Get the split dimension (only meaningful if this is a non-leaf in a trained tree).
double Train(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the decision tree on the given data.
~DecisionTree()
Clean up memory.