SpillTree
The SpillTree
class represents a generic multidimensional binary space
partitioning tree that allows overlapping volumes between nodes, also known as a
"hybrid spill tree". It is heavily templatized to control splitting behavior
and other behaviors, and is the actual class underlying trees such as the
SPTree
. In general, the SpillTree
class is not meant to be
used directly; one of the handful of variants should be used instead:
- `SPTree`
- `MeanSPTree`
- `NonOrtSPTree`
- `NonOrtMeanSPTree`
The SpillTree
is similar to the BinarySpaceTree
,
except that the two children of a node are allowed to overlap, and thus a single
point can be contained in multiple branches of the tree. This can be useful to,
e.g., improve nearest neighbor performance when using defeatist traversals
without backtracking.
For users who want to use SpillTree
directly or with custom behavior,
the full class is still detailed in the subsections below. SpillTree
supports
the TreeType API and can be used
with mlpack's tree-based algorithms, although using custom behavior may require
a template typedef.
- Template parameters
- Constructors
- Basic tree properties
- Bounding distances with the tree
- `HyperplaneType` template parameter
- `SplitType` template parameter
- Tree traversals
- Example usage
See also
- `SPTree`
- `MeanSPTree`
- `NonOrtSPTree`
- `NonOrtMeanSPTree`
- `BinarySpaceTree`
- An Investigation of Practical Approximate Nearest Neighbor Algorithms (pdf)
- Tree-Independent Dual-Tree Algorithms (pdf)
Template parameters
The SpillTree
class takes five template parameters. The first three of
these are required by the
TreeType API
(see also
this more detailed section). The
full signature of the class is:
template<typename DistanceType,
typename StatisticType,
typename MatType,
template<typename HyperplaneDistanceType,
typename HyperplaneMatType> class HyperplaneType,
template<typename SplitDistanceType,
typename SplitMatType> class SplitType>
class SpillTree;
- `DistanceType`: the distance metric to use for distance computations. By default, this is `EuclideanDistance`.
- `StatisticType`: this holds auxiliary information in each tree node. By default, `EmptyStatistic` is used, which holds no information.
  - See the `StatisticType` section in the `BinarySpaceTree` documentation for more details.
- `MatType`: the type of matrix used to represent points. Must be a type matching the Armadillo API. By default, `arma::mat` is used, but other types such as `arma::fmat` or similar will work just fine.
- `HyperplaneType`: the class defining the type of the hyperplane that will split each node. By default, `AxisOrthogonalHyperplane` is used.
  - See the `HyperplaneType` section for more details.
- `SplitType`: the class defining how an individual `SpillTree` node should be split. By default, `MidpointSpaceSplit` is used.
  - See the `SplitType` section for more details.
Note that the TreeType API requires trees to have only three template
parameters. In order to use a SpillTree
with its five template parameters
with an mlpack algorithm that needs a TreeType, it is easiest to define a
template typedef:
template<typename DistanceType, typename StatisticType, typename MatType>
using CustomTree = SpillTree<DistanceType, StatisticType, MatType,
    CustomHyperplaneType, CustomSplitType>;
Here, CustomHyperplaneType
and CustomSplitType
are the desired hyperplane
type and split strategy. This is the way that all SpillTree
variants (such as
SPTree
) are defined.
Constructors
`SpillTree`s are constructed by iteratively finding splitting hyperplanes, and
points within a margin of the hyperplane are assigned to both child nodes.
Unlike the constructors of `BinarySpaceTree`, the dataset is not
permuted during construction.

- `node = SpillTree(data, tau=0.0, maxLeafSize=20, rho=0.7)`
  - Construct a `SpillTree` on the given `data`, using the specified hyperparameters to control tree construction behavior.
  - Default template parameters are used, meaning that this tree will be a `SPTree`.
  - By default, a reference to `data` is stored. If `data` goes out of scope after tree construction, memory errors will occur! To avoid this, pass the dataset (or a copy of it) with `std::move()` (e.g. `std::move(data)`); when doing this, `data` will be set to an empty matrix.
- `node = SpillTree<DistanceType, StatisticType, MatType, HyperplaneType, SplitType>(data, tau=0.0, maxLeafSize=20, rho=0.7)`
  - Construct a `SpillTree` on the given `data`, using custom template parameters, and using the specified hyperparameters to control tree construction behavior.
  - By default, a reference to `data` is stored. If `data` goes out of scope after tree construction, memory errors will occur! To avoid this, pass the dataset (or a copy of it) with `std::move()` (e.g. `std::move(data)`); when doing this, `data` will be set to an empty matrix.
- `node = SpillTree()`
  - Construct an empty `SpillTree` with no children, no points, and default template parameters.
Notes:
- The name `node` is used here for `SpillTree` objects instead of `tree`, because each `SpillTree` object is a single node in the tree. The constructor returns the node that is the root of the tree.
- Inserting individual points or removing individual points from a `SpillTree` is not supported, because this generally results in a tree with very suboptimal hyperplane splits. It is better to simply build a new `SpillTree` on the modified dataset. For trees that support individual insertion and deletions, see the `RectangleTree` class and all its variants (e.g. `RTree`, `RStarTree`, etc.).
- See also the developer documentation on tree constructors.
Constructor parameters:

name | type | description | default |
---|---|---|---|
`data` | `MatType` | Column-major matrix to build the tree on. | (N/A) |
`tau` | `double` | Width of spill margin: points within `tau` of the splitting hyperplane of a node will be contained in both left and right children. | `0.0` |
`maxLeafSize` | `size_t` | Maximum number of points to store in each leaf. | `20` |
`rho` | `double` | Balance threshold. When splitting, if either overlapping node would contain a fraction of more than `rho` of the points, a non-overlapping split is performed. Must be in the range `[0.0, 1.0)`. | `0.7` |
Caveats:
- `tau` must be manually tuned for the properties of each dataset; the default, `0.0`, will never allow overlap between nodes (and thus the created tree will essentially be a non-overlapping `BinarySpaceTree`).
- If `tau` is set too large, nodes will overlap too much and search quality will be degraded.
- `rho` implicitly controls the depth of the tree by forcing very overlapping children to be non-overlapping. As `rho` gets closer to `1`, more overlap is allowed, which in turn makes the tree deeper. If `rho` is set to `0.5` or less, then all splits will be non-overlapping (and the tree will essentially be a `BinarySpaceTree`).
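The interaction between `tau` and `rho` can be illustrated with a small standalone sketch. This is not mlpack's implementation: `ToySplit` and `SpillSplit()` are hypothetical names, the input is assumed to be each point's signed distance to the splitting hyperplane (negative meaning left), and the treatment of points lying exactly on a boundary is an arbitrary choice made for the illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Indices of the points assigned to each child, plus whether the children
// overlap.  (Hypothetical type for illustration; not part of mlpack.)
struct ToySplit
{
  std::vector<std::size_t> left;
  std::vector<std::size_t> right;
  bool overlap;
};

// proj[i] is the signed distance of point i to the splitting hyperplane.
inline ToySplit SpillSplit(const std::vector<double>& proj,
                           const double tau,
                           const double rho)
{
  assert(tau >= 0.0 && rho >= 0.0 && rho < 1.0);

  // First try an overlapping split: points within tau of the hyperplane are
  // assigned to both children.
  ToySplit s;
  s.overlap = (tau > 0.0);
  for (std::size_t i = 0; i < proj.size(); ++i)
  {
    if (proj[i] <= tau)
      s.left.push_back(i);
    if (proj[i] > -tau)
      s.right.push_back(i);
  }

  // If either child would hold more than a fraction rho of the points, fall
  // back to a non-overlapping split at the hyperplane itself.
  const double limit = rho * static_cast<double>(proj.size());
  if (static_cast<double>(s.left.size()) > limit ||
      static_cast<double>(s.right.size()) > limit)
  {
    s.left.clear();
    s.right.clear();
    s.overlap = false;
    for (std::size_t i = 0; i < proj.size(); ++i)
      (proj[i] <= 0.0 ? s.left : s.right).push_back(i);
  }

  return s;
}
```

For example, with projections `{-2, -1, -0.1, 0.1, 1, 2}` and `tau = 0.5`, the two points nearest the hyperplane land in both children, so each child holds 4 of the 6 points. That overlapping split is kept with `rho = 0.7`, but rejected in favor of a 3/3 non-overlapping split with `rho = 0.6`.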
Basic tree properties
Once a SpillTree
object is constructed, various properties of the tree
can be accessed or inspected. Many of these functions are required by the
TreeType API.
Navigating the tree
- `node.NumChildren()` returns the number of children in `node`. This is either `2` if `node` has children, or `0` if `node` is a leaf.
- `node.IsLeaf()` returns a `bool` indicating whether or not `node` is a leaf.
- `node.Child(i)` returns a `SpillTree&` that is the `i`th child.
  - `i` must be `0` or `1`.
  - This function should only be called if `node.NumChildren()` is not `0` (i.e. if `node` is not a leaf). Note that this returns a valid `SpillTree&` that can itself be used just like the root node of the tree!
  - `node.Left()` and `node.Right()` are convenience functions specific to `SpillTree` that will return `SpillTree*` (pointers) to the left and right children, respectively, or `NULL` if `node` has no children.
- `node.Parent()` will return a `SpillTree*` that points to the parent of `node`, or `NULL` if `node` is the root of the `SpillTree`.
Accessing members of a tree
- `node.Overlap()` will return a `bool` that is `true` if `node`'s children are overlapping, and `false` otherwise.
- `node.Hyperplane()` will return a `HyperplaneType&` object that represents the splitting hyperplane of `node`.
  - All points in `node.Left()` are to the left of `node.Hyperplane()` if `node.Overlap()` is `false`; otherwise, all points in `node.Left()` are to the left of `node.Hyperplane() + tau`.
  - All points in `node.Right()` are to the right of `node.Hyperplane()` if `node.Overlap()` is `false`; otherwise, all points in `node.Right()` are to the right of `node.Hyperplane() - tau`.
- `node.Bound()` will return a `const HRectBound&` representing the bounding box associated with `node`.
  - If a custom `HyperplaneType` is specified, then the `BoundType` associated with that hyperplane type is returned instead.
  - If a custom `DistanceType` and/or `MatType` are specified, then a `const HRectBound<DistanceType, ElemType>&` is returned (or a `BoundType` with that `DistanceType`, if a custom `HyperplaneType` was also specified). `ElemType` is the element type of the specified `MatType` (e.g. `double` for `arma::mat`, `float` for `arma::fmat`, etc.).
- `node.Stat()` will return a `StatisticType&` holding the statistics of the node that were computed during tree construction.
- `node.Distance()` will return a `DistanceType&`.
See also the developer documentation for basic tree functionality in mlpack.
Accessing data held in a tree
- `node.Dataset()` will return a `const MatType&` that is the dataset the tree was built on.
- `node.NumPoints()` returns a `size_t` indicating the number of points held directly in `node`.
  - If `node` is not a leaf, this will return `0`, as `SpillTree` only holds points directly in its leaves.
  - If `node` is a leaf, then the number of points will be less than or equal to the `maxLeafSize` that was specified when the tree was constructed.
- `node.Point(i)` returns a `size_t` indicating the index of the `i`'th point in `node.Dataset()`.
  - `i` must be in the range `[0, node.NumPoints() - 1]` (inclusive).
  - `node` must be a leaf (as non-leaves do not hold any points).
  - The `i`'th point in `node` can then be accessed as `node.Dataset().col(node.Point(i))`.
- `node.NumDescendants()` returns a `size_t` indicating the number of points held in all descendant leaves of `node`.
  - If `node` is the root of the tree, then `node.NumDescendants()` will be equal to `node.Dataset().n_cols`.
- `node.Descendant(i)` returns a `size_t` indicating the index of the `i`'th descendant point in `node.Dataset()`.
  - `i` must be in the range `[0, node.NumDescendants() - 1]` (inclusive).
  - `node` does not need to be a leaf.
  - The `i`'th descendant point in `node` can then be accessed as `node.Dataset().col(node.Descendant(i))`.
Accessing computed bound quantities of a tree
The following quantities are cached for each node in a `SpillTree`, and so
accessing them does not require any computation. In the documentation below,
`ElemType` is the element type of the given `MatType`; e.g., if `MatType` is
`arma::mat`, then `ElemType` is `double`.
- `node.FurthestPointDistance()` returns an `ElemType` representing the distance between the center of the bound of `node` and the furthest point held by `node`.
  - If `node` is not a leaf, this returns `0` (because `node` does not hold any points).
- `node.FurthestDescendantDistance()` returns an `ElemType` representing the distance between the center of the bound of `node` and the furthest descendant point held by `node`.
- `node.MinimumBoundDistance()` returns an `ElemType` representing the minimum possible distance from the center of the node to any edge of its bound.
- `node.ParentDistance()` returns an `ElemType` representing the distance between the center of the bound of `node` and the center of the bound of its parent.
  - If `node` is the root of the tree, `0` is returned.

Note: for more details on each bound quantity, see the developer documentation on bound quantities for trees.
Other functionality
- `node.Center(center)` computes the center of the bound of `node` and stores it in `center`.
  - `center` should be of type `arma::Col<ElemType>&`, where `ElemType` is the element type of the specified `MatType`.
  - `center` will be set to have size equivalent to the dimensionality of the dataset held by `node`.
  - This is equivalent to calling `node.Bound().Center(center)`.
- A `SpillTree` can be serialized with `data::Save()` and `data::Load()`.
Bounding distances with the tree
The primary use of trees in mlpack is bounding distances to points or other tree nodes. The following functions can be used for these tasks.
- `node.GetNearestChild(point)`
- `node.GetFurthestChild(point)`
  - Return a `size_t` indicating the index of the child (`0` for left, `1` for right) that is closest to (or furthest from) `point`, with respect to the `MinDistance()` (or `MaxDistance()`) function.
  - If there is a tie, `0` (the left child) is returned.
  - If `node` is a leaf, `0` is returned.
  - `point` should be a column vector type matching `MatType` (e.g., if `MatType` is `arma::mat`, then `point` should be an `arma::vec`).
- `node.GetNearestChild(other)`
- `node.GetFurthestChild(other)`
  - Return a `size_t` indicating the index of the child (`0` for left, `1` for right) that is closest to (or furthest from) the `SpillTree` node `other`, with respect to the `MinDistance()` (or `MaxDistance()`) function.
  - If there is a tie, `0` (the left child) is returned.
  - If `node` is a leaf, `0` is returned.
- `node.MinDistance(point)`
- `node.MinDistance(other)`
  - Return a `double` indicating the minimum possible distance between `node` and `point`, or the `SpillTree` node `other`.
  - This is equivalent to the minimum possible distance between any point contained in the bounding hyperrectangle of `node` and `point`, or between any point contained in the bounding hyperrectangle of `node` and any point contained in the bounding hyperrectangle of `other`.
  - `point` should be a column vector type matching `MatType` (e.g., if `MatType` is `arma::mat`, then `point` should be an `arma::vec`).
- `node.MaxDistance(point)`
- `node.MaxDistance(other)`
  - Return a `double` indicating the maximum possible distance between `node` and `point`, or the `SpillTree` node `other`.
  - This is equivalent to the maximum possible distance between any point contained in the bounding hyperrectangle of `node` and `point`, or between any point contained in the bounding hyperrectangle of `node` and any point contained in the bounding hyperrectangle of `other`.
  - `point` should be a column vector type matching `MatType` (e.g., if `MatType` is `arma::mat`, then `point` should be an `arma::vec`).
- `node.RangeDistance(point)`
- `node.RangeDistance(other)`
  - Return a `RangeType<ElemType>` whose lower bound is `node.MinDistance(point)` or `node.MinDistance(other)`, and whose upper bound is `node.MaxDistance(point)` or `node.MaxDistance(other)`.
  - `ElemType` is the element type of `MatType`.
  - `point` should be a column vector type matching `MatType` (e.g., if `MatType` is `arma::mat`, then `point` should be an `arma::vec`).
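To make the meaning of the minimum and maximum possible distances concrete, the following standalone sketch computes both for a point and an axis-aligned bounding box under the Euclidean distance. It does not use mlpack: `ToyBox`, `BoxMinDistance()`, and `BoxMaxDistance()` are hypothetical names introduced only for this illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// An axis-aligned box given by its lower and upper corner.
// (Hypothetical type for illustration; not part of mlpack.)
struct ToyBox
{
  std::vector<double> lo;
  std::vector<double> hi;
};

// Smallest Euclidean distance between p and any location inside the box.
inline double BoxMinDistance(const ToyBox& box, const std::vector<double>& p)
{
  assert(p.size() == box.lo.size() && p.size() == box.hi.size());
  double sum = 0.0;
  for (std::size_t d = 0; d < p.size(); ++d)
  {
    // Distance to the interval [lo, hi] along this axis; zero if p is inside.
    const double gap = std::max({ box.lo[d] - p[d], p[d] - box.hi[d], 0.0 });
    sum += gap * gap;
  }
  return std::sqrt(sum);
}

// Largest Euclidean distance between p and any location inside the box.
inline double BoxMaxDistance(const ToyBox& box, const std::vector<double>& p)
{
  assert(p.size() == box.lo.size() && p.size() == box.hi.size());
  double sum = 0.0;
  for (std::size_t d = 0; d < p.size(); ++d)
  {
    // The farthest location along this axis is one of the two faces.
    const double reach = std::max(std::abs(p[d] - box.lo[d]),
                                  std::abs(box.hi[d] - p[d]));
    sum += reach * reach;
  }
  return std::sqrt(sum);
}
```

For the unit square and the point `(4, 5)`, the minimum distance is 5 (to the corner `(1, 1)`) and the maximum is the distance to the opposite corner `(0, 0)`. Any point stored in a node with that bound is guaranteed to lie between those two distances from the query, which is what makes pruning possible.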
Tree traversals
Like every mlpack tree, the SpillTree
class provides a single-tree and
dual-tree traversal that can be paired
with a RuleType
class to implement a
single-tree or dual-tree algorithm.
- `SpillTree::SingleTreeTraverser`
  - Implements a depth-first single-tree traverser.
- `SpillTree::DualTreeTraverser`
  - Implements a dual-depth-first dual-tree traverser.
However, the spill tree is primarily useful because the overlapping nodes allow
defeatist search to be effective. Defeatist search is non-backtracking: the
tree is traversed to one leaf only. For example, finding the approximate
nearest neighbor of a point p
with defeatist search is done by recursing in
the tree, choosing the child with smallest minimum distance to p
, and when a
leaf is encountered, choosing the closest point in the leaf to p
as the
nearest neighbor. This is the strategy used in the
original spill tree paper (pdf).
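The defeatist strategy described above can be sketched on a toy one-dimensional spill tree. This is a standalone illustration that assumes C++14 and does not use mlpack's traversers: `ToyNode`, `MakeLeaf()`, `MakeParent()`, and `DefeatistNearest()` are hypothetical names, and each node's bound is simply the interval spanned by its points.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <memory>
#include <utility>
#include <vector>

// A node of a toy 1-D spill tree.  (Hypothetical; not part of mlpack.)
struct ToyNode
{
  double lo, hi;                        // Interval bounding all descendants.
  std::vector<double> points;           // Only leaves hold points.
  std::unique_ptr<ToyNode> left, right;

  bool IsLeaf() const { return !left; }

  // Minimum possible distance between p and anything inside [lo, hi].
  double MinDistance(const double p) const
  {
    return std::max({ lo - p, p - hi, 0.0 });
  }
};

inline std::unique_ptr<ToyNode> MakeLeaf(std::vector<double> pts)
{
  assert(!pts.empty());
  auto node = std::make_unique<ToyNode>();
  node->lo = *std::min_element(pts.begin(), pts.end());
  node->hi = *std::max_element(pts.begin(), pts.end());
  node->points = std::move(pts);
  return node;
}

inline std::unique_ptr<ToyNode> MakeParent(std::unique_ptr<ToyNode> l,
                                           std::unique_ptr<ToyNode> r)
{
  auto node = std::make_unique<ToyNode>();
  node->lo = std::min(l->lo, r->lo);
  node->hi = std::max(l->hi, r->hi);
  node->left = std::move(l);
  node->right = std::move(r);
  return node;
}

// Descend to exactly one leaf, never backtracking, then scan that leaf.
inline double DefeatistNearest(const ToyNode& root, const double p)
{
  const ToyNode* node = &root;
  while (!node->IsLeaf())
  {
    // Choose the child with the smaller minimum distance; ties go left.
    node = (node->left->MinDistance(p) <= node->right->MinDistance(p)) ?
        node->left.get() : node->right.get();
  }

  double best = node->points[0];
  for (const double x : node->points)
    if (std::abs(x - p) < std::abs(best - p))
      best = x;
  return best;
}
```

A tree built as `MakeParent(MakeLeaf({1.0, 2.0, 2.8, 3.2}), MakeLeaf({2.8, 3.2, 4.0, 5.0}))` stores the points `2.8` and `3.2` in both leaves. A query at `3.1` ties between the children, descends left, and still finds `3.2` because that point was spilled into the left leaf; without the overlap, the left leaf would have returned `2.8`.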
Defeatist traversers, matching the API for a regular traversal, are made available as the following two classes:
- `SpillTree::DefeatistSingleTreeTraverser`
  - Implements a depth-first single-tree defeatist traverser with no backtracking. Traversal will terminate after the first leaf is visited.
- `SpillTree::DefeatistDualTreeTraverser`
  - Implements a dual-depth-first dual-tree defeatist traversal with no backtracking. For each query leaf node, traversal will terminate after the first reference leaf node is visited.
Any RuleType
that is being used with a
defeatist traversal, in addition to the functions required by the RuleType
API, must implement the following functions:
// This is only required for single-tree defeatist traversals.
// It should return the index of the branch that should be chosen for the given
// query point and reference node.
template<typename VecType, typename TreeType>
size_t GetBestChild(const VecType& queryPoint, TreeType& referenceNode);
// This is only required for dual-tree defeatist traversals.
// It should return the index of the best child of the reference node that
// should be chosen for the given query node.
template<typename TreeType>
size_t GetBestChild(TreeType& queryNode, TreeType& referenceNode);
// Return the minimum number of base cases (point-to-point computations) that
// are required during the traversal.
size_t MinimumBaseCases();
HyperplaneType
Each node in a SpillTree
corresponds to some region in space that contains all
of the descendant points in the node. Similar to the KDTree
, this
region is a hyperrectangle; however, instead of representing that hyperrectangle
explicitly like the KDTree
with the
HRectBound
class, the SpillTree
represents the region implicitly, with each node storing only the hyperplane
and margin required to determine whether a point belongs to the left node, the
right node, or both.
The type of hyperplane (e.g. axis-aligned or arbitrary) can be controlled by the
HyperplaneType
template parameter. mlpack supplies two drop-in classes that
can be used for HyperplaneType
, and it is also possible to write a custom
HyperplaneType
:
- `AxisOrthogonalHyperplane`: uses hyperplanes that are axis-orthogonal (or axis-aligned).
- `Hyperplane`: uses arbitrary hyperplanes specified by any vector.
- Custom `HyperplaneType`s: implement a fully custom `HyperplaneType` class
AxisOrthogonalHyperplane
The `AxisOrthogonalHyperplane` class is used to provide an axis-orthogonal split
for a `SpillTree`. With this hyperplane type, determining whether a point is on
the left or right side of the split is a very efficient computation that uses
only a single dimension of the data.
- The `AxisOrthogonalHyperplane` class defines the following two typedefs:
  - `AxisOrthogonalHyperplane::BoundType`, which is the type of the bound used by the spill tree with this hyperplane, is `HRectBound`, or `HRectBound<DistanceType, ElemType>` (where `ElemType` is the element type of `MatType`) if custom `DistanceType` and/or `MatType` are specified.
  - `AxisOrthogonalHyperplane::ProjVectorType` is `AxisParallelProjVector`, a class that simply holds the index of the dimension of the projection vector.
    - For more details, see the source code.
- An `AxisOrthogonalHyperplane` object `h` (e.g. returned with `node.Hyperplane()`) has the following members:
  - `h.Project(point)` returns a `double` that is the orthogonal projection of `point` onto the tangent vector of the hyperplane `h`.
  - `h.Left(point)` returns `true` if `point` is to the left of `h`.
  - `h.Right(point)` returns `true` if `point` is to the right of `h`.
  - `h.Left(bound)` returns `true` if `bound` (an `AxisOrthogonalHyperplane::BoundType`; see the bullet point above) is to the left of `h`.
  - `h.Right(bound)` returns `true` if `bound` (an `AxisOrthogonalHyperplane::BoundType`; see the bullet point above) is to the right of `h`.
- An `AxisOrthogonalHyperplane` object can be serialized with `data::Save()` and `data::Load()`.

For more details, see the source code.
Hyperplane
The `Hyperplane` class is used to provide an arbitrary hyperplane split for a
`SpillTree`. The computation of whether or not a point is on the left or right
side of the split is less efficient than with `AxisOrthogonalHyperplane`, but
`Hyperplane` is able to represent any possible hyperplane.
- The `Hyperplane` class defines the following two typedefs:
  - `Hyperplane::BoundType`, which is the type of the bound used by the spill tree with this hyperplane, is `BallBound`, or `BallBound<DistanceType>` if a custom `DistanceType` is specified.
  - `Hyperplane::ProjVectorType` is `ProjVector<>`, an arbitrary projection vector class that wraps an `arma::vec`.
    - If a custom `MatType` is specified, then `Hyperplane::ProjVectorType` is `ProjVector<MatType>`, which wraps a vector of the same type as `MatType`.
    - For more details, see the source code.
- A `Hyperplane` object `h` (e.g. returned with `node.Hyperplane()`) has the following members:
  - `h.Project(point)` returns a `double` that is the orthogonal projection of `point` onto the tangent vector of the hyperplane `h`.
  - `h.Left(point)` returns `true` if `point` is to the left of `h`.
  - `h.Right(point)` returns `true` if `point` is to the right of `h`.
  - `h.Left(bound)` returns `true` if `bound` (a `Hyperplane::BoundType`; see the bullet point above) is to the left of `h`.
  - `h.Right(bound)` returns `true` if `bound` (a `Hyperplane::BoundType`; see the bullet point above) is to the right of `h`.
- A `Hyperplane` object can be serialized with `data::Save()` and `data::Load()`.

For more details, see the source code.
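The difference in cost between the two hyperplane types can be seen in a minimal standalone sketch. The structs below are hypothetical stand-ins, not mlpack's classes; in particular, computing the projection as a signed offset from the split value, and treating a projection of exactly zero as "left", are assumptions made for the illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Axis-orthogonal split: only one coordinate of the point is read.
// (Hypothetical type for illustration; not part of mlpack.)
struct ToyAxisHyperplane
{
  std::size_t dim;   // Dimension the hyperplane is orthogonal to.
  double splitVal;   // Position of the hyperplane along that dimension.

  double Project(const std::vector<double>& p) const
  {
    assert(dim < p.size());
    return p[dim] - splitVal;
  }
  bool Left(const std::vector<double>& p) const { return Project(p) <= 0.0; }
  bool Right(const std::vector<double>& p) const { return Project(p) > 0.0; }
};

// Arbitrary split: a full dot product with the direction vector is needed.
// (Hypothetical type for illustration; not part of mlpack.)
struct ToyHyperplane
{
  std::vector<double> direction;  // Vector the points are projected onto.
  double splitVal;                // Split position along that direction.

  double Project(const std::vector<double>& p) const
  {
    assert(p.size() == direction.size());
    return std::inner_product(p.begin(), p.end(), direction.begin(), 0.0) -
        splitVal;
  }
  bool Left(const std::vector<double>& p) const { return Project(p) <= 0.0; }
  bool Right(const std::vector<double>& p) const { return Project(p) > 0.0; }
};
```

With `ToyAxisHyperplane{1, 2.0}`, classifying a point costs one subtraction regardless of dimensionality; with `ToyHyperplane{{1.0, 1.0}, 3.0}`, the cost grows linearly with the number of dimensions, which is the efficiency trade-off noted above.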
Custom `HyperplaneType`s
Custom hyperplane types for a spill tree can be implemented via the
HyperplaneType
template parameter. By default, the
AxisOrthogonalHyperplane
hyperplane type is used,
but it is also possible to implement and use a custom HyperplaneType
. Any
custom HyperplaneType
class must implement the following signature:
// NOTE: the custom HyperplaneType class must take two template parameters.
template<typename DistanceType, typename MatType>
class HyperplaneType
{
 public:
  // The hyperplane type must specify these two public typedefs, which are used
  // by the spill tree and the splitting strategy.
  //
  // Substitute HRectBound and ProjVector with your choices.
  using BoundType = mlpack::HRectBound<DistanceType,
                                       typename MatType::elem_type>;
  using ProjVectorType = mlpack::ProjVector<MatType>;

  // Empty constructor, which will construct an empty or default hyperplane.
  HyperplaneType();

  // Construct the HyperplaneType with the given projection vector and split
  // value along that projection.
  HyperplaneType(const ProjVectorType& projVector, double splitVal);

  // Compute the projection of the given point (an `arma::vec` or similar type
  // matching the Armadillo API and element type of `MatType`) onto the vector
  // tangent to the hyperplane.
  template<typename VecType>
  double Project(const VecType& point) const;

  // Return true if the point (an `arma::vec` or similar type matching the
  // Armadillo API and element type of `MatType`) falls to the left of the
  // hyperplane.
  template<typename VecType>
  bool Left(const VecType& point) const;

  // Return true if the point (an `arma::vec` or similar type matching the
  // Armadillo API and element type of `MatType`) falls to the right of the
  // hyperplane.
  template<typename VecType>
  bool Right(const VecType& point) const;

  // Return true if the given bound is fully to the left of the hyperplane.
  bool Left(const BoundType& bound) const;

  // Return true if the given bound is fully to the right of the hyperplane.
  bool Right(const BoundType& bound) const;

  // Serialize the hyperplane using cereal.
  template<typename Archive>
  void serialize(Archive& ar, const uint32_t version);
};
SplitType
The SplitType
template parameter controls the algorithm used to split each
node of a SpillTree
while building. The splitting strategy used can be
entirely arbitrary: the SplitType only needs to compute a
only needs to compute a
HyperplaneType
to split a set of points.
mlpack provides two drop-in choices for SplitType
, and it is also possible
to write a fully custom split:
- `MidpointSpaceSplit`: split a set of points using a hyperplane built on the midpoint (median) of points in a dataset.
- `MeanSpaceSplit`: split a set of points using a hyperplane built on the mean (average) of points in a dataset.
- Custom `SplitType`s: implement a fully custom `SplitType` class
MidpointSpaceSplit
The MidpointSpaceSplit
class is a splitting strategy that can be used by
SpillTree
. It is the default strategy for splitting `SPTree`s
and `NonOrtSPTree`s.
The splitting strategy for the MidpointSpaceSplit
class is, given a set of
points:
- If `AxisOrthogonalHyperplane` is being used, then select the dimension with the maximum width, and use the midpoint of the points' values in that dimension.
- If `Hyperplane` is being used, then estimate the furthest two points in the dataset by random sampling, and use the vector connecting those points as the tangent vector to the hyperplane. The midpoint of the points projected onto this hyperplane is used as the split value.
Note that MidpointSpaceSplit
can only be used with a HyperplaneType
with
HyperplaneType::ProjVectorType
as either AxisParallelProjVector
or
ProjVector
.
For implementation details, see the source code.
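The axis-orthogonal case of this strategy can be sketched in a few lines of standalone C++. This is an illustration only, not mlpack's code: `ToyAxisSplit` and `WidestDimensionSplit()` are hypothetical names, and "midpoint" is assumed here to mean the center of the range of values in the chosen dimension. The mean of the same values, which is what `MeanSpaceSplit` uses as the split value instead, is computed alongside for comparison.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// The chosen dimension and two candidate split values along it.
// (Hypothetical type for illustration; not part of mlpack.)
struct ToyAxisSplit
{
  std::size_t dim;   // Dimension with the largest spread of values.
  double midpoint;   // Center of the range of values in that dimension.
  double mean;       // Average of the values in that dimension.
};

// points[i] is one point; all points must have the same dimensionality.
inline ToyAxisSplit WidestDimensionSplit(
    const std::vector<std::vector<double>>& points)
{
  assert(!points.empty() && !points[0].empty());
  const std::size_t dims = points[0].size();

  ToyAxisSplit best{0, 0.0, 0.0};
  double bestWidth = -1.0;
  for (std::size_t d = 0; d < dims; ++d)
  {
    double lo = points[0][d], hi = points[0][d], sum = 0.0;
    for (const std::vector<double>& p : points)
    {
      assert(p.size() == dims);
      lo = std::min(lo, p[d]);
      hi = std::max(hi, p[d]);
      sum += p[d];
    }

    // Keep the dimension whose values span the widest range.
    if (hi - lo > bestWidth)
    {
      bestWidth = hi - lo;
      best = { d, (lo + hi) / 2.0, sum / static_cast<double>(points.size()) };
    }
  }

  return best;
}
```

For the three points `(0, 0)`, `(1, 10)`, and `(2, 2)`, the second dimension is widest; its midpoint is `5` while its mean is `4`, so the two strategies place the hyperplane differently whenever the values are skewed.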
MeanSpaceSplit
The MeanSpaceSplit
class is a splitting strategy that can be used by
SpillTree
. It is the splitting strategy used by the
MeanSPTree
and the
NonOrtMeanSPTree
classes.
The splitting strategy for the MeanSpaceSplit
class is, given a set of
points:
- If `AxisOrthogonalHyperplane` is being used, then select the dimension with the maximum width, and use the mean of the points' values in that dimension.
- If `Hyperplane` is being used, then estimate the furthest two points in the dataset by random sampling, and use the vector connecting those points as the tangent vector to the hyperplane. The mean of the points projected onto this hyperplane is used as the split value.
Note that MeanSpaceSplit
can only be used with a HyperplaneType
with
HyperplaneType::ProjVectorType
as either AxisParallelProjVector
or
ProjVector
.
For implementation details, see the source code.
Custom `SplitType`s
Custom split strategies for a spill tree can be implemented via the
SplitType
template parameter. By default, the
MidpointSpaceSplit
splitting strategy is used, but it
is also possible to implement and use a custom SplitType
. Any custom
SplitType
class must implement the following signature:
// NOTE: the custom SplitType class must take two template parameters.
template<typename DistanceType, typename MatType>
class SplitType
{
 public:
  // The SplitType class is only required to provide one static function.

  // Create a splitting hyperplane and store it in the given `HyperplaneType`,
  // using the given data and bounding box `bound`. `data` will be an Armadillo
  // matrix that is the entire dataset, and `points` are the indices of points
  // in `data` that should be split.
  template<typename HyperplaneType>
  static bool SplitSpace(
      const typename HyperplaneType::BoundType& bound,
      const MatType& data,
      const arma::Col<size_t>& points,
      HyperplaneType& hyp);
};
Example usage
The SpillTree
class is only really necessary when a custom hyperplane type or
custom splitting strategy is intended to be used. For simpler use cases, one of
the typedefs of SpillTree
(such as SPTree
) will suffice.
For this reason, all of the examples below explicitly specify all five template
parameters of SpillTree
.
Writing a custom hyperplane type and
writing a custom splitting strategy are discussed
in the previous sections. Each of the parameters in the examples below can be
trivially changed for different behavior.
Build a SpillTree
on the cloud
dataset and print basic statistics about the
tree.
// See https://datasets.mlpack.org/cloud.csv.
arma::mat dataset;
mlpack::data::Load("cloud.csv", dataset, true);
// Build the spill tree with a tau (margin) of 0.2 and a leaf size of 10.
// (This means that nodes are split until they contain 10 or fewer points.)
//
// The std::move() means that `dataset` will be empty after this call, and no
// data will be copied during tree building.
mlpack::SpillTree<mlpack::EuclideanDistance,
mlpack::EmptyStatistic,
arma::mat,
mlpack::AxisOrthogonalHyperplane,
mlpack::MidpointSpaceSplit> tree(std::move(dataset), 0.2, 10);
// Print the bounding box of the root node.
std::cout << "Bounding box of root node:" << std::endl;
for (size_t i = 0; i < tree.Bound().Dim(); ++i)
{
std::cout << " - Dimension " << i << ": [" << tree.Bound()[i].Lo() << ", "
<< tree.Bound()[i].Hi() << "]." << std::endl;
}
std::cout << std::endl;
// Print the number of descendant points of the root, and of each of its
// children.
std::cout << "Descendant points of root: "
<< tree.NumDescendants() << "." << std::endl;
std::cout << "Descendant points of left child: "
<< tree.Left()->NumDescendants() << "." << std::endl;
std::cout << "Descendant points of right child: "
<< tree.Right()->NumDescendants() << "." << std::endl;
std::cout << std::endl;
// Compute the center of the SpillTree.
arma::vec center;
tree.Center(center);
std::cout << "Center of tree: " << center.t();
Build two `SpillTree`s on subsets of the corel dataset and compute minimum
and maximum distances between different nodes in the tree.
// See https://datasets.mlpack.org/corel-histogram.csv.
arma::mat dataset;
mlpack::data::Load("corel-histogram.csv", dataset, true);
// Convenience typedef for the tree type.
using TreeType = mlpack::SpillTree<mlpack::EuclideanDistance,
mlpack::EmptyStatistic,
arma::mat,
mlpack::AxisOrthogonalHyperplane,
mlpack::MidpointSpaceSplit>;
// Build trees on the first half and the second half of points. Use a tau
// (overlap) parameter of 0.3, which is tuned to this dataset, and a rho value
// of 0.6 to prevent the trees getting too deep.
TreeType tree1(dataset.cols(0, dataset.n_cols / 2), 0.3, 20, 0.6);
TreeType tree2(dataset.cols(dataset.n_cols / 2 + 1, dataset.n_cols - 1),
0.3, 20, 0.6);
// Compute the maximum distance between the trees.
std::cout << "Maximum distance between tree root nodes: "
<< tree1.MaxDistance(tree2) << "." << std::endl;
// Get the leftmost grandchild of the first tree's root, if it exists.
if (!tree1.IsLeaf() && !tree1.Child(0).IsLeaf())
{
  TreeType& node1 = tree1.Child(0).Child(0);

  // Get the rightmost grandchild of the second tree's root, if it exists.
  if (!tree2.IsLeaf() && !tree2.Child(1).IsLeaf())
  {
    TreeType& node2 = tree2.Child(1).Child(1);

    // Print the minimum and maximum distance between the nodes.
    mlpack::Range dists = node1.RangeDistance(node2);
    std::cout << "Possible distances between two grandchild nodes: ["
        << dists.Lo() << ", " << dists.Hi() << "]." << std::endl;

    // Print the minimum distance between the first node and the first
    // descendant point of the second node.
    const size_t descendantIndex = node2.Descendant(0);
    const double descendantMinDist =
        node1.MinDistance(node2.Dataset().col(descendantIndex));
    std::cout << "Minimum distance between grandchild node and descendant "
        << "point: " << descendantMinDist << "." << std::endl;

    // Which child of node2 is closer to node1?  As documented above, 0 is
    // returned on a tie (and also when node2 is a leaf).
    const size_t closerIndex = node2.GetNearestChild(node1);
    if (closerIndex == 0)
      std::cout << "The left child of node2 is at least as close to node1 as "
          << "the right child." << std::endl;
    else
      std::cout << "The right child of node2 is closer to node1." << std::endl;

    // And which child of node1 is further from node2?  Again, 0 is returned
    // on a tie (and also when node1 is a leaf).
    const size_t furtherIndex = node1.GetFurthestChild(node2);
    if (furtherIndex == 0)
      std::cout << "The left child of node1 is at least as far from node2 as "
          << "the right child." << std::endl;
    else
      std::cout << "The right child of node1 is further from node2."
          << std::endl;
  }
}
Build a SpillTree
on 32-bit floating point data and save it to disk.
// See https://datasets.mlpack.org/corel-histogram.csv.
arma::fmat dataset;
mlpack::data::Load("corel-histogram.csv", dataset);
// Build the SpillTree using 32-bit floating point data as the matrix type.
// We will still use the default EmptyStatistic and EuclideanDistance
// parameters.
mlpack::SpillTree<mlpack::EuclideanDistance,
mlpack::EmptyStatistic,
arma::fmat,
mlpack::AxisOrthogonalHyperplane,
mlpack::MidpointSpaceSplit> tree(
std::move(dataset), 0.1, 20, 0.95);
// Save the tree to disk with the name 'tree'.
mlpack::data::Save("tree.bin", "tree", tree);
std::cout << "Saved tree with " << tree.Dataset().n_cols << " points to "
<< "'tree.bin'." << std::endl;
Load a 32-bit floating point SpillTree from disk, then traverse it manually and find the number of nodes whose children overlap.
// This assumes the tree has already been saved to 'tree.bin' (as in the example
// above).
// This convenient typedef saves us a long type name!
using TreeType = mlpack::SpillTree<mlpack::EuclideanDistance,
mlpack::EmptyStatistic,
arma::fmat,
mlpack::AxisOrthogonalHyperplane,
mlpack::MidpointSpaceSplit>;
TreeType tree;
mlpack::data::Load("tree.bin", "tree", tree);
std::cout << "Tree loaded with " << tree.NumDescendants() << " points."
<< std::endl;
// Recurse in a depth-first manner. Count both the total number of non-leaves,
// and the number of non-leaves that have overlapping children.
size_t overlapCount = 0;
size_t totalInternalNodeCount = 0;
std::stack<TreeType*> stack;
stack.push(&tree);
while (!stack.empty())
{
TreeType* node = stack.top();
stack.pop();
if (node->IsLeaf())
continue;
if (node->Overlap())
++overlapCount;
++totalInternalNodeCount;
stack.push(node->Left());
stack.push(node->Right());
}
// Note that it would be possible to use TreeType::SingleTreeTraverser to
// perform the recursion above, but that is better suited to more complex
// tasks that require pruning and other non-trivial behavior, so a simple
// stack is the better option here.
// Print the results.
std::cout << overlapCount << " out of " << totalInternalNodeCount
<< " internal nodes have overlapping children." << std::endl;
Use a defeatist traversal to find the approximate nearest neighbor of points 3 and 6 in the corel-histogram dataset. (Note: this can also be done more easily with the KNN class! This example is a demonstration of how to use the defeatist traverser.)
For this example, we must first define a RuleType class.
// For simplicity, this only implements those methods required by single-tree
// traversals, and cannot be used with a dual-tree traversal.
//
// `.Reset()` must be called before each additional single-tree traversal
// after the first.
class SpillNearestNeighborRule
{
public:
// Store the dataset internally.
SpillNearestNeighborRule(const arma::mat& dataset) :
dataset(dataset),
nearestNeighbor(size_t(-1)),
nearestDistance(DBL_MAX) { }
// Compute the base case (point-to-point comparison).
double BaseCase(const size_t queryIndex, const size_t referenceIndex)
{
// Skip the base case if the points are the same.
if (queryIndex == referenceIndex)
return 0.0;
const double dist = mlpack::EuclideanDistance::Evaluate(
dataset.col(queryIndex), dataset.col(referenceIndex));
if (dist < nearestDistance)
{
nearestNeighbor = referenceIndex;
nearestDistance = dist;
}
return dist;
}
// Score the given node in the tree; if it is sufficiently far away that it
// cannot contain a better nearest neighbor candidate, we can prune it.
template<typename TreeType>
double Score(const size_t queryIndex, const TreeType& referenceNode) const
{
const double minDist = referenceNode.MinDistance(dataset.col(queryIndex));
if (minDist > nearestDistance)
return DBL_MAX; // Prune: this cannot contain a better candidate!
return minDist;
}
// Rescore the given node/point combination. Note that this will not be used
// by the defeatist traversal as it never backtracks, but we include it for
// completeness because the RuleType API requires it.
template<typename TreeType>
double Rescore(const size_t, const TreeType&, const double oldScore) const
{
if (oldScore > nearestDistance)
return DBL_MAX; // Prune: the node is too far away.
return oldScore;
}
// This is required by defeatist traversals to select the best reference
// child to recurse into for overlapping nodes.
template<typename TreeType>
size_t GetBestChild(const size_t queryIndex, TreeType& referenceNode)
const
{
return referenceNode.GetNearestChild(dataset.col(queryIndex));
}
// We must perform at least two base cases in order to have a result. Note
// that this is two, and not one, because we skip base cases where the query
// and reference points are the same. That can happen at most once, so to
// ensure that we compare a query point to a different reference point at
// least once, we must return 2 here.
size_t MinimumBaseCases() const { return 2; }
// Get the results (to be called after the traversal).
size_t NearestNeighbor() const { return nearestNeighbor; }
double NearestDistance() const { return nearestDistance; }
// Reset the internal statistics for an additional traversal.
void Reset()
{
nearestNeighbor = size_t(-1);
nearestDistance = DBL_MAX;
}
private:
const arma::mat& dataset;
size_t nearestNeighbor;
double nearestDistance;
};
// See https://datasets.mlpack.org/corel-histogram.csv.
arma::mat dataset;
mlpack::data::Load("corel-histogram.csv", dataset, true);
typedef mlpack::SpillTree<mlpack::EuclideanDistance,
mlpack::EmptyStatistic,
arma::mat,
mlpack::AxisOrthogonalHyperplane,
mlpack::MidpointSpaceSplit> TreeType;
// Build two trees, one with a lot of overlap, and one with no overlap
// (i.e. tau = 0).
TreeType tree1(dataset, 0.5, 10), tree2(dataset, 0.0, 10);
// Construct the rule types, and then the traversals.
SpillNearestNeighborRule r1(dataset), r2(dataset);
TreeType::DefeatistSingleTreeTraverser<SpillNearestNeighborRule> t1(r1);
TreeType::DefeatistSingleTreeTraverser<SpillNearestNeighborRule> t2(r2);
// Search for the approximate nearest neighbor of point 3 using both trees.
t1.Traverse(3, tree1);
t2.Traverse(3, tree2);
std::cout << "Approximate nearest neighbor of point 3:" << std::endl;
std::cout << " - Spill tree with overlap 0.5 found: point "
<< r1.NearestNeighbor() << ", distance " << r1.NearestDistance()
<< "." << std::endl;
std::cout << " - Spill tree with no overlap found: point "
<< r2.NearestNeighbor() << ", distance " << r2.NearestDistance()
<< "." << std::endl;
// Now search for point 6.
r1.Reset();
r2.Reset();
t1.Traverse(6, tree1);
t2.Traverse(6, tree2);
std::cout << "Approximate nearest neighbor of point 6:" << std::endl;
std::cout << " - Spill tree with overlap 0.5 found: point "
<< r1.NearestNeighbor() << ", distance " << r1.NearestDistance()
<< "." << std::endl;
std::cout << " - Spill tree with no overlap found: point "
<< r2.NearestNeighbor() << ", distance " << r2.NearestDistance()
<< "." << std::endl;
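To see why the overlap can change the result of a defeatist traversal, the following is a minimal, self-contained sketch that does not use mlpack at all. `ToyNode`, `BuildToyTree()`, and `DefeatistNearest()` are hypothetical names invented for this illustration, and the 1-dimensional midpoint split with an absolute buffer `tau` on each side of the split is a simplifying assumption; it is not how SpillTree is implemented.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical toy node for illustration only; this is NOT mlpack's SpillTree.
struct ToyNode
{
  double split = 0.0;          // Splitting value (used by internal nodes).
  std::vector<double> points;  // Points held directly (leaves only).
  std::unique_ptr<ToyNode> left, right;

  bool IsLeaf() const { return left == nullptr; }
};

// Build a 1-dimensional toy spill tree with a midpoint split.  Any point
// within `tau` of the split is copied into both children (assumed rule).
std::unique_ptr<ToyNode> BuildToyTree(std::vector<double> pts,
                                      const double tau,
                                      const std::size_t maxLeafSize)
{
  assert(maxLeafSize >= 1);
  std::unique_ptr<ToyNode> node = std::make_unique<ToyNode>();
  if (pts.size() <= maxLeafSize)
  {
    node->points = std::move(pts);
    return node;
  }

  const auto mm = std::minmax_element(pts.begin(), pts.end());
  node->split = (*mm.first + *mm.second) / 2.0;

  std::vector<double> l, r;
  for (const double p : pts)
  {
    if (p <= node->split + tau) l.push_back(p);
    if (p >= node->split - tau) r.push_back(p);
  }

  // If a child would hold every point, splitting makes no progress, so keep
  // this node as a leaf instead.
  if (l.size() == pts.size() || r.size() == pts.size())
  {
    node->points = std::move(pts);
    return node;
  }

  node->left = BuildToyTree(std::move(l), tau, maxLeafSize);
  node->right = BuildToyTree(std::move(r), tau, maxLeafSize);
  return node;
}

// Defeatist search: descend into exactly one child at each level, never
// backtrack, and scan only the leaf that is reached.  Returns NaN if that
// leaf holds no points.
double DefeatistNearest(const ToyNode& root, const double query)
{
  const ToyNode* node = &root;
  while (!node->IsLeaf())
    node = (query <= node->split) ? node->left.get() : node->right.get();

  double best = std::numeric_limits<double>::quiet_NaN();
  double bestDist = std::numeric_limits<double>::infinity();
  for (const double p : node->points)
  {
    const double dist = std::abs(p - query);
    if (dist < bestDist)
    {
      best = p;
      bestDist = dist;
    }
  }
  return best;
}
```

For the points {0, 1, 2, 3, 4, 5.5, 7, 8, 9, 10}, a maximum leaf size of 3, and a query of 4.9, calling `DefeatistNearest()` on the tree built with `tau = 0` returns 4 (distance 0.9), because the search commits to the left side of the root split at 5 and never sees 5.5. With `tau = 1`, the point 5.5 is also stored on the left side, and the same search returns 5.5 (distance 0.6), which is the true nearest neighbor. The cost of the overlap is that points near a split are stored in both children.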