kmeans/src/kmeans.rs

83 lines
2.4 KiB
Rust

use std::slice::Iter;
use std::iter::Zip;
use clusterable::Clusterable;
/// State of a k-means clustering over elements of a single `Clusterable` type.
///
/// The struct owns the clustered elements together with the current centroids
/// and the label (centroid index) assigned to each element.
pub struct Kmeans<T>
where
    T: Clusterable,
{
    /// Current centroids; a label is an index into this vector.
    centroids: Vec<T>,
    /// The elements being clustered.
    elements: Vec<T>,
    /// Label of each element, stored in the same order as `elements`.
    labels: Vec<usize>,
    /// Number of clusters, taken from the number of initial centroids.
    cluster_number: usize,
}
impl<T: Clusterable> Kmeans<T> {
    /// Creates a clustering of `data` seeded with the initial `centroids`.
    ///
    /// Every element is labelled right away with the index of its nearest
    /// centroid, and the cluster count is set to `centroids.len()`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as `build_labels`.
    pub fn new(centroids: Vec<T>, data: Vec<T>) -> Kmeans<T> {
        let labels = Self::build_labels(&centroids, &data);
        let cluster_number = centroids.len();
        Kmeans {
            centroids,
            elements: data,
            labels,
            cluster_number,
        }
    }

    /// Runs one refinement step and reports whether the clustering converged.
    ///
    /// The centroids are recomputed from the current labels. If they are equal
    /// to the stored centroids, nothing is modified and `true` is returned.
    /// Otherwise the new centroids are stored, every element is relabelled
    /// against them, and `false` is returned.
    pub fn iterate(&mut self) -> bool {
        let centroids = Self::build_centroids(&self.elements, &self.labels, self.cluster_number);
        if self.centroids == centroids {
            return true;
        }
        self.centroids = centroids;
        Self::update_labels(&self.centroids, &self.elements, &mut self.labels);
        false
    }

    /// Returns, for each element of `data`, the index of its nearest centroid.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `centroids` is empty. Otherwise panics under
    /// the same conditions as `update_labels`.
    pub fn build_labels(centroids: &Vec<T>, data: &Vec<T>) -> Vec<usize> {
        debug_assert_ne!(0, centroids.len());
        let mut labels = vec![0; data.len()];
        Self::update_labels(centroids, data, &mut labels);
        labels
    }

    /// Overwrites each label with the index of the centroid nearest to the
    /// element at the same position in `data`.
    ///
    /// When several centroids are equally near, the one with the lowest index
    /// is chosen.
    ///
    /// # Panics
    ///
    /// Panics if `centroids` is empty while there is at least one element to
    /// label, or if two distances cannot be ordered by `partial_cmp`. In debug
    /// builds, also panics if `data` and `labels` differ in length.
    pub fn update_labels(centroids: &Vec<T>, data: &Vec<T>, labels: &mut Vec<usize>) {
        // `zip` would silently stop at the shorter input and leave stale labels.
        debug_assert_eq!(data.len(), labels.len());
        for (element, label) in data.iter().zip(labels.iter_mut()) {
            *label = centroids
                .iter()
                // Compute every distance exactly once instead of twice per comparison.
                .map(|centroid| centroid.distance(element))
                .enumerate()
                .min_by(|a, b| {
                    a.1.partial_cmp(&b.1)
                        .expect("distances must be comparable")
                })
                .expect("at least one centroid is required")
                .0;
        }
    }

    /// Computes one centroid per label in `0..cluster_number` from the elements
    /// carrying that label.
    ///
    /// A label for which `T::get_centroid` returns `None` contributes no entry,
    /// so the result may hold fewer than `cluster_number` centroids and the
    /// centroids of later labels then move to lower indices.
    ///
    /// # Panics
    ///
    /// Panics if `cluster_number` is non-zero and `labels` is shorter than `data`.
    pub fn build_centroids(data: &Vec<T>, labels: &Vec<usize>, cluster_number: usize) -> Vec<T> {
        (0..cluster_number)
            .filter_map(|label| {
                // One up-front bounds check replaces an index check per element.
                let members = data
                    .iter()
                    .zip(labels[..data.len()].iter())
                    .filter(|&(_, &assigned)| assigned == label)
                    .map(|(element, _)| element);
                T::get_centroid(members)
            })
            .collect()
    }

    /// Iterates over `(element, label)` pairs in element order.
    pub fn iter(&self) -> Zip<Iter<T>, Iter<usize>> {
        debug_assert_eq!(self.elements.len(), self.labels.len());
        self.elements.iter().zip(self.labels.iter())
    }
}