From 77ea21607528ef1f2865a9ac843b1a5208f25a0e Mon Sep 17 00:00:00 2001 From: Thomas Forgione Date: Wed, 28 Feb 2018 15:08:36 +0100 Subject: [PATCH] Even cleaner --- src/example.rs | 7 ++--- src/lib.rs | 72 ++++++++++++++++++++++++++++---------------------- 2 files changed, 44 insertions(+), 35 deletions(-) diff --git a/src/example.rs b/src/example.rs index 3898829..7c251f2 100644 --- a/src/example.rs +++ b/src/example.rs @@ -103,13 +103,14 @@ fn main() { "green", ]; - let kmeans = kmeans(elements, initial, &distance, &centroid, 1000).unwrap(); + let (kmeans, nb_iterations) = kmeans(elements, initial, &distance, &centroid, 1000).unwrap(); + let kmeans = kmeans.into_iter(); - println!("Converged in {} iterations.", kmeans.1); + println!("Converged in {} iterations.", nb_iterations); let mut output = File::create("plot/dat.dat").unwrap(); - for (index, (cluster, color)) in kmeans.0.iter().zip(colors.iter()).enumerate() { + for (index, (cluster, color)) in kmeans.iter().zip(colors.iter()).enumerate() { println!("Cluster {}: {} elements", index, cluster.len()); for element in cluster { use std::io::Write; diff --git a/src/lib.rs b/src/lib.rs index 340282c..0665eee 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,5 @@ +use std::slice::Iter; +use std::iter::Zip; use std::collections::HashMap; #[derive(Debug)] @@ -5,23 +7,20 @@ pub enum Error { TooManyIterations, } -type Distance = Fn(&T, &T) -> f64; -type Centroid = Fn(&Vec) -> T; - pub struct KmeansData { pub elements: Vec, pub labels: Vec, } -pub struct Kmeans<'a, T: 'a + Clone + PartialEq> { +pub struct Kmeans f64, C: Fn(&Vec) -> T> { data: KmeansData, centroids: Vec, - distance: &'a Distance, - centroid: &'a Centroid, + distance: D, + centroid: C, } -impl<'a, T: 'a + Clone + PartialEq> Kmeans<'a, T> { - pub fn new(data: Vec, centroids: Vec, distance: &'a Distance, centroid: &'a Centroid) -> Kmeans<'a, T> { 
let len = data.len(); Kmeans { @@ -85,15 +84,42 @@ impl<'a, T: 'a + Clone + PartialEq> Kmeans<'a, T> { new_centroids } + + pub fn iter(&self) -> Zip, Iter> { + self.data.elements.iter().zip(self.data.labels.iter()) + } + + pub fn into_iter(self) -> Vec> { + let mut map = HashMap::new(); + + for (element, label) in self.data.elements.iter().zip(self.data.labels.iter()) { + let mut centroid = map.entry(label).or_insert(vec![]); + centroid.push(element.clone()); + } + + let mut output = vec![]; + + for (_, value) in map { + let mut cluster = vec![]; + + for element in value { + cluster.push(element); + } + + output.push(cluster); + } + + output + } } -pub fn kmeans<'a, T: 'a + Clone + PartialEq>( +pub fn kmeans f64, C: Fn(&Vec) -> T>( elements: Vec, initial: Vec, - distance: &'a Distance, - centroid: &'a Centroid, + distance: D, + centroid: C, max_iteration: usize, -) -> Result<(Vec>,usize), Error> { +) -> Result<(Kmeans,usize), Error> { let mut clusters = Kmeans::new(elements, initial, distance, centroid); @@ -109,24 +135,6 @@ pub fn kmeans<'a, T: 'a + Clone + PartialEq>( }; - let mut map = HashMap::new(); - - for (element, label) in clusters.data.elements.iter().zip(clusters.data.labels.iter()) { - let mut centroid = map.entry(label).or_insert(vec![]); - centroid.push(element.clone()); - } - - let mut output = vec![]; - - for (_, value) in map { - let mut cluster = vec![]; - - for element in value { - cluster.push(element); - } - - output.push(cluster); - } - - Ok((output, iterations)) + Ok((clusters, iterations)) } +