I am working on classifying simple data using KNN with the Euclidean distance. I have seen an example of what I would like to do, done with the MATLAB knnsearch function as shown below:
load fisheriris
x = meas(:,3:4);
gscatter(x(:,1),x(:,2),species)
newpoint = [5 1.45];
[n,d] = knnsearch(x,newpoint,'k',10);
line(x(n,1),x(n,2),'color',[.5 .5 .5],'marker','o','linestyle','none','markersize',10)
The above code takes a new point, i.e. [5 1.45], and finds the 10 points in x that are closest to it. Can anyone please show me a MATLAB algorithm with a detailed explanation of what the knnsearch function does? Is there any other way to do this?
The basis of the K-Nearest Neighbour (KNN) algorithm is that you have a data matrix that consists of N rows and M columns, where N is the number of data points that we have and M is the dimensionality of each data point. For example, if we placed Cartesian coordinates inside a data matrix, this is usually an N x 2 or an N x 3 matrix. With this data matrix, you provide a query point and search for the k points within this data matrix that are closest to this query point.
We usually use the Euclidean distance between the query and each point in the data matrix to calculate our distances. However, other distances, like the L1 (City-Block / Manhattan) distance, are also used. After this operation, you will have N Euclidean or Manhattan distances, one between the query and each corresponding point in the data set. Once you have these, you simply search for the k nearest points to the query by sorting the distances in ascending order and retrieving the k points with the smallest distances.
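To make this concrete, here is the whole procedure on a tiny, made-up data set that you can check by hand. The values of x, newpoint and k are hypothetical, chosen only so that the distances are easy to verify; this is a sketch of the idea, not a replacement for the step-by-step code that follows.

```matlab
x = [0 0; 3 4; 1 1; -2 1];   % hypothetical N x M data matrix (N = 4, M = 2)
newpoint = [0 0];            % hypothetical 1 x M query point
k = 2;                       % number of neighbours to return

N = size(x, 1);
eucl = zeros(N, 1);
manh = zeros(N, 1);
for idx = 1 : N
    diffs = x(idx, :) - newpoint;       % element-by-element differences
    eucl(idx) = sqrt(sum(diffs.^2));    % Euclidean (L2) distance
    manh(idx) = sum(abs(diffs));        % Manhattan (L1 / City-Block) distance
end
% eucl = [0; 5; sqrt(2); sqrt(5)] and manh = [0; 7; 2; 3]

[~, ind] = sort(eucl);          % ascending order of distance
ind_closest = ind(1:k);         % rows 1 and 3 of x are the nearest two
x_closest = x(ind_closest, :);  % the points themselves: [0 0; 1 1]
```

With either metric the ordering here is rows 1, 3, 4, 2, so the two nearest points are the same; on real data the two metrics can disagree.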
Supposing your data matrix was stored in x, and newpoint is a sample point with M columns (i.e. 1 x M), this is the general procedure you would follow in point form:
1. Calculate the Euclidean or Manhattan distance between newpoint and every point in x.
2. Sort these distances in ascending order.
3. Return the k data points in x that are closest to newpoint.
Let's do each step slowly.
One way to do this is with a for loop like so:
N = size(x,1);
dists = zeros(N,1);
for idx = 1 : N
dists(idx) = sqrt(sum((x(idx,:) - newpoint).^2));
end
If you wanted to implement the Manhattan distance, this would simply be:
N = size(x,1);
dists = zeros(N,1);
for idx = 1 : N
dists(idx) = sum(abs(x(idx,:) - newpoint));
end
dists would be an N element vector that contains the distances between each data point in x and newpoint. We do an element-by-element subtraction between newpoint and a data point in x, square the differences, then sum them all together. This sum is then square rooted, which completes the Euclidean distance. For the Manhattan distance, you would perform an element-by-element subtraction, take the absolute values, then sum all of the components together. This is probably the simplest of the implementations to understand, but it could also be the most inefficient, especially for larger data sets and higher-dimensional data.
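If you want to see the cost of the loop for yourself, a rough timing sketch is below. The data size is arbitrary and the random data is hypothetical; timings depend heavily on your MATLAB version and machine (the JIT has made loops considerably faster over the years), so treat this as a way to measure rather than a prediction. It compares the loop against the vectorised bsxfun one-liner that appears further down and checks that both give the same distances.

```matlab
rng(0);                    % reproducible hypothetical data
x = rand(200000, 3);       % N = 200000 points, M = 3 (arbitrary sizes)
newpoint = rand(1, 3);

% Loop version, as above
tic;
N = size(x, 1);
dists_loop = zeros(N, 1);
for idx = 1 : N
    dists_loop(idx) = sqrt(sum((x(idx,:) - newpoint).^2));
end
t_loop = toc;

% Vectorised version: bsxfun applies newpoint to every row of x
tic;
dists_vec = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
t_vec = toc;

fprintf('loop: %.4f s, vectorised: %.4f s\n', t_loop, t_vec);

% Both versions should agree up to floating-point round-off
max_diff = max(abs(dists_loop - dists_vec));
```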
Another possible solution would be to replicate newpoint so that it forms a matrix the same size as x, do an element-by-element subtraction with x, then sum over all of the columns for each row and take the square root. Therefore, we can do something like this:
N = size(x, 1);
dists = sqrt(sum((x - repmat(newpoint, N, 1)).^2, 2));
For the Manhattan distance, you would do:
N = size(x, 1);
dists = sum(abs(x - repmat(newpoint, N, 1)), 2);
repmat takes a matrix or vector and repeats it a certain number of times in each direction. In our case, we want to take our newpoint vector and stack it N times on top of itself to create an N x M matrix, where each row is M elements long. We subtract these two matrices, then square each component. Once we do this, we sum over all of the columns for each row and finally take the square root of each result. For the Manhattan distance, we do the subtraction, take the absolute value and then sum.
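Here is what the replication looks like on a small, hypothetical 3 x 2 data matrix (the three rows of x are made-up values near the query), so you can see the intermediate N x M matrix that repmat builds before the subtraction:

```matlab
newpoint = [5 1.45];
x = [5 1.5; 4.9 1.5; 4.7 1.4];      % hypothetical 3 x 2 data matrix
N = size(x, 1);

rep = repmat(newpoint, N, 1);       % 3 x 2: every row is a copy of newpoint
diffs = x - rep;                    % element-by-element, row by row
dists = sqrt(sum(diffs.^2, 2));     % one Euclidean distance per row of x
% dists is approximately [0.0500; 0.1118; 0.3041]
```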
However, the most efficient way to do this in my opinion would be to use bsxfun. This essentially does the replication that we talked about under the hood, with a single function call. Therefore, the code would simply be this:
dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
To me this looks much cleaner and to the point. For the Manhattan distance, you would do:
dists = sum(abs(bsxfun(@minus, x, newpoint)), 2);
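On MATLAB R2016b and later, arithmetic operators expand singleton dimensions automatically (implicit expansion), so the bsxfun call can be written as a plain subtraction. Newer releases also provide vecnorm(A, p, dim), which computes the p-norm of each row when dim = 2 (p = 2 for Euclidean, p = 1 for Manhattan). A sketch, assuming a release that has both features; the small data matrix is hypothetical:

```matlab
x = [5 1.5; 4.9 1.5; 1.4 0.2; 6.0 2.5];   % hypothetical N x 2 data matrix
newpoint = [5 1.45];

% bsxfun forms from above
d_bsx = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
m_bsx = sum(abs(bsxfun(@minus, x, newpoint)), 2);

% Implicit expansion: newpoint is broadcast across the rows of x
d_impl = sqrt(sum((x - newpoint).^2, 2));

% vecnorm over dimension 2 gives one norm per row
d_vn = vecnorm(x - newpoint, 2, 2);   % Euclidean
m_vn = vecnorm(x - newpoint, 1, 2);   % Manhattan
```

All three Euclidean versions should agree to round-off; the bsxfun version has the advantage of also running on older releases.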
Now that we have our distances, we simply sort them. We can use sort to sort our distances:
[d,ind] = sort(dists);
d contains the distances sorted in ascending order, while ind tells you, for each position in the sorted result, where that value was located in the unsorted array (so d is the same as dists(ind)). We take the first k elements of ind, then use them to index into our x data matrix to return the points that are closest to newpoint.
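A quick way to convince yourself of what the second output of sort means is a small vector of hypothetical distances with no ties:

```matlab
dists = [0.7; 0.2; 0.5; 0.1];   % hypothetical unsorted distances
[d, ind] = sort(dists);
% d   = [0.1; 0.2; 0.5; 0.7]    sorted in ascending order
% ind = [4; 2; 3; 1]            d(i) came from dists(ind(i))
same = isequal(d, dists(ind));  % true by construction
```

So ind(1) is the row of x with the smallest distance, ind(2) the row with the second smallest, and so on.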
The final step is to return those k data points that are closest to newpoint. We can do this very simply by:
ind_closest = ind(1:k);
x_closest = x(ind_closest,:);
ind_closest contains the indices in the original data matrix x that are closest to newpoint. Specifically, ind_closest tells you which rows you need to sample from in x to obtain the closest points to newpoint. x_closest contains those actual data points.
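Newer MATLAB releases also have mink, which returns the k smallest values and their indices directly, so you do not have to sort the whole vector yourself. A sketch, assuming your release includes mink; the distances and data below are hypothetical. I would not rely on tied distances being ordered exactly as sort orders them, so compare distances rather than indices if ties are possible.

```matlab
dists = [0.7; 0.2; 0.5; 0.1; 0.3];   % hypothetical distances to 5 points
x = (1:5).' * [1 1];                 % hypothetical 5 x 2 data matrix
k = 3;

[d_k, ind_closest] = mink(dists, k); % k smallest distances, ascending
x_closest = x(ind_closest, :);       % the corresponding rows of x
% ind_closest = [4; 2; 5], d_k = [0.1; 0.2; 0.3]
```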
For your copying and pasting pleasure, this is what the code looks like:
k = 10; %// Number of nearest neighbours to return
dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
%// Or do this for Manhattan
% dists = sum(abs(bsxfun(@minus, x, newpoint)), 2);
[d,ind] = sort(dists);
ind_closest = ind(1:k);
x_closest = x(ind_closest,:);
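If you want to reuse this, one option is to wrap it in a function file. The name knn_brute and its metric argument are hypothetical, chosen only for this sketch; it returns the same three quantities as the code above and lets you switch between the two distances.

```matlab
function [ind_closest, x_closest, dist_sorted] = knn_brute(x, newpoint, k, metric)
%KNN_BRUTE  Hypothetical helper: brute-force k-nearest-neighbour search.
%   x        : N x M data matrix (one point per row)
%   newpoint : 1 x M query point
%   k        : number of neighbours to return
%   metric   : 'euclidean' (default) or 'manhattan'
if nargin < 4
    metric = 'euclidean';
end
diffs = bsxfun(@minus, x, newpoint);      % N x M differences
switch lower(metric)
    case 'euclidean'
        dists = sqrt(sum(diffs.^2, 2));
    case 'manhattan'
        dists = sum(abs(diffs), 2);
    otherwise
        error('knn_brute:metric', 'Unknown metric "%s".', metric);
end
[d, ind] = sort(dists);                   % ascending distances
k = min(k, size(x, 1));                   % cannot return more than N points
ind_closest = ind(1:k);
x_closest = x(ind_closest, :);
dist_sorted = d(1:k);
end
```

Saved as knn_brute.m, the example from the question would be [ind_closest, x_closest, dist_sorted] = knn_brute(meas(:,3:4), [5 1.45], 10).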
Running through your example, let's see our code in action:
load fisheriris
x = meas(:,3:4);
newpoint = [5 1.45];
k = 10;
%// Use Euclidean
dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
[d,ind] = sort(dists);
ind_closest = ind(1:k);
x_closest = x(ind_closest,:);
By inspecting ind_closest and x_closest, this is what we get:
>> ind_closest
ind_closest =
120
53
73
134
84
77
78
51
64
87
>> x_closest
x_closest =
5.0000 1.5000
4.9000 1.5000
4.9000 1.5000
5.1000 1.5000
5.1000 1.6000
4.8000 1.4000
5.0000 1.7000
4.7000 1.4000
4.7000 1.4000
4.7000 1.5000
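Since the question is about classification: KNN classifies a query by a majority vote over the labels of its k neighbours, which is a step knnsearch does not do for you. Here is a minimal sketch of that vote using the species labels that come with fisheriris; the tie-breaking rule (mode picks the label that sorts first when counts are equal) is an assumption of this sketch, not a standard.

```matlab
load fisheriris
x = meas(:,3:4);
newpoint = [5 1.45];
k = 10;

dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
[~, ind] = sort(dists);
ind_closest = ind(1:k);

% species is a cell array of strings; map each label to an integer ID
neighbour_labels = species(ind_closest);
[labels, ~, ids] = unique(neighbour_labels);
predicted = labels{mode(ids)}   % most frequent label among the k neighbours
```

With the indices shown above, 8 of the 10 neighbours lie in rows 51 to 100 (versicolor), so the vote should come out as 'versicolor'. If you have the Statistics and Machine Learning Toolbox, fitcknn wraps the search and the vote into a classifier object, although its default tie-breaking may differ from this sketch.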
If you run knnsearch, you will see that its output n matches up with ind_closest (points tied at exactly the same distance may be listed in a different order). However, the second output d of knnsearch contains the distances from newpoint to each of the returned neighbours, not the data points themselves. If you want those distances, simply do the following after the code I wrote:
dist_sorted = d(1:k);
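To check this against knnsearch yourself, a sketch is below. It compares the distances rather than the raw indices, because neighbours tied at the same distance (e.g. the several points at 0.3041 in this example) are not guaranteed to come out in the same order from both methods.

```matlab
load fisheriris
x = meas(:,3:4);
newpoint = [5 1.45];
k = 10;

% Hand-rolled version from above
dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
[d, ind] = sort(dists);
ind_closest = ind(1:k);
dist_sorted = d(1:k);

% Built-in version: for a single query, n and dk are 1 x k row vectors
[n, dk] = knnsearch(x, newpoint, 'k', k);

% Distances should agree up to floating-point round-off
dist_ok = max(abs(dk(:) - dist_sorted)) < 1e-12;
```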
Note that the above answer uses only one query point against a data set of N examples. Very frequently, KNN is used on multiple query points simultaneously. Suppose that we have Q query points that we want to test. Returning the k closest points, each of dimensionality M, for every query results in a k x M x Q array, where each slice holds the neighbours of one query. Alternatively, we can return the IDs of the k closest points, resulting in a Q x k matrix. Let's compute both.
A naive way to do this would be to apply the above code in a loop over every query point. Something like this would work, where we allocate a Q x k matrix and use the bsxfun based approach to set each row of this matrix to the IDs of the k closest points in the data set (and each slice of the k x M x Q array to the points themselves). We will use the Fisher Iris data set just like before, keep the same dimensionality as the previous example and use four query points, so Q = 4 and M = 2:
%// Load the data and create the query points
load fisheriris;
x = meas(:,3:4);
newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5];
%// Define k and the output matrices
Q = size(newpoints, 1);
M = size(x, 2);
k = 10;
x_closest = zeros(k, M, Q);
ind_closest = zeros(Q, k);
%// Loop through each point and do logic as seen above:
for ii = 1 : Q
%// Get the point
newpoint = newpoints(ii, :);
%// Use Euclidean
dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
[d,ind] = sort(dists);
%// New - Output the IDs of the match as well as the points themselves
ind_closest(ii, :) = ind(1 : k).';
x_closest(:, :, ii) = x(ind_closest(ii, :), :);
end
Though this is very nice, we can do even better. There is an efficient way to compute the squared Euclidean distance between two sets of vectors (I'll leave it as an exercise if you want to do this with the Manhattan distance). Consulting this blog, given that A is a Q1 x M matrix where each row is a point of dimensionality M, with Q1 points, and B is a Q2 x M matrix where each row is also a point of dimensionality M, with Q2 points, we can efficiently compute a distance matrix D, where the element D(i, j) at row i and column j is the distance between row i of A and row j of B, using the following matrix formulation:
nA = sum(A.^2, 2); %// Sum of squares for each row of A
nB = sum(B.^2, 2); %// Sum of squares for each row of B
D = bsxfun(@plus, nA, nB.') - 2*A*B.'; %// Squared distance matrix
D = sqrt(max(D, 0)); %// Clamp round-off negatives to 0, then take the square root
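The reason these lines work is the expansion of the squared distance. nA(i) and nB(j) are exactly the row-wise sums of squares, the cross term is one entry of the matrix product A*B.', and bsxfun(@plus, nA, nB.') lays out nA(i) + nB(j) on a Q1 x Q2 grid:

```latex
\begin{aligned}
D(i,j)^2 &= \sum_{m=1}^{M} \bigl(A(i,m) - B(j,m)\bigr)^2 \\
         &= \underbrace{\sum_{m=1}^{M} A(i,m)^2}_{nA(i)}
          + \underbrace{\sum_{m=1}^{M} B(j,m)^2}_{nB(j)}
          - 2\underbrace{\sum_{m=1}^{M} A(i,m)\,B(j,m)}_{(A B^{\top})(i,j)}
\end{aligned}
```

When two points are very close, the first two sums and the cross term nearly cancel, so the computed squared distance can come out as a tiny negative number; clamping with max(D, 0) before the square root keeps sqrt from returning complex values in that case.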
Therefore, if we let A be a matrix of query points and B be the data set consisting of your original data, we can determine the k closest points by sorting each row individually and finding the k locations in each row with the smallest distances. We can also use this to retrieve the actual points themselves. Putting it all together:
%// Load the data and create the query points
load fisheriris;
x = meas(:,3:4);
newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5];
%// Define k and other variables
k = 10;
Q = size(newpoints, 1);
M = size(x, 2);
nA = sum(newpoints.^2, 2); %// Sum of squares for each row of A
nB = sum(x.^2, 2); %// Sum of squares for each row of B
D = bsxfun(@plus, nA, nB.') - 2*newpoints*x.'; %// Squared distance matrix
D = sqrt(max(D, 0)); %// Clamp round-off negatives to 0, then take the square root
%// Sort the distances
[d, ind] = sort(D, 2);
%// Get the indices of the closest distances
ind_closest = ind(:, 1:k);
%// Also get the nearest points
x_closest = permute(reshape(x(reshape(ind_closest.', [], 1), :).', M, k, []), [2 1 3]);
The logic for computing the distance matrix is the same as before; only the variable names have changed to suit the example. We sort each row independently using the two-input version of sort, so ind contains the IDs per row and d contains the corresponding sorted distances. We then figure out which indices are closest to each query point by simply truncating this matrix to k columns. Finally, we use reshape and permute to gather the associated closest points. We first transpose ind_closest so that, when it is unrolled into a single column, the k IDs of the first query come first, followed by the k IDs of the second query, and so on. Using these to index into x stacks the points into a Q * k x M matrix, one block of k rows per query. Transposing this, reshaping it into an M x k x Q array and swapping the first two dimensions with permute gives us the k x M x Q array we specified. Be careful here: indexing with ind_closest(:) directly would unroll the Q x k matrix column by column and interleave the neighbours of different queries across the slices.
If you want the actual distances themselves, the sorted distances are simply the first k columns of d, i.e. d(:, 1:k). Equivalently, you can index into the unsorted distance matrix D, using sub2ind to obtain linear indices so that everything is grabbed in one shot. The values of ind_closest already tell us which columns we need to access. The rows we need to access are simply 1 (k times), 2 (k times), and so on up to Q, where k is the number of points we wanted to return:
row_indices = repmat((1:Q).', 1, k);
linear_ind = sub2ind(size(D), row_indices, ind_closest);
dist_sorted = D(linear_ind);
When we run the above code for the above query points, these are the indices, points and distances we get:
>> ind_closest
ind_closest =
120 134 53 73 84 77 78 51 64 87
123 119 118 106 132 108 131 136 126 110
107 62 86 122 71 127 139 115 60 52
99 65 58 94 60 61 80 44 54 72
>> x_closest
x_closest(:,:,1) =
5.0000 1.5000
5.1000 1.5000
4.9000 1.5000
4.9000 1.5000
5.1000 1.6000
4.8000 1.4000
5.0000 1.7000
4.7000 1.4000
4.7000 1.4000
4.7000 1.5000
x_closest(:,:,2) =
6.7000 2.0000
6.9000 2.3000
6.7000 2.2000
6.6000 2.1000
6.4000 2.0000
6.3000 1.8000
6.1000 1.9000
6.1000 2.3000
6.0000 1.8000
6.1000 2.5000
x_closest(:,:,3) =
4.5000 1.7000
4.2000 1.5000
4.5000 1.6000
4.9000 2.0000
4.8000 1.8000
4.8000 1.8000
4.8000 1.8000
5.1000 2.4000
3.9000 1.4000
4.5000 1.5000
x_closest(:,:,4) =
3.0000 1.1000
3.6000 1.3000
3.3000 1.0000
3.3000 1.0000
3.9000 1.4000
3.5000 1.0000
3.5000 1.0000
1.6000 0.6000
4.0000 1.3000
4.0000 1.3000
>> dist_sorted
dist_sorted =
0.0500 0.1118 0.1118 0.1118 0.1803 0.2062 0.2500 0.3041 0.3041 0.3041
0.3000 0.3162 0.3606 0.4123 0.6000 0.7280 0.9055 0.9487 1.0198 1.0296
0.9434 1.0198 1.0296 1.0296 1.0630 1.0630 1.0630 1.1045 1.1045 1.1180
2.6000 2.7203 2.8178 2.8178 2.8320 2.9155 2.9155 2.9275 2.9732 2.9732
To compare this with knnsearch, pass the matrix of query points as the second parameter, where each row is a query point, e.g. [n, dk] = knnsearch(x, newpoints, 'k', k). You will see that the indices and sorted distances match between this implementation and knnsearch, up to the ordering of neighbours that are tied at the same distance.
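As for "is there any other way to do this?": if you have the Statistics and Machine Learning Toolbox (which also provides knnsearch), pdist2 with the 'Smallest' option returns the k smallest distances and their indices for every query in one call. The result is k x Q, one column per query, so it is transposed below to match the Q x k layout used above. A sketch, cross-checked against the matrix formulation:

```matlab
load fisheriris
x = meas(:,3:4);
newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5];
k = 10;

% For each row of newpoints, the k smallest distances to the rows of x,
% sorted in ascending order, plus the corresponding row indices into x
[dk, ik] = pdist2(x, newpoints, 'euclidean', 'Smallest', k);
dist_sorted_p = dk.';    % Q x k, same layout as dist_sorted above
ind_closest_p = ik.';    % Q x k, same layout as ind_closest above

% Cross-check the distances against the matrix formulation
nA = sum(newpoints.^2, 2);
nB = sum(x.^2, 2);
D = sqrt(max(bsxfun(@plus, nA, nB.') - 2*newpoints*x.', 0));
d = sort(D, 2);
max_err = max(max(abs(d(:, 1:k) - dist_sorted_p)));
```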
Hope this helps you. Good luck!