commit c3aff5c77d9718d0ef1100d3483192ed481ba86d Author: Thomas Forgione Date: Fri Feb 16 15:34:13 2018 +0100 First try diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..143b1ca --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ + +/target/ +**/*.rs.bk +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..f883c0a --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,6 @@ +[package] +name = "kmeans" +version = "0.1.0" +authors = ["Thomas Forgione "] + +[dependencies] diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..26860e9 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,190 @@ +pub mod test; + +use std::marker::PhantomData; + +pub trait Clusterable where Self: Sized { + fn distance(&self, rhs: &Self) -> f64; + fn get_centroid(elements: &Vec) -> Option; +} + +macro_rules! impl_clusterable { + ( $type: ty) => { + impl Clusterable for $type { + fn distance(&self, rhs: &Self) -> f64 { + if self > rhs { + *self as f64 - *rhs as f64 + } else { + *rhs as f64 - *self as f64 + } + } + + fn get_centroid(elements: &Vec) -> Option { + + if elements.len() == 0 { + return None; + } + + let mut tmp = 0.0 as Self; + for element in elements { + tmp += *element as Self; + } + Some(tmp / elements.len() as Self) + + } + } + } +} + +impl_clusterable!(f32); +impl_clusterable!(f64); + + +pub type Cluster = Vec; + +pub struct Kmeans { + clusters: Vec>, +} + +impl Kmeans { + + pub fn new(centroids: Vec, data: KmeansIntoIter) -> Kmeans { + + let mut clusters = vec![]; + + for _ in ¢roids { + clusters.push(Cluster::new()); + } + + let mut new_kmeans: Kmeans = Kmeans { + clusters: clusters, + }; + + for element in data { + + // Compute the distance + let mut distance = std::f64::MAX; + let mut index = 0; + + for (new_index, centroid) in centroids.iter().enumerate() { + let new_distance = element.distance(centroid); + + if new_distance < distance { + distance = new_distance; + index = new_index; + } + } + + // Add element to the new kmeans + new_kmeans.clusters[index].push(element) + + } + + new_kmeans + } + + pub fn add_cluster(&mut self) { + self.clusters.push(Cluster::new()); + } + + pub fn iterate(self) -> Kmeans { + let mut centroids = vec![]; + + // Compute the centroids + for cluster in &self.clusters { + + if let Some(centroid) = T::get_centroid(cluster) { + centroids.push(centroid); + } + + } + + Kmeans::new(centroids, self.into_iter()) + } + + pub fn iter(&self) -> std::slice::Iter> { + self.clusters.iter() + } + +} + +pub struct KmeansIter<'a, T> where T:'a, T: Clusterable { + global_iter: std::slice::Iter<'a, std::vec::Vec>, + local_iter: Option>, + _phantom: PhantomData, +} + +impl<'a, T: 'a + Clusterable> Iterator for KmeansIter<'a, T> { + type Item = &'a T; + fn next(&mut self) -> Option<&'a T> { + if let Some(ref mut local_iter) = self.local_iter { + match local_iter.next() { + Some(t) => Some(t), + None => { + if let Some(next) = self.global_iter.next() { + *local_iter = next.iter(); + match local_iter.next() { + Some(t) => Some(t), + None => None, + } + } else { + None + } + } + } + } else { + self.local_iter = match self.global_iter.next() { + None => None, + Some(t) => Some(t.iter()), + }; + self.next() + } + } +} + +pub struct KmeansIntoIter where T: Clusterable { + global_iter: std::vec::IntoIter>, + local_iter: Option>, + _phantom: PhantomData, +} + +impl Iterator for KmeansIntoIter { + type Item = T; + fn next(&mut self) -> Option { + if let Some(ref mut local_iter) = self.local_iter { + match local_iter.next() { + Some(t) => Some(t), + None => { + if let Some(next) = self.global_iter.next() { + *local_iter = next.into_iter(); + match local_iter.next() { + Some(t) => Some(t), + None => None, + } + } else { + None + } + } + } + } else { + self.local_iter = match self.global_iter.next() { + None => None, + Some(t) => Some(t.into_iter()), + }; + self.next() + } + } +} + + + +impl IntoIterator for Kmeans { + type Item = T; + type IntoIter = KmeansIntoIter; + fn into_iter(self) -> Self::IntoIter { + Self::IntoIter { + global_iter: self.clusters.into_iter(), + local_iter: None, + _phantom: PhantomData, + } + } +} diff --git a/src/test.rs b/src/test.rs new file mode 100644 index 0000000..fdb2e52 --- /dev/null +++ b/src/test.rs @@ -0,0 +1,36 @@ +#[cfg(test)] +mod test { + #[test] + fn iterators() { + use Kmeans; + let data = vec![4.0, 5.0, 11.0, 12.0, 13.0]; + let kmeans = Kmeans { + clusters: vec![ + vec![4.0, 5.0], + vec![11.0, 12.0, 13.0], + ], + }; + + for (val1, val2) in kmeans.into_iter().zip(data) { + assert_eq!(val1, val2); + } + } + + #[test] + fn iterate() { + use Kmeans; + let kmeans = Kmeans { + clusters: vec![ + vec![4.0, 5.0, 11.0, 12.0], + vec![13.0], + ], + }; + let kmeans = kmeans.iterate(); + let kmeans = kmeans.iterate(); + let kmeans = kmeans.iterate(); + let kmeans = kmeans.iterate(); + + assert_eq!(kmeans.clusters[0], vec![4.0, 5.0]); + assert_eq!(kmeans.clusters[1], vec![11.0, 12.0, 13.0]); + } +}