</syntaxhighlight>
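Our classifier is implemented in a class called "KMeans" because it closely resembles k-means clustering: each class is represented by a single centroid, and a new point receives the label of the nearest one. (Strictly speaking, this is a nearest-centroid classifier, since the centroids are computed once from labeled points instead of being found iteratively.)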
The guts of the classifier are in "KMeans.cpp". When a KMeans object is instantiated (i.e. when its constructor is called), we pass it the points and their class labels, and it immediately calculates the centroid of each class.
<syntaxhighlight lang="c++"> | <syntaxhighlight lang="c++"> | ||
// ...
{
    _n_classes = n_classes;
    for(int c=0; c < n_classes; c++) // traverse once for every class
    {
        _centroids.push_back(Point2D(0,0));
        // ... (lines omitted in this excerpt)
        _centroids[c] = _centroids[c]/numPoints;
    }
}
</syntaxhighlight>
To do this, the constructor traverses the vector of points once for every class. In each pass it only looks at points of the current class: it sums them up, counts how many there were, and finally divides the sum by that count. The result is one centroid per class, stored in a vector called "_centroids", which is a private member variable of the KMeans class.
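The constructor excerpt above does not show the summing step, so here is a minimal, self-contained sketch of the centroid computation just described. It is not the article's code: the free function computeCentroids() and the plain Point2D struct are hypothetical stand-ins for the project's own KMeans constructor and Point2D class. It also adds a guard for classes without any points, which would otherwise cause a division by zero.

<syntaxhighlight lang="c++">
#include <cstddef>
#include <vector>

// Hypothetical minimal stand-in for the article's Point2D class.
struct Point2D {
    float x;
    float y;
};

// Sketch of the centroid computation: one pass over all points per class,
// summing the points of that class and dividing by how many there were.
// computeCentroids() is a hypothetical name, not part of the article's code.
std::vector<Point2D> computeCentroids(const std::vector<Point2D>& points,
                                      const std::vector<int>& labels,
                                      int n_classes)
{
    std::vector<Point2D> centroids;
    for (int c = 0; c < n_classes; c++)          // traverse once for every class
    {
        Point2D sum{0.0f, 0.0f};
        int numPoints = 0;
        for (std::size_t i = 0; i < points.size(); i++)
        {
            if (labels[i] == c)                  // only look at points of class c
            {
                sum.x += points[i].x;
                sum.y += points[i].y;
                numPoints++;
            }
        }
        // Added guard (not in the original excerpt): a class with no points
        // keeps the centroid (0,0) instead of dividing by zero.
        if (numPoints > 0)
        {
            sum.x /= numPoints;
            sum.y /= numPoints;
        }
        centroids.push_back(sum);
    }
    return centroids;
}
</syntaxhighlight>

For example, the points (0,0) and (2,0) labeled 0 and the points (10,10) and (12,14) labeled 1 yield the centroids (1,0) and (11,12).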
To classify a new point, we already know the centroids, so we only need to calculate the distance from the new point to every centroid and return the label of the closest one.
<syntaxhighlight lang="c++">
#include <limits> // for std::numeric_limits

int KMeans::classify(Point2D newPoint)
{
    float min_distance = std::numeric_limits<float>::max(); // start with the largest possible distance
    int class_label = -1; // and an invalid class label
    for(int c=0; c < _n_classes; c++) // compare the new point with every class centroid
    {
        float distance = _centroids[c].getDistance(newPoint);
        if(distance < min_distance) // remember the closest centroid so far
        {
            min_distance = distance;
            class_label = c;
        }
    }
    return class_label;
}
</syntaxhighlight>
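As a variation on the classify() method, the following self-contained sketch compares squared Euclidean distances. Because the square root is monotonic, skipping it does not change which centroid is nearest, and it saves one sqrt call per class. We don't know how Point2D::getDistance() is implemented in the article, so classifyNearest(), squaredDistance() and the plain Point2D struct are hypothetical stand-ins.

<syntaxhighlight lang="c++">
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical stand-in for the article's Point2D class.
struct Point2D {
    float x;
    float y;
};

// Squared Euclidean distance between two points (no square root needed
// when we only want to know which distance is smallest).
float squaredDistance(const Point2D& a, const Point2D& b)
{
    float dx = a.x - b.x;
    float dy = a.y - b.y;
    return dx * dx + dy * dy;
}

// Returns the index of the nearest centroid, or -1 if there are no centroids.
// classifyNearest() is a hypothetical free-function version of KMeans::classify.
int classifyNearest(const std::vector<Point2D>& centroids, const Point2D& newPoint)
{
    float min_distance = std::numeric_limits<float>::max();
    int class_label = -1;
    for (std::size_t c = 0; c < centroids.size(); c++)
    {
        float distance = squaredDistance(centroids[c], newPoint);
        if (distance < min_distance)             // closest centroid so far
        {
            min_distance = distance;
            class_label = static_cast<int>(c);
        }
    }
    return class_label;
}
</syntaxhighlight>

With the centroids (1,0) and (11,12), the point (0,1) is assigned to class 0 and the point (9,9) to class 1; an empty centroid list yields -1, the same "wrong class label" the original code starts with.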