SpillTree
The SpillTree class represents a generic multidimensional binary space
partitioning tree that allows overlapping volumes between nodes, also known as a
"hybrid spill tree". It is heavily templatized to control splitting behavior
and other behaviors, and is the actual class underlying trees such as the
SPTree. In general, the SpillTree class is not meant to be used directly;
instead, one of the handful of variants should be used:
- SPTree
- MeanSPTree
- NonOrtSPTree
- NonOrtMeanSPTree
The SpillTree is similar to the BinarySpaceTree,
except that the two children of a node are allowed to overlap, and thus a single
point can be contained in multiple branches of the tree. This can be useful to,
e.g., improve nearest neighbor performance when using defeatist traversals
without backtracking.
For users who want to use SpillTree directly or with custom behavior,
the full class is still detailed in the subsections below. SpillTree supports
the TreeType API and can be used
with mlpack's tree-based algorithms, although using custom behavior may require
a template typedef.
- Template parameters
- Constructors
- Basic tree properties
- Bounding distances with the tree
- Tree traversals
- HyperplaneType template parameter
- SplitType template parameter
- Example usage
See also
- SPTree
- MeanSPTree
- NonOrtSPTree
- NonOrtMeanSPTree
- BinarySpaceTree
- An Investigation of Practical Approximate Nearest Neighbor Algorithms (pdf)
- Tree-Independent Dual-Tree Algorithms (pdf)
Template parameters
The SpillTree class takes five template parameters. The first three of these
are required by the TreeType API (see also the more detailed discussion of
these parameters in the TreeType API documentation). The full signature of the
class is:
template<typename DistanceType,
typename StatisticType,
typename MatType,
template<typename HyperplaneDistanceType,
typename HyperplaneMatType> class HyperplaneType,
template<typename SplitDistanceType,
typename SplitMatType> class SplitType>
class SpillTree;
- DistanceType: the distance metric to use for distance computations. By default, this is EuclideanDistance.
- StatisticType: this holds auxiliary information in each tree node. By default, EmptyStatistic is used, which holds no information.
  - See the StatisticType section in the BinarySpaceTree documentation for more details.
- MatType: the type of matrix used to represent points. Must be a type matching the Armadillo API. By default, arma::mat is used, but other types such as arma::fmat or similar will work just fine.
- HyperplaneType: the class defining the type of the hyperplane that will split each node. By default, AxisOrthogonalHyperplane is used.
  - See the HyperplaneType section for more details.
- SplitType: the class defining how an individual SpillTree node should be split. By default, MidpointSpaceSplit is used.
  - See the SplitType section for more details.
Note that the TreeType API requires trees to have only three template
parameters. In order to use a SpillTree with its five template parameters
with an mlpack algorithm that needs a TreeType, it is easiest to define a
template typedef:
template<typename DistanceType, typename StatisticType, typename MatType>
using CustomTree = SpillTree<DistanceType, StatisticType, MatType,
CustomHyperplaneType, CustomSplitType>;
Here, CustomHyperplaneType and CustomSplitType are the desired hyperplane
type and split strategy. This is the way that all SpillTree variants (such as
SPTree) are defined.
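To illustrate why a three-parameter alias is enough, the following standalone sketch reproduces the pattern with stub types instead of mlpack classes. Every name in it (DistanceStub, StatisticStub, MatStub, CustomHyperplaneType, CustomSplitType, FiveParamTree, ThreeParamAlgorithm) is hypothetical; the only point shown is that an alias template with three parameters can be passed wherever a three-parameter template template argument is expected, and that it expands to the full five-parameter type.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stub types standing in for a distance, a statistic, and a
// matrix type; none of these are mlpack classes.
struct DistanceStub {};
struct StatisticStub {};
struct MatStub {};

// Hypothetical hyperplane and split policies with the two-parameter shape
// that SpillTree expects for its last two template parameters.
template<typename DistanceType, typename MatType>
struct CustomHyperplaneType {};

template<typename DistanceType, typename MatType>
struct CustomSplitType {};

// A five-parameter class template with the same shape as SpillTree.
template<typename DistanceType,
         typename StatisticType,
         typename MatType,
         template<typename, typename> class HyperplaneType,
         template<typename, typename> class SplitType>
struct FiveParamTree
{
  using Hyperplane = HyperplaneType<DistanceType, MatType>;
  using Split = SplitType<DistanceType, MatType>;
};

// The three-parameter alias, written the same way as the CustomTree typedef.
template<typename DistanceType, typename StatisticType, typename MatType>
using CustomTree = FiveParamTree<DistanceType, StatisticType, MatType,
                                 CustomHyperplaneType, CustomSplitType>;

// A stand-in for an algorithm that only accepts three-parameter tree
// templates, as the TreeType API requires.
template<template<typename, typename, typename> class TreeType>
struct ThreeParamAlgorithm
{
  using Tree = TreeType<DistanceStub, StatisticStub, MatStub>;
};
```

Here ThreeParamAlgorithm<CustomTree>::Tree is exactly the five-parameter FiveParamTree with the custom policies filled in, which is the same mechanism the SpillTree variants rely on.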
Constructors
SpillTrees are constructed by iteratively finding splitting hyperplanes, and
points within a margin of the hyperplane are assigned to both child nodes.
Unlike the constructors of
BinarySpaceTree, the dataset is not
permuted during construction.
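The spill rule can be sketched on a single projected coordinate. In this minimal sketch, SpillAssignment and AssignWithSpill are hypothetical helpers, not mlpack functions, and the boundary convention (at most splitVal + tau goes left, greater than splitVal - tau goes right) is an assumption. Each child receives a list of indices into the original data, so nothing is permuted and a point inside the margin simply appears in both lists.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical result type: indices of points assigned to each child.
struct SpillAssignment
{
  std::vector<std::size_t> left;
  std::vector<std::size_t> right;
};

// Hypothetical helper, not an mlpack function. `projections[i]` is the
// coordinate of point i along the splitting direction. A point is sent left
// if it is at most splitVal + tau, and right if it is greater than
// splitVal - tau, so points within tau of the split land in both children.
SpillAssignment AssignWithSpill(const std::vector<double>& projections,
                                const double splitVal,
                                const double tau)
{
  SpillAssignment result;
  for (std::size_t i = 0; i < projections.size(); ++i)
  {
    const double p = projections[i];
    if (p <= splitVal + tau)
      result.left.push_back(i);
    if (p > splitVal - tau)
      result.right.push_back(i);
  }
  return result;
}
```

For example, with projections {0.0, 0.9, 1.0, 1.1, 2.0}, a split value of 1.0, and tau = 0.2, indices 1, 2, and 3 go to both children; with tau = 0 the two lists are disjoint.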
- node = SpillTree(data, tau=0.0, maxLeafSize=20, rho=0.7)
  - Construct a SpillTree on the given data, using the specified hyperparameters to control tree construction behavior.
  - Default template parameters are used, meaning that this tree will be an SPTree.
  - By default, a reference to data is stored. If data goes out of scope after tree construction, memory errors will occur! To avoid this, either pass the dataset or a copy with std::move() (e.g. std::move(data)); when doing this, data will be set to an empty matrix.
- node = SpillTree<DistanceType, StatisticType, MatType, HyperplaneType, SplitType>(data, tau=0.0, maxLeafSize=20, rho=0.7)
  - Construct a SpillTree on the given data, using custom template parameters, and using the specified hyperparameters to control tree construction behavior.
  - By default, a reference to data is stored. If data goes out of scope after tree construction, memory errors will occur! To avoid this, either pass the dataset or a copy with std::move() (e.g. std::move(data)); when doing this, data will be set to an empty matrix.
- node = SpillTree()
  - Construct an empty SpillTree with no children, no points, and default template parameters.
Notes:
- The name node is used here for SpillTree objects instead of tree, because each SpillTree object is a single node in the tree. The constructor returns the node that is the root of the tree.
- Inserting individual points or removing individual points from a SpillTree is not supported, because this generally results in a tree with very suboptimal hyperplane splits. It is better to simply build a new SpillTree on the modified dataset. For trees that support individual insertion and deletions, see the RectangleTree class and all its variants (e.g. RTree, RStarTree, etc.).
- See also the developer documentation on tree constructors.
Constructor parameters:
| name | type | description | default |
|---|---|---|---|
| data | MatType | Column-major matrix to build the tree on. | (N/A) |
| tau | double | Width of spill margin: points within tau of the splitting hyperplane of a node will be contained in both left and right children. | 0.0 |
| maxLeafSize | size_t | Maximum number of points to store in each leaf. | 20 |
| rho | double | Balance threshold. When splitting, if either overlapping node would contain a fraction of more than rho of the points, a non-overlapping split is performed. Must be in the range [0.0, 1.0). | 0.7 |
Caveats:
- tau must be manually tuned for the properties of each dataset; the default, 0.0, will never allow overlap between nodes (and thus the created tree will essentially be a non-overlapping BinarySpaceTree).
- If tau is set too large, nodes will overlap too much and search quality will be degraded.
- rho implicitly controls the depth of the tree by forcing very overlapping children to be non-overlapping. As rho gets closer to 1, more overlap is allowed, which in turn makes the tree deeper. If rho is set to 0.5 or less, then all splits will be non-overlapping (and the tree will essentially be a BinarySpaceTree).
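The balance rule for rho can be written as a single predicate. This is a minimal sketch under the stated assumption that the check compares each child's point count against rho times the node's point count; UseOverlappingSplit is a hypothetical name, and mlpack's internal check may differ in details such as strict versus non-strict comparison.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical predicate, not an mlpack function: keep the overlapping
// (spill) split only if neither child would hold more than a fraction rho of
// the node's points; otherwise a non-overlapping split should be used.
bool UseOverlappingSplit(const std::size_t numPoints,
                         const std::size_t leftCount,
                         const std::size_t rightCount,
                         const double rho)
{
  const double limit = rho * static_cast<double>(numPoints);
  return (static_cast<double>(leftCount) <= limit) &&
         (static_cast<double>(rightCount) <= limit);
}
```

This also shows why rho at or below 0.5 disables overlap: every point lands in at least one child, so as soon as one point is spilled the two counts sum to more than the node's size, and one child must then hold more than half of the points.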
Basic tree properties
Once a SpillTree object is constructed, various properties of the tree
can be accessed or inspected. Many of these functions are required by the
TreeType API.
Navigating the tree
- node.NumChildren() returns the number of children in node. This is either 2 if node has children, or 0 if node is a leaf.
- node.IsLeaf() returns a bool indicating whether or not node is a leaf.
- node.Child(i) returns a SpillTree& that is the i'th child.
  - i must be 0 or 1.
  - This function should only be called if node.NumChildren() is not 0 (i.e. if node is not a leaf). Note that this returns a valid SpillTree& that can itself be used just like the root node of the tree!
- node.Left() and node.Right() are convenience functions specific to SpillTree that will return SpillTree* (pointers) to the left and right children, respectively, or NULL if node has no children.
- node.Parent() will return a SpillTree* that points to the parent of node, or NULL if node is the root of the SpillTree.
Accessing members of a tree
- node.Overlap() will return a bool that is true if node's children are overlapping, and false otherwise.
- node.Hyperplane() will return a HyperplaneType& object that represents the splitting hyperplane of node.
  - All points in node.Left() are to the left of node.Hyperplane() if node.Overlap() is false; otherwise, all points in node.Left() are to the left of node.Hyperplane() + tau.
  - All points in node.Right() are to the right of node.Hyperplane() if node.Overlap() is false; otherwise, all points in node.Right() are to the right of node.Hyperplane() - tau.
- node.Bound() will return a const HRectBound& representing the bounding box associated with node.
  - If a custom HyperplaneType is specified, then the BoundType associated with that hyperplane type is returned instead.
  - If a custom DistanceType and/or MatType are specified, then a const HRectBound<DistanceType, ElemType>& is returned (or a BoundType with that DistanceType, if a custom HyperplaneType was also specified). ElemType is the element type of the specified MatType (e.g. double for arma::mat, float for arma::fmat, etc.).
- node.Stat() will return a StatisticType& holding the statistics of the node that were computed during tree construction.
- node.Distance() will return a DistanceType&.
See also the developer documentation for basic tree functionality in mlpack.
Accessing data held in a tree
- node.Dataset() will return a const MatType& that is the dataset the tree was built on.
- node.NumPoints() returns a size_t indicating the number of points held directly in node.
  - If node is not a leaf, this will return 0, as SpillTree only holds points directly in its leaves.
  - If node is a leaf, then the number of points will be less than or equal to the maxLeafSize that was specified when the tree was constructed.
- node.Point(i) returns a size_t indicating the index of the i'th point in node.Dataset().
  - i must be in the range [0, node.NumPoints() - 1] (inclusive).
  - node must be a leaf (as non-leaves do not hold any points).
  - The i'th point in node can then be accessed as node.Dataset().col(node.Point(i)).
- node.NumDescendants() returns a size_t indicating the number of points held in all descendant leaves of node.
  - If node is the root of the tree, then node.NumDescendants() will be equal to node.Dataset().n_cols.
- node.Descendant(i) returns a size_t indicating the index of the i'th descendant point in node.Dataset().
  - i must be in the range [0, node.NumDescendants() - 1] (inclusive).
  - node does not need to be a leaf.
  - The i'th descendant point in node can then be accessed as node.Dataset().col(node.Descendant(i)).
Accessing computed bound quantities of a tree
The following quantities are cached for each node in a SpillTree, and so
accessing them does not require any computation. In the documentation below,
ElemType is the element type of the given MatType; e.g., if MatType is
arma::mat, then ElemType is double.
- node.FurthestPointDistance() returns an ElemType representing the distance between the center of the bound of node and the furthest point held by node.
  - If node is not a leaf, this returns 0 (because node does not hold any points).
- node.FurthestDescendantDistance() returns an ElemType representing the distance between the center of the bound of node and the furthest descendant point held by node.
- node.MinimumBoundDistance() returns an ElemType representing the minimum possible distance from the center of the node to any edge of its bound.
- node.ParentDistance() returns an ElemType representing the distance between the center of the bound of node and the center of the bound of its parent.
  - If node is the root of the tree, 0 is returned.
Note: for more details on each bound quantity, see the developer documentation on bound quantities for trees.
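To make these quantities concrete, here is a minimal sketch for the default case of an axis-aligned hyperrectangle bound with Euclidean distance. Box, Center, Euclidean, FurthestDistanceFromCenter, and MinimumBoundDistance are hypothetical names that only mirror the spirit of node.Center(), node.FurthestDescendantDistance(), and node.MinimumBoundDistance(); mlpack caches its values during construction rather than computing them like this on demand.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical box described by per-dimension [lo, hi] ranges; a stand-in for
// a hyperrectangle bound, not the mlpack HRectBound class.
struct Box
{
  std::vector<double> lo;
  std::vector<double> hi;
};

// Center of the box: the midpoint of each dimension's range.
std::vector<double> Center(const Box& b)
{
  std::vector<double> c(b.lo.size());
  for (std::size_t d = 0; d < c.size(); ++d)
    c[d] = (b.lo[d] + b.hi[d]) / 2.0;
  return c;
}

double Euclidean(const std::vector<double>& a, const std::vector<double>& b)
{
  double sum = 0.0;
  for (std::size_t d = 0; d < a.size(); ++d)
    sum += (a[d] - b[d]) * (a[d] - b[d]);
  return std::sqrt(sum);
}

// Distance from the box center to the furthest of the given points.
double FurthestDistanceFromCenter(const Box& b,
                                  const std::vector<std::vector<double>>& points)
{
  const std::vector<double> c = Center(b);
  double furthest = 0.0;
  for (const auto& p : points)
    furthest = std::max(furthest, Euclidean(c, p));
  return furthest;
}

// Smallest distance from the center to any face of the box: half of the
// narrowest width.
double MinimumBoundDistance(const Box& b)
{
  double minHalfWidth = std::numeric_limits<double>::max();
  for (std::size_t d = 0; d < b.lo.size(); ++d)
    minHalfWidth = std::min(minHalfWidth, (b.hi[d] - b.lo[d]) / 2.0);
  return minHalfWidth;
}
```

For a box spanning [0, 4] x [0, 2], the center is (2, 1) and the minimum bound distance is 1, since the narrowest dimension has width 2.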
Other functionality
- node.Center(center) computes the center of the bound of node and stores it in center.
  - center should be of type arma::Col<ElemType>&, where ElemType is the element type of the specified MatType.
  - center will be set to have size equivalent to the dimensionality of the dataset held by node.
  - This is equivalent to calling node.Bound().Center(center).
- A SpillTree can be serialized with data::Save() and data::Load().
Bounding distances with the tree
The primary use of trees in mlpack is bounding distances to points or other tree nodes. The following functions can be used for these tasks.
- node.GetNearestChild(point)
- node.GetFurthestChild(point)
  - Return a size_t indicating the index of the child (0 for left, 1 for right) that is closest to (or furthest from) point, with respect to the MinDistance() (or MaxDistance()) function.
  - If there is a tie, 0 (the left child) is returned.
  - If node is a leaf, 0 is returned.
  - point should be a column vector type of the same type as MatType. (e.g., if MatType is arma::mat, then point should be an arma::vec.)
- node.GetNearestChild(other)
- node.GetFurthestChild(other)
  - Return a size_t indicating the index of the child (0 for left, 1 for right) that is closest to (or furthest from) the SpillTree node other, with respect to the MinDistance() (or MaxDistance()) function.
  - If there is a tie, 0 (the left child) is returned.
  - If node is a leaf, 0 is returned.
- node.MinDistance(point)
- node.MinDistance(other)
  - Return a double indicating the minimum possible distance between node and point, or the SpillTree node other.
  - This is equivalent to the minimum possible distance between any point contained in the bounding hyperrectangle of node and point, or between any point contained in the bounding hyperrectangle of node and any point contained in the bounding hyperrectangle of other.
  - point should be a column vector type of the same type as MatType. (e.g., if MatType is arma::mat, then point should be an arma::vec.)
- node.MaxDistance(point)
- node.MaxDistance(other)
  - Return a double indicating the maximum possible distance between node and point, or the SpillTree node other.
  - This is equivalent to the maximum possible distance between any point contained in the bounding hyperrectangle of node and point, or between any point contained in the bounding hyperrectangle of node and any point contained in the bounding hyperrectangle of other.
  - point should be a column vector type of the same type as MatType. (e.g., if MatType is arma::mat, then point should be an arma::vec.)
- node.RangeDistance(point)
- node.RangeDistance(other)
  - Return a RangeType<ElemType> whose lower bound is node.MinDistance(point) or node.MinDistance(other), and whose upper bound is node.MaxDistance(point) or node.MaxDistance(other).
  - ElemType is the element type of MatType.
  - point should be a column vector type of the same type as MatType. (e.g., if MatType is arma::mat, then point should be an arma::vec.)
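For the point versions with a hyperrectangle bound and Euclidean distance, the minimum and maximum distances decompose per dimension. The sketch below assumes exactly that setting; Box, MinDistance, MaxDistance, and NearestChild are hypothetical free functions illustrating the math, not the mlpack member functions, and NearestChild applies the documented tie rule of returning 0.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical box with per-dimension [lo, hi] ranges; a stand-in for the
// bounding hyperrectangle of a node, not the mlpack HRectBound class.
struct Box
{
  std::vector<double> lo;
  std::vector<double> hi;
};

// Minimum Euclidean distance from a point to any point inside the box: in
// each dimension the gap is zero if the coordinate lies within [lo, hi].
double MinDistance(const Box& b, const std::vector<double>& p)
{
  double sum = 0.0;
  for (std::size_t d = 0; d < p.size(); ++d)
  {
    const double gap = std::max({b.lo[d] - p[d], p[d] - b.hi[d], 0.0});
    sum += gap * gap;
  }
  return std::sqrt(sum);
}

// Maximum Euclidean distance from a point to any point inside the box: in
// each dimension, take the further of the two faces.
double MaxDistance(const Box& b, const std::vector<double>& p)
{
  double sum = 0.0;
  for (std::size_t d = 0; d < p.size(); ++d)
  {
    const double farthest = std::max(std::abs(p[d] - b.lo[d]),
                                     std::abs(p[d] - b.hi[d]));
    sum += farthest * farthest;
  }
  return std::sqrt(sum);
}

// Index of the child box with the smaller minimum distance (0 = left,
// 1 = right); a tie returns 0, following the documented tie rule.
std::size_t NearestChild(const Box& left, const Box& right,
                         const std::vector<double>& p)
{
  return (MinDistance(right, p) < MinDistance(left, p)) ? 1 : 0;
}
```

The pair (MinDistance, MaxDistance) is the interval that RangeDistance() reports; a point inside the box has a minimum distance of 0.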
Tree traversals
Like every mlpack tree, the SpillTree class provides a single-tree and
dual-tree traversal that can be paired
with a RuleType class to implement a
single-tree or dual-tree algorithm.
- SpillTree::SingleTreeTraverser: implements a depth-first single-tree traverser.
- SpillTree::DualTreeTraverser: implements a dual-depth-first dual-tree traverser.
However, the spill tree is primarily useful because the overlapping nodes allow
defeatist search to be effective. Defeatist search is non-backtracking: the
tree is traversed to one leaf only. For example, finding the approximate
nearest neighbor of a point p with defeatist search is done by recursing in
the tree, choosing the child with smallest minimum distance to p, and when a
leaf is encountered, choosing the closest point in the leaf to p as the
nearest neighbor. This is the strategy used in the
original spill tree paper (pdf).
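The defeatist strategy can be sketched on a toy one-dimensional tree. ToyNode and DefeatistNearestNeighbor are hypothetical names for illustration only; the sketch does not use mlpack's DefeatistSingleTreeTraverser or any RuleType, and it routes the query by the side of the split value it falls on, a simplification of choosing the child with the smallest minimum distance.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <memory>
#include <vector>

// Hypothetical 1-D spill-tree node. Internal nodes split at `splitVal`;
// leaves hold indices into a dataset (an index may appear in both leaves).
struct ToyNode
{
  double splitVal = 0.0;
  std::unique_ptr<ToyNode> left, right;
  std::vector<std::size_t> points;  // Only used in leaves.

  bool IsLeaf() const { return !left && !right; }
};

// Defeatist (non-backtracking) approximate nearest neighbor search: descend
// to exactly one leaf, then scan only that leaf. Returns the dataset index
// of the best point found, or SIZE_MAX if the leaf is empty.
std::size_t DefeatistNearestNeighbor(const ToyNode& root,
                                     const std::vector<double>& data,
                                     const double query)
{
  const ToyNode* node = &root;
  while (!node->IsLeaf())
    node = (query <= node->splitVal) ? node->left.get() : node->right.get();

  std::size_t best = std::numeric_limits<std::size_t>::max();
  double bestDist = std::numeric_limits<double>::max();
  for (const std::size_t i : node->points)
  {
    const double dist = std::abs(data[i] - query);
    if (dist < bestDist)
    {
      bestDist = dist;
      best = i;
    }
  }
  return best;
}
```

With data {0, 1, 2, 3, 4, 5}, a root split at 3.0, and a margin of 0.5, the left leaf holds indices {0, 1, 2, 3} and the right leaf holds {3, 4, 5}. A query of 3.1 is routed right yet still finds index 3 (value 3.0); without the spilled point the right leaf would hold only {4, 5} and the search would return index 4.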
Defeatist traversers, matching the API for a regular traversal, are made available as the following two classes:
- SpillTree::DefeatistSingleTreeTraverser: implements a depth-first single-tree defeatist traverser with no backtracking. Traversal will terminate after the first leaf is visited.
- SpillTree::DefeatistDualTreeTraverser: implements a dual-depth-first dual-tree defeatist traversal with no backtracking. For each query leaf node, traversal will terminate after the first reference leaf node is visited.
Any RuleType that is being used with a
defeatist traversal, in addition to the functions required by the RuleType
API, must implement the following functions:
// This is only required for single-tree defeatist traversals.
// It should return the index of the branch that should be chosen for the given
// query point and reference node.
template<typename VecType, typename TreeType>
size_t GetBestChild(const VecType& queryPoint, TreeType& referenceNode);
// This is only required for dual-tree defeatist traversals.
// It should return the index of the best child of the reference node that
// should be chosen for the given query node.
template<typename TreeType>
size_t GetBestChild(TreeType& queryNode, TreeType& referenceNode);
// Return the minimum number of base cases (point-to-point computations) that
// are required during the traversal.
size_t MinimumBaseCases();
HyperplaneType
Each node in a SpillTree corresponds to some region in space that contains all
of the descendant points in the node. Similar to the KDTree, this
region is a hyperrectangle; however, instead of representing that hyperrectangle
explicitly like the KDTree with the
HRectBound class, the SpillTree
represents the region implicitly, with each node storing only the hyperplane
and margin required to determine whether a point belongs to the left node, the
right node, or both.
The type of hyperplane (e.g. axis-aligned or arbitrary) can be controlled by the
HyperplaneType template parameter. mlpack supplies two drop-in classes that
can be used for HyperplaneType, and it is also possible to write a custom
HyperplaneType:
- AxisOrthogonalHyperplane: uses hyperplanes that are axis-orthogonal (or axis-aligned).
- Hyperplane: uses arbitrary hyperplanes specified by any vector.
- Custom HyperplaneTypes: implement a fully custom HyperplaneType class.
AxisOrthogonalHyperplane
The AxisOrthogonalHyperplane class is used to provide an axis-orthogonal split
for a SpillTree. That is, whether or not a point is on the left or right side
of the split is a very efficient computation using only a single dimension of
the data.
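The single-dimension test can be sketched as follows. AxisHyperplane and Box are hypothetical stand-ins, not the mlpack AxisOrthogonalHyperplane or HRectBound classes; the convention that Project() returns the signed offset from the split value, and that a value exactly on the split counts as left, are assumptions of this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical axis-aligned box; a stand-in for the bound type.
struct Box
{
  std::vector<double> lo;
  std::vector<double> hi;
};

// Hypothetical axis-orthogonal hyperplane {x : x[dim] == splitVal}.
struct AxisHyperplane
{
  std::size_t dim;
  double splitVal;

  // Signed offset from the hyperplane along its normal (the unit vector of
  // dimension `dim`). Only one coordinate is read, so this is O(1).
  double Project(const std::vector<double>& point) const
  {
    return point[dim] - splitVal;
  }

  bool Left(const std::vector<double>& point) const
  {
    return Project(point) <= 0.0;
  }

  bool Right(const std::vector<double>& point) const
  {
    return Project(point) > 0.0;
  }

  // A box lies entirely on one side if its extent in `dim` does not cross
  // the split value.
  bool Left(const Box& bound) const { return bound.hi[dim] <= splitVal; }
  bool Right(const Box& bound) const { return bound.lo[dim] > splitVal; }
};
```

A box that straddles the split value is neither fully left nor fully right, which is the situation in which a traversal has to consider both children.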
- The AxisOrthogonalHyperplane class defines the following two typedefs:
  - AxisOrthogonalHyperplane::BoundType, which is the type of the bound used by the spill tree with this hyperplane, is HRectBound, or HRectBound<DistanceType, ElemType> if a custom DistanceType and/or MatType are specified (ElemType is the element type of MatType).
  - AxisOrthogonalHyperplane::ProjVectorType is AxisParallelProjVector, a class that simply holds the index of the dimension of the projection vector.
  - For more details, see the source code.
- An AxisOrthogonalHyperplane object h (e.g. returned with node.Hyperplane()) has the following members:
  - h.Project(point) returns a double that is the orthogonal projection of point onto the normal vector of the hyperplane h.
  - h.Left(point) returns true if point is to the left of h.
  - h.Right(point) returns true if point is to the right of h.
  - h.Left(bound) returns true if bound (an AxisOrthogonalHyperplane::BoundType; see the bullet point above) is to the left of h.
  - h.Right(bound) returns true if bound (an AxisOrthogonalHyperplane::BoundType; see the bullet point above) is to the right of h.
- An AxisOrthogonalHyperplane object can be serialized with data::Save() and data::Load().
For more details, see the source code.
Hyperplane
The Hyperplane class is used to provide an arbitrary hyperplane split for a
SpillTree. The computation of whether or not a point is on the left or right
side of the split is less efficient than
AxisOrthogonalHyperplane, but Hyperplane is
able to represent any possible hyperplane.
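The cost difference comes from the projection: an arbitrary hyperplane needs a full dot product. In this minimal sketch, GeneralHyperplane is a hypothetical stand-in for the mlpack Hyperplane class; it assumes a unit-length normal vector so that Project() is a signed distance, and it represents a ball bound simply as a (center, radius) pair.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Hypothetical general hyperplane {x : dot(normal, x) == splitVal}.
// `normal` is assumed to have unit length so that Project() returns a
// signed distance to the hyperplane.
struct GeneralHyperplane
{
  std::vector<double> normal;
  double splitVal;

  // Every coordinate contributes to the dot product, so this costs O(d),
  // compared to O(1) for an axis-orthogonal hyperplane.
  double Project(const std::vector<double>& point) const
  {
    return std::inner_product(normal.begin(), normal.end(), point.begin(),
                              0.0) - splitVal;
  }

  bool Left(const std::vector<double>& point) const
  {
    return Project(point) <= 0.0;
  }

  bool Right(const std::vector<double>& point) const
  {
    return Project(point) > 0.0;
  }

  // A ball (center, radius) is fully on one side if its center lies at
  // least `radius` away from the hyperplane on that side.
  bool Left(const std::vector<double>& center, const double radius) const
  {
    return Project(center) + radius <= 0.0;
  }

  bool Right(const std::vector<double>& center, const double radius) const
  {
    return Project(center) - radius > 0.0;
  }
};
```

Pairing a general hyperplane with a ball bound keeps the bound test as cheap as a single projection plus the radius, whichever direction the hyperplane faces.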
- The Hyperplane class defines the two following typedefs:
  - Hyperplane::BoundType, which is the type of the bound used by the spill tree with this hyperplane, is BallBound, or BallBound<DistanceType> if a custom DistanceType is specified.
  - Hyperplane::ProjVectorType is ProjVector<>, an arbitrary projection vector class that wraps an arma::vec.
    - If a custom MatType is specified, then Hyperplane::ProjVectorType is ProjVector<MatType>, which wraps a vector of the same type as MatType.
    - For more details, see the source code.
- A Hyperplane object h (e.g. returned with node.Hyperplane()) has the following members:
  - h.Project(point) returns a double that is the orthogonal projection of point onto the normal vector of the hyperplane h.
  - h.Left(point) returns true if point is to the left of h.
  - h.Right(point) returns true if point is to the right of h.
  - h.Left(bound) returns true if bound (a Hyperplane::BoundType; see the bullet point above) is to the left of h.
  - h.Right(bound) returns true if bound (a Hyperplane::BoundType; see the bullet point above) is to the right of h.
- A Hyperplane object can be serialized with data::Save() and data::Load().
For more details, see the source code.
Custom HyperplaneTypes
Custom hyperplane types for a spill tree can be implemented via the
HyperplaneType template parameter. By default, the
AxisOrthogonalHyperplane hyperplane type is used,
but it is also possible to implement and use a custom HyperplaneType. Any
custom HyperplaneType class must implement the following signature:
// NOTE: the custom HyperplaneType class must take two template parameters.
template<typename DistanceType, typename MatType>
class HyperplaneType
{
public:
// The hyperplane type must specify these two public typedefs, which are used
// by the spill tree and the splitting strategy.
//
// Substitute HRectBound and ProjVector with your choices.
using BoundType = mlpack::HRectBound<DistanceType,
typename MatType::elem_type>;
using ProjVectorType = mlpack::ProjVector<MatType>;
// Empty constructor, which will construct an empty or default hyperplane.
HyperplaneType();
// Construct the HyperplaneType with the given projection vector and split
// value along that projection.
HyperplaneType(const ProjVectorType& projVector, double splitVal);
// Compute the projection of the given point (an `arma::vec` or similar type
// matching the Armadillo API and element type of `MatType`) onto the vector
// normal to the hyperplane.
template<typename VecType>
double Project(const VecType& point) const;
// Return true if the point (an `arma::vec` or similar type matching the
// Armadillo API and element type of `MatType`) falls to the left of the
// hyperplane.
template<typename VecType>
bool Left(const VecType& point) const;
// Return true if the point (an `arma::vec` or similar type matching the
// Armadillo API and element type of `MatType`) falls to the right of the
// hyperplane.
template<typename VecType>
bool Right(const VecType& point) const;
// Return true if the given bound is fully to the left of the hyperplane.
bool Left(const BoundType& bound) const;
// Return true if the given bound is fully to the right of the hyperplane.
bool Right(const BoundType& bound) const;
// Serialize the hyperplane using cereal.
template<typename Archive>
void serialize(Archive& ar, const uint32_t version);
};
SplitType
The SplitType template parameter controls the algorithm used to split each
node of a SpillTree while building. The splitting strategy used can be
entirely arbitrary: the SplitType only needs to compute a
HyperplaneType to split a set of points.
mlpack provides two drop-in choices for SplitType, and it is also possible
to write a fully custom split:
- MidpointSpaceSplit: split a set of points using a hyperplane built on the midpoint of the range of the points' values in a dataset.
- MeanSpaceSplit: split a set of points using a hyperplane built on the mean (average) of points in a dataset.
- Custom SplitTypes: implement a fully custom SplitType class.
MidpointSpaceSplit
The MidpointSpaceSplit class is a splitting strategy that can be used by
SpillTree. It is the default strategy for splitting SPTrees
and NonOrtSPTrees.
The splitting strategy for the MidpointSpaceSplit class is, given a set of
points:
- If AxisOrthogonalHyperplane is being used, then select the dimension with the maximum width, and use the midpoint of the points' values in that dimension (i.e. halfway between the minimum and maximum values).
- If Hyperplane is being used, then estimate the furthest two points in the dataset by random sampling, and use the vector connecting those points as the normal vector to the hyperplane. The midpoint of the points projected onto this vector is used as the split value.
Note that MidpointSpaceSplit can only be used with a HyperplaneType with
HyperplaneType::ProjVectorType as either AxisParallelProjVector or
ProjVector.
For implementation details, see the source code.
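The axis-orthogonal case can be sketched in a few lines. MidpointAxisSplit is a hypothetical function, not the mlpack MidpointSpaceSplit class; it assumes a non-empty set of points of equal dimension stored as nested std::vectors, and it returns the chosen dimension together with the midpoint of that dimension's range. Note that the midpoint is not the median: for the values {10, 0, 4} the midpoint is 5 while the median is 4.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical sketch of a midpoint split for axis-orthogonal hyperplanes.
// Assumes `points` is non-empty and all points have the same dimension.
// Returns {splitDim, splitVal}; a caller would still need to handle the
// degenerate case where every dimension has zero width.
std::pair<std::size_t, double> MidpointAxisSplit(
    const std::vector<std::vector<double>>& points)
{
  const std::size_t dims = points.front().size();
  std::size_t bestDim = 0;
  double bestWidth = -1.0;
  double bestMid = 0.0;
  for (std::size_t d = 0; d < dims; ++d)
  {
    // Range of the points' values in dimension d.
    double lo = points.front()[d];
    double hi = points.front()[d];
    for (const auto& p : points)
    {
      lo = std::min(lo, p[d]);
      hi = std::max(hi, p[d]);
    }
    // Keep the widest dimension and the midpoint of its range.
    if (hi - lo > bestWidth)
    {
      bestWidth = hi - lo;
      bestDim = d;
      bestMid = (lo + hi) / 2.0;
    }
  }
  return {bestDim, bestMid};
}
```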
MeanSpaceSplit
The MeanSpaceSplit class is a splitting strategy that can be used by
SpillTree. It is the splitting strategy used by the
MeanSPTree and the
NonOrtMeanSPTree classes.
The splitting strategy for the MeanSpaceSplit class is, given a set of
points:
- If AxisOrthogonalHyperplane is being used, then select the dimension with the maximum width, and use the mean of the points' values in that dimension.
- If Hyperplane is being used, then estimate the furthest two points in the dataset by random sampling, and use the vector connecting those points as the normal vector to the hyperplane. The mean of the points projected onto this vector is used as the split value.
Note that MeanSpaceSplit can only be used with a HyperplaneType with
HyperplaneType::ProjVectorType as either AxisParallelProjVector or
ProjVector.
For implementation details, see the source code.
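The mean-based variant differs only in how the split value is chosen once the widest dimension is found. MeanAxisSplit is a hypothetical function, not the mlpack MeanSpaceSplit class, and it makes the same assumptions as a simple nested-vector sketch: a non-empty set of points with equal dimension.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical sketch of a mean split for axis-orthogonal hyperplanes.
// Assumes `points` is non-empty and all points have the same dimension.
std::pair<std::size_t, double> MeanAxisSplit(
    const std::vector<std::vector<double>>& points)
{
  // Step 1: find the dimension with the largest range of values.
  const std::size_t dims = points.front().size();
  std::size_t bestDim = 0;
  double bestWidth = -1.0;
  for (std::size_t d = 0; d < dims; ++d)
  {
    double lo = points.front()[d];
    double hi = points.front()[d];
    for (const auto& p : points)
    {
      lo = std::min(lo, p[d]);
      hi = std::max(hi, p[d]);
    }
    if (hi - lo > bestWidth)
    {
      bestWidth = hi - lo;
      bestDim = d;
    }
  }

  // Step 2: split at the mean of the points' values in that dimension.
  double sum = 0.0;
  for (const auto& p : points)
    sum += p[bestDim];
  return {bestDim, sum / static_cast<double>(points.size())};
}
```

On the points {(0, 10), (1, 0), (2, 4)} both strategies pick dimension 1, but the mean split value is 14/3 rather than the range midpoint of 5, so the mean tends to follow where the points are concentrated.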
Custom SplitTypes
Custom split strategies for a spill tree can be implemented via the
SplitType template parameter. By default, the
MidpointSpaceSplit splitting strategy is used, but it
is also possible to implement and use a custom SplitType. Any custom
SplitType class must implement the following signature:
// NOTE: the custom SplitType class must take two template parameters.
template<typename DistanceType, typename MatType>
class SplitType
{
public:
// The SplitType class is only required to provide one static function.
// Create a splitting hyperplane and store it in the given `HyperplaneType`,
// using the given data and bounding box `bound`. `data` will be an Armadillo
// matrix that is the entire dataset, and `points` are the indices of points
// in `data` that should be split.
template<typename HyperplaneType>
static bool SplitSpace(
const typename HyperplaneType::BoundType& bound,
const MatType& data,
const arma::Col<size_t>& points,
HyperplaneType& hyp);
};
Example usage
The SpillTree class is only really necessary when a custom hyperplane type or
custom splitting strategy is intended to be used. For simpler use cases, one of
the typedefs of SpillTree (such as SPTree) will suffice.
For this reason, all of the examples below explicitly specify all five template
parameters of SpillTree.
Writing a custom hyperplane type and
writing a custom splitting strategy are discussed
in the previous sections. Each of the parameters in the examples below can be
trivially changed for different behavior.
Build a SpillTree on the cloud dataset and print basic statistics about the
tree.
// See https://datasets.mlpack.org/cloud.csv.
arma::mat dataset;
mlpack::data::Load("cloud.csv", dataset, true);
// Build the spill tree with a tau (margin) of 0.2 and a leaf size of 10.
// (This means that nodes are split until they contain 10 or fewer points.)
//
// The std::move() means that `dataset` will be empty after this call, and no
// data will be copied during tree building.
mlpack::SpillTree<mlpack::EuclideanDistance,
mlpack::EmptyStatistic,
arma::mat,
mlpack::AxisOrthogonalHyperplane,
mlpack::MidpointSpaceSplit> tree(std::move(dataset), 0.2, 10);
// Print the bounding box of the root node.
std::cout << "Bounding box of root node:" << std::endl;
for (size_t i = 0; i < tree.Bound().Dim(); ++i)
{
std::cout << " - Dimension " << i << ": [" << tree.Bound()[i].Lo() << ", "
<< tree.Bound()[i].Hi() << "]." << std::endl;
}
std::cout << std::endl;
// Print the number of descendant points of the root, and of each of its
// children.
std::cout << "Descendant points of root: "
<< tree.NumDescendants() << "." << std::endl;
std::cout << "Descendant points of left child: "
<< tree.Left()->NumDescendants() << "." << std::endl;
std::cout << "Descendant points of right child: "
<< tree.Right()->NumDescendants() << "." << std::endl;
std::cout << std::endl;
// Compute the center of the SpillTree.
arma::vec center;
tree.Center(center);
std::cout << "Center of tree: " << center.t();
Build two SpillTrees on subsets of the corel dataset and compute minimum
and maximum distances between different nodes in the trees.
// See https://datasets.mlpack.org/corel-histogram.csv.
arma::mat dataset;
mlpack::data::Load("corel-histogram.csv", dataset, true);
// Convenience typedef for the tree type.
using TreeType = mlpack::SpillTree<mlpack::EuclideanDistance,
mlpack::EmptyStatistic,
arma::mat,
mlpack::AxisOrthogonalHyperplane,
mlpack::MidpointSpaceSplit>;
// Build trees on the first half and the second half of points. Use a tau
// (overlap) parameter of 0.3, which is tuned to this dataset, and a rho value
// of 0.6 to prevent the trees getting too deep.
TreeType tree1(dataset.cols(0, dataset.n_cols / 2), 0.3, 20, 0.6);
TreeType tree2(dataset.cols(dataset.n_cols / 2 + 1, dataset.n_cols - 1),
0.3, 20, 0.6);
// Compute the maximum distance between the trees.
std::cout << "Maximum distance between tree root nodes: "
<< tree1.MaxDistance(tree2) << "." << std::endl;
// Get the leftmost grandchild of the first tree's root---if it exists.
if (!tree1.IsLeaf() && !tree1.Child(0).IsLeaf())
{
TreeType& node1 = tree1.Child(0).Child(0);
// Get the rightmost grandchild of the second tree's root---if it exists.
if (!tree2.IsLeaf() && !tree2.Child(1).IsLeaf())
{
TreeType& node2 = tree2.Child(1).Child(1);
// Print the minimum and maximum distance between the nodes.
mlpack::Range dists = node1.RangeDistance(node2);
std::cout << "Possible distances between two grandchild nodes: ["
<< dists.Lo() << ", " << dists.Hi() << "]." << std::endl;
// Print the minimum distance between the first node and the first
// descendant point of the second node.
const size_t descendantIndex = node2.Descendant(0);
const double descendantMinDist =
node1.MinDistance(node2.Dataset().col(descendantIndex));
std::cout << "Minimum distance between grandchild node and descendant "
<< "point: " << descendantMinDist << "." << std::endl;
// Which child of node2 is closer to node1? GetNearestChild() returns 0 for
// the left child and 1 for the right child; a tie (or a leaf) returns 0.
const size_t closerIndex = node2.GetNearestChild(node1);
if (closerIndex == 0)
std::cout << "The left child of node2 is closer to node1 (or tied)."
<< std::endl;
else
std::cout << "The right child of node2 is closer to node1." << std::endl;
// And which child of node1 is further from node2? Again, a tie (or a leaf)
// returns 0.
const size_t furtherIndex = node1.GetFurthestChild(node2);
if (furtherIndex == 0)
std::cout << "The left child of node1 is further from node2 (or tied)."
<< std::endl;
else
std::cout << "The right child of node1 is further from node2."
<< std::endl;
}
}
Build a SpillTree on 32-bit floating point data and save it to disk.
// See https://datasets.mlpack.org/corel-histogram.csv.
arma::fmat dataset;
mlpack::data::Load("corel-histogram.csv", dataset);
// Build the SpillTree using 32-bit floating point data as the matrix type.
// We will still use the default EmptyStatistic and EuclideanDistance
// parameters.
mlpack::SpillTree<mlpack::EuclideanDistance,
mlpack::EmptyStatistic,
arma::fmat,
mlpack::AxisOrthogonalHyperplane,
mlpack::MidpointSpaceSplit> tree(
std::move(dataset), 0.1, 20, 0.95);
// Save the tree to disk with the name 'tree'.
mlpack::data::Save("tree.bin", "tree", tree);
std::cout << "Saved tree with " << tree.Dataset().n_cols << " points to "
<< "'tree.bin'." << std::endl;
Load a 32-bit floating point SpillTree from disk, then traverse it
manually and find the number of nodes whose children overlap.
// This assumes the tree has already been saved to 'tree.bin' (as in the example
// above).
// This convenient typedef saves us a long type name!
using TreeType = mlpack::SpillTree<mlpack::EuclideanDistance,
mlpack::EmptyStatistic,
arma::fmat,
mlpack::AxisOrthogonalHyperplane,
mlpack::MidpointSpaceSplit>;
TreeType tree;
mlpack::data::Load("tree.bin", "tree", tree);
std::cout << "Tree loaded with " << tree.NumDescendants() << " points."
<< std::endl;
// Recurse in a depth-first manner. Count both the total number of non-leaves,
// and the number of non-leaves that have overlapping children.
size_t overlapCount = 0;
size_t totalInternalNodeCount = 0;
std::stack<TreeType*> stack;
stack.push(&tree);
while (!stack.empty())
{
TreeType* node = stack.top();
stack.pop();
if (node->IsLeaf())
continue;
if (node->Overlap())
++overlapCount;
++totalInternalNodeCount;
stack.push(node->Left());
stack.push(node->Right());
}
// Note that it would be possible to use TreeType::SingleTreeTraverser to
// perform the recursion above, but that traverser is better suited to complex
// tasks that require pruning or other non-trivial behavior; a simple stack is
// the better option here.
// Print the results.
std::cout << overlapCount << " out of " << totalInternalNodeCount
<< " internal nodes have overlapping children." << std::endl;
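To make concrete what `Overlap()` reports, here is a toy 1-D model of a single spill split. It is loosely modeled on the meaning of the `tau` and `rho` constructor parameters and is an assumption for illustration, not mlpack's implementation: points within `tau` of the split value are copied into both children, and if either child would then hold more than `rho` of the node's points, the split falls back to an ordinary non-overlapping one. The names `ToySplit` and `SpillSplit` are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy model of one spill split in 1-D (not mlpack's code).
// Points whose signed distance to the split value lies in (-tau, tau) are
// copied into both children. If either child would then hold more than
// rho * n points, the split falls back to an ordinary, non-overlapping one.
struct ToySplit
{
  std::vector<double> left;
  std::vector<double> right;
  bool overlap = false;  // true if at least one point landed in both children
};

ToySplit SpillSplit(const std::vector<double>& points, const double splitValue,
                    const double tau, const double rho)
{
  assert(tau >= 0.0 && rho > 0.0 && rho <= 1.0);
  ToySplit s;
  for (const double p : points)
  {
    const double proj = p - splitValue;  // signed distance to the split
    if (proj <= 0.0 || proj < tau)       // left side, plus the buffer
      s.left.push_back(p);
    if (proj > 0.0 || proj > -tau)       // right side, plus the buffer
      s.right.push_back(p);
  }

  // Balance check: too many points in one child disables the overlap buffer.
  const double limit = rho * static_cast<double>(points.size());
  if (tau > 0.0 && (static_cast<double>(s.left.size()) > limit ||
                    static_cast<double>(s.right.size()) > limit))
    return SpillSplit(points, splitValue, 0.0, rho);

  s.overlap = (s.left.size() + s.right.size() > points.size());
  return s;
}
```

With the points {0, 1, 2, 2.9, 3.1, 4, 5, 6}, split value 3 and rho = 0.7, tau = 0.5 places 2.9 and 3.1 in both children (five points each, so `overlap` is true). With tau = 2.5, seven of the eight points would land in each child, so the sketch falls back to a 4/4 split with no overlap.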
Use a defeatist traversal to find the approximate nearest neighbors of points 3
and 6 (0-indexed) in the corel-histogram dataset. (Note: this can also be
done more easily with the KNN class! This example demonstrates how
to use the defeatist traverser.)
For this example, we must first define a
RuleType class.
// For simplicity, this only implements those methods required by single-tree
// traversals, and cannot be used with a dual-tree traversal.
//
// `.Reset()` must be called before each single-tree traversal after the first
// one.
class SpillNearestNeighborRule
{
public:
// Store the dataset internally.
SpillNearestNeighborRule(const arma::mat& dataset) :
dataset(dataset),
nearestNeighbor(size_t(-1)),
nearestDistance(DBL_MAX) { }
// Compute the base case (point-to-point comparison).
double BaseCase(const size_t queryIndex, const size_t referenceIndex)
{
// Skip the base case if the points are the same.
if (queryIndex == referenceIndex)
return 0.0;
const double dist = mlpack::EuclideanDistance::Evaluate(
dataset.col(queryIndex), dataset.col(referenceIndex));
if (dist < nearestDistance)
{
nearestNeighbor = referenceIndex;
nearestDistance = dist;
}
return dist;
}
// Score the given node in the tree; if it is sufficiently far away that it
// cannot contain a better nearest neighbor candidate, we can prune it.
template<typename TreeType>
double Score(const size_t queryIndex, const TreeType& referenceNode) const
{
const double minDist = referenceNode.MinDistance(dataset.col(queryIndex));
if (minDist > nearestDistance)
return DBL_MAX; // Prune: this cannot contain a better candidate!
return minDist;
}
// Rescore the given node/point combination. Note that this will not be used
// by the defeatist traversal as it never backtracks, but we include it for
// completeness because the RuleType API requires it.
template<typename TreeType>
double Rescore(const size_t, const TreeType&, const double oldScore) const
{
if (oldScore > nearestDistance)
return DBL_MAX; // Prune: the node is too far away.
return oldScore;
}
// This is required by defeatist traversals to select the best reference
// child to recurse into for overlapping nodes.
template<typename TreeType>
size_t GetBestChild(const size_t queryIndex, TreeType& referenceNode)
const
{
return referenceNode.GetNearestChild(dataset.col(queryIndex));
}
// We must perform at least two base cases in order to have a result. This is
// two, not one, because base cases where the query and reference points are
// the same are skipped. That can happen at most once, so returning 2 here
// ensures that the query point is compared to at least one different
// reference point.
size_t MinimumBaseCases() const { return 2; }
// Get the results (to be called after the traversal).
size_t NearestNeighbor() const { return nearestNeighbor; }
double NearestDistance() const { return nearestDistance; }
// Reset the internal statistics for an additional traversal.
void Reset()
{
nearestNeighbor = size_t(-1);
nearestDistance = DBL_MAX;
}
private:
const arma::mat& dataset;
size_t nearestNeighbor;
double nearestDistance;
};
// See https://datasets.mlpack.org/corel-histogram.csv.
arma::mat dataset;
mlpack::data::Load("corel-histogram.csv", dataset, true);
typedef mlpack::SpillTree<mlpack::EuclideanDistance,
mlpack::EmptyStatistic,
arma::mat,
mlpack::AxisOrthogonalHyperplane,
mlpack::MidpointSpaceSplit> TreeType;
// Build two trees with a maximum leaf size of 10: one with a lot of overlap
// (tau = 0.5), and one with no overlap (tau = 0).
TreeType tree1(dataset, 0.5, 10), tree2(dataset, 0.0, 10);
// Construct the rule types, and then the traversals.
SpillNearestNeighborRule r1(dataset), r2(dataset);
TreeType::DefeatistSingleTreeTraverser<SpillNearestNeighborRule> t1(r1);
TreeType::DefeatistSingleTreeTraverser<SpillNearestNeighborRule> t2(r2);
// Search for the approximate nearest neighbor of point 3 using both trees.
t1.Traverse(3, tree1);
t2.Traverse(3, tree2);
std::cout << "Approximate nearest neighbor of point 3:" << std::endl;
std::cout << " - Spill tree with overlap 0.5 found: point "
<< r1.NearestNeighbor() << ", distance " << r1.NearestDistance()
<< "." << std::endl;
std::cout << " - Spill tree with no overlap found: point "
<< r2.NearestNeighbor() << ", distance " << r2.NearestDistance()
<< "." << std::endl;
// Now search for point 6.
r1.Reset();
r2.Reset();
t1.Traverse(6, tree1);
t2.Traverse(6, tree2);
std::cout << "Approximate nearest neighbor of point 6:" << std::endl;
std::cout << " - Spill tree with overlap 0.5 found: point "
<< r1.NearestNeighbor() << ", distance " << r1.NearestDistance()
<< "." << std::endl;
std::cout << " - Spill tree with no overlap found: point "
<< r2.NearestNeighbor() << ", distance " << r2.NearestDistance()
<< "." << std::endl;