I'm struggling to understand how to implement a least-squares linear classifier for my data in MATLAB. My data has N rows, each row 10 columns wide; each row represents a data point with 10 features. There are only two classes: the first N/2 rows of my data are Class 1 and the rest are Class 2.
The explanations of least squares I've found online make sense, but I'm not able to adapt them to my data. I just need a bit of conceptual explanation relating the least-squares method to my data.
The idea of using least squares to create a linear classifier is to define a linear function f(x) = w'*x and adjust w so that f(x) is close to 1 for the data points of one class and close to -1 for the other class. The adjustment of w is done by minimizing, for each data point, the squared distance between f(x) and either 1 or -1, depending on its class.
% Create a two-cluster data set with 100 points in each cluster
N = 100;
X1 = 0.3*bsxfun(@plus, randn(N, 2), [6 6]);
X2 = 0.6*bsxfun(@plus, randn(N, 2), [-2 -1]);
% Create a 200 by 3 data matrix similar to the one you have
% (see note below why 200 by 3 and not 200 by 2)
X = [[X1; X2] ones(2*N, 1)];
% Create 200 by 1 vector containing 1 for each point in the first cluster
% and -1 for each point in the second cluster
b = [ones(N, 1); -ones(N, 1)];
% Solve the least squares problem min ||X*z - b||^2
% (lsqlin needs the Optimization Toolbox; z = X \ b gives the same solution)
z = lsqlin(X, b);
% Plot data points and linear separator found above:
% the separator is the line f(x) = 0, i.e. z(1)*x + z(2)*y + z(3) = 0,
% solved here for y as a function of x
x = linspace(-3, 3);
y = -z(3)/z(2) - (z(1)/z(2))*x;
hold on;
plot(X(:, 1), X(:, 2), 'bx'); xlim([-3 3]); ylim([-3 3]);
plot(x, y, 'r');
I have added an additional column of ones to the data matrix in order to allow for a shift of the separator, thus making it a little more versatile. If you don't do this, you force the separator to pass through the origin, which will more often than not result in worse classification results.
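Once z has been computed, classifying a point reduces to checking the sign of f(x). A minimal sketch, assuming the z found above and a hypothetical 2-D query point xnew:

```
% Classify a new point: append the constant 1 to match the bias column,
% then take the sign of the linear function f(x) = [x 1]*z
xnew = [1.5 2.0];            % hypothetical query point
label = sign([xnew 1]*z);    % +1 -> first cluster, -1 -> second cluster
```

Note that sign returns 0 for a point lying exactly on the separator, so you may want to map that case to one of the two classes explicitly.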