mlpack

The KNN class implements k-nearest neighbor search, a core computational task that is useful in many machine learning situations. Either exact or approximate nearest neighbors can be computed. mlpack’s KNN class uses trees, by default the KDTree, to provide significantly accelerated computation; depending on input options, an efficient dual-tree or single-tree algorithm is used.

[Figure: reference points r₀, r₁, r₂, r₃, r₄ and query points q₀, q₁.]

The exact nearest neighbor of the query point q₀ is r₂.
The exact nearest neighbor of the query point q₁ is r₀.
Approximate search is not guaranteed to return exact neighbors, but results will be close;
e.g., r₃ could be returned as the approximate nearest neighbor of q₁.

Given a reference set of points and a query set of points, the KNN class will compute the nearest neighbors in the reference set of every point in the query set. If no query set is given, then KNN will find the nearest neighbors of every point in the reference set; this is also called the all-nearest-neighbors problem.
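
To make the problem statement concrete, here is a minimal brute-force sketch in plain C++ (standard library only, no mlpack). The names Point, Distance(), and BruteForceKNN() are hypothetical and are not part of mlpack. The KNN class computes the same kind of result, but uses trees to avoid most of the distance computations.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical point type for this sketch: one coordinate per dimension.
using Point = std::vector<double>;

// Euclidean distance between two points of equal dimension.
double Distance(const Point& a, const Point& b)
{
  assert(a.size() == b.size());
  double sum = 0.0;
  for (std::size_t d = 0; d < a.size(); ++d)
    sum += (a[d] - b[d]) * (a[d] - b[d]);
  return std::sqrt(sum);
}

// For each query point, return the indices of its k nearest reference points,
// ordered from nearest to farthest.  Every query is compared against every
// reference point, so this takes O(|queries| * |references|) distances.
std::vector<std::vector<std::size_t>> BruteForceKNN(
    const std::vector<Point>& references,
    const std::vector<Point>& queries,
    const std::size_t k)
{
  assert(k <= references.size());
  std::vector<std::vector<std::size_t>> neighbors;
  for (const Point& q : queries)
  {
    // Pair each reference index with its distance to q.
    std::vector<std::pair<double, std::size_t>> candidates;
    for (std::size_t r = 0; r < references.size(); ++r)
      candidates.emplace_back(Distance(q, references[r]), r);

    // Move the k smallest distances to the front, in sorted order.
    std::partial_sort(candidates.begin(), candidates.begin() + k,
        candidates.end());

    std::vector<std::size_t> result;
    for (std::size_t i = 0; i < k; ++i)
      result.push_back(candidates[i].second);
    neighbors.push_back(result);
  }
  return neighbors;
}
```

Passing the reference set as the query set in this sketch would report each point as its own nearest neighbor (distance 0); the all-nearest-neighbors problem excludes that trivial match.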

The KNN class supports configurable behavior, with numerous runtime and compile-time parameters, including the distance metric, type of data, search strategy, and tree type.

Note that the KNN class is not a classifier, and instead focuses on simply computing the nearest neighbors of points.

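Simple usage example (this and the later snippets assume #include <mlpack.hpp> and are meant to be placed inside a function such as main()):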

// Compute the 5 exact nearest neighbors of every point of random numeric data.

// All data is uniform random: 10-dimensional data.  Replace with a Load()
// call or similar for a real application.
arma::mat referenceSet(10, 1000, arma::fill::randu); // 1000 points.

mlpack::KNN knn;                     // Step 1: create object.
knn.Train(referenceSet);             // Step 2: set the reference set.
arma::mat distances;
arma::Mat<size_t> neighbors;
knn.Search(5, neighbors, distances); // Step 3: find 5 nearest neighbors of
                                     //         every point in `referenceSet`.

// Note: you can also call `knn.Search(querySet, 5, neighbors, distances)` to
// find the nearest neighbors in `referenceSet` of a different set of points.

// Print some information about the results.
std::cout << "Found " << neighbors.n_rows << " neighbors for each of "
    << neighbors.n_cols << " points in the dataset." << std::endl;

🔗 Constructors

Note: if std::move() is not used to pass referenceSet or referenceTree, those objects will be copied—which can be expensive! Be sure to use std::move() if possible.


Constructor Parameters:

| name | type | description | default |
|------|------|-------------|---------|
| referenceSet | arma::mat | Column-major matrix containing dataset to search for nearest neighbors in. | (N/A) |
| referenceTree | KNN::Tree (a KDTree) | Pre-built kd-tree on reference data. | (N/A) |
| strategy | enum NeighborSearchStrategy | The search strategy that will be used when Search() is called. Must be one of NAIVE, SINGLE_TREE, DUAL_TREE, or GREEDY_SINGLE_TREE. See Search strategies. | DUAL_TREE |
| epsilon | double | Allowed relative approximation error. 0 means exact search. Must be non-negative. | 0.0 |
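
The epsilon parameter is described only as the "allowed relative approximation error". Assuming the usual meaning of that term (a returned neighbor may be at most a factor of 1 + epsilon farther away than the true neighbor), the bound can be sketched as a simple check. The function WithinRelativeError() is hypothetical and is not part of mlpack.

```cpp
#include <cassert>

// Returns true if an approximate neighbor distance satisfies the relative
// error bound `epsilon` with respect to the true neighbor distance, i.e.
//   approxDistance <= (1 + epsilon) * trueDistance.
// This is an assumed definition used for illustration only.
bool WithinRelativeError(const double approxDistance,
                         const double trueDistance,
                         const double epsilon)
{
  assert(epsilon >= 0.0); // Epsilon must be non-negative.
  return approxDistance <= (1.0 + epsilon) * trueDistance;
}
```

With epsilon set to 0 the check only accepts distances no larger than the true distance, which matches the statement that 0 means exact search.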

Notes:

🔗 Search strategies

The KNN class can search for nearest neighbors using one of four strategies: NAIVE, SINGLE_TREE, DUAL_TREE, or GREEDY_SINGLE_TREE. The strategy can be specified in the constructor as the strategy parameter, or by calling knn.SearchStrategy() = strategy.

🔗 Setting the reference set (Train())

If the reference set was not set in the constructor, or if it needs to be changed to a new reference set, the Train() method can be used.

🔗 Searching for neighbors

Once the reference set and parameters are set, searching for nearest neighbors can be done with the Search() method.

Notes:


Search Parameters:

| name | type | description |
|------|------|-------------|
| querySet | arma::mat | Column-major matrix of query points for which the nearest neighbors in the reference set should be found. |
| k | size_t | Number of nearest neighbors to search for. |
| neighbors | arma::Mat<size_t> | Matrix to store indices of nearest neighbors into. Will be set to size k x N, where N is the number of points in the query set (if specified), or the reference set (if not). |
| distances | arma::mat | Matrix to store distances to nearest neighbors into. Will be set to the same size as neighbors. |
| sameSet | bool | (Only for Search() with a query set.) If true, then querySet is the same set as the reference set. |
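
Whether the query set is treated as the same set as the reference set changes the results: if the sets are treated as separate, a point that appears in both is its own nearest neighbor at distance 0. The following standalone sketch (standard library only, 1-dimensional points for brevity) illustrates the difference. NearestIndex() is a hypothetical helper, and skipping the self-match is an assumption about what "same set" implies here, not a description of mlpack's internals.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Index of the nearest reference point for each query (1-dimensional data).
// If sameSet is true, queries[i] is taken to be references[i], and the
// trivial self-match is skipped.
std::vector<std::size_t> NearestIndex(const std::vector<double>& references,
                                      const std::vector<double>& queries,
                                      const bool sameSet)
{
  assert(!sameSet || references.size() == queries.size());
  std::vector<std::size_t> nearest;
  for (std::size_t q = 0; q < queries.size(); ++q)
  {
    double best = std::numeric_limits<double>::infinity();
    std::size_t bestIndex = references.size(); // Sentinel: no neighbor found.
    for (std::size_t r = 0; r < references.size(); ++r)
    {
      if (sameSet && r == q)
        continue; // Do not report a point as its own neighbor.
      const double dist = std::abs(references[r] - queries[q]);
      if (dist < best)
      {
        best = dist;
        bestIndex = r;
      }
    }
    nearest.push_back(bestIndex);
  }
  return nearest;
}
```

For the points {0, 1, 3} used as both sets, this returns indices {0, 1, 2} when sameSet is false (every point matches itself) and {1, 0, 1} when sameSet is true.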

🔗 Computing quality metrics

If approximate nearest neighbor search is performed (e.g. if knn.Epsilon() > 0), and exact nearest neighbors are known, it is possible to compute quality metrics of the approximate search.
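
As a rough illustration of what such quality metrics measure, the sketch below computes a recall and an average relative distance error in plain C++ (standard library only). The definitions used here are assumptions chosen for illustration: recall is the fraction of true neighbors that also appear in the approximate results, and the error is the mean of |found - true| / true over entries with a nonzero true distance. They may differ in detail from what knn.Recall() and knn.EffectiveError() compute, and the free functions below are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// lists[i] holds the k results (indices or distances) for query point i.
using NeighborLists = std::vector<std::vector<std::size_t>>;
using DistanceLists = std::vector<std::vector<double>>;

// Fraction of true neighbors that also appear in the approximate results.
double Recall(const NeighborLists& found, const NeighborLists& truth)
{
  assert(found.size() == truth.size());
  std::size_t hits = 0, total = 0;
  for (std::size_t i = 0; i < truth.size(); ++i)
  {
    for (const std::size_t t : truth[i])
    {
      if (std::find(found[i].begin(), found[i].end(), t) != found[i].end())
        ++hits;
      ++total;
    }
  }
  return (total == 0) ? 1.0
                      : static_cast<double>(hits) / static_cast<double>(total);
}

// Mean relative error of the approximate distances.  Entries whose true
// distance is zero are skipped, since the relative error is undefined there.
double EffectiveError(const DistanceLists& found, const DistanceLists& truth)
{
  assert(found.size() == truth.size());
  double sum = 0.0;
  std::size_t count = 0;
  for (std::size_t i = 0; i < truth.size(); ++i)
  {
    assert(found[i].size() == truth[i].size());
    for (std::size_t j = 0; j < truth[i].size(); ++j)
    {
      if (truth[i][j] == 0.0)
        continue;
      sum += std::abs(found[i][j] - truth[i][j]) / truth[i][j];
      ++count;
    }
  }
  return (count == 0) ? 0.0 : sum / static_cast<double>(count);
}
```

A recall of 1 together with an error of 0 would mean the approximate search returned exactly the true neighbors.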

🔗 Other functionality

🔗 Simple examples

Find the exact nearest neighbor of every point in the cloud dataset.

// See https://datasets.mlpack.org/cloud.csv.
arma::mat dataset;
mlpack::Load("cloud.csv", dataset);

// Construct the KNN object; this will avoid copies via std::move(), and build a
// kd-tree on the dataset.
mlpack::KNN knn(std::move(dataset));

arma::mat distances;
arma::Mat<size_t> neighbors;

// Compute the exact nearest neighbor.
knn.Search(1, neighbors, distances);

// Print the nearest neighbor and distance of the fifth point in the dataset.
std::cout << "Point 4:" << std::endl;
std::cout << " - Point values: " << knn.ReferenceSet().col(4).t();
std::cout << " - Index of nearest neighbor: " << neighbors(0, 4) << "."
    << std::endl;
std::cout << " - Distance to nearest neighbor: " << distances(0, 4) << "."
    << std::endl;

Split the corel-histogram dataset into two sets, and find the exact nearest neighbor in the first set of every point in the second set.

// See https://datasets.mlpack.org/corel-histogram.csv.
arma::mat dataset;
mlpack::Load("corel-histogram.csv", dataset);

// Split the dataset into two equal-sized sets randomly with `Split()`.
arma::mat referenceSet, querySet;
mlpack::Split(dataset, referenceSet, querySet, 0.5);

// Construct the KNN object, building a tree on the reference set.  Copies are
// avoided by the use of `std::move()`.
mlpack::KNN knn(std::move(referenceSet));

arma::mat distances;
arma::Mat<size_t> neighbors;

// Compute the exact nearest neighbor in `referenceSet` of every point in
// `querySet`.
knn.Search(querySet, 1, neighbors, distances);

// Print information about the dual-tree traversal.
std::cout << "Dual-tree traversal computed " << knn.BaseCases()
    << " point-to-point distances and " << knn.Scores()
    << " tree node-to-tree node distances." << std::endl;

// Print information about nearest neighbors of the points in the query set.
std::cout << "The nearest neighbor of query point 3 is reference point index "
    << neighbors(0, 3) << ", with distance " << distances(0, 3) << "."
    << std::endl;
std::cout << "The L2-norm of reference point " << neighbors(0, 3) << " is "
    << arma::norm(knn.ReferenceSet().col(neighbors(0, 3)), 2) << "."
    << std::endl;

// Compute the average nearest neighbor distance for all points in the query
// set.
const double averageDist = arma::mean(arma::vectorise(distances));
std::cout << "Average distance between a query point and its nearest neighbor: "
    << averageDist << "." << std::endl;

Perform single-tree search to find the 5 nearest neighbors of the first point in a subset of the LCDM dataset. (Epsilon is left at its default of 0, so the search is exact.)

// See https://datasets.mlpack.org/lcdm_tiny.csv.
arma::mat dataset;
mlpack::Load("lcdm_tiny.csv", dataset);

// Build a KNN object on the LCDM dataset, and pass with `std::move()` so that
// we can avoid copying the dataset.  Set the search strategy to single-tree.
mlpack::KNN knn(std::move(dataset), mlpack::SINGLE_TREE);

// Now we will compute the 5 nearest neighbors of the first point in the
// dataset.
//
// NOTE: because the first point is in the reference set, and because we are
// passing a separate query set, KNN will return that the nearest neighbor of
// the point is itself!  This is an important caveat to be aware of when calling
// Search() with a query set.
arma::mat distances;
arma::Mat<size_t> neighbors;
knn.Search(knn.ReferenceSet().col(0), 5, neighbors, distances);

std::cout << "The five nearest neighbors of the first point in the LCDM dataset"
    << " are:" << std::endl;
for (size_t k = 0; k < 5; ++k)
{
  std::cout << " - " << neighbors(k, 0) << " (with distance " << distances(k, 0)
      << ")." << std::endl;
  if (k == 0)
  {
    std::cout << "    (the first point's nearest neighbor is itself, because it"
        << " is in the reference set, and we called Search() with a separate "
        << "query set!)" << std::endl;
  }
}

Use greedy single-tree search to find 5 approximate nearest neighbors of every point in the cloud dataset. Then, compute the exact nearest neighbors, and use these to find the average error and recall of the approximate search.

Note: greedy single-tree search is far more effective when using spill trees—see the advanced examples for another example that does exactly that.

// See https://datasets.mlpack.org/cloud.csv.
arma::mat dataset;
mlpack::Load("cloud.csv", dataset);

// Build a tree on the dataset and set the search strategy to the greedy single
// tree strategy.
mlpack::KNN knn(std::move(dataset), mlpack::GREEDY_SINGLE_TREE);

// Compute the 5 approximate nearest neighbors of every point in the dataset.
arma::mat distances;
arma::Mat<size_t> neighbors;
knn.Search(5, neighbors, distances);

std::cout << "Greedy approximate kNN search computed " << knn.BaseCases()
    << " point-to-point distances and visited " << knn.Scores()
    << " tree nodes in total." << std::endl;

// Now switch to exact computation and compute the true neighbors and distances.
arma::Mat<size_t> trueNeighbors;
arma::mat trueDistances;
knn.SearchStrategy() = mlpack::DUAL_TREE;
knn.Epsilon() = 0.0;
knn.Search(5, trueNeighbors, trueDistances);

// Compute the recall and effective error.
const double recall = knn.Recall(neighbors, trueNeighbors);
const double effectiveError = knn.EffectiveError(distances, trueDistances);

std::cout << "Recall of greedy search: " << recall << "." << std::endl;
std::cout << "Effective error of greedy search: " << effectiveError << "."
    << std::endl;

Build a KNN object on the cloud dataset and save it to disk.

// See https://datasets.mlpack.org/cloud.csv.
arma::mat dataset;
mlpack::Load("cloud.csv", dataset);

// Build the reference tree.
mlpack::KNN knn(std::move(dataset));

// Save the KNN object to disk with the name 'knn'.
mlpack::Save("knn.bin", "knn", knn);

std::cout << "Successfully saved KNN model to 'knn.bin'." << std::endl;

Load a KNN object from disk, and inspect the KDTree that is held in the object.

// Load the KNN object with name 'knn' from 'knn.bin'.
mlpack::KNN knn;
mlpack::Load("knn.bin", "knn", knn);

// Inspect the KDTree held by the KNN object.
std::cout << "The KDTree in the KNN object in 'knn.bin' holds "
    << knn.ReferenceTree().NumDescendants() << " points." << std::endl;
std::cout << "The root of the tree has " << knn.ReferenceTree().NumChildren()
    << " children." << std::endl;
if (knn.ReferenceTree().NumChildren() == 2)
{
  std::cout << " - The left child holds "
      << knn.ReferenceTree().Child(0).NumDescendants() << " points."
      << std::endl;
  std::cout << " - The right child holds "
      << knn.ReferenceTree().Child(1).NumDescendants() << " points."
      << std::endl;
}

Split the corel-histogram dataset into two subsets, and compute the 5 approximate nearest neighbors in the first subset of every point in the second subset, using a pre-built query tree. Then, reuse the query tree to compute the exact neighbors and compute the effective error.

// See https://datasets.mlpack.org/corel-histogram.csv.
arma::mat dataset;
mlpack::Load("corel-histogram.csv", dataset);

// Split the corel-histogram dataset into two parts of equal size.
arma::mat referenceSet, querySet;
mlpack::Split(dataset, referenceSet, querySet, 0.5);

// Build the KNN object, passing the reference set with `std::move()` to avoid a
// copy.  We use the default dual-tree strategy for search and set the maximum
// allowed relative error to 0.1 (10%).
mlpack::KNN knn(std::move(referenceSet), mlpack::DUAL_TREE, 0.1);

// Now build a tree on the query points.  This is a KDTree, and we manually
// specify a leaf size of 50 points.  Note that the KDTree rearranges the
// ordering of points in the query set.
mlpack::KNN::Tree queryTree(std::move(querySet), 50 /* leaf size */);

// Compute the 5 approximate nearest neighbors of all points in the query set.
arma::mat distances;
arma::Mat<size_t> neighbors;
knn.Search(queryTree, 5, neighbors, distances);

// Now compute the exact neighbors---but since we are using dual-tree search and
// an externally-constructed query tree, we must reset the bounds!
arma::mat trueDistances;
arma::Mat<size_t> trueNeighbors;
knn.ResetTree(queryTree);
knn.Epsilon() = 0;
knn.Search(queryTree, 5, trueNeighbors, trueDistances);

// Compute the effective error.
const double effectiveError = knn.EffectiveError(distances, trueDistances);

std::cout << "Effective error of approximate dual-tree search was "
    << effectiveError << " (limit via knn.Epsilon() was 0.1)." << std::endl;

🔗 Advanced functionality: template parameters

The KNN class is a typedef of the configurable KNNType class, which has five template parameters that can be used for custom behavior. The full signature of the class is:

KNNType<DistanceType,
        TreeType,
        MatType,
        DualTreeTraversalType,
        SingleTreeTraversalType>

When custom template parameters are specified:


DistanceType


TreeType


MatType


DualTreeTraversalType


SingleTreeTraversalType


🔗 Advanced examples

Perform exact nearest neighbor search to find the nearest neighbor of every point in the cloud dataset, using 32-bit floats to represent the data.

// See https://datasets.mlpack.org/cloud.csv.
arma::fmat dataset;
mlpack::Load("cloud.csv", dataset);

// Construct the KNN object; this will avoid copies via std::move(), and build a
// kd-tree on the dataset.
mlpack::KNNType<mlpack::EuclideanDistance, mlpack::KDTree, arma::fmat> knn(
    std::move(dataset));

arma::fmat distances; // This type is arma::fmat, just like the dataset.
arma::Mat<size_t> neighbors;

// Compute the exact nearest neighbor.
knn.Search(1, neighbors, distances);

// Print the nearest neighbor and distance of the fifth point in the dataset.
std::cout << "Point 4:" << std::endl;
std::cout << " - Point values: " << knn.ReferenceSet().col(4).t();
std::cout << " - Index of nearest neighbor: " << neighbors(0, 4) << "."
    << std::endl;
std::cout << " - Distance to nearest neighbor: " << distances(0, 4) << "."
    << std::endl;

Perform approximate single-tree nearest neighbor search using the Chebyshev (L-infinity) distance as the distance metric.

// See https://datasets.mlpack.org/cloud.csv.
arma::mat dataset;
mlpack::Load("cloud.csv", dataset);

// Construct the KNN object; this will avoid copies via std::move(), and build a
// kd-tree on the dataset.
mlpack::KNNType<mlpack::ChebyshevDistance> knn(std::move(dataset),
    mlpack::SINGLE_TREE, 0.2);

arma::mat distances;
arma::Mat<size_t> neighbors;

// Compute the approximate nearest neighbor (relative error of at most 0.2).
knn.Search(1, neighbors, distances);

// Print the approximate nearest neighbor and distance of the fifth point in the
// dataset.
std::cout << "Point 4:" << std::endl;
std::cout << " - Point values: " << knn.ReferenceSet().col(4).t();
std::cout << " - Index of approximate nearest neighbor: " << neighbors(0, 4)
    << "." << std::endl;
std::cout << " - Chebyshev distance to approximate nearest neighbor: "
    << distances(0, 4) << "." << std::endl;

Use an Octree (a tree known to be faster in very few dimensions) to find the exact nearest neighbors of all the points in a tiny subset of the 3-dimensional LCDM dataset.

// See https://datasets.mlpack.org/lcdm_tiny.csv.
arma::mat dataset;
mlpack::Load("lcdm_tiny.csv", dataset);

// Construct the KNN object with Octrees.
mlpack::KNNType<mlpack::EuclideanDistance, mlpack::Octree> knn(
    std::move(dataset));

arma::mat distances;
arma::Mat<size_t> neighbors;

// Find the exact nearest neighbor of every point.
knn.Search(1, neighbors, distances);

// Print the average, minimum, and maximum nearest neighbor distances.
std::cout << "Average nearest neighbor distance: " <<
    arma::mean(arma::vectorise(distances)) << "." << std::endl;
std::cout << "Minimum nearest neighbor distance: " <<
    arma::min(arma::vectorise(distances)) << "." << std::endl;
std::cout << "Maximum nearest neighbor distance: " <<
    arma::max(arma::vectorise(distances)) << "." << std::endl;

Using a 32-bit floating point representation, split the lcdm_tiny dataset into a query and a reference set, and then use KNN with preconstructed random projection trees (RPTree) to find the 5 approximate nearest neighbors of each point in the query set with the Manhattan distance. Then, compute exact nearest neighbors and the average error and recall.

// See https://datasets.mlpack.org/lcdm_tiny.csv.
arma::fmat dataset;
mlpack::Load("lcdm_tiny.csv", dataset);

// Split the dataset into a query set and a reference set (each with the same
// size).
arma::fmat referenceSet, querySet;
mlpack::Split(dataset, referenceSet, querySet, 0.5);

// This is the type of tree we will build on the datasets.
using TreeType = mlpack::RPTree<mlpack::ManhattanDistance,
                                mlpack::NearestNeighborStat,
                                arma::fmat>;

// Note that we could also define TreeType as below (it is the same type!):
//
// using TreeType = mlpack::KNNType<mlpack::ManhattanDistance,
//                                  mlpack::RPTree,
//                                  arma::fmat>::Tree;

// We build the trees here with std::move() in order to avoid copying data.
//
// For RPTrees, this reorders the points in the dataset, but if original indices
// are needed, trees can be constructed with mapping objects.  (See the RPTree
// documentation for more details.)
TreeType referenceTree(std::move(referenceSet));
TreeType queryTree(std::move(querySet));

// Construct the KNN object with the prebuilt reference tree.
mlpack::KNNType<mlpack::ManhattanDistance, mlpack::RPTree, arma::fmat> knn(
    std::move(referenceTree), mlpack::DUAL_TREE, 0.1 /* max 10% error */);

// Find 5 approximate nearest neighbors.
arma::fmat approxDistances;
arma::Mat<size_t> approxNeighbors;
knn.Search(queryTree, 5, approxNeighbors, approxDistances);
std::cout << "Computed approximate neighbors." << std::endl;

// Now compute exact neighbors.  When reusing the query tree, this requires
// resetting the statistics inside the query tree manually.
arma::fmat exactDistances;
arma::Mat<size_t> exactNeighbors;
knn.ResetTree(queryTree);
knn.Epsilon() = 0.0; // Error tolerance is now 0% (exact search).
knn.Search(queryTree, 5, exactNeighbors, exactDistances);
std::cout << "Computed exact neighbors." << std::endl;

// Compute error measures.
const double recall = knn.Recall(approxNeighbors, exactNeighbors);
const double effectiveError = knn.EffectiveError(approxDistances,
    exactDistances);

std::cout << "Recall of approximate search: " << recall << "." << std::endl;
std::cout << "Effective relative error of approximate search: "
    << effectiveError << " (vs. limit of 0.1)." << std::endl;

Use spill trees to perform greedy single-tree approximate nearest neighbor search on the cloud dataset, and compare with the other spill tree traversers and exact results. Compare with the results in the simple examples section where the default KDTree is used—spill trees perform significantly better for greedy search!

// See https://datasets.mlpack.org/cloud.csv.
arma::mat dataset;
mlpack::Load("cloud.csv", dataset);

// Build a spill tree on the dataset and set the search strategy to the greedy
// single tree strategy.  We will build the tree manually, so that we can
// configure the build-time parameters (see the SPTree documentation for more
// details).
using TreeType = mlpack::SPTree<mlpack::EuclideanDistance,
                                mlpack::NearestNeighborStat,
                                arma::mat>;
TreeType referenceTree(std::move(dataset), 10.0 /* tau, overlap parameter */);

mlpack::KNNType<mlpack::EuclideanDistance,
                mlpack::SPTree,
                arma::mat,
                // Use the special defeatist spill tree traversers.
                TreeType::template DefeatistDualTreeTraverser,
                TreeType::template DefeatistSingleTreeTraverser> knn(
    std::move(referenceTree),
    mlpack::GREEDY_SINGLE_TREE);

arma::mat greedyDistances, dualDistances, singleDistances, exactDistances;
arma::Mat<size_t> greedyNeighbors, dualNeighbors, singleNeighbors,
    exactNeighbors;

// Compute the 5 approximate nearest neighbors of every point in the dataset.
knn.Search(5, greedyNeighbors, greedyDistances);

std::cout << "Greedy approximate kNN search computed " << knn.BaseCases()
    << " point-to-point distances and visited " << knn.Scores()
    << " tree nodes in total." << std::endl;

// Now do the same thing, but with defeatist dual-tree search.  Note that
// defeatist dual-tree search is not backtracking, so we don't need to set
// knn.Epsilon().
knn.SearchStrategy() = mlpack::DUAL_TREE;
knn.Search(5, dualNeighbors, dualDistances);

std::cout << "Dual-tree approximate kNN search computed " << knn.BaseCases()
    << " point-to-point distances and visited " << knn.Scores()
    << " tree nodes in total." << std::endl;

// Finally, use defeatist single-tree search.
knn.SearchStrategy() = mlpack::SINGLE_TREE;
knn.Search(5, singleNeighbors, singleDistances);

std::cout << "Single-tree approximate kNN search computed " << knn.BaseCases()
    << " point-to-point distances and visited " << knn.Scores()
    << " tree nodes in total." << std::endl;

// Now switch to the exact naive strategy and compute the true neighbors and
// distances.
knn.SearchStrategy() = mlpack::NAIVE;
knn.Epsilon() = 0.0;
knn.Search(5, exactNeighbors, exactDistances);

// Compute the recall and effective error for each strategy.
const double greedyRecall = knn.Recall(greedyNeighbors, exactNeighbors);
const double dualRecall = knn.Recall(dualNeighbors, exactNeighbors);
const double singleRecall = knn.Recall(singleNeighbors, exactNeighbors);

const double greedyError = knn.EffectiveError(greedyDistances, exactDistances);
const double dualError = knn.EffectiveError(dualDistances, exactDistances);
const double singleError = knn.EffectiveError(singleDistances, exactDistances);

// Print the results.  To tune the results, try rebuilding the SPTree with
// different construction parameters (such as tau above).
std::cout << std::endl;
std::cout << "Recall with spill trees:" << std::endl;
std::cout << " - Greedy search:      " << greedyRecall << "." << std::endl;
std::cout << " - Dual-tree search:   " << dualRecall << "." << std::endl;
std::cout << " - Single-tree search: " << singleRecall << "." << std::endl;
std::cout << std::endl;
std::cout << "Effective error with spill trees:" << std::endl;
std::cout << " - Greedy search:      " << greedyError << "." << std::endl;
std::cout << " - Dual-tree search:   " << dualError << "." << std::endl;
std::cout << " - Single-tree search: " << singleError << "." << std::endl;