🔗 KNN: k-nearest-neighbor search
The KNN class implements k-nearest neighbor search, a core computational task
that is useful in many machine learning situations. Either exact or approximate
nearest neighbors can be computed. mlpack’s KNN class uses
trees, by default the KDTree,
to provide significantly accelerated computation; depending on input options, an
efficient dual-tree or single-tree algorithm is used.
*(Figure omitted: an illustration with query points `q₀`, `q₁` and reference
points `r₀`, `r₂`, `r₃`.)* In that illustration, the exact nearest neighbor of
the query point `q₀` is `r₂`, and the exact nearest neighbor of `q₁` is `r₀`.
Approximate search will not necessarily return exact neighbors, but results
will be close; e.g., `r₃` could be returned as the approximate nearest neighbor
of `q₁`.
Given a reference set of points and a query set of queries, the KNN class
will compute the nearest neighbors in the reference set of every point in the
query set. If no query set is given, then KNN will find the nearest neighbors
of every point in the reference set; this is also called the
all-nearest-neighbors problem.
The KNN class supports configurable behavior, with numerous runtime and
compile-time parameters, including the distance metric, type of data, search
strategy, and tree type.
Note that the KNN class is not a classifier, and instead focuses on simply
computing the nearest neighbors of points.
Simple usage example:
// Compute the 5 exact nearest neighbors of every point of random numeric data.
// All data is uniform random: 10-dimensional data. Replace with a Load()
// call or similar for a real application.
arma::mat referenceSet(10, 1000, arma::fill::randu); // 1000 points.
mlpack::KNN knn; // Step 1: create object.
knn.Train(referenceSet); // Step 2: set the reference set.
arma::mat distances;
arma::Mat<size_t> neighbors;
knn.Search(5, neighbors, distances); // Step 3: find 5 nearest neighbors of
// every point in `referenceSet`.
// Note: you can also call `knn.Search(querySet, 5, neighbors, distances)` to
// find the nearest neighbors in `referenceSet` of a different set of points.
// Print some information about the results.
std::cout << "Found " << neighbors.n_rows << " neighbors for each of "
<< neighbors.n_cols << " points in the dataset." << std::endl;
Quick links:

- Constructors: create `KNN` objects.
- Search strategies: details of search strategies supported by `KNN`.
- Setting the reference set (`Train()`): set the dataset that will be searched
  for nearest neighbors.
- Searching for neighbors: call `Search()` to compute nearest neighbors (exact
  or approximate).
- Computing quality metrics: determine how accurate the computed nearest
  neighbors are, if approximate search was used.
- Other functionality for loading, saving, and inspecting.
- Examples of simple usage.
- Template parameters for configuring behavior, including distance metrics,
  tree types, and different element types.
- Advanced examples that make use of custom template parameters.
See also:
- mlpack trees
- mlpack geometric algorithms
- Nearest neighbor search on Wikipedia
- Tree-Independent Dual-Tree Algorithms (pdf)
- KFN (k-furthest-neighbors)
🔗 Constructors
- `knn = KNN()`
- `knn = KNN(strategy=DUAL_TREE, epsilon=0)`
  - Construct a `KNN` object without a reference set, optionally using the
    given `strategy` for search and `epsilon` for the maximum relative
    approximation level. The reference set must be set with `Train()` before
    `Search()` can be called.

- `knn = KNN(referenceSet)`
- `knn = KNN(referenceSet, strategy=DUAL_TREE, epsilon=0)`
  - Construct a `KNN` object on the given set of reference points, using the
    given `strategy` for search and `epsilon` for the maximum relative
    approximation level.
  - This will build a `KDTree` with default parameters on `referenceSet`, if
    `strategy` is not `NAIVE`.
  - If `referenceSet` is not needed elsewhere, pass with `std::move()` (e.g.
    `std::move(referenceSet)`) to avoid copying `referenceSet`. The dataset
    will still be accessible via `ReferenceSet()`, but points may be in
    shuffled order.

- `knn = KNN(referenceTree)`
- `knn = KNN(referenceTree, strategy=DUAL_TREE, epsilon=0)`
  - Construct a `KNN` object with a pre-built tree `referenceTree`, which
    should be of type `KNN::Tree` (a convenience typedef of `KDTree` that uses
    `NearestNeighborStat` as its `StatisticType`).
  - The search strategy will be set to `strategy`, and the maximum relative
    approximation level will be set to `epsilon`.
  - If `referenceTree` is not needed elsewhere, pass with `std::move()` (e.g.
    `std::move(referenceTree)`) to avoid copying `referenceTree`. The tree
    will still be accessible via `ReferenceTree()`.

Note: if `std::move()` is not used to pass `referenceSet` or `referenceTree`,
those objects will be copied, which can be expensive! Be sure to use
`std::move()` if possible.
Constructor Parameters:

| name | type | description | default |
|------|------|-------------|---------|
| `referenceSet` | `arma::mat` | Column-major matrix containing the dataset to search for nearest neighbors in. | (N/A) |
| `referenceTree` | `KNN::Tree` (a `KDTree`) | Pre-built kd-tree on the reference data. | (N/A) |
| `strategy` | `NeighborSearchStrategy` | The search strategy that will be used when `Search()` is called. Must be one of `NAIVE`, `SINGLE_TREE`, `DUAL_TREE`, or `GREEDY_SINGLE_TREE` (see Search strategies, below). | `DUAL_TREE` |
| `epsilon` | `double` | Allowed relative approximation error. `0` means exact search. Must be non-negative. | `0.0` |
Notes:

- By default, exact nearest neighbors are found. When `strategy` is
  `SINGLE_TREE` or `DUAL_TREE`, set `epsilon` to a positive value to enable
  approximation (higher `epsilon` means more approximation is allowed). See
  Search strategies, below, for more details.

- If constructing a tree manually, the `KNN::Tree` type can be used (e.g.
  `tree = KNN::Tree(referenceData)`). `KNN::Tree` is a convenience typedef of
  either `KDTree` or the chosen `TreeType`, if custom template parameters are
  being used.
🔗 Search strategies
The KNN class can search for nearest neighbors using one of the following four
strategies. These can be specified in the constructor as the strategy
parameter, or by calling knn.SearchStrategy() = strategy.
- `DUAL_TREE` (default): two trees will be used at search time with a
  dual-tree algorithm (pdf) to allow the maximum amount of pruning.
  - This is generally the fastest strategy for exact search.
  - Under some assumptions on the structure of the dataset and the tree type
    being used, dual-tree search scales linearly (pdf) (e.g. `O(1)` time for
    each point whose nearest neighbors are being computed).
  - Backtracking search is performed to find either exact nearest neighbors,
    or approximate nearest neighbors if `knn.Epsilon() > 0`.

- `SINGLE_TREE`: a tree built on the reference points will be traversed once
  for each point whose nearest neighbors are being searched for.
  - Single-tree search generally scales logarithmically in practice.
  - Backtracking search is performed to find either exact nearest neighbors,
    or approximate nearest neighbors if `knn.Epsilon() > 0`.

- `GREEDY_SINGLE_TREE`: for each point whose nearest neighbors are being
  searched for, a tree built on the reference points will be traversed in a
  greedy manner, recursing directly and only into the nearest node in the tree
  to find nearest neighbor candidates.
  - The approximation level with this strategy cannot be controlled; the
    setting of `knn.Epsilon()` is ignored.
  - Greedy single-tree search scales logarithmically (e.g. `O(log N)` for each
    point whose neighbors are being computed, if the size of the reference set
    is `N`); since no backtracking is performed, results are obtained
    extremely efficiently.
  - This strategy is most effective when spill trees are used; to do this, use
    `SPTree` or another spill tree variant as the `TreeType` template
    parameter.

- `NAIVE`: brute-force search; for each point whose nearest neighbors are
  being searched for, the distance to every point in the reference set is
  computed.
  - This strategy always gives exact results; the setting of `knn.Epsilon()`
    is ignored.
  - Brute-force search scales poorly, with a runtime cost of `O(N)` per point,
    where `N` is the size of the reference set.
  - However, brute-force search does not suffer from poor performance in high
    dimensions, as trees often do.
  - When this strategy is used, no tree structure is built or used.
🔗 Setting the reference set (Train())
If the reference set was not set in the constructor, or if it needs to be
changed to a new reference set, the Train() method can be used.
- `knn.Train(referenceSet)`
  - Set the reference set to `referenceSet`.
  - This will build a `KDTree` with default parameters on `referenceSet`, if
    the search strategy is not `NAIVE`.
  - If `referenceSet` is not needed elsewhere, pass with `std::move()` (e.g.
    `std::move(referenceSet)`) to avoid copying `referenceSet`. The dataset
    will still be accessible via `ReferenceSet()`, but points may be in
    shuffled order.

- `knn.Train(referenceTree)`
  - Set the reference tree to `referenceTree`, which should be of type
    `KNN::Tree` (a convenience typedef of `KDTree` that uses
    `NearestNeighborStat` as its `StatisticType`).
  - If `referenceTree` is not needed elsewhere, pass with `std::move()` (e.g.
    `std::move(referenceTree)`) to avoid copying `referenceTree`. The tree
    will still be accessible via `ReferenceTree()`.
🔗 Searching for neighbors
Once the reference set and parameters are set, searching for nearest neighbors
can be done with the Search() method.
- `knn.Search(k, neighbors, distances)`
  - Search for the `k` nearest neighbors of all points in the reference set
    (i.e. `knn.ReferenceSet()`), storing the results in `neighbors` and
    `distances`.
  - `neighbors` and `distances` will be set to have `k` rows and
    `knn.ReferenceSet().n_cols` columns.
  - `neighbors(i, j)` (i.e. the `i`th row and `j`th column of `neighbors`)
    will hold the column index of the `i`th nearest neighbor of the `j`th
    point in `knn.ReferenceSet()`.
    - That is, the `i`th nearest neighbor of `knn.ReferenceSet().col(j)` is
      `knn.ReferenceSet().col(neighbors(i, j))`.
  - `distances(i, j)` will hold the distance between the `j`th point in
    `knn.ReferenceSet()` and its `i`th nearest neighbor.

- `knn.Search(querySet, k, neighbors, distances)`
  - Search for the `k` nearest neighbors in the reference set of all points in
    `querySet`, storing the results in `neighbors` and `distances`.
  - `neighbors` and `distances` will be set to have `k` rows and
    `querySet.n_cols` columns.
  - `neighbors(i, j)` will hold the column index in `knn.ReferenceSet()` of
    the `i`th nearest neighbor of the `j`th point in `querySet`.
    - That is, the `i`th nearest neighbor of `querySet.col(j)` is
      `knn.ReferenceSet().col(neighbors(i, j))`.
  - `distances(i, j)` will hold the distance between the `j`th point in
    `querySet` and its `i`th nearest neighbor in `knn.ReferenceSet()`.

- `knn.Search(queryTree, k, neighbors, distances, sameSet=false)`
  - Search for the `k` nearest neighbors in the reference set of all points in
    `queryTree`, storing the results in `neighbors` and `distances`.
  - `neighbors` and `distances` will be set to have `k` rows and
    `queryTree.Dataset().n_cols` columns.
  - `neighbors(i, j)` will hold the column index in `knn.ReferenceSet()` of
    the `i`th nearest neighbor of the `j`th point in `queryTree.Dataset()`.
    - That is, the `i`th nearest neighbor of `queryTree.Dataset().col(j)` is
      `knn.ReferenceSet().col(neighbors(i, j))`.
  - `distances(i, j)` will hold the distance between the `j`th point in
    `queryTree.Dataset()` and its `i`th nearest neighbor in
    `knn.ReferenceSet()`.
  - If `sameSet` is `true`, then the query set is understood to be the same as
    the reference set, and query points will not return their own index as
    their nearest neighbor.
Notes:

- When `querySet` and `queryTree` are not specified, or when `sameSet` is
  `true`, a point will not return itself as its nearest neighbor. However, if
  there are duplicate points `x` and `y` in the dataset, `y` may be returned
  as the nearest neighbor of `x`.

- If `knn.Epsilon() > 0` and the search strategy is `DUAL_TREE` or
  `SINGLE_TREE`, then the search will return approximate nearest neighbors
  whose distances are within a relative factor of `knn.Epsilon()` of the true
  nearest neighbor distances.

- `knn.Epsilon()` is ignored when the search strategy is `GREEDY_SINGLE_TREE`
  or `NAIVE`.

- When using a `queryTree` multiple times, the bounds in the tree must be
  reset. Call `knn.ResetTree(queryTree)` after each call to `Search()` to
  reset the bounds, or call `node.Stat().Reset()` on each node in `queryTree`.

- When searching for approximate neighbors, it is possible that no nearest
  neighbor candidate will be found for some point. If so, the corresponding
  element in `neighbors` will be set to `SIZE_MAX` (i.e. `size_t(-1)`), and
  the corresponding element in `distances` will be set to `DBL_MAX`.
Search Parameters:

| name | type | description |
|------|------|-------------|
| `querySet` | `arma::mat` | Column-major matrix of query points for which the nearest neighbors in the reference set should be found. |
| `k` | `size_t` | Number of nearest neighbors to search for. |
| `neighbors` | `arma::Mat<size_t>` | Matrix to store the indices of nearest neighbors into. Will be set to size `k` x `N`, where `N` is the number of points in the query set (if specified), or the reference set (if not). |
| `distances` | `arma::mat` | Matrix to store the distances to nearest neighbors into. Will be set to the same size as `neighbors`. |
| `sameSet` | `bool` | (Only for `Search()` with a query set.) If `true`, then `querySet` is the same set as the reference set. |
🔗 Computing quality metrics
If approximate nearest neighbor search is performed (e.g. if
knn.Epsilon() > 0), and exact nearest neighbors are known, it is possible to
compute quality metrics of the approximate search.
- `double error = knn.EffectiveError(computedDistances, exactDistances)`
  - Given a matrix of exact distances and a matrix of computed approximate
    distances, both with the same size (`k` rows, and columns equal to the
    number of points in the query or reference set), compute the average
    relative error of the computed distances.
  - Any neighbors with distance either `0` (i.e. the same point) or `DBL_MAX`
    (i.e. no neighbor candidate found) in `exactDistances` will be ignored in
    the computation.
  - `computedDistances` and `exactDistances` should be matrices produced by
    `knn.Search()`.
  - When dual-tree or single-tree search was used, `error` will be no greater
    than `knn.Epsilon()`.

- `double recall = knn.Recall(computedNeighbors, trueNeighbors)`
  - Given a matrix containing the indices of exact nearest neighbors and a
    matrix of computed approximate neighbors, both with the same size (`k`
    rows, and columns equal to the number of points in the query or reference
    set), compute the recall (the fraction of true neighbors found).
  - `computedNeighbors` and `trueNeighbors` should be matrices produced by
    `knn.Search()`.
  - The recall will be between `0.0` and `1.0`, with `1.0` indicating perfect
    recall.
🔗 Other functionality
- `knn.ReferenceSet()` will return a `const arma::mat&` representing the data
  points in the reference set. This matrix cannot be modified.
  - If a custom `MatType` template parameter has been specified, then the
    return type will be `const MatType&`.

- `knn.ReferenceTree()` will return a `KNN::Tree*` (a `KDTree` with
  `NearestNeighborStat` as the `StatisticType`).
  - This is the tree that will be used at search time, if the search strategy
    is not `NAIVE`.
  - If the search strategy was `NAIVE` when the object was constructed, then
    `knn.ReferenceTree()` will return `nullptr`.
  - If a custom `TreeType` template parameter has been specified, then
    `KNN::Tree*` will be that type of tree, not a `KDTree`.

- `knn.SearchStrategy()` will return the search strategy that will be used
  when `knn.Search()` is called.
  - `knn.SearchStrategy() = newStrategy` will set the search strategy to
    `newStrategy`. `newStrategy` must be one of the supported search
    strategies.

- `knn.Epsilon()` returns a `double` representing the allowed level of
  approximation. If it is `0` and `knn.SearchStrategy()` is either dual- or
  single-tree search, then `knn.Search()` will return exact results.
  - `knn.Epsilon() = eps` will set the allowed level of approximation to
    `eps`. `eps` must be greater than or equal to `0.0`.

- After calling `knn.Search()`, `knn.BaseCases()` will return a `size_t`
  representing the number of point-to-point distance computations that were
  performed, if a tree-traversing search strategy was used.

- After calling `knn.Search()`, `knn.Scores()` will return a `size_t`
  indicating the number of tree nodes that were visited during search, if a
  tree-traversing search strategy was used.

- A `KNN` object can be serialized with `Save()` and `Load()`. Note that for
  large reference sets, this will also serialize the dataset
  (`knn.ReferenceSet()`) and the tree (`knn.ReferenceTree()`), so the
  resulting file may be quite large.

- `KNN::Tree` is a convenience typedef representing the type of the tree that
  is used for searching.
  - By default, this is a `KDTree`; specifically, `KNN::Tree` is
    `KDTree<EuclideanDistance, NearestNeighborStat, arma::mat>`.
  - If a custom `TreeType`, `DistanceType`, and/or `MatType` are specified,
    then `KNNType<DistanceType, TreeType, MatType>::Tree` is
    `TreeType<DistanceType, NearestNeighborStat, MatType>`.
  - A custom tree can be built and passed to `Train()` or the constructor
    with, e.g., `tree = KNN::Tree(referenceSet)` or
    `tree = KNN::Tree(std::move(referenceSet))`.
🔗 Simple examples
Find the exact nearest neighbor of every point in the cloud dataset.
// See https://datasets.mlpack.org/cloud.csv.
arma::mat dataset;
mlpack::Load("cloud.csv", dataset);
// Construct the KNN object; this will avoid copies via std::move(), and build a
// kd-tree on the dataset.
mlpack::KNN knn(std::move(dataset));
arma::mat distances;
arma::Mat<size_t> neighbors;
// Compute the exact nearest neighbor.
knn.Search(1, neighbors, distances);
// Print the nearest neighbor and distance of the fifth point in the dataset.
std::cout << "Point 4:" << std::endl;
std::cout << " - Point values: " << knn.ReferenceSet().col(4).t();
std::cout << " - Index of nearest neighbor: " << neighbors(0, 4) << "."
<< std::endl;
std::cout << " - Distance to nearest neighbor: " << distances(0, 4) << "."
<< std::endl;
Split the corel-histogram dataset into two sets, and find the exact nearest
neighbor in the first set of every point in the second set.
// See https://datasets.mlpack.org/corel-histogram.csv.
arma::mat dataset;
mlpack::Load("corel-histogram.csv", dataset);
// Split the dataset into two equal-sized sets randomly with `Split()`.
arma::mat referenceSet, querySet;
mlpack::Split(dataset, referenceSet, querySet, 0.5);
// Construct the KNN object, building a tree on the reference set. Copies are
// avoided by the use of `std::move()`.
mlpack::KNN knn(std::move(referenceSet));
arma::mat distances;
arma::Mat<size_t> neighbors;
// Compute the exact nearest neighbor in `referenceSet` of every point in
// `querySet`.
knn.Search(querySet, 1, neighbors, distances);
// Print information about the dual-tree traversal.
std::cout << "Dual-tree traversal computed " << knn.BaseCases()
<< " point-to-point distances and " << knn.Scores()
<< " tree node-to-tree node distances." << std::endl;
// Print information about nearest neighbors of the points in the query set.
std::cout << "The nearest neighbor of query point 3 is reference point index "
<< neighbors(0, 3) << ", with distance " << distances(0, 3) << "."
<< std::endl;
std::cout << "The L2-norm of reference point " << neighbors(0, 3) << " is "
<< arma::norm(knn.ReferenceSet().col(neighbors(0, 3)), 2) << "."
<< std::endl;
// Compute the average nearest neighbor distance for all points in the query
// set.
const double averageDist = arma::mean(arma::vectorise(distances));
std::cout << "Average distance between a query point and its nearest neighbor: "
<< averageDist << "." << std::endl;
Perform approximate single-tree search to find 5 nearest neighbors of the first
point in a subset of the LCDM dataset.
// See https://datasets.mlpack.org/lcdm_tiny.csv.
arma::mat dataset;
mlpack::Load("lcdm_tiny.csv", dataset);
// Build a KNN object on the LCDM dataset, and pass with `std::move()` so that
// we can avoid copying the dataset. Set the search strategy to single-tree,
// and allow up to 10% relative approximation error.
mlpack::KNN knn(std::move(dataset), mlpack::SINGLE_TREE, 0.1);
// Now we will compute the 5 nearest neighbors of the first point in the
// dataset.
//
// NOTE: because the first point is in the reference set, and because we are
// passing a separate query set, KNN will return that the nearest neighbor of
// the point is itself! This is an important caveat to be aware of when calling
// Search() with a query set.
arma::mat distances;
arma::Mat<size_t> neighbors;
knn.Search(knn.ReferenceSet().col(0), 5, neighbors, distances);
std::cout << "The five nearest neighbors of the first point in the LCDM dataset"
<< " are:" << std::endl;
for (size_t k = 0; k < 5; ++k)
{
std::cout << " - " << neighbors(k, 0) << " (with distance " << distances(k, 0)
<< ")." << std::endl;
if (k == 0)
{
std::cout << " (the first point's nearest neighbor is itself, because it"
<< " is in the reference set, and we called Search() with a separate "
<< "query set!)" << std::endl;
}
}
Use greedy single-tree search to find 5 approximate nearest neighbors of every
point in the cloud dataset. Then, compute the exact nearest neighbors, and
use these to find the average error and recall of the approximate search.
Note: greedy single-tree search is far more effective when using spill trees—see the advanced examples for another example that does exactly that.
// See https://datasets.mlpack.org/cloud.csv.
arma::mat dataset;
mlpack::Load("cloud.csv", dataset);
// Build a tree on the dataset and set the search strategy to the greedy single
// tree strategy.
mlpack::KNN knn(std::move(dataset), mlpack::GREEDY_SINGLE_TREE);
// Compute the 5 approximate nearest neighbors of every point in the dataset.
arma::mat distances;
arma::Mat<size_t> neighbors;
knn.Search(5, neighbors, distances);
std::cout << "Greedy approximate kNN search computed " << knn.BaseCases()
<< " point-to-point distances and visited " << knn.Scores()
<< " tree nodes in total." << std::endl;
// Now switch to exact computation and compute the true neighbors and distances.
arma::Mat<size_t> trueNeighbors;
arma::mat trueDistances;
knn.SearchStrategy() = mlpack::DUAL_TREE;
knn.Epsilon() = 0.0;
knn.Search(5, trueNeighbors, trueDistances);
// Compute the recall and effective error.
const double recall = knn.Recall(neighbors, trueNeighbors);
const double effectiveError = knn.EffectiveError(distances, trueDistances);
std::cout << "Recall of greedy search: " << recall << "." << std::endl;
std::cout << "Effective error of greedy search: " << effectiveError << "."
<< std::endl;
Build a KNN object on the cloud dataset and save it to disk.
// See https://datasets.mlpack.org/cloud.csv.
arma::mat dataset;
mlpack::Load("cloud.csv", dataset);
// Build the reference tree.
mlpack::KNN knn(std::move(dataset));
// Save the KNN object to disk with the name 'knn'.
mlpack::Save("knn.bin", "knn", knn);
std::cout << "Successfully saved KNN model to 'knn.bin'." << std::endl;
Load a KNN object from disk, and inspect the
KDTree that is held in the object.
// Load the KNN object with name 'knn' from 'knn.bin'.
mlpack::KNN knn;
mlpack::Load("knn.bin", "knn", knn);
// Inspect the KDTree held by the KNN object. Note that ReferenceTree()
// returns a pointer.
std::cout << "The KDTree in the KNN object in 'knn.bin' holds "
    << knn.ReferenceTree()->NumDescendants() << " points." << std::endl;
std::cout << "The root of the tree has " << knn.ReferenceTree()->NumChildren()
    << " children." << std::endl;
if (knn.ReferenceTree()->NumChildren() == 2)
{
  std::cout << " - The left child holds "
      << knn.ReferenceTree()->Child(0).NumDescendants() << " points."
      << std::endl;
  std::cout << " - The right child holds "
      << knn.ReferenceTree()->Child(1).NumDescendants() << " points."
      << std::endl;
}
Compute the 5 approximate nearest neighbors of two subsets of the
corel-histogram dataset using a pre-built query tree. Then, reuse the query
tree to compute the exact neighbors and compute the effective error.
// See https://datasets.mlpack.org/corel-histogram.csv.
arma::mat dataset;
mlpack::Load("corel-histogram.csv", dataset);
// Split the corel-histogram dataset into two parts of equal size.
arma::mat referenceSet, querySet;
mlpack::Split(dataset, referenceSet, querySet, 0.5);
// Build the KNN object, passing the reference set with `std::move()` to avoid a
// copy. We use the default dual-tree strategy for search and set the maximum
// allowed relative error to 0.1 (10%).
mlpack::KNN knn(std::move(referenceSet), mlpack::DUAL_TREE, 0.1);
// Now build a tree on the query points. This is a KDTree built with default
// parameters. Note that the KDTree rearranges the ordering of points in the
// query set.
mlpack::KNN::Tree queryTree(std::move(querySet));
// Compute the 5 approximate nearest neighbors of all points in the query set.
arma::mat distances;
arma::Mat<size_t> neighbors;
knn.Search(queryTree, 5, neighbors, distances);
// Now compute the exact neighbors---but since we are using dual-tree search and
// an externally-constructed query tree, we must reset the bounds!
arma::mat trueDistances;
arma::Mat<size_t> trueNeighbors;
knn.ResetTree(queryTree);
knn.Epsilon() = 0;
knn.Search(queryTree, 5, trueNeighbors, trueDistances);
// Compute the effective error.
const double effectiveError = knn.EffectiveError(distances, trueDistances);
std::cout << "Effective error of approximate dual-tree search was "
<< effectiveError << " (limit via knn.Epsilon() was 0.1)." << std::endl;
🔗 Advanced functionality: template parameters
The KNN class is a typedef of the configurable KNNType class, which has five
template parameters that can be used for custom behavior. The full signature of
the class is:
KNNType<DistanceType,
TreeType,
MatType,
DualTreeTraversalType,
SingleTreeTraversalType>
- `DistanceType`: specifies the distance metric to use for finding nearest
  neighbors.
- `TreeType`: specifies the type of tree to use for indexing points for fast
  tree-based search.
- `MatType`: specifies the type of matrix used for the representation of data.
- `DualTreeTraversalType`: specifies the traversal strategy that will be used
  when searching with the dual-tree strategy.
- `SingleTreeTraversalType`: specifies the traversal strategy that will be
  used when searching with the single-tree strategy.
When custom template parameters are specified:
- The `referenceSet` and `querySet` parameters to the constructor, `Train()`,
  and `Search()` must have type `MatType` instead of `arma::mat`.
- The `distances` parameter to `Search()` should have type `MatType`.
- When nearest neighbor candidates are not found during `Search()`, the
  corresponding returned distance will be the maximum value supported by the
  element type of `MatType` (e.g. `DBL_MAX` for `double`, `FLT_MAX` for
  `float`, etc.).
- The convenience typedef `Tree` (i.e. `KNNType<DistanceType, TreeType,
  MatType, DualTreeTraversalType, SingleTreeTraversalType>::Tree`) will be
  equivalent to `TreeType<DistanceType, NearestNeighborStat, MatType>`.
- All tree parameters (`referenceTree` and `queryTree`) should have type
  `TreeType<DistanceType, NearestNeighborStat, MatType>`.
DistanceType
- Specifies the distance metric that will be used when searching for nearest
  neighbors.
- The default distance type is `EuclideanDistance`.
- Many pre-implemented distance metrics are available for use, such as
  `ManhattanDistance`, `ChebyshevDistance`, and others.
- Custom distance metrics are easy to implement, but must satisfy the triangle
  inequality to provide correct results when searching with trees (i.e. when
  `knn.SearchStrategy()` is not `NAIVE`).
  - NOTE: the cosine distance does not satisfy the triangle inequality.
TreeType
- Specifies the tree type that will be built on the reference set (and
  possibly the query set), if `knn.SearchStrategy()` is not `NAIVE`.
- The default tree type is `KDTree`.
- Numerous pre-implemented tree types are available for use.
- Custom trees are very difficult to implement, but it is possible if desired.
  - If you have implemented a fully-working `TreeType` yourself, please
    contribute it upstream if possible!
MatType
- Specifies the type of matrix to use for representing data (the reference set
  and the query set).
- The default `MatType` is `arma::mat` (a dense matrix of 64-bit floating
  point elements).
- Any matrix type implementing the Armadillo API will work; so, for instance,
  `arma::fmat` or `arma::sp_mat` can also be used.
DualTreeTraversalType
- Specifies the traversal strategy to use when performing a dual-tree search
  for nearest neighbors (i.e. when `knn.SearchStrategy()` is `DUAL_TREE`).
- By default, the `TreeType`'s default dual-tree traversal (i.e.
  `TreeType<DistanceType, NearestNeighborStat,
  MatType>::DualTreeTraversalType`) will be used.
- In general, this parameter does not need to be specified, except when a
  custom type of traversal is desired.
  - For instance, the `SpillTree` class provides the
    `DefeatistDualTreeTraversal` strategy, a specific greedy strategy for
    spill trees when performing approximate nearest neighbor search.
SingleTreeTraversalType
- Specifies the traversal strategy to use when performing a single-tree search
  for nearest neighbors (i.e. when `knn.SearchStrategy()` is `SINGLE_TREE`).
- By default, the `TreeType`'s default single-tree traversal (i.e.
  `TreeType<DistanceType, NearestNeighborStat,
  MatType>::SingleTreeTraversalType`) will be used.
- In general, this parameter does not need to be specified, except when a
  custom type of traversal is desired.
  - For instance, the `SpillTree` class provides the
    `DefeatistSingleTreeTraversal` strategy, a specific greedy strategy for
    spill trees when performing approximate nearest neighbor search.
🔗 Advanced examples
Perform exact nearest neighbor search to find the nearest neighbor of every
point in the cloud dataset, using 32-bit floats to represent the data.
// See https://datasets.mlpack.org/cloud.csv.
arma::fmat dataset;
mlpack::Load("cloud.csv", dataset);
// Construct the KNN object; this will avoid copies via std::move(), and build a
// kd-tree on the dataset.
mlpack::KNNType<mlpack::EuclideanDistance, mlpack::KDTree, arma::fmat> knn(
std::move(dataset));
arma::fmat distances; // This type is arma::fmat, just like the dataset.
arma::Mat<size_t> neighbors;
// Compute the exact nearest neighbor.
knn.Search(1, neighbors, distances);
// Print the nearest neighbor and distance of the fifth point in the dataset.
std::cout << "Point 4:" << std::endl;
std::cout << " - Point values: " << knn.ReferenceSet().col(4).t();
std::cout << " - Index of nearest neighbor: " << neighbors(0, 4) << "."
<< std::endl;
std::cout << " - Distance to nearest neighbor: " << distances(0, 4) << "."
<< std::endl;
Perform approximate single-tree nearest neighbor search using the Chebyshev (L-infinity) distance as the distance metric.
// See https://datasets.mlpack.org/cloud.csv.
arma::mat dataset;
mlpack::Load("cloud.csv", dataset);
// Construct the KNN object; this will avoid copies via std::move(), and build a
// kd-tree on the dataset.
mlpack::KNNType<mlpack::ChebyshevDistance> knn(std::move(dataset),
mlpack::SINGLE_TREE, 0.2);
arma::mat distances;
arma::Mat<size_t> neighbors;
// Compute the approximate nearest neighbor.
knn.Search(1, neighbors, distances);
// Print the nearest neighbor and distance of the fifth point in the dataset.
std::cout << "Point 4:" << std::endl;
std::cout << " - Point values: " << knn.ReferenceSet().col(4).t();
std::cout << " - Index of approximate nearest neighbor: " << neighbors(0, 4)
<< "." << std::endl;
std::cout << " - Chebyshev distance to approximate nearest neighbor: "
<< distances(0, 4) << "." << std::endl;
Use an Octree (a tree known to be faster in very few dimensions) to find the exact nearest neighbors of all the points in a tiny subset of the 3-dimensional LCDM dataset.
// See https://datasets.mlpack.org/lcdm_tiny.csv.
arma::mat dataset;
mlpack::Load("lcdm_tiny.csv", dataset);
// Construct the KNN object with Octrees.
mlpack::KNNType<mlpack::EuclideanDistance, mlpack::Octree> knn(
std::move(dataset));
arma::mat distances;
arma::Mat<size_t> neighbors;
// Find the exact nearest neighbor of every point.
knn.Search(1, neighbors, distances);
// Print the average, minimum, and maximum nearest neighbor distances.
std::cout << "Average nearest neighbor distance: " <<
arma::mean(arma::vectorise(distances)) << "." << std::endl;
std::cout << "Minimum nearest neighbor distance: " <<
arma::min(arma::vectorise(distances)) << "." << std::endl;
std::cout << "Maximum nearest neighbor distance: " <<
arma::max(arma::vectorise(distances)) << "." << std::endl;
Using a 32-bit floating point representation, split the lcdm_tiny dataset into
a query set and a reference set, then use KNN with preconstructed random
projection trees (RPTree) to find the 5 approximate nearest neighbors of each
point in the query set under the Manhattan distance. Then, compute exact
nearest neighbors and the average error and recall.
// See https://datasets.mlpack.org/lcdm_tiny.csv.
arma::fmat dataset;
mlpack::Load("lcdm_tiny.csv", dataset);
// Split the dataset into a query set and a reference set (each with the same
// size).
arma::fmat referenceSet, querySet;
mlpack::Split(dataset, referenceSet, querySet, 0.5);
// This is the type of tree we will build on the datasets.
using TreeType = mlpack::RPTree<mlpack::ManhattanDistance,
mlpack::NearestNeighborStat,
arma::fmat>;
// Note that we could also define TreeType as below (it is the same type!):
//
// using TreeType = mlpack::KNNType<mlpack::ManhattanDistance,
// mlpack::RPTree,
// arma::fmat>::Tree;
// We build the trees here with std::move() in order to avoid copying data.
//
// For RPTrees, this reorders the points in the dataset, but if original indices
// are needed, trees can be constructed with mapping objects. (See the RPTree
// documentation for more details.)
TreeType referenceTree(std::move(referenceSet));
TreeType queryTree(std::move(querySet));
// Construct the KNN object with the prebuilt reference tree.
mlpack::KNNType<mlpack::ManhattanDistance, mlpack::RPTree, arma::fmat> knn(
std::move(referenceTree), mlpack::DUAL_TREE, 0.1 /* max 10% error */);
// Find 5 approximate nearest neighbors.
arma::fmat approxDistances;
arma::Mat<size_t> approxNeighbors;
knn.Search(queryTree, 5, approxNeighbors, approxDistances);
std::cout << "Computed approximate neighbors." << std::endl;
// Now compute exact neighbors. When reusing the query tree, this requires
// resetting the statistics inside the query tree manually.
arma::fmat exactDistances;
arma::Mat<size_t> exactNeighbors;
knn.ResetTree(queryTree);
knn.Epsilon() = 0.0; // Error tolerance is now 0% (exact search).
knn.Search(queryTree, 5, exactNeighbors, exactDistances);
std::cout << "Computed exact neighbors." << std::endl;
// Compute error measures.
const double recall = knn.Recall(approxNeighbors, exactNeighbors);
const double effectiveError = knn.EffectiveError(approxDistances,
exactDistances);
std::cout << "Recall of approximate search: " << recall << "." << std::endl;
std::cout << "Effective relative error of approximate search: "
<< effectiveError << " (vs. limit of 0.1)." << std::endl;
Use spill trees to perform greedy single-tree
approximate nearest neighbor search on the cloud dataset, and compare with
the other spill tree traversers and with exact results. Compare with the
results in the simple examples section, where the default KDTree is used:
spill trees perform significantly better for greedy search!
// See https://datasets.mlpack.org/cloud.csv.
arma::mat dataset;
mlpack::Load("cloud.csv", dataset);
// Build a spill tree on the dataset and set the search strategy to the greedy
// single-tree strategy. We will build the tree manually so that we can
// configure the build-time parameters (see the SPTree documentation for more
// details).
using TreeType = mlpack::SPTree<mlpack::EuclideanDistance,
mlpack::NearestNeighborStat,
arma::mat>;
TreeType referenceTree(std::move(dataset), 10.0 /* tau, overlap parameter */);
mlpack::KNNType<mlpack::EuclideanDistance,
mlpack::SPTree,
arma::mat,
// Use the special defeatist spill tree traversers.
TreeType::template DefeatistDualTreeTraverser,
TreeType::template DefeatistSingleTreeTraverser> knn(
std::move(referenceTree),
mlpack::GREEDY_SINGLE_TREE);
arma::mat greedyDistances, dualDistances, singleDistances, exactDistances;
arma::Mat<size_t> greedyNeighbors, dualNeighbors, singleNeighbors,
exactNeighbors;
// Compute the 5 approximate nearest neighbors of every point in the dataset.
knn.Search(5, greedyNeighbors, greedyDistances);
std::cout << "Greedy approximate kNN search computed " << knn.BaseCases()
<< " point-to-point distances and visited " << knn.Scores()
<< " tree nodes in total." << std::endl;
// Now do the same thing, but with defeatist dual-tree search. Note that
// defeatist dual-tree search does not backtrack, so we don't need to set
// knn.Epsilon().
knn.SearchStrategy() = mlpack::DUAL_TREE;
knn.Search(5, dualNeighbors, dualDistances);
std::cout << "Dual-tree approximate kNN search computed " << knn.BaseCases()
<< " point-to-point distances and visited " << knn.Scores()
<< " tree nodes in total." << std::endl;
// Finally, use defeatist single-tree search.
knn.SearchStrategy() = mlpack::SINGLE_TREE;
knn.Search(5, singleNeighbors, singleDistances);
std::cout << "Single-tree approximate kNN search computed " << knn.BaseCases()
<< " point-to-point distances and visited " << knn.Scores()
<< " tree nodes in total." << std::endl;
// Now switch to the exact naive strategy and compute the true neighbors and
// distances.
knn.SearchStrategy() = mlpack::NAIVE;
knn.Epsilon() = 0.0;
knn.Search(5, exactNeighbors, exactDistances);
// Compute the recall and effective error for each strategy.
const double greedyRecall = knn.Recall(greedyNeighbors, exactNeighbors);
const double dualRecall = knn.Recall(dualNeighbors, exactNeighbors);
const double singleRecall = knn.Recall(singleNeighbors, exactNeighbors);
const double greedyError = knn.EffectiveError(greedyDistances, exactDistances);
const double dualError = knn.EffectiveError(dualDistances, exactDistances);
const double singleError = knn.EffectiveError(singleDistances, exactDistances);
// Print the results. To tune the results, try constructing the SPTrees
// manually and specifying different construction parameters.
std::cout << std::endl;
std::cout << "Recall with spill trees:" << std::endl;
std::cout << " - Greedy search: " << greedyRecall << "." << std::endl;
std::cout << " - Dual-tree search: " << dualRecall << "." << std::endl;
std::cout << " - Single-tree search: " << singleRecall << "." << std::endl;
std::cout << std::endl;
std::cout << "Effective error with spill trees:" << std::endl;
std::cout << " - Greedy search: " << greedyError << "." << std::endl;
std::cout << " - Dual-tree search: " << dualError << "." << std::endl;
std::cout << " - Single-tree search: " << singleError << "." << std::endl;