From c9de1ca3337991de0035ae863e3939ba57f9df30 Mon Sep 17 00:00:00 2001 From: Thomas Forgione Date: Mon, 26 Feb 2018 10:32:25 +0100 Subject: [PATCH] Added equal kmeans... not very good though --- src/example.rs | 18 ++++--- src/kmeans.rs | 138 +++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 23 ++------- 3 files changed, 153 insertions(+), 26 deletions(-) diff --git a/src/example.rs b/src/example.rs index 79c5668..81434c9 100644 --- a/src/example.rs +++ b/src/example.rs @@ -5,7 +5,7 @@ use std::fmt::{Display, Formatter, Result}; use std::fs::File; use rand::distributions::Range; use rand::distributions::normal::Normal; -use generic_kmeans::{kmeans, Clusterable}; +use generic_kmeans::{equal_kmeans, Clusterable}; #[derive(PartialEq, Copy, Clone, Debug)] struct Vector2 { @@ -57,6 +57,8 @@ impl Clusterable for Vector2 { fn main() { + const VAR: f64 = 1.0; + use rand::distributions::IndependentSample; let colors = vec![ @@ -72,7 +74,7 @@ fn main() { let mut centers = vec![]; for _ in 0..cluster_number { let center = Vector2::new(range.ind_sample(&mut rng), range.ind_sample(&mut rng)); - centers.push((center, Normal::new(center.x, 0.5), Normal::new(center.y, 0.5))); + centers.push((center, Normal::new(center.x, VAR), Normal::new(center.y, VAR))); } let mut elements = vec![]; @@ -92,13 +94,17 @@ fn main() { ]; - let (clusters, nb_iterations) = kmeans(initialization, elements, 100000).ok().unwrap(); + let (clusters, nb_iterations) = equal_kmeans(initialization, elements, 100000).ok().unwrap(); + let clusters = clusters.into_vec_vec(); let mut output = File::create("plot/dat.dat").unwrap(); - for (element, &label) in clusters.iter() { - use std::io::Write; - writeln!(output, "{} {} {}", element.x, element.y, colors[label]).unwrap(); + for (index, (cluster, color)) in clusters.iter().zip(colors.iter()).enumerate() { + println!("Cluster {}: {} elements", index, cluster.len()); + for element in cluster { + use std::io::Write; + writeln!(output, "{} {} {}", element.x, element.y, color).unwrap(); + } } println!("Finished in {} iterations", nb_iterations); diff --git a/src/kmeans.rs b/src/kmeans.rs index a245dc2..75510b9 100644 --- a/src/kmeans.rs +++ b/src/kmeans.rs @@ -1,4 +1,7 @@ +use std; +use std::collections::HashMap; use std::slice::Iter; +use std::vec::IntoIter; use std::iter::Zip; use clusterable::Clusterable; @@ -78,5 +81,140 @@ impl Kmeans { return self.elements.iter().zip(self.labels.iter()); } + pub fn into_iter(self) -> Zip, IntoIter> { + debug_assert_eq!(self.elements.len(), self.labels.len()); + return self.elements.into_iter().zip(self.labels.into_iter()); + } + + pub fn into_vec_vec(self) -> Vec> { + let mut map = HashMap::new(); + + for (element, label) in self.into_iter() { + let mut entry = map.entry(label).or_insert(vec![]); + entry.push(element); + } + + let mut output = vec![]; + + for (_, cluster) in map { + let mut vec = vec![]; + for element in cluster { + vec.push(element); + } + output.push(vec); + } + + output + } + + pub fn to_vec_vec(&self) -> Vec<(Vec, usize)> { + let mut map = HashMap::new(); + + for (element, label) in self.iter() { + let mut entry = map.entry(*label).or_insert(vec![]); + entry.push(element.clone()); + } + + let mut output = vec![]; + + for (label, cluster) in map { + let mut vec = vec![]; + for element in cluster { + vec.push(element); + } + output.push((vec, label)); + } + + output + } + } +pub enum Error { + IterationsLimitExceeded, +} + +pub fn kmeans(centroids: Vec, data: Vec, max_iterations: usize) + -> Result<(Kmeans, usize), Error> { + + let mut kmeans = Kmeans::new(centroids, data); + + for nb_iterations in 0..max_iterations { + + let stable = kmeans.iterate(); + + if stable { + return Ok((kmeans, nb_iterations)); + } + + } + + Err(Error::IterationsLimitExceeded) +} + +pub fn equal_kmeans(centroids: Vec, data: Vec, max_iterations: usize) + -> Result<(Kmeans, usize), Error> { + + let number_of_elements = data.len(); + let (kmeans, nb_iterations) = kmeans(centroids, data, max_iterations)?; + let mut clusters = kmeans.to_vec_vec(); + + for &mut (ref mut cluster, ref mut label) in &mut clusters { + cluster.sort_by(|ref e1, ref e2| { + *&e1.distance(&kmeans.centroids[*label]) + .partial_cmp(&e2.distance(&kmeans.centroids[*label])).unwrap() + }); + } + + let max = (number_of_elements as f64 / clusters.len() as f64).ceil() as usize; + + let mut to_replace = vec![]; + for &mut (ref mut cluster, _) in &mut clusters { + while cluster.len() > max { + to_replace.push(cluster.pop().unwrap()); + } + } + + for element in to_replace { + // Find the best non full cluster to place it in + let mut best_distance = std::f64::MAX; + let mut best_index = 0; + + for (index, centroid) in kmeans.centroids.iter().enumerate() { + + let new_distance = element.distance(centroid); + if new_distance < best_distance && clusters[index].0.len() < max { + best_distance = new_distance; + best_index = index; + } + + } + + clusters[best_index].0.push(element); + } + + let mut elements = vec![]; + let mut labels = vec![]; + let mut centroids = vec![]; + + for (index, &(ref cluster, _)) in clusters.iter().enumerate() { + for element in cluster { + elements.push(element.clone()); + labels.push(index); + } + centroids.push(Clusterable::get_centroid(cluster.iter()).unwrap()); + } + + let len = centroids.len(); + + // Build a new k-means from clusters + let new_kmeans = Kmeans { + centroids: centroids, + elements: elements, + labels: labels, + cluster_number: len, + }; + + Ok((new_kmeans, nb_iterations)) + +} diff --git a/src/lib.rs b/src/lib.rs index dcc08ff..3c60614 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,25 +3,8 @@ pub mod clusterable; pub use kmeans::Kmeans; pub use clusterable::Clusterable; +pub use kmeans::Error; +pub use kmeans::kmeans; +pub use kmeans::equal_kmeans; -pub enum Error { - IterationsLimitExceeded, -} -pub fn kmeans(centroids: Vec, data: Vec, max_iterations: usize) - -> Result<(Kmeans, usize), Error> { - - let mut kmeans = Kmeans::new(centroids, data); - - for nb_iterations in 0..max_iterations { - - let stable = kmeans.iterate(); - - if stable { - return Ok((kmeans, nb_iterations)); - } - - } - - Err(Error::IterationsLimitExceeded) -}