use std::slice::Iter; use std::iter::Zip; use clusterable::Clusterable; pub struct Kmeans { centroids: Vec, elements: Vec, labels: Vec, cluster_number: usize, } impl Kmeans { pub fn new(centroids: Vec, data: Vec) -> Kmeans { let labels = Kmeans::build_labels(¢roids, &data); let cluster_number = centroids.len(); Kmeans { centroids: centroids, elements: data, labels: labels, cluster_number: cluster_number } } /// \returns True if converged pub fn iterate(&mut self) -> bool { // Update the centroids let centroids = Kmeans::build_centroids(&self.elements, &self.labels, self.cluster_number); if self.centroids == centroids { true } else { self.centroids = centroids; Kmeans::update_labels(&self.centroids, &self.elements, &mut self.labels); false } } pub fn build_labels(centroids: &Vec, data: &Vec) -> Vec { debug_assert_ne!(0, centroids.len()); let mut output = vec![0; data.len()]; Kmeans::update_labels(centroids, data, &mut output); output } pub fn update_labels(centroids: &Vec, data: &Vec, labels: &mut Vec) { for (element, new_label) in data.iter().zip(labels.iter_mut()) { *new_label = centroids .iter() .enumerate() .min_by(|&(_, c1), &(_, c2)| { c1.distance(element).partial_cmp(&c2.distance(element)).unwrap() }).unwrap().0; } } pub fn build_centroids(data: &Vec, labels: &Vec, cluster_number: usize) -> Vec { let mut centroids = vec![]; for label in 0..cluster_number { let to_consider = data .iter() .enumerate() .filter(|&(index, _)| labels[index] == label) .map(|(_, element)| element); if let Some(centroid) = T::get_centroid(to_consider) { centroids.push(centroid); } } centroids } pub fn iter(&self) -> Zip, Iter> { debug_assert_eq!(self.elements.len(), self.labels.len()); return self.elements.iter().zip(self.labels.iter()); } }