From 14455d3693f84fb822b7c6c42f855060c8d91b63 Mon Sep 17 00:00:00 2001 From: Thomas Forgione Date: Fri, 16 Feb 2018 17:21:05 +0100 Subject: [PATCH] Update, clean, test, example --- Cargo.toml | 7 +- src/cluster.rs | 38 +++++++++ src/example.rs | 82 ++++++++++++++++++++ src/kmeans.rs | 48 ++++++++++++ src/kmeansdata.rs | 168 ++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 193 ++++------------------------------------------ src/test.rs | 48 +++++++----- 7 files changed, 387 insertions(+), 197 deletions(-) create mode 100644 src/cluster.rs create mode 100644 src/example.rs create mode 100644 src/kmeans.rs create mode 100644 src/kmeansdata.rs diff --git a/Cargo.toml b/Cargo.toml index f883c0a..17c5fd0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,11 @@ [package] -name = "kmeans" +name = "generic_kmeans" version = "0.1.0" authors = ["Thomas Forgione "] [dependencies] + +[[bin]] +name = "example" +path = "src/example.rs" + diff --git a/src/cluster.rs b/src/cluster.rs new file mode 100644 index 0000000..7b0e8ae --- /dev/null +++ b/src/cluster.rs @@ -0,0 +1,38 @@ +pub type Cluster = Vec; + +pub trait Clusterable where Self: Sized + Clone + PartialEq { + fn distance(&self, rhs: &Self) -> f64; + fn get_centroid(elements: &Vec) -> Option; +} + +macro_rules! impl_clusterable { + ( $type: ty) => { + impl Clusterable for $type { + fn distance(&self, rhs: &Self) -> f64 { + if self > rhs { + *self as f64 - *rhs as f64 + } else { + *rhs as f64 - *self as f64 + } + } + + fn get_centroid(elements: &Vec) -> Option { + + if elements.len() == 0 { + return None; + } + + let mut tmp = 0.0 as Self; + for element in elements { + tmp += *element as Self; + } + Some(tmp / elements.len() as Self) + + } + } + } +} + +impl_clusterable!(f32); +impl_clusterable!(f64); + diff --git a/src/example.rs b/src/example.rs new file mode 100644 index 0000000..6620720 --- /dev/null +++ b/src/example.rs @@ -0,0 +1,82 @@ +extern crate generic_kmeans; + +use std::fmt::{Display, Formatter, Result}; + +use generic_kmeans::{kmeans, Clusterable}; + +#[derive(PartialEq, Clone)] +struct Vector2 { + pub x: T, + pub y: T, +} + +impl Vector2 { + pub fn new(x: T, y: T) -> Vector2 { + Vector2 { + x: x, + y: y, + } + } +} + +impl Display for Vector2 { + fn fmt(&self, formatter: &mut Formatter) -> Result { + write!(formatter, "({}, {})", self.x, self.y) + } +} + +impl Clusterable for Vector2 { + fn distance(&self, other: &Self) -> f64 { + (self.x - other.x) * (self.x - other.y) + (self.y - other.y) * (self.y - other.y) + } + + fn get_centroid(cluster: &Vec>) -> Option> { + + let len = cluster.len(); + + if len == 0 { + return None; + } + + let mut centroid = Vector2::new(0.0, 0.0); + + for i in cluster { + centroid.x += i.x; + centroid.y += i.y; + } + + centroid.x /= len as f64; + centroid.y /= len as f64; + + Some(centroid) + } +} + +fn main() { + + let elements = vec![ + Vector2::new(8.0, 3.0), + Vector2::new(9.0, 3.0), + Vector2::new(9.0, 2.0), + Vector2::new(1.0, 8.0), + Vector2::new(2.0, 9.0), + Vector2::new(3.0, 8.0), + ]; + + let initial = vec![ + Vector2::new(1.0, 10.0), + Vector2::new(10.0, 0.0), + ]; + + let (clusters, nb_iterations) = kmeans(initial, elements, 1000).ok().unwrap(); + + println!("{}", nb_iterations); + + for (index, cluster) in clusters.iter().enumerate() { + println!("CLUSTER {}", index); + + for element in cluster { + println!("\t{}", element); + } + } +} diff --git a/src/kmeans.rs b/src/kmeans.rs new file mode 100644 index 0000000..29a3c8d --- /dev/null +++ b/src/kmeans.rs @@ -0,0 +1,48 @@ +use std; + +use kmeansdata::KmeansData; +use cluster::{Cluster, Clusterable}; + +pub struct Kmeans { + pub centroids: Vec, + pub data: KmeansData, +} + +impl Kmeans { + pub fn new(centroids: Vec, data: Vec>) -> Kmeans { + Kmeans { + centroids: centroids, + data: KmeansData::from_clusters(data), + } + } + + pub fn guess_centroids(data: Vec>) -> Kmeans { + let mut centroids = vec![]; + for cluster in &data { + if let Some(centroid) = T::get_centroid(cluster) { + centroids.push(centroid); + } + } + + Kmeans::new(centroids, data) + + } + + fn from_data(centroids: Vec, data: KmeansData) -> Kmeans { + Kmeans { + centroids: centroids, + data: data, + } + } + + pub fn next_iteration(self) -> (Kmeans, bool) { + let (new_centroids, data) = self.data.iterate(&self.centroids); + let stable = new_centroids == self.centroids; + (Kmeans::from_data(new_centroids, data), stable) + } + + pub fn iter(&self) -> std::slice::Iter> { + self.data.clusters() + } +} + diff --git a/src/kmeansdata.rs b/src/kmeansdata.rs new file mode 100644 index 0000000..aca75ce --- /dev/null +++ b/src/kmeansdata.rs @@ -0,0 +1,168 @@ +use std; +use std::marker::PhantomData; + +use cluster::{Clusterable, Cluster}; + + +pub struct KmeansData { + clusters: Vec>, +} + +impl KmeansData { + + pub fn from_clusters(data: Vec>) -> KmeansData { + KmeansData { + clusters: data, + } + } + + pub fn new(centroids: &Vec, data: KmeansDataIntoIter) -> KmeansData { + + let mut clusters = vec![]; + + for _ in centroids { + clusters.push(Cluster::new()); + } + + let mut new_kmeans: KmeansData = KmeansData { + clusters: clusters, + }; + + for element in data { + + // Compute the distance + let mut distance = std::f64::MAX; + let mut index = 0; + + for (new_index, centroid) in centroids.iter().enumerate() { + let new_distance = element.distance(centroid); + + if new_distance < distance { + distance = new_distance; + index = new_index; + } + } + + // Add element to the new kmeans + new_kmeans.clusters[index].push(element) + + } + + new_kmeans + } + + pub fn add_cluster(&mut self) { + self.clusters.push(Cluster::new()); + } + + pub fn iterate(self, centroids: &Vec) -> (Vec, KmeansData) { + // Compute the result with the given centroids + let result = KmeansData::new(¢roids, self.into_iter()); + + // Compute the new centroids + let mut new_centroids = vec![]; + + for cluster in &result.clusters { + + if let Some(centroid) = T::get_centroid(cluster) { + new_centroids.push(centroid); + } + + } + + (new_centroids, result) + } + + pub fn iter(&self) -> KmeansDataIter { + KmeansDataIter { + global_iter: self.clusters.iter(), + local_iter: None, + _phantom: PhantomData, + } + } + + pub fn clusters(&self) -> std::slice::Iter> { + self.clusters.iter() + } + +} + +pub struct KmeansDataIter<'a, T> where T:'a, T: Clusterable { + global_iter: std::slice::Iter<'a, std::vec::Vec>, + local_iter: Option>, + _phantom: PhantomData, +} + +impl<'a, T: 'a + Clusterable> Iterator for KmeansDataIter<'a, T> { + type Item = &'a T; + fn next(&mut self) -> Option<&'a T> { + if let Some(ref mut local_iter) = self.local_iter { + match local_iter.next() { + Some(t) => Some(t), + None => { + if let Some(next) = self.global_iter.next() { + *local_iter = next.iter(); + match local_iter.next() { + Some(t) => Some(t), + None => None, + } + } else { + None + } + } + } + } else { + self.local_iter = match self.global_iter.next() { + None => None, + Some(t) => Some(t.iter()), + }; + self.next() + } + } +} + +pub struct KmeansDataIntoIter where T: Clusterable { + global_iter: std::vec::IntoIter>, + local_iter: Option>, + _phantom: PhantomData, +} + +impl Iterator for KmeansDataIntoIter { + type Item = T; + fn next(&mut self) -> Option { + if let Some(ref mut local_iter) = self.local_iter { + match local_iter.next() { + Some(t) => Some(t), + None => { + if let Some(next) = self.global_iter.next() { + *local_iter = next.into_iter(); + match local_iter.next() { + Some(t) => Some(t), + None => None, + } + } else { + None + } + } + } + } else { + self.local_iter = match self.global_iter.next() { + None => None, + Some(t) => Some(t.into_iter()), + }; + self.next() + } + } +} + +impl IntoIterator for KmeansData { + type Item = T; + type IntoIter = KmeansDataIntoIter; + fn into_iter(self) -> Self::IntoIter { + Self::IntoIter { + global_iter: self.clusters.into_iter(), + local_iter: None, + _phantom: PhantomData, + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 26860e9..84dd9d6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,190 +1,31 @@ +pub mod kmeans; +pub mod kmeansdata; +pub mod cluster; pub mod test; -use std::marker::PhantomData; +pub use kmeans::Kmeans; +pub use kmeansdata::KmeansData; +pub use cluster::{Cluster, Clusterable}; -pub trait Clusterable where Self: Sized { - fn distance(&self, rhs: &Self) -> f64; - fn get_centroid(elements: &Vec) -> Option; +pub enum Error { + IterationsLimitExceeded, } -macro_rules! impl_clusterable { - ( $type: ty) => { - impl Clusterable for $type { - fn distance(&self, rhs: &Self) -> f64 { - if self > rhs { - *self as f64 - *rhs as f64 - } else { - *rhs as f64 - *self as f64 - } - } +pub fn kmeans(centroids: Vec, data: Vec, max_iterations: usize) + -> Result<(Kmeans, usize), Error> { - fn get_centroid(elements: &Vec) -> Option { + let mut kmeans = Kmeans::new(centroids, vec![data]); - if elements.len() == 0 { - return None; - } + for nb_iterations in 0..max_iterations { - let mut tmp = 0.0 as Self; - for element in elements { - tmp += *element as Self; - } - Some(tmp / elements.len() as Self) + let (new_kmeans, stable) = kmeans.next_iteration(); + kmeans = new_kmeans; - } - } - } -} - -impl_clusterable!(f32); -impl_clusterable!(f64); - - -pub type Cluster = Vec; - -pub struct Kmeans { - clusters: Vec>, -} - -impl Kmeans { - - pub fn new(centroids: Vec, data: KmeansIntoIter) -> Kmeans { - - let mut clusters = vec![]; - - for _ in ¢roids { - clusters.push(Cluster::new()); + if stable { + return Ok((kmeans, nb_iterations)); } - let mut new_kmeans: Kmeans = Kmeans { - clusters: clusters, - }; - - for element in data { - - // Compute the distance - let mut distance = std::f64::MAX; - let mut index = 0; - - for (new_index, centroid) in centroids.iter().enumerate() { - let new_distance = element.distance(centroid); - - if new_distance < distance { - distance = new_distance; - index = new_index; - } - } - - // Add element to the new kmeans - new_kmeans.clusters[index].push(element) - - } - - new_kmeans - } - - pub fn add_cluster(&mut self) { - self.clusters.push(Cluster::new()); - } - - pub fn iterate(self) -> Kmeans { - let mut centroids = vec![]; - - // Compute the centroids - for cluster in &self.clusters { - - if let Some(centroid) = T::get_centroid(cluster) { - centroids.push(centroid); - } - - } - - Kmeans::new(centroids, self.into_iter()) - } - - pub fn iter(&self) -> std::slice::Iter> { - self.clusters.iter() } -} - -pub struct KmeansIter<'a, T> where T:'a, T: Clusterable { - global_iter: std::slice::Iter<'a, std::vec::Vec>, - local_iter: Option>, - _phantom: PhantomData, -} - -impl<'a, T: 'a + Clusterable> Iterator for KmeansIter<'a, T> { - type Item = &'a T; - fn next(&mut self) -> Option<&'a T> { - if let Some(ref mut local_iter) = self.local_iter { - match local_iter.next() { - Some(t) => Some(t), - None => { - if let Some(next) = self.global_iter.next() { - *local_iter = next.iter(); - match local_iter.next() { - Some(t) => Some(t), - None => None, - } - } else { - None - } - } - } - } else { - self.local_iter = match self.global_iter.next() { - None => None, - Some(t) => Some(t.iter()), - }; - self.next() - } - } -} - -pub struct KmeansIntoIter where T: Clusterable { - global_iter: std::vec::IntoIter>, - local_iter: Option>, - _phantom: PhantomData, -} - -impl Iterator for KmeansIntoIter { - type Item = T; - fn next(&mut self) -> Option { - if let Some(ref mut local_iter) = self.local_iter { - match local_iter.next() { - Some(t) => Some(t), - None => { - if let Some(next) = self.global_iter.next() { - *local_iter = next.into_iter(); - match local_iter.next() { - Some(t) => Some(t), - None => None, - } - } else { - None - } - } - } - } else { - self.local_iter = match self.global_iter.next() { - None => None, - Some(t) => Some(t.into_iter()), - }; - self.next() - } - } -} - - - -impl IntoIterator for Kmeans { - type Item = T; - type IntoIter = KmeansIntoIter; - fn into_iter(self) -> Self::IntoIter { - Self::IntoIter { - global_iter: self.clusters.into_iter(), - local_iter: None, - _phantom: PhantomData, - } - } + Err(Error::IterationsLimitExceeded) } diff --git a/src/test.rs b/src/test.rs index fdb2e52..e78028a 100644 --- a/src/test.rs +++ b/src/test.rs @@ -2,14 +2,12 @@ mod test { #[test] fn iterators() { - use Kmeans; + use kmeansdata::KmeansData; let data = vec![4.0, 5.0, 11.0, 12.0, 13.0]; - let kmeans = Kmeans { - clusters: vec![ - vec![4.0, 5.0], - vec![11.0, 12.0, 13.0], - ], - }; + let kmeans = KmeansData::from_clusters(vec![ + vec![4.0, 5.0], + vec![11.0, 12.0, 13.0], + ]); for (val1, val2) in kmeans.into_iter().zip(data) { assert_eq!(val1, val2); @@ -18,19 +16,29 @@ mod test { #[test] fn iterate() { - use Kmeans; - let kmeans = Kmeans { - clusters: vec![ - vec![4.0, 5.0, 11.0, 12.0], - vec![13.0], - ], - }; - let kmeans = kmeans.iterate(); - let kmeans = kmeans.iterate(); - let kmeans = kmeans.iterate(); - let kmeans = kmeans.iterate(); + use kmeans::Kmeans; - assert_eq!(kmeans.clusters[0], vec![4.0, 5.0]); - assert_eq!(kmeans.clusters[1], vec![11.0, 12.0, 13.0]); + let data = vec![ + vec![4.0, 5.0, 11.0, 12.0], + vec![13.0], + ]; + + let solution = vec![ + vec![4.0, 5.0], + vec![11.0, 12.0, 13.0], + ]; + + let mut kmeans = Kmeans::guess_centroids(data.clone()); + + for _ in 0..4 { + let (new_kmeans, stable) = kmeans.next_iteration(); + kmeans = new_kmeans; + } + + for (k1, k2) in kmeans.iter().zip(solution) { + for (i, j) in k1.iter().zip(&k2) { + assert_eq!(i, j); + } + } } }