83 lines
2.4 KiB
Rust
83 lines
2.4 KiB
Rust
use std::slice::Iter;
|
|
use std::iter::Zip;
|
|
use clusterable::Clusterable;
|
|
|
|
pub struct Kmeans<T: Clusterable> {
|
|
centroids: Vec<T>,
|
|
elements: Vec<T>,
|
|
labels: Vec<usize>,
|
|
cluster_number: usize,
|
|
}
|
|
|
|
impl<T:Clusterable> Kmeans<T> {
|
|
pub fn new(centroids: Vec<T>, data: Vec<T>) -> Kmeans<T> {
|
|
let labels = Kmeans::build_labels(¢roids, &data);
|
|
let cluster_number = centroids.len();
|
|
Kmeans {
|
|
centroids: centroids,
|
|
elements: data,
|
|
labels: labels,
|
|
cluster_number: cluster_number
|
|
}
|
|
}
|
|
|
|
/// \returns True if converged
|
|
pub fn iterate(&mut self) -> bool {
|
|
// Update the centroids
|
|
let centroids = Kmeans::build_centroids(&self.elements, &self.labels, self.cluster_number);
|
|
|
|
if self.centroids == centroids {
|
|
true
|
|
} else {
|
|
self.centroids = centroids;
|
|
Kmeans::update_labels(&self.centroids, &self.elements, &mut self.labels);
|
|
false
|
|
}
|
|
}
|
|
|
|
pub fn build_labels(centroids: &Vec<T>, data: &Vec<T>) -> Vec<usize> {
|
|
debug_assert_ne!(0, centroids.len());
|
|
|
|
let mut output = vec![0; data.len()];
|
|
Kmeans::update_labels(centroids, data, &mut output);
|
|
output
|
|
}
|
|
|
|
pub fn update_labels(centroids: &Vec<T>, data: &Vec<T>, labels: &mut Vec<usize>) {
|
|
for (element, new_label) in data.iter().zip(labels.iter_mut()) {
|
|
*new_label = centroids
|
|
.iter()
|
|
.enumerate()
|
|
.min_by(|&(_, c1), &(_, c2)| {
|
|
c1.distance(element).partial_cmp(&c2.distance(element)).unwrap()
|
|
}).unwrap().0;
|
|
}
|
|
}
|
|
|
|
pub fn build_centroids(data: &Vec<T>, labels: &Vec<usize>, cluster_number: usize) -> Vec<T> {
|
|
|
|
let mut centroids = vec![];
|
|
for label in 0..cluster_number {
|
|
|
|
let to_consider = data
|
|
.iter()
|
|
.enumerate()
|
|
.filter(|&(index, _)| labels[index] == label)
|
|
.map(|(_, element)| element);
|
|
|
|
if let Some(centroid) = T::get_centroid(to_consider) {
|
|
centroids.push(centroid);
|
|
}
|
|
}
|
|
|
|
centroids
|
|
}
|
|
|
|
pub fn iter(&self) -> Zip<Iter<T>, Iter<usize>> {
|
|
debug_assert_eq!(self.elements.len(), self.labels.len());
|
|
return self.elements.iter().zip(self.labels.iter());
|
|
}
|
|
|
|
}
|
|
|