mlpack

SpillTree

The SpillTree class represents a generic multidimensional binary space partitioning tree that allows overlapping volumes between nodes, also known as a 'hybrid spill tree'. It is heavily templatized to control splitting and other behaviors, and is the actual class underlying trees such as the SPTree. In general, the SpillTree class is not meant to be used directly; one of its variants, such as SPTree, MeanSPTree, NonOrtSPTree, or NonOrtMeanSPTree, should be used instead.

The SpillTree is similar to the BinarySpaceTree, except that the two children of a node are allowed to overlap, and thus a single point can be contained in multiple branches of the tree. This can be useful to, e.g., improve nearest neighbor performance when using defeatist traversals without backtracking.


For users who want to use SpillTree directly or with custom behavior, the full class is still detailed in the subsections below. SpillTree supports the TreeType API and can be used with mlpack's tree-based algorithms, although using custom behavior may require a template typedef.

See also

Template parameters

The SpillTree class takes five template parameters. The first three of these are required by the TreeType API (see also this more detailed section). The full signature of the class is:

template<typename DistanceType,
         typename StatisticType,
         typename MatType,
         template<typename HyperplaneDistanceType,
                  typename HyperplaneMatType> class HyperplaneType,
         template<typename SplitDistanceType,
                  typename SplitMatType> class SplitType>
class SpillTree;

Note that the TreeType API requires trees to have only three template parameters. In order to use a SpillTree with its five template parameters with an mlpack algorithm that needs a TreeType, it is easiest to define a template typedef:

template<typename DistanceType, typename StatisticType, typename MatType>
using CustomTree = SpillTree<DistanceType, StatisticType, MatType,
    CustomHyperplaneType, CustomSplitType>;

Here, CustomHyperplaneType and CustomSplitType are the desired hyperplane type and split strategy. This is the way that all SpillTree variants (such as SPTree) are defined.
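The alias pattern itself does not depend on mlpack. The following standalone sketch shows how fixing the last two template template parameters yields a three-parameter template that a generic algorithm can accept. It uses hypothetical stand-in types (`FiveParamTree`, `DemoHyperplane`, `DemoSplit`, `DemoAlgorithm`) rather than mlpack classes, so it illustrates the technique and not mlpack's actual definitions.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-ins for a hyperplane type and a split strategy; each
// takes two template parameters, like the SpillTree policy classes.
template<typename DistanceType, typename MatType>
struct DemoHyperplane { };

template<typename DistanceType, typename MatType>
struct DemoSplit { };

// Hypothetical five-parameter tree with the same template shape as SpillTree.
template<typename DistanceType,
         typename StatisticType,
         typename MatType,
         template<typename, typename> class HyperplaneType,
         template<typename, typename> class SplitType>
struct FiveParamTree
{
  using Hyperplane = HyperplaneType<DistanceType, MatType>;
  using Split = SplitType<DistanceType, MatType>;
};

// The template typedef: fix the last two parameters, leaving three.
template<typename DistanceType, typename StatisticType, typename MatType>
using DemoTree = FiveParamTree<DistanceType, StatisticType, MatType,
    DemoHyperplane, DemoSplit>;

// A generic "algorithm" that expects a tree template with exactly three type
// parameters, as the TreeType API does.
template<template<typename, typename, typename> class TreeType>
struct DemoAlgorithm
{
  using Tree = TreeType<double, int, float>;
};

// Returns true if the alias resolves to the expected five-parameter type.
inline bool AliasResolves()
{
  using Expected =
      FiveParamTree<double, int, float, DemoHyperplane, DemoSplit>;
  return std::is_same<DemoAlgorithm<DemoTree>::Tree, Expected>::value;
}
```

Because `DemoTree` is an alias template, `DemoAlgorithm<DemoTree>` compiles even though the underlying class takes five parameters.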

Constructors

SpillTrees are constructed by iteratively finding splitting hyperplanes, and points within a margin of the hyperplane are assigned to both child nodes. Unlike the constructors of BinarySpaceTree, the dataset is not permuted during construction.





Notes:


Constructor parameters:

name | type | description | default
data | MatType | Column-major matrix to build the tree on. | (N/A)
tau | double | Width of spill margin: points within tau of the splitting hyperplane of a node will be contained in both left and right children. | 0.0
maxLeafSize | size_t | Maximum number of points to store in each leaf. | 20
rho | double | Balance threshold. When splitting, if either overlapping node would contain a fraction of more than rho of the points, a non-overlapping split is performed. Must be in the range [0.0, 1.0). | 0.7
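To make the interaction between tau and rho concrete, here is a minimal standalone sketch of a single split step on one-dimensional projections. The `SpillSplit` struct and `SpillPartition` function are hypothetical names for illustration, and the treatment of points exactly on the hyperplane is an assumption; this is not mlpack's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Result of splitting a set of 1-D projections; an index may appear in both
// children when the split overlaps.
struct SpillSplit
{
  std::vector<std::size_t> left;
  std::vector<std::size_t> right;
  bool overlapping;
};

// Hypothetical helper: split `values` (projections onto the split direction)
// at `splitVal` with spill margin `tau` and balance threshold `rho`.
inline SpillSplit SpillPartition(const std::vector<double>& values,
                                 const double splitVal,
                                 const double tau,
                                 const double rho)
{
  assert(tau >= 0.0);
  assert(rho >= 0.0 && rho < 1.0);

  SpillSplit result;
  result.overlapping = true;
  for (std::size_t i = 0; i < values.size(); ++i)
  {
    // Points within tau of the hyperplane are assigned to both children.
    if (values[i] <= splitVal + tau)
      result.left.push_back(i);
    if (values[i] > splitVal - tau)
      result.right.push_back(i);
  }

  // If either child would hold more than a fraction rho of the points, fall
  // back to a non-overlapping split.
  const double limit = rho * values.size();
  if (result.left.size() > limit || result.right.size() > limit)
  {
    result.left.clear();
    result.right.clear();
    result.overlapping = false;
    for (std::size_t i = 0; i < values.size(); ++i)
    {
      if (values[i] <= splitVal)
        result.left.push_back(i);
      else
        result.right.push_back(i);
    }
  }

  return result;
}
```

With ten evenly spaced values 0 through 9 split at 4.5 and rho = 0.7, a margin of tau = 1.0 gives two overlapping children of six points each, while tau = 3.0 would put eight points in each child and so triggers the non-overlapping fallback.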

Caveats:

Basic tree properties

Once a SpillTree object is constructed, various properties of the tree can be accessed or inspected. Many of these functions are required by the TreeType API.


Accessing members of a tree

See also the developer documentation for basic tree functionality in mlpack.


Accessing data held in a tree


Accessing computed bound quantities of a tree

The following quantities are cached for each node in a SpillTree, and so accessing them does not require any computation. In the documentation below, ElemType is the element type of the given MatType; e.g., if MatType is arma::mat, then ElemType is double.

Note: for more details on each bound quantity, see the developer documentation on bound quantities for trees.


Other functionality

Bounding distances with the tree

The primary use of trees in mlpack is bounding distances to points or other tree nodes. The following functions can be used for these tasks.
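As a rough illustration of what bounding a distance means, the sketch below computes the minimum and maximum Euclidean distance from a point to an axis-aligned box. The `Box` struct and both function names are hypothetical; this is a standalone sketch of the idea, not the SpillTree API.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical axis-aligned bounding box: [lo[d], hi[d]] in each dimension d.
struct Box
{
  std::vector<double> lo;
  std::vector<double> hi;
};

// Smallest possible distance between `p` and any point inside the box.
inline double MinDistanceToBox(const Box& box, const std::vector<double>& p)
{
  assert(box.lo.size() == p.size() && box.hi.size() == p.size());
  double sum = 0.0;
  for (std::size_t d = 0; d < p.size(); ++d)
  {
    // Gap to the nearest face in this dimension; zero if p lies within it.
    const double below = box.lo[d] - p[d];
    const double above = p[d] - box.hi[d];
    const double gap = std::max(0.0, std::max(below, above));
    sum += gap * gap;
  }
  return std::sqrt(sum);
}

// Largest possible distance between `p` and any point inside the box.
inline double MaxDistanceToBox(const Box& box, const std::vector<double>& p)
{
  assert(box.lo.size() == p.size() && box.hi.size() == p.size());
  double sum = 0.0;
  for (std::size_t d = 0; d < p.size(); ++d)
  {
    // The farthest point of the box in this dimension is one of its two faces.
    const double farthest = std::max(std::abs(p[d] - box.lo[d]),
                                     std::abs(p[d] - box.hi[d]));
    sum += farthest * farthest;
  }
  return std::sqrt(sum);
}
```

Any point stored under a node whose region is the box must lie between these two distances from `p`, which is what lets a tree-based algorithm skip a node without visiting its points.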


Tree traversals

Like every mlpack tree, the SpillTree class provides a single-tree and dual-tree traversal that can be paired with a RuleType class to implement a single-tree or dual-tree algorithm.

However, the spill tree is primarily useful because the overlapping nodes allow defeatist search to be effective. Defeatist search is non-backtracking: the tree is traversed to one leaf only. For example, finding the approximate nearest neighbor of a point p with defeatist search is done by recursing in the tree, choosing the child with smallest minimum distance to p, and when a leaf is encountered, choosing the closest point in the leaf to p as the nearest neighbor. This is the strategy used in the original spill tree paper (pdf).
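The effect of overlap on defeatist search can be seen in a toy one-dimensional setting. The sketch below is standalone and uses hypothetical names (`Node`, `DefeatistNearest`, `MakeDemoTree`); each node covers an interval, and the search descends into the child with the smaller minimum distance to the query, breaking ties toward the left child. It is an illustration under these assumptions, not mlpack's traverser.

```cpp
#include <cassert>
#include <cmath>
#include <initializer_list>
#include <memory>
#include <vector>

// Hypothetical 1-D tree node covering the region [lo, hi]. Leaves hold points.
struct Node
{
  double lo = 0.0;
  double hi = 0.0;
  std::vector<double> points; // Only used by leaves.
  std::unique_ptr<Node> left, right;

  bool IsLeaf() const { return !left && !right; }

  // Minimum possible distance from q to any location in [lo, hi].
  double MinDistance(const double q) const
  {
    if (q < lo) return lo - q;
    if (q > hi) return q - hi;
    return 0.0;
  }
};

// Defeatist search: descend to exactly one leaf, never backtracking, and
// return the closest point stored in that leaf.
inline double DefeatistNearest(const Node& root, const double q)
{
  const Node* node = &root;
  while (!node->IsLeaf())
  {
    node = (node->left->MinDistance(q) <= node->right->MinDistance(q)) ?
        node->left.get() : node->right.get();
  }

  assert(!node->points.empty());
  double best = node->points[0];
  for (const double p : node->points)
    if (std::abs(p - q) < std::abs(best - q))
      best = p;
  return best;
}

// Build a two-leaf tree over six points, split at 5. With `spill` set, a
// margin of 0.5 around the split is shared by both leaves.
inline Node MakeDemoTree(const bool spill)
{
  const double tau = spill ? 0.5 : 0.0;
  Node root;
  root.lo = 1.0;
  root.hi = 9.0;
  root.left.reset(new Node());
  root.right.reset(new Node());
  root.left->lo = 1.0;
  root.left->hi = 5.0 + tau;
  root.right->lo = 5.0 - tau;
  root.right->hi = 9.0;
  for (const double p : { 1.0, 2.0, 4.9, 5.4, 8.0, 9.0 })
  {
    if (p <= 5.0 + tau) root.left->points.push_back(p);
    if (p > 5.0 - tau) root.right->points.push_back(p);
  }
  return root;
}
```

For the query 5.05, the true nearest neighbor is 4.9. Without overlap the search commits to the right leaf and returns 5.4; with overlap both leaves contain 4.9 and 5.4, so the single descent finds 4.9.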

Defeatist traversers, matching the API for a regular traversal, are made available as the following two classes:

Any RuleType that is being used with a defeatist traversal, in addition to the functions required by the RuleType API, must implement the following functions:

// This is only required for single-tree defeatist traversals.
// It should return the index of the branch that should be chosen for the given
// query point and reference node.
template<typename VecType, typename TreeType>
size_t GetBestChild(const VecType& queryPoint, TreeType& referenceNode);

// This is only required for dual-tree defeatist traversals.
// It should return the index of the best child of the reference node that
// should be chosen for the given query node.
template<typename TreeType>
size_t GetBestChild(TreeType& queryNode, TreeType& referenceNode);

// Return the minimum number of base cases (point-to-point computations) that
// are required during the traversal.
size_t MinimumBaseCases();

HyperplaneType

Each node in a SpillTree corresponds to some region in space that contains all of the descendant points in the node. Similar to the KDTree, this region is a hyperrectangle; however, instead of representing that hyperrectangle explicitly like the KDTree with the HRectBound class, the SpillTree represents the region implicitly, with each node storing only the hyperplane and margin required to determine whether a point belongs to the left node, the right node, or both.

The type of hyperplane (e.g. axis-aligned or arbitrary) can be controlled by the HyperplaneType template parameter. mlpack supplies two drop-in classes that can be used for HyperplaneType, and it is also possible to write a custom HyperplaneType:

AxisOrthogonalHyperplane

The AxisOrthogonalHyperplane class is used to provide an axis-orthogonal split for a SpillTree. That is, whether or not a point is on the left or right side of the split is a very efficient computation using only a single dimension of the data.

For more details, see the source code.
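The reason an axis-orthogonal split is cheap can be shown with a small standalone sketch. `AxisSplit` is a hypothetical stand-in, not the mlpack class, and the convention that a point exactly on the hyperplane counts as left is an assumption made here for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for an axis-orthogonal hyperplane: the set of points
// whose coordinate in dimension `dim` equals `splitVal`.
struct AxisSplit
{
  std::size_t dim;
  double splitVal;

  // Signed offset of the point from the hyperplane. Only one coordinate of
  // the point is read, regardless of its dimensionality.
  double Project(const std::vector<double>& point) const
  {
    assert(dim < point.size());
    return point[dim] - splitVal;
  }

  // Assumed convention: points on the hyperplane count as left.
  bool Left(const std::vector<double>& point) const
  {
    return Project(point) <= 0.0;
  }

  bool Right(const std::vector<double>& point) const
  {
    return Project(point) > 0.0;
  }
};
```

For example, with `dim = 1` and `splitVal = 2.0`, the point (5.0, 3.5, -1.0) projects to 1.5 and falls to the right, and no other coordinate needs to be inspected.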

Hyperplane

The Hyperplane class is used to provide an arbitrary hyperplane split for a SpillTree. The computation of whether or not a point is on the left or right side of the split is less efficient than AxisOrthogonalHyperplane, but Hyperplane is able to represent any possible hyperplane.

For more details, see the source code.
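By contrast, deciding the side of an arbitrary hyperplane needs a full dot product with the direction vector. The standalone sketch below uses a hypothetical `GeneralSplit` struct, not the mlpack class, and again assumes that points on the hyperplane count as left.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Hypothetical stand-in for an arbitrary hyperplane: all points x with
// dot(direction, x) == splitVal.
struct GeneralSplit
{
  std::vector<double> direction; // Vector normal to the hyperplane.
  double splitVal;

  // Signed offset along `direction`; every coordinate of the point is read.
  // If `direction` is not unit length, this is a scaled distance.
  double Project(const std::vector<double>& point) const
  {
    assert(point.size() == direction.size());
    return std::inner_product(point.begin(), point.end(),
        direction.begin(), 0.0) - splitVal;
  }

  // Assumed convention: points on the hyperplane count as left.
  bool Left(const std::vector<double>& point) const
  {
    return Project(point) <= 0.0;
  }

  bool Right(const std::vector<double>& point) const
  {
    return Project(point) > 0.0;
  }
};
```

The cost per query grows linearly with the number of dimensions, which is the efficiency trade-off described above.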

Custom HyperplaneTypes

Custom hyperplane types for a spill tree can be implemented via the HyperplaneType template parameter. By default, the AxisOrthogonalHyperplane hyperplane type is used, but it is also possible to implement and use a custom HyperplaneType. Any custom HyperplaneType class must implement the following signature:

// NOTE: the custom HyperplaneType class must take two template parameters.
template<typename DistanceType, typename MatType>
class HyperplaneType
{
 public:
  // The hyperplane type must specify these two public typedefs, which are used
  // by the spill tree and the splitting strategy.
  //
  // Substitute HRectBound and ProjVector with your choices.
  using BoundType = mlpack::HRectBound<DistanceType,
                                       typename MatType::elem_type>;
  using ProjVectorType = mlpack::ProjVector<MatType>;

  // Empty constructor, which will construct an empty or default hyperplane.
  HyperplaneType();

  // Construct the HyperplaneType with the given projection vector and split
  // value along that projection.
  HyperplaneType(const ProjVectorType& projVector, double splitVal);

  // Compute the projection of the given point (an `arma::vec` or similar type
  // matching the Armadillo API and element type of `MatType`) onto the vector
  // normal to the hyperplane.
  template<typename VecType>
  double Project(const VecType& point) const;

  // Return true if the point (an `arma::vec` or similar type matching the
  // Armadillo API and element type of `MatType`) falls to the left of the
  // hyperplane.
  template<typename VecType>
  bool Left(const VecType& point) const;

  // Return true if the point (an `arma::vec` or similar type matching the
  // Armadillo API and element type of `MatType`) falls to the right of the
  // hyperplane.
  template<typename VecType>
  bool Right(const VecType& point) const;

  // Return true if the given bound is fully to the left of the hyperplane.
  bool Left(const BoundType& bound) const;

  // Return true if the given bound is fully to the right of the hyperplane.
  bool Right(const BoundType& bound) const;

  // Serialize the hyperplane using cereal.
  template<typename Archive>
  void serialize(Archive& ar, const uint32_t version);
};

SplitType

The SplitType template parameter controls the algorithm used to split each node of a SpillTree while building. The splitting strategy used can be entirely arbitrary: the SplitType only needs to compute a HyperplaneType to split a set of points.

mlpack provides two drop-in choices for SplitType, and it is also possible to write a fully custom split:

MidpointSpaceSplit

The MidpointSpaceSplit class is a splitting strategy that can be used by SpillTree. It is the default strategy for splitting SPTrees and NonOrtSPTrees.

The splitting strategy for the MidpointSpaceSplit class is, given a set of points:

Note that MidpointSpaceSplit can only be used with a HyperplaneType with HyperplaneType::ProjVectorType as either AxisParallelProjVector or ProjVector.

For implementation details, see the source code.

MeanSpaceSplit

The MeanSpaceSplit class is a splitting strategy that can be used by SpillTree. It is the splitting strategy used by the MeanSPTree and the NonOrtMeanSPTree classes.

The splitting strategy for the MeanSpaceSplit class is, given a set of points:

Note that MeanSpaceSplit can only be used with a HyperplaneType with HyperplaneType::ProjVectorType as either AxisParallelProjVector or ProjVector.

For implementation details, see the source code.
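The difference between a midpoint split and a mean split can be sketched on a toy dataset. The code below is standalone and rests on an assumption drawn from the class names: both strategies pick the dimension with the greatest extent, then the midpoint strategy splits halfway between the minimum and maximum while the mean strategy splits at the average coordinate. The function names are hypothetical, and the exact mlpack behavior should be checked against the source code.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Toy dataset: one inner vector per point, points[i][d] is coordinate d.
using Points = std::vector<std::vector<double>>;

// Index of the dimension in which the points spread the most.
inline std::size_t WidestDimension(const Points& points)
{
  assert(!points.empty());
  std::size_t best = 0;
  double bestWidth = -1.0;
  for (std::size_t d = 0; d < points[0].size(); ++d)
  {
    double lo = points[0][d], hi = points[0][d];
    for (const auto& p : points)
    {
      lo = std::min(lo, p[d]);
      hi = std::max(hi, p[d]);
    }
    if (hi - lo > bestWidth)
    {
      bestWidth = hi - lo;
      best = d;
    }
  }
  return best;
}

// Assumed midpoint rule: halfway between the extremes in dimension d.
inline double MidpointSplitValue(const Points& points, const std::size_t d)
{
  assert(!points.empty());
  double lo = points[0][d], hi = points[0][d];
  for (const auto& p : points)
  {
    lo = std::min(lo, p[d]);
    hi = std::max(hi, p[d]);
  }
  return (lo + hi) / 2.0;
}

// Assumed mean rule: average coordinate of the points in dimension d.
inline double MeanSplitValue(const Points& points, const std::size_t d)
{
  assert(!points.empty());
  double sum = 0.0;
  for (const auto& p : points)
    sum += p[d];
  return sum / points.size();
}
```

On the points (0, 0), (1, 10), (2, 2), (3, 0), the widest dimension is the second one; the midpoint rule splits at 5 while the mean rule splits at 3, so a single outlier moves the midpoint split much more than the mean split.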

Custom SplitTypes

Custom split strategies for a spill tree can be implemented via the SplitType template parameter. By default, the MidpointSpaceSplit splitting strategy is used, but it is also possible to implement and use a custom SplitType. Any custom SplitType class must implement the following signature:

// NOTE: the custom SplitType class must take two template parameters.
template<typename DistanceType, typename MatType>
class SplitType
{
 public:
  // The SplitType class is only required to provide one static function.

  // Create a splitting hyperplane and store it in the given `HyperplaneType`,
  // using the given data and bounding box `bound`.  `data` will be an Armadillo
  // matrix that is the entire dataset, and `points` are the indices of points
  // in `data` that should be split.
  template<typename HyperplaneType>
  static bool SplitSpace(
      const typename HyperplaneType::BoundType& bound,
      const MatType& data,
      const arma::Col<size_t>& points,
      HyperplaneType& hyp);
};

Example usage

The SpillTree class is only really necessary when a custom hyperplane type or custom splitting strategy is intended to be used. For simpler use cases, one of the typedefs of SpillTree (such as SPTree) will suffice.

For this reason, all of the examples below explicitly specify all five template parameters of SpillTree. Writing a custom hyperplane type and writing a custom splitting strategy are discussed in the previous sections. Each of the parameters in the examples below can be trivially changed for different behavior.


Build a SpillTree on the cloud dataset and print basic statistics about the tree.

// See https://datasets.mlpack.org/cloud.csv.
arma::mat dataset;
mlpack::data::Load("cloud.csv", dataset, true);

// Build the spill tree with a tau (margin) of 0.2 and a leaf size of 10.
// (This means that nodes are split until they contain 10 or fewer points.)
//
// The std::move() means that `dataset` will be empty after this call, and no
// data will be copied during tree building.
mlpack::SpillTree<mlpack::EuclideanDistance,
                  mlpack::EmptyStatistic,
                  arma::mat,
                  mlpack::AxisOrthogonalHyperplane,
                  mlpack::MidpointSpaceSplit> tree(std::move(dataset), 0.2, 10);

// Print the bounding box of the root node.
std::cout << "Bounding box of root node:" << std::endl;
for (size_t i = 0; i < tree.Bound().Dim(); ++i)
{
  std::cout << " - Dimension " << i << ": [" << tree.Bound()[i].Lo() << ", "
      << tree.Bound()[i].Hi() << "]." << std::endl;
}
std::cout << std::endl;

// Print the number of descendant points of the root, and of each of its
// children.
std::cout << "Descendant points of root:        "
    << tree.NumDescendants() << "." << std::endl;
std::cout << "Descendant points of left child:  "
    << tree.Left()->NumDescendants() << "." << std::endl;
std::cout << "Descendant points of right child: "
    << tree.Right()->NumDescendants() << "." << std::endl;
std::cout << std::endl;

// Compute the center of the SpillTree.
arma::vec center;
tree.Center(center);
std::cout << "Center of tree: " << center.t();

Build two SpillTrees on subsets of the corel dataset and compute minimum and maximum distances between different nodes in the tree.

// See https://datasets.mlpack.org/corel-histogram.csv.
arma::mat dataset;
mlpack::data::Load("corel-histogram.csv", dataset, true);

// Convenience typedef for the tree type.
using TreeType = mlpack::SpillTree<mlpack::EuclideanDistance,
                                   mlpack::EmptyStatistic,
                                   arma::mat,
                                   mlpack::AxisOrthogonalHyperplane,
                                   mlpack::MidpointSpaceSplit>;

// Build trees on the first half and the second half of points.  Use a tau
// (overlap) parameter of 0.3, which is tuned to this dataset, and a rho value
// of 0.6 to prevent the trees getting too deep.
TreeType tree1(dataset.cols(0, dataset.n_cols / 2), 0.3, 20, 0.6);
TreeType tree2(dataset.cols(dataset.n_cols / 2 + 1, dataset.n_cols - 1),
    0.3, 20, 0.6);

// Compute the maximum distance between the trees.
std::cout << "Maximum distance between tree root nodes: "
    << tree1.MaxDistance(tree2) << "." << std::endl;

// Get the leftmost grandchild of the first tree's root---if it exists.
if (!tree1.IsLeaf() && !tree1.Child(0).IsLeaf())
{
  TreeType& node1 = tree1.Child(0).Child(0);

  // Get the rightmost grandchild of the second tree's root---if it exists.
  if (!tree2.IsLeaf() && !tree2.Child(1).IsLeaf())
  {
    TreeType& node2 = tree2.Child(1).Child(1);

    // Print the minimum and maximum distance between the nodes.
    mlpack::Range dists = node1.RangeDistance(node2);
    std::cout << "Possible distances between two grandchild nodes: ["
        << dists.Lo() << ", " << dists.Hi() << "]." << std::endl;

    // Print the minimum distance between the first node and the first
    // descendant point of the second node.
    const size_t descendantIndex = node2.Descendant(0);
    const double descendantMinDist =
        node1.MinDistance(node2.Dataset().col(descendantIndex));
    std::cout << "Minimum distance between grandchild node and descendant "
        << "point: " << descendantMinDist << "." << std::endl;

    // Which child of node2 is closer to node1?
    const size_t closerIndex = node2.GetNearestChild(node1);
    if (closerIndex == 0)
      std::cout << "The left child of node2 is closer to node1." << std::endl;
    else if (closerIndex == 1)
      std::cout << "The right child of node2 is closer to node1." << std::endl;
    else // closerIndex == 2 in this case.
      std::cout << "Both children of node2 are equally close to node1."
          << std::endl;

    // And which child of node1 is further from node2?
    const size_t furtherIndex = node1.GetFurthestChild(node2);
    if (furtherIndex == 0)
      std::cout << "The left child of node1 is further from node2."
          << std::endl;
    else if (furtherIndex == 1)
      std::cout << "The right child of node1 is further from node2."
          << std::endl;
    else // furtherIndex == 2 in this case.
      std::cout << "Both children of node1 are equally far from node2."
          << std::endl;
  }
}

Build a SpillTree on 32-bit floating point data and save it to disk.

// See https://datasets.mlpack.org/corel-histogram.csv.
arma::fmat dataset;
mlpack::data::Load("corel-histogram.csv", dataset);

// Build the SpillTree using 32-bit floating point data as the matrix type.
// We will still use the default EmptyStatistic and EuclideanDistance
// parameters.
mlpack::SpillTree<mlpack::EuclideanDistance,
                  mlpack::EmptyStatistic,
                  arma::fmat,
                  mlpack::AxisOrthogonalHyperplane,
                  mlpack::MidpointSpaceSplit> tree(
    std::move(dataset), 0.1, 20, 0.95);

// Save the tree to disk with the name 'tree'.
mlpack::data::Save("tree.bin", "tree", tree);

std::cout << "Saved tree with " << tree.Dataset().n_cols << " points to "
    << "'tree.bin'." << std::endl;

Load a 32-bit floating point SpillTree from disk, then traverse it manually and find the number of nodes whose children overlap.

// This assumes the tree has already been saved to 'tree.bin' (as in the example
// above).

// This convenient typedef saves us a long type name!
using TreeType = mlpack::SpillTree<mlpack::EuclideanDistance,
                                   mlpack::EmptyStatistic,
                                   arma::fmat,
                                   mlpack::AxisOrthogonalHyperplane,
                                   mlpack::MidpointSpaceSplit>;

TreeType tree;
mlpack::data::Load("tree.bin", "tree", tree);
std::cout << "Tree loaded with " << tree.NumDescendants() << " points."
    << std::endl;

// Recurse in a depth-first manner.  Count both the total number of non-leaves,
// and the number of non-leaves that have overlapping children.
size_t overlapCount = 0;
size_t totalInternalNodeCount = 0;
std::stack<TreeType*> stack;
stack.push(&tree);
while (!stack.empty())
{
  TreeType* node = stack.top();
  stack.pop();

  if (node->IsLeaf())
    continue;

  if (node->Overlap())
    ++overlapCount;
  ++totalInternalNodeCount;

  stack.push(node->Left());
  stack.push(node->Right());
}

// Note that it would be possible to use TreeType::SingleTreeTraverser to
// perform the recursion above, but that is better suited to more complex
// tasks that require pruning and other non-trivial behavior; so using a simple
// stack is the better option here.

// Print the results.
std::cout << overlapCount << " out of " << totalInternalNodeCount
    << " internal nodes have overlapping children." << std::endl;

Use a defeatist traversal to find the approximate nearest neighbor of points 3 and 6 in the corel-histogram dataset. (Note: this can also be done more easily with the KNN class! This example is a demonstration of how to use the defeatist traverser.)

For this example, we must first define a RuleType class.

// For simplicity, this only implements those methods required by single-tree
// traversals, and cannot be used with a dual-tree traversal.
//
// `.Reset()` must be called before any additional single-tree traversals after
// the first is run.
class SpillNearestNeighborRule
{
 public:
  // Store the dataset internally.
  SpillNearestNeighborRule(const arma::mat& dataset) :
      dataset(dataset),
      nearestNeighbor(size_t(-1)),
      nearestDistance(DBL_MAX) { }

  // Compute the base case (point-to-point comparison).
  double BaseCase(const size_t queryIndex, const size_t referenceIndex)
  {
    // Skip the base case if the points are the same.
    if (queryIndex == referenceIndex)
      return 0.0;

    const double dist = mlpack::EuclideanDistance::Evaluate(
        dataset.col(queryIndex), dataset.col(referenceIndex));

    if (dist < nearestDistance)
    {
      nearestNeighbor = referenceIndex;
      nearestDistance = dist;
    }

    return dist;
  }

  // Score the given node in the tree; if it is sufficiently far away that it
  // cannot contain a better nearest neighbor candidate, we can prune it.
  template<typename TreeType>
  double Score(const size_t queryIndex, const TreeType& referenceNode) const
  {
    const double minDist = referenceNode.MinDistance(dataset.col(queryIndex));
    if (minDist > nearestDistance)
      return DBL_MAX; // Prune: this cannot contain a better candidate!

    return minDist;
  }

  // Rescore the given node/point combination.  Note that this will not be used
  // by the defeatist traversal as it never backtracks, but we include it for
  // completeness because the RuleType API requires it.
  template<typename TreeType>
  double Rescore(const size_t, const TreeType&, const double oldScore) const
  {
    if (oldScore > nearestDistance)
      return DBL_MAX; // Prune: the node is too far away.
    return oldScore;
  }

  // This is required by defeatist traversals to select the best reference
  // child to recurse into for overlapping nodes.
  template<typename TreeType>
  size_t GetBestChild(const size_t queryIndex, TreeType& referenceNode)
      const
  {
    return referenceNode.GetNearestChild(dataset.col(queryIndex));
  }

  // We must perform at least two base cases in order to have a result.  Note
  // that this is two, and not one, because we skip base cases where the query
  // and reference points are the same.  That can only happen a maximum of once,
  // so to ensure that we compare a query point to a different reference point
  // at least once, we must return 2 here.
  size_t MinimumBaseCases() const { return 2; }

  // Get the results (to be called after the traversal).
  size_t NearestNeighbor() const { return nearestNeighbor; }
  double NearestDistance() const { return nearestDistance; }

  // Reset the internal statistics for an additional traversal.
  void Reset()
  {
    nearestNeighbor = size_t(-1);
    nearestDistance = DBL_MAX;
  }

 private:
  const arma::mat& dataset;

  size_t nearestNeighbor;
  double nearestDistance;
};

// See https://datasets.mlpack.org/corel-histogram.csv.
arma::mat dataset;
mlpack::data::Load("corel-histogram.csv", dataset, true);

typedef mlpack::SpillTree<mlpack::EuclideanDistance,
                          mlpack::EmptyStatistic,
                          arma::mat,
                          mlpack::AxisOrthogonalHyperplane,
                          mlpack::MidpointSpaceSplit> TreeType;

// Build two trees, one with a lot of overlap, and one with no overlap
// (i.e. tau = 0).
TreeType tree1(dataset, 0.5, 10), tree2(dataset, 0.0, 10);

// Construct the rule types, and then the traversals.
SpillNearestNeighborRule r1(dataset), r2(dataset);

TreeType::DefeatistSingleTreeTraverser<SpillNearestNeighborRule> t1(r1);
TreeType::DefeatistSingleTreeTraverser<SpillNearestNeighborRule> t2(r2);

// Search for the approximate nearest neighbor of point 3 using both trees.
t1.Traverse(3, tree1);
t2.Traverse(3, tree2);

std::cout << "Approximate nearest neighbor of point 3:" << std::endl;
std::cout << " - Spill tree with overlap 0.5 found: point "
    << r1.NearestNeighbor() << ", distance " << r1.NearestDistance()
    << "." << std::endl;

std::cout << " - Spill tree with no overlap found: point "
    << r2.NearestNeighbor() << ", distance " << r2.NearestDistance()
    << "." << std::endl;

// Now search for point 6.
r1.Reset();
r2.Reset();

t1.Traverse(6, tree1);
t2.Traverse(6, tree2);

std::cout << "Approximate nearest neighbor of point 6:" << std::endl;
std::cout << " - Spill tree with overlap 0.5 found: point "
    << r1.NearestNeighbor() << ", distance " << r1.NearestDistance()
    << "." << std::endl;

std::cout << " - Spill tree with no overlap found: point "
    << r2.NearestNeighbor() << ", distance " << r2.NearestDistance()
    << "." << std::endl;