mlpack

NonOrtSPTree

The NonOrtSPTree class implements the hybrid spill tree with non-axis-orthogonal splitting hyperplanes; this is a binary space partitioning tree that allows overlapping volumes between nodes. This type of tree can be more effective than trees like the KDTree for approximate nearest neighbor search and related tasks.

NonOrtSPTree supports three template parameters for configurable behavior, and implements all the functionality required by the TreeType API, plus some additional functionality specific to spill trees. NonOrtSPTree is built on the more generic SpillTree class, so if fully custom behavior is desired, that class can be used instead.

See also

Template parameters

In accordance with the TreeType API (see also this more detailed section), the NonOrtSPTree class takes three template parameters:

NonOrtSPTree<DistanceType, StatisticType, MatType>

The NonOrtSPTree class itself is a convenience typedef of the generic SpillTree class, using the Hyperplane class as the splitting hyperplane type, and the MidpointSpaceSplit class as the splitting strategy.

If no template parameters are explicitly specified, then defaults are used:

NonOrtSPTree<> = NonOrtSPTree<EuclideanDistance, EmptyStatistic, arma::mat>

Constructors

NonOrtSPTrees are constructed by iteratively finding splitting hyperplanes, and points within a margin of the hyperplane are assigned to both child nodes. Unlike the constructors of BinarySpaceTree, the dataset is not permuted during construction.





Notes:


Constructor parameters:

| name | type | description | default |
|------|------|-------------|---------|
| data | MatType | Column-major matrix to build the tree on. | (N/A) |
| tau | double | Width of spill margin: points within tau of the splitting hyperplane of a node will be contained in both left and right children. | 0.0 |
| maxLeafSize | size_t | Maximum number of points to store in each leaf. | 20 |
| rho | double | Balance threshold. When splitting, if either overlapping node would contain a fraction of more than rho of the points, a non-overlapping split is performed. Must be in the range [0.0, 1.0). | 0.7 |
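The interaction between tau and rho can be illustrated with a minimal, self-contained sketch. It is not mlpack's implementation: `SplitResult`, `Project()`, and `SpillSplit()` are hypothetical names, the hyperplane is assumed to be given as a unit normal plus an offset, and the exact boundary handling (`<=` versus `<`) is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical result type: indices of the points sent to each child.
struct SplitResult
{
  std::vector<std::size_t> left;
  std::vector<std::size_t> right;
  bool overlapping;
};

// Signed offset of `point` from the hyperplane { x : normal . x = offset }.
// `normal` is assumed to have unit length.
double Project(const std::vector<double>& normal,
               const double offset,
               const std::vector<double>& point)
{
  assert(normal.size() == point.size());
  return std::inner_product(normal.begin(), normal.end(), point.begin(), 0.0)
      - offset;
}

// Try an overlapping split with margin tau; if either child would hold more
// than a fraction rho of the points, fall back to a non-overlapping split.
SplitResult SpillSplit(const std::vector<std::vector<double>>& points,
                       const std::vector<double>& normal,
                       const double offset,
                       const double tau,
                       const double rho)
{
  assert(tau >= 0.0 && rho >= 0.0 && rho < 1.0);

  SplitResult result;
  result.overlapping = true;
  for (std::size_t i = 0; i < points.size(); ++i)
  {
    const double p = Project(normal, offset, points[i]);
    if (p <= tau)
      result.left.push_back(i);  // Left of the hyperplane, or in the margin.
    if (p > -tau)
      result.right.push_back(i); // Right of the hyperplane, or in the margin.
  }

  const double limit = rho * static_cast<double>(points.size());
  if (tau > 0.0 &&
      static_cast<double>(result.left.size()) <= limit &&
      static_cast<double>(result.right.size()) <= limit)
    return result;

  // Too unbalanced (or no margin at all): every point goes to one child only.
  result.left.clear();
  result.right.clear();
  result.overlapping = false;
  for (std::size_t i = 0; i < points.size(); ++i)
  {
    if (Project(normal, offset, points[i]) <= 0.0)
      result.left.push_back(i);
    else
      result.right.push_back(i);
  }
  return result;
}
```

With six points whose projections are roughly -2, -1, -0.1, 0.1, 1, and 2, a tau of 0.5 places the two middle points in both children, so each child holds 4 of the 6 points. That passes a rho of 0.7 but exceeds a rho of 0.6, in which case the sketch falls back to a 3/3 non-overlapping split.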

Caveats:

Basic tree properties

Once a NonOrtSPTree object is constructed, various properties of the tree can be accessed or inspected. Many of these functions are required by the TreeType API.


Accessing members of a tree

See also the developer documentation for basic tree functionality in mlpack.


Accessing data held in a tree


Accessing computed bound quantities of a tree

The following quantities are cached for each node in a NonOrtSPTree, and so accessing them does not require any computation. In the documentation below, ElemType is the element type of the given MatType; e.g., if MatType is arma::mat, then ElemType is double.

Note: for more details on each bound quantity, see the developer documentation on bound quantities for trees.


Other functionality

Bounding distances with the tree

The primary use of trees in mlpack is bounding distances to points or other tree nodes. The following functions can be used for these tasks.
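As a rough illustration of what such bounding functions compute, the sketch below bounds distances to a ball, the bound shape printed in the examples in this document (`tree.Bound().Center()` and `tree.Bound().Radius()`). The `Ball` struct and the free functions are hypothetical stand-ins written against the standard library only, not mlpack's bound class.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a node's bounding ball.
struct Ball
{
  std::vector<double> center;
  double radius;
};

// Euclidean distance between two points of equal dimension.
double Distance(const std::vector<double>& a, const std::vector<double>& b)
{
  assert(a.size() == b.size());
  double sum = 0.0;
  for (std::size_t i = 0; i < a.size(); ++i)
    sum += (a[i] - b[i]) * (a[i] - b[i]);
  return std::sqrt(sum);
}

// Lower bound on the distance from `point` to anything inside `ball`.
// It is zero when the point lies inside the ball.
double MinDistance(const Ball& ball, const std::vector<double>& point)
{
  return std::max(0.0, Distance(ball.center, point) - ball.radius);
}

// Upper bound on the distance from `point` to anything inside `ball`.
double MaxDistance(const Ball& ball, const std::vector<double>& point)
{
  return Distance(ball.center, point) + ball.radius;
}

// Lower bound on the distance between anything in `a` and anything in `b`.
// It is zero when the two balls intersect.
double MinDistance(const Ball& a, const Ball& b)
{
  return std::max(0.0, Distance(a.center, b.center) - a.radius - b.radius);
}
```

These bounds are what make pruning possible: if the lower bound to a node is already larger than the best candidate distance found so far, nothing inside that node can improve the result.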


Tree traversals

Like every mlpack tree, the NonOrtSPTree class provides a single-tree and dual-tree traversal that can be paired with a RuleType class to implement a single-tree or dual-tree algorithm.

However, spill trees are primarily useful because the overlapping nodes allow defeatist search to be effective. Defeatist search is non-backtracking: the tree is traversed to one leaf only. For example, finding the approximate nearest neighbor of a point p with defeatist search is done by recursing in the tree, choosing the child with smallest minimum distance to p, and when a leaf is encountered, choosing the closest point in the leaf to p as the nearest neighbor. This is the strategy used in the original spill tree paper (pdf).
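The descent described above can be sketched on a toy one-dimensional tree. This is a self-contained illustration, not mlpack's traverser: the `Node` struct, `MakeLeaf()`, `MakeNode()`, and `DefeatistNearest()` are hypothetical, and intervals stand in for the real node bounds.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical 1-D node: an interval bound, plus either two children or (for
// a leaf) the points it holds.
struct Node
{
  double lo = 0.0;
  double hi = 0.0;
  std::unique_ptr<Node> left;
  std::unique_ptr<Node> right;
  std::vector<double> points;

  bool IsLeaf() const { return !left; }

  // Lower bound on the distance from p to any point in this node.
  double MinDistance(const double p) const
  {
    return std::max({ 0.0, lo - p, p - hi });
  }
};

// Build a leaf whose bound is the tightest interval around its points.
std::unique_ptr<Node> MakeLeaf(std::vector<double> points)
{
  assert(!points.empty());
  auto node = std::make_unique<Node>();
  const auto mm = std::minmax_element(points.begin(), points.end());
  node->lo = *mm.first;
  node->hi = *mm.second;
  node->points = std::move(points);
  return node;
}

// Build an internal node whose bound covers both children.
std::unique_ptr<Node> MakeNode(std::unique_ptr<Node> left,
                               std::unique_ptr<Node> right)
{
  auto node = std::make_unique<Node>();
  node->lo = std::min(left->lo, right->lo);
  node->hi = std::max(left->hi, right->hi);
  node->left = std::move(left);
  node->right = std::move(right);
  return node;
}

// Defeatist search: descend to exactly one leaf, never backtracking, and
// return the closest point stored in that leaf.
double DefeatistNearest(const Node& root, const double query)
{
  const Node* node = &root;
  while (!node->IsLeaf())
  {
    // Choose the child with the smaller minimum distance to the query.
    node = (node->left->MinDistance(query) <= node->right->MinDistance(query))
        ? node->left.get() : node->right.get();
  }

  double best = node->points.front();
  for (const double p : node->points)
    if (std::abs(p - query) < std::abs(best - query))
      best = p;
  return best;
}
```

Because only one leaf is ever visited, the result is approximate in general. Overlapping children make it more likely that the visited leaf also contains the true nearest neighbor of a query that falls close to a splitting hyperplane.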

Defeatist traversers, matching the API for a regular traversal, are made available as the following two classes:

Any RuleType that is being used with a defeatist traversal, in addition to the functions required by the RuleType API, must implement the following functions:

// This is only required for single-tree defeatist traversals.
// It should return the index of the branch that should be chosen for the given
// query point and reference node.
template<typename VecType, typename TreeType>
size_t GetBestChild(const VecType& queryPoint, TreeType& referenceNode);

// This is only required for dual-tree defeatist traversals.
// It should return the index of the best child of the reference node that
// should be chosen for the given query node.
template<typename TreeType>
size_t GetBestChild(TreeType& queryNode, TreeType& referenceNode);

// Return the minimum number of base cases (point-to-point computations) that
// are required during the traversal.
size_t MinimumBaseCases();

Example usage

Build a NonOrtSPTree on the cloud dataset and print basic statistics about the tree.

// See https://datasets.mlpack.org/cloud.csv.
arma::mat dataset;
mlpack::data::Load("cloud.csv", dataset, true);

// Build the spill tree with a tau (margin) of 0.2 and a leaf size of 10.
// (This means that nodes are split until they contain 10 or fewer points.)
//
// The std::move() means that `dataset` will be empty after this call, and no
// data will be copied during tree building.
//
// When C++20 is enabled, then the <> is not necessary and the following line
// will work:
// mlpack::NonOrtSPTree tree(std::move(dataset), 0.2, 10);
mlpack::NonOrtSPTree<> tree(std::move(dataset), 0.2, 10);

// Print the bounding ball of the root node.
std::cout << "Bounding ball of root node:" << std::endl;
std::cout << "  Center: " << tree.Bound().Center().t();
std::cout << "  Radius: " << tree.Bound().Radius() << "." << std::endl;
std::cout << std::endl;

// Print the number of descendant points of the root, and of each of its
// children.
std::cout << "Descendant points of root:        "
    << tree.NumDescendants() << "." << std::endl;
std::cout << "Descendant points of left child:  "
    << tree.Left()->NumDescendants() << "." << std::endl;
std::cout << "Descendant points of right child: "
    << tree.Right()->NumDescendants() << "." << std::endl;
std::cout << std::endl;

// Compute the center of the NonOrtSPTree.  This is the same as the center of
// the bounding ball of the root.
arma::vec center;
tree.Center(center);
std::cout << "Center of tree: " << center.t();

Build two NonOrtSPTrees on subsets of the corel dataset and compute minimum and maximum distances between different nodes in the trees.

// See https://datasets.mlpack.org/corel-histogram.csv.
arma::mat dataset;
mlpack::data::Load("corel-histogram.csv", dataset, true);

// Build trees on the first half and the second half of points.  Use a tau
// (overlap) parameter of 0.3, which is tuned to this dataset, and a rho value
// of 0.6 to prevent the trees getting too deep.
mlpack::NonOrtSPTree<> tree1(dataset.cols(0, dataset.n_cols / 2), 0.3, 20, 0.6);
mlpack::NonOrtSPTree<> tree2(dataset.cols(dataset.n_cols / 2 + 1,
                                          dataset.n_cols - 1), 0.3, 20, 0.6);

// Compute the maximum distance between the trees.
std::cout << "Maximum distance between tree root nodes: "
    << tree1.MaxDistance(tree2) << "." << std::endl;

// Get the leftmost grandchild of the first tree's root---if it exists.
if (!tree1.IsLeaf() && !tree1.Child(0).IsLeaf())
{
  mlpack::NonOrtSPTree<>& node1 = tree1.Child(0).Child(0);

  // Get the rightmost grandchild of the second tree's root---if it exists.
  if (!tree2.IsLeaf() && !tree2.Child(1).IsLeaf())
  {
    mlpack::NonOrtSPTree<>& node2 = tree2.Child(1).Child(1);

    // Print the minimum and maximum distance between the nodes.
    mlpack::Range dists = node1.RangeDistance(node2);
    std::cout << "Possible distances between two grandchild nodes: ["
        << dists.Lo() << ", " << dists.Hi() << "]." << std::endl;

    // Print the minimum distance between the first node and the first
    // descendant point of the second node.
    const size_t descendantIndex = node2.Descendant(0);
    const double descendantMinDist =
        node1.MinDistance(node2.Dataset().col(descendantIndex));
    std::cout << "Minimum distance between grandchild node and descendant "
        << "point: " << descendantMinDist << "." << std::endl;

    // Which child of node2 is closer to node1?
    const size_t closerIndex = node2.GetNearestChild(node1);
    if (closerIndex == 0)
      std::cout << "The left child of node2 is closer to node1." << std::endl;
    else if (closerIndex == 1)
      std::cout << "The right child of node2 is closer to node1." << std::endl;
    else // closerIndex == 2 in this case.
      std::cout << "Both children of node2 are equally close to node1."
          << std::endl;

    // And which child of node1 is further from node2?
    const size_t furtherIndex = node1.GetFurthestChild(node2);
    if (furtherIndex == 0)
      std::cout << "The left child of node1 is further from node2."
          << std::endl;
    else if (furtherIndex == 1)
      std::cout << "The right child of node1 is further from node2."
          << std::endl;
    else // furtherIndex == 2 in this case.
      std::cout << "Both children of node1 are equally far from node2."
          << std::endl;
  }
}

Build a NonOrtSPTree on 32-bit floating point data and save it to disk.

// See https://datasets.mlpack.org/corel-histogram.csv.
arma::fmat dataset;
mlpack::data::Load("corel-histogram.csv", dataset);

// Build the NonOrtSPTree using 32-bit floating point data as the matrix type.
// We will still use the default EmptyStatistic and EuclideanDistance
// parameters.
mlpack::NonOrtSPTree<mlpack::EuclideanDistance,
                     mlpack::EmptyStatistic,
                     arma::fmat> tree(std::move(dataset), 0.1, 20, 0.6);

// Save the tree to disk with the name 'tree'.
mlpack::data::Save("tree.bin", "tree", tree);

std::cout << "Saved tree with " << tree.Dataset().n_cols << " points to "
    << "'tree.bin'." << std::endl;

Load a 32-bit floating point NonOrtSPTree from disk, then traverse it manually and find the number of nodes whose children overlap.

// This assumes the tree has already been saved to 'tree.bin' (as in the example
// above).

// This convenient typedef saves us a long type name!
using TreeType = mlpack::NonOrtSPTree<mlpack::EuclideanDistance,
                                      mlpack::EmptyStatistic,
                                      arma::fmat>;

TreeType tree;
mlpack::data::Load("tree.bin", "tree", tree);
std::cout << "Tree loaded with " << tree.NumDescendants() << " points."
    << std::endl;

// Recurse in a depth-first manner.  Count both the total number of non-leaves,
// and the number of non-leaves that have overlapping children.
size_t overlapCount = 0;
size_t totalInternalNodeCount = 0;
std::stack<TreeType*> stack;
stack.push(&tree);
while (!stack.empty())
{
  TreeType* node = stack.top();
  stack.pop();

  if (node->IsLeaf())
    continue;

  if (node->Overlap())
    ++overlapCount;
  ++totalInternalNodeCount;

  stack.push(node->Left());
  stack.push(node->Right());
}

// Note that it would be possible to use TreeType::SingleTreeTraverser to
// perform the recursion above, but that is better suited to more complex
// tasks that require pruning and other non-trivial behavior; so using a simple
// stack is the better option here.

// Print the results.
std::cout << overlapCount << " out of " << totalInternalNodeCount
    << " internal nodes have overlapping children." << std::endl;

Use a defeatist traversal to find the approximate nearest neighbors of the points with indices 3 and 6 in the corel-histogram dataset. (Note: this can also be done more easily with the KNN class! This example is a demonstration of how to use the defeatist traverser.)

For this example, we must first define a RuleType class.

// For simplicity, this only implements those methods required by single-tree
// traversals, and cannot be used with a dual-tree traversal.
//
// `.Reset()` must be called before any additional single-tree traversals after
// the first is run.
class SpillNearestNeighborRule
{
 public:
  // Store the dataset internally.
  SpillNearestNeighborRule(const arma::mat& dataset) :
      dataset(dataset),
      nearestNeighbor(size_t(-1)),
      nearestDistance(DBL_MAX) { }

  // Compute the base case (point-to-point comparison).
  double BaseCase(const size_t queryIndex, const size_t referenceIndex)
  {
    // Skip the base case if the points are the same.
    if (queryIndex == referenceIndex)
      return 0.0;

    const double dist = mlpack::EuclideanDistance::Evaluate(
        dataset.col(queryIndex), dataset.col(referenceIndex));

    if (dist < nearestDistance)
    {
      nearestNeighbor = referenceIndex;
      nearestDistance = dist;
    }

    return dist;
  }

  // Score the given node in the tree; if it is sufficiently far away that it
  // cannot contain a better nearest neighbor candidate, we can prune it.
  template<typename TreeType>
  double Score(const size_t queryIndex, const TreeType& referenceNode) const
  {
    const double minDist = referenceNode.MinDistance(dataset.col(queryIndex));
    if (minDist > nearestDistance)
      return DBL_MAX; // Prune: this cannot contain a better candidate!

    return minDist;
  }

  // Rescore the given node/point combination.  Note that this will not be used
  // by the defeatist traversal as it never backtracks, but we include it for
  // completeness because the RuleType API requires it.
  template<typename TreeType>
  double Rescore(const size_t, const TreeType&, const double oldScore) const
  {
    if (oldScore > nearestDistance)
      return DBL_MAX; // Prune: the node is too far away.
    return oldScore;
  }

  // This is required by defeatist traversals to select the best reference
  // child to recurse into for overlapping nodes.
  template<typename TreeType>
  size_t GetBestChild(const size_t queryIndex, TreeType& referenceNode)
      const
  {
    return referenceNode.GetNearestChild(dataset.col(queryIndex));
  }

  // We must perform at least two base cases in order to have a result.  Note
  // that this is two, and not one, because we skip base cases where the query
  // and reference points are the same.  That can only happen a maximum of once,
  // so to ensure that we compare a query point to a different reference point
  // at least once, we must return 2 here.
  size_t MinimumBaseCases() const { return 2; }

  // Get the results (to be called after the traversal).
  size_t NearestNeighbor() const { return nearestNeighbor; }
  double NearestDistance() const { return nearestDistance; }

  // Reset the internal statistics for an additional traversal.
  void Reset()
  {
    nearestNeighbor = size_t(-1);
    nearestDistance = DBL_MAX;
  }

 private:
  const arma::mat& dataset;

  size_t nearestNeighbor;
  double nearestDistance;
};

// See https://datasets.mlpack.org/corel-histogram.csv.
arma::mat dataset;
mlpack::data::Load("corel-histogram.csv", dataset, true);

// Build two trees, one with a lot of overlap, and one with no overlap
// (i.e. tau = 0).
mlpack::NonOrtSPTree<> tree1(dataset, 0.5, 10), tree2(dataset, 0.0, 10);

// Construct the rule types, and then the traversals.
SpillNearestNeighborRule r1(dataset), r2(dataset);

mlpack::NonOrtSPTree<>::DefeatistSingleTreeTraverser<SpillNearestNeighborRule>
    t1(r1), t2(r2);

// Search for the approximate nearest neighbor of point 3 using both trees.
t1.Traverse(3, tree1);
t2.Traverse(3, tree2);

std::cout << "Approximate nearest neighbor of point 3:" << std::endl;
std::cout << " - Non-axis-aligned spill tree with overlap 0.5 found: point "
    << r1.NearestNeighbor() << ", distance " << r1.NearestDistance()
    << "." << std::endl;

std::cout << " - Non-axis-aligned spill tree with no overlap found: point "
    << r2.NearestNeighbor() << ", distance " << r2.NearestDistance()
    << "." << std::endl;

// Now search for point 6.
r1.Reset();
r2.Reset();

t1.Traverse(6, tree1);
t2.Traverse(6, tree2);

std::cout << "Approximate nearest neighbor of point 6:" << std::endl;
std::cout << " - Non-axis-aligned spill tree with overlap 0.5 found: point "
    << r1.NearestNeighbor() << ", distance " << r1.NearestDistance()
    << "." << std::endl;

std::cout << " - Non-axis-aligned spill tree with no overlap found: point "
    << r2.NearestNeighbor() << ", distance " << r2.NearestDistance()
    << "." << std::endl;