diff --git a/src/clusterable.rs b/src/clusterable.rs deleted file mode 100644 index f685424..0000000 --- a/src/clusterable.rs +++ /dev/null @@ -1,43 +0,0 @@ -use std::fmt::Debug; - -pub trait Clusterable where Self: Sized + Clone + PartialEq + Debug { - fn distance(&self, rhs: &Self) -> f64; - fn get_centroid<'a, I>(elements: I) -> Option - where I: Iterator, Self: 'a; -} - -macro_rules! impl_clusterable { - ( $type: ty) => { - impl Clusterable for $type { - fn distance(&self, rhs: &Self) -> f64 { - if self > rhs { - *self as f64 - *rhs as f64 - } else { - *rhs as f64 - *self as f64 - } - } - - fn get_centroid<'a, I>(elements: I) -> Option - where I: Iterator, Self: 'a { - - let mut tmp = 0.0; - let mut count = 0.0; - for element in elements { - tmp += element; - count += 1.0; - } - - if count > 0.0 { - Some(tmp / count) - } else { - None - } - - } - } - } -} - -impl_clusterable!(f32); -impl_clusterable!(f64); - diff --git a/src/example.rs b/src/example.rs index 81434c9..3898829 100644 --- a/src/example.rs +++ b/src/example.rs @@ -5,7 +5,12 @@ use std::fmt::{Display, Formatter, Result}; use std::fs::File; use rand::distributions::Range; use rand::distributions::normal::Normal; -use generic_kmeans::{equal_kmeans, Clusterable}; + +use generic_kmeans::kmeans; + +const CLUSTER_NUMBER: usize = 3; +const VAR: f64 = 1.0; +const MAX: f64 = 10.0; #[derive(PartialEq, Copy, Clone, Debug)] struct Vector2 { @@ -28,38 +33,69 @@ impl Display for Vector2 { } } -impl Clusterable for Vector2 { - fn distance(&self, other: &Self) -> f64 { - (self.x - other.x) * (self.x - other.x) + (self.y - other.y) * (self.y - other.y) +fn distance(v1: &Vector2, v2: &Vector2) -> f64 { + let dx = v2.x - v1.x; + let dy = v2.y - v1.y; + dx * dx + dy * dy +} + +fn centroid(elements: &Vec>) -> Vector2 { + let mut v = Vector2::new(0.0, 0.0); + + for element in elements { + v.x += element.x; + v.y += element.y; } - fn get_centroid<'a, I>(cluster: I) -> Option - where I: Iterator, Self: 'a { + v.x /= elements.len() as f64; + v.y /= elements.len() as f64; - let mut centroid = Vector2::new(0.0, 0.0); - let mut count = 0.0; + v +} - for i in cluster { - centroid.x += i.x; - centroid.y += i.y; - count += 1.0; - } +fn generate_points(centers: &mut Vec<(Vector2, Normal, Normal)>) -> Vec> { - if count > 0.0 { - centroid.x /= count as f64; - centroid.y /= count as f64; - Some(centroid) - } else { - None + let mut rng = rand::thread_rng(); + let mut output = vec![]; + + for &mut (_, x_rng, y_rng) in centers.iter_mut() { + for _ in 0..100 { + use rand::distributions::IndependentSample; + output.push(Vector2::new(x_rng.ind_sample(&mut rng), y_rng.ind_sample(&mut rng))); } } + + output +} + +fn generate_centers(number: usize) -> Vec<(Vector2, Normal, Normal)> { + + let mut output = vec![]; + + let range = Range::new(0.0, MAX); + let mut rng = rand::thread_rng(); + + for _ in 0..number { + + use rand::distributions::IndependentSample; + let center = Vector2::new(range.ind_sample(&mut rng), range.ind_sample(&mut rng)); + output.push((center, Normal::new(center.x, VAR), Normal::new(center.y, VAR))); + + } + + output } fn main() { - const VAR: f64 = 1.0; + let mut centers = generate_centers(CLUSTER_NUMBER); + let elements = generate_points(&mut centers); - use rand::distributions::IndependentSample; + let initial = vec![ + Vector2::new(0.0, 0.0), + Vector2::new(10.0, 0.0), + Vector2::new(0.0, 10.0), + ]; let colors = vec![ "blue", @@ -67,39 +103,13 @@ fn main() { "green", ]; - let range = Range::new(0.0, 10.0); - let mut rng = rand::thread_rng(); - let cluster_number = 3; + let kmeans = kmeans(elements, initial, &distance, ¢roid, 1000).unwrap(); - let mut centers = vec![]; - for _ in 0..cluster_number { - let center = Vector2::new(range.ind_sample(&mut rng), range.ind_sample(&mut rng)); - centers.push((center, Normal::new(center.x, VAR), Normal::new(center.y, VAR))); - } - - let mut elements = vec![]; - - for &mut (_, x_rng, y_rng) in centers.iter_mut() { - - for _ in 0..100 { - elements.push(Vector2::new(x_rng.ind_sample(&mut rng), y_rng.ind_sample(&mut rng))); - } - - } - - let initialization = vec![ - Vector2::new(0.0,0.0), - Vector2::new(10.0,0.0), - Vector2::new(0.0,10.0), - ]; - - - let (clusters, nb_iterations) = equal_kmeans(initialization, elements, 100000).ok().unwrap(); - let clusters = clusters.into_vec_vec(); + println!("Converged in {} iterations.", kmeans.1); let mut output = File::create("plot/dat.dat").unwrap(); - for (index, (cluster, color)) in clusters.iter().zip(colors.iter()).enumerate() { + for (index, (cluster, color)) in kmeans.0.iter().zip(colors.iter()).enumerate() { println!("Cluster {}: {} elements", index, cluster.len()); for element in cluster { use std::io::Write; @@ -107,11 +117,10 @@ fn main() { } } - println!("Finished in {} iterations", nb_iterations); - let mut center_file = File::create("plot/centers.dat").unwrap(); for (&(center, _, _), color) in centers.iter().zip(&colors) { use std::io::Write; writeln!(center_file, "{} {} {}", center.x, center.y, color).unwrap(); } + } diff --git a/src/kmeans.rs b/src/kmeans.rs deleted file mode 100644 index 75510b9..0000000 --- a/src/kmeans.rs +++ /dev/null @@ -1,220 +0,0 @@ -use std; -use std::collections::HashMap; -use std::slice::Iter; -use std::vec::IntoIter; -use std::iter::Zip; -use clusterable::Clusterable; - -pub struct Kmeans { - centroids: Vec, - elements: Vec, - labels: Vec, - cluster_number: usize, -} - -impl Kmeans { - pub fn new(centroids: Vec, data: Vec) -> Kmeans { - let labels = Kmeans::build_labels(¢roids, &data); - let cluster_number = centroids.len(); - Kmeans { - centroids: centroids, - elements: data, - labels: labels, - cluster_number: cluster_number - } - } - - /// \returns True if converged - pub fn iterate(&mut self) -> bool { - // Update the centroids - let centroids = Kmeans::build_centroids(&self.elements, &self.labels, self.cluster_number); - - if self.centroids == centroids { - true - } else { - self.centroids = centroids; - Kmeans::update_labels(&self.centroids, &self.elements, &mut self.labels); - false - } - } - - pub fn build_labels(centroids: &Vec, data: &Vec) -> Vec { - debug_assert_ne!(0, centroids.len()); - - let mut output = vec![0; data.len()]; - Kmeans::update_labels(centroids, data, &mut output); - output - } - - pub fn update_labels(centroids: &Vec, data: &Vec, labels: &mut Vec) { - for (element, new_label) in data.iter().zip(labels.iter_mut()) { - *new_label = centroids - .iter() - .enumerate() - .min_by(|&(_, c1), &(_, c2)| { - c1.distance(element).partial_cmp(&c2.distance(element)).unwrap() - }).unwrap().0; - } - } - - pub fn build_centroids(data: &Vec, labels: &Vec, cluster_number: usize) -> Vec { - - let mut centroids = vec![]; - for label in 0..cluster_number { - - let to_consider = data - .iter() - .enumerate() - .filter(|&(index, _)| labels[index] == label) - .map(|(_, element)| element); - - if let Some(centroid) = T::get_centroid(to_consider) { - centroids.push(centroid); - } - } - - centroids - } - - pub fn iter(&self) -> Zip, Iter> { - debug_assert_eq!(self.elements.len(), self.labels.len()); - return self.elements.iter().zip(self.labels.iter()); - } - - pub fn into_iter(self) -> Zip, IntoIter> { - debug_assert_eq!(self.elements.len(), self.labels.len()); - return self.elements.into_iter().zip(self.labels.into_iter()); - } - - pub fn into_vec_vec(self) -> Vec> { - let mut map = HashMap::new(); - - for (element, label) in self.into_iter() { - let mut entry = map.entry(label).or_insert(vec![]); - entry.push(element); - } - - let mut output = vec![]; - - for (_, cluster) in map { - let mut vec = vec![]; - for element in cluster { - vec.push(element); - } - output.push(vec); - } - - output - } - - pub fn to_vec_vec(&self) -> Vec<(Vec, usize)> { - let mut map = HashMap::new(); - - for (element, label) in self.iter() { - let mut entry = map.entry(*label).or_insert(vec![]); - entry.push(element.clone()); - } - - let mut output = vec![]; - - for (label, cluster) in map { - let mut vec = vec![]; - for element in cluster { - vec.push(element); - } - output.push((vec, label)); - } - - output - } - -} - -pub enum Error { - IterationsLimitExceeded, -} - -pub fn kmeans(centroids: Vec, data: Vec, max_iterations: usize) - -> Result<(Kmeans, usize), Error> { - - let mut kmeans = Kmeans::new(centroids, data); - - for nb_iterations in 0..max_iterations { - - let stable = kmeans.iterate(); - - if stable { - return Ok((kmeans, nb_iterations)); - } - - } - - Err(Error::IterationsLimitExceeded) -} - -pub fn equal_kmeans(centroids: Vec, data: Vec, max_iterations: usize) - -> Result<(Kmeans, usize), Error> { - - let number_of_elements = data.len(); - let (kmeans, nb_iterations) = kmeans(centroids, data, max_iterations)?; - let mut clusters = kmeans.to_vec_vec(); - - for &mut (ref mut cluster, ref mut label) in &mut clusters { - cluster.sort_by(|ref e1, ref e2| { - *&e1.distance(&kmeans.centroids[*label]) - .partial_cmp(&e2.distance(&kmeans.centroids[*label])).unwrap() - }); - } - - let max = (number_of_elements as f64 / clusters.len() as f64).ceil() as usize; - - let mut to_replace = vec![]; - for &mut (ref mut cluster, _) in &mut clusters { - while cluster.len() > max { - to_replace.push(cluster.pop().unwrap()); - } - } - - for element in to_replace { - // Find the best non full cluster to place it in - let mut best_distance = std::f64::MAX; - let mut best_index = 0; - - for (index, centroid) in kmeans.centroids.iter().enumerate() { - - let new_distance = element.distance(centroid); - if new_distance < best_distance && clusters[index].0.len() < max { - best_distance = new_distance; - best_index = index; - } - - } - - clusters[best_index].0.push(element); - } - - let mut elements = vec![]; - let mut labels = vec![]; - let mut centroids = vec![]; - - for (index, &(ref cluster, _)) in clusters.iter().enumerate() { - for element in cluster { - elements.push(element.clone()); - labels.push(index); - } - centroids.push(Clusterable::get_centroid(cluster.iter()).unwrap()); - } - - let len = centroids.len(); - - // Build a new k-means from clusters - let new_kmeans = Kmeans { - centroids: centroids, - elements: elements, - labels: labels, - cluster_number: len, - }; - - Ok((new_kmeans, nb_iterations)) - -} diff --git a/src/lib.rs b/src/lib.rs index 3c60614..340282c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,10 +1,132 @@ -pub mod kmeans; -pub mod clusterable; +use std::collections::HashMap; -pub use kmeans::Kmeans; -pub use clusterable::Clusterable; -pub use kmeans::Error; -pub use kmeans::kmeans; -pub use kmeans::equal_kmeans; +#[derive(Debug)] +pub enum Error { + TooManyIterations, +} +type Distance = Fn(&T, &T) -> f64; +type Centroid = Fn(&Vec) -> T; +pub struct KmeansData { + pub elements: Vec, + pub labels: Vec, +} + +pub struct Kmeans<'a, T: 'a + Clone + PartialEq> { + data: KmeansData, + centroids: Vec, + distance: &'a Distance, + centroid: &'a Centroid, +} + +impl<'a, T: 'a + Clone + PartialEq> Kmeans<'a, T> { + pub fn new(data: Vec, centroids: Vec, distance: &'a Distance, centroid: &'a Centroid) -> Kmeans<'a, T> { + + let len = data.len(); + Kmeans { + data: KmeansData { elements: data, labels: vec![0; len] }, + centroids: centroids, + distance: distance, + centroid: centroid, + } + + } + + /// True if converged + pub fn iterate(&mut self) -> bool { + // Update the labels + self.update_labels(); + + // Update the centroids + let new_centroids = self.compute_centroids(); + + let ret = new_centroids == self.centroids; + self.centroids = new_centroids; + ret + } + + fn update_labels(&mut self) { + + let iterator = self.data.elements.iter().zip(self.data.labels.iter_mut()); + + for (ref element, ref mut label) in iterator { + + let mut best_distance = std::f64::MAX; + let mut best_label = 0; + + for (index, centroid) in self.centroids.iter().enumerate() { + + let new_distance = (self.distance)(*element, centroid); + if new_distance < best_distance { + best_distance = new_distance; + best_label = index; + } + } + + **label = best_label; + } + + } + + fn compute_centroids(&self) -> Vec { + let mut centroids_map = HashMap::new(); + + for (element, label) in self.data.elements.iter().zip(self.data.labels.iter()) { + let mut centroid = centroids_map.entry(label).or_insert(vec![]); + centroid.push(element.clone()); + } + + let mut new_centroids = vec![]; + + for (_, value) in centroids_map { + new_centroids.push((self.centroid)(&value)); + } + + new_centroids + } +} + +pub fn kmeans<'a, T: 'a + Clone + PartialEq>( + elements: Vec, + initial: Vec, + distance: &'a Distance, + centroid: &'a Centroid, + max_iteration: usize, +) -> Result<(Vec>,usize), Error> { + + let mut clusters = Kmeans::new(elements, initial, distance, centroid); + + let mut counter = 0; + + let iterations = loop { + + counter += 1; + + if clusters.iterate() || counter > max_iteration { + break counter; + } + + }; + + let mut map = HashMap::new(); + + for (element, label) in clusters.data.elements.iter().zip(clusters.data.labels.iter()) { + let mut centroid = map.entry(label).or_insert(vec![]); + centroid.push(element.clone()); + } + + let mut output = vec![]; + + for (_, value) in map { + let mut cluster = vec![]; + + for element in value { + cluster.push(element); + } + + output.push(cluster); + } + + Ok((output, iterations)) +}