Even cleaner

This commit is contained in:
Thomas Forgione 2018-02-28 15:08:36 +01:00
parent 35fec4243d
commit 77ea216075
No known key found for this signature in database
GPG Key ID: C75CD416BD1FFCE1
2 changed files with 44 additions and 35 deletions

View File

@ -103,13 +103,14 @@ fn main() {
"green", "green",
]; ];
let kmeans = kmeans(elements, initial, &distance, &centroid, 1000).unwrap(); let (kmeans, nb_iterations) = kmeans(elements, initial, &distance, &centroid, 1000).unwrap();
let kmeans = kmeans.into_iter();
println!("Converged in {} iterations.", kmeans.1); println!("Converged in {} iterations.", nb_iterations);
let mut output = File::create("plot/dat.dat").unwrap(); let mut output = File::create("plot/dat.dat").unwrap();
for (index, (cluster, color)) in kmeans.0.iter().zip(colors.iter()).enumerate() { for (index, (cluster, color)) in kmeans.iter().zip(colors.iter()).enumerate() {
println!("Cluster {}: {} elements", index, cluster.len()); println!("Cluster {}: {} elements", index, cluster.len());
for element in cluster { for element in cluster {
use std::io::Write; use std::io::Write;

View File

@ -1,3 +1,5 @@
use std::slice::Iter;
use std::iter::Zip;
use std::collections::HashMap; use std::collections::HashMap;
#[derive(Debug)] #[derive(Debug)]
@ -5,23 +7,20 @@ pub enum Error {
TooManyIterations, TooManyIterations,
} }
type Distance<T> = Fn(&T, &T) -> f64;
type Centroid<T> = Fn(&Vec<T>) -> T;
pub struct KmeansData<T: Clone + PartialEq> { pub struct KmeansData<T: Clone + PartialEq> {
pub elements: Vec<T>, pub elements: Vec<T>,
pub labels: Vec<usize>, pub labels: Vec<usize>,
} }
pub struct Kmeans<'a, T: 'a + Clone + PartialEq> { pub struct Kmeans<T: Clone + PartialEq, D: Fn(&T, &T) -> f64, C: Fn(&Vec<T>) -> T> {
data: KmeansData<T>, data: KmeansData<T>,
centroids: Vec<T>, centroids: Vec<T>,
distance: &'a Distance<T>, distance: D,
centroid: &'a Centroid<T>, centroid: C,
} }
impl<'a, T: 'a + Clone + PartialEq> Kmeans<'a, T> { impl<T: Clone + PartialEq, D: Fn(&T, &T) -> f64, C: Fn(&Vec<T>) -> T> Kmeans<T, D, C> {
pub fn new(data: Vec<T>, centroids: Vec<T>, distance: &'a Distance<T>, centroid: &'a Centroid<T>) -> Kmeans<'a, T> { pub fn new(data: Vec<T>, centroids: Vec<T>, distance: D, centroid: C) -> Kmeans<T, D, C> {
let len = data.len(); let len = data.len();
Kmeans { Kmeans {
@ -85,15 +84,42 @@ impl<'a, T: 'a + Clone + PartialEq> Kmeans<'a, T> {
new_centroids new_centroids
} }
/// Iterates over the stored elements paired with their labels.
///
/// Each item is a `(&T, &usize)` tuple: a reference to an element of
/// `data.elements` and a reference to the entry at the same position in
/// `data.labels`. Iteration stops at the end of the shorter of the two vectors.
pub fn iter(&self) -> Zip<Iter<T>, Iter<usize>> {
    let elements = self.data.elements.iter();
    let labels = self.data.labels.iter();
    elements.zip(labels)
}
pub fn into_iter(self) -> Vec<Vec<T>> {
let mut map = HashMap::new();
for (element, label) in self.data.elements.iter().zip(self.data.labels.iter()) {
let mut centroid = map.entry(label).or_insert(vec![]);
centroid.push(element.clone());
}
let mut output = vec![];
for (_, value) in map {
let mut cluster = vec![];
for element in value {
cluster.push(element);
}
output.push(cluster);
}
output
}
} }
pub fn kmeans<'a, T: 'a + Clone + PartialEq>( pub fn kmeans<T: Clone + PartialEq, D: Fn(&T, &T) -> f64, C: Fn(&Vec<T>) -> T>(
elements: Vec<T>, elements: Vec<T>,
initial: Vec<T>, initial: Vec<T>,
distance: &'a Distance<T>, distance: D,
centroid: &'a Centroid<T>, centroid: C,
max_iteration: usize, max_iteration: usize,
) -> Result<(Vec<Vec<T>>,usize), Error> { ) -> Result<(Kmeans<T, D, C>,usize), Error> {
let mut clusters = Kmeans::new(elements, initial, distance, centroid); let mut clusters = Kmeans::new(elements, initial, distance, centroid);
@ -109,24 +135,6 @@ pub fn kmeans<'a, T: 'a + Clone + PartialEq>(
}; };
let mut map = HashMap::new(); Ok((clusters, iterations))
for (element, label) in clusters.data.elements.iter().zip(clusters.data.labels.iter()) {
let mut centroid = map.entry(label).or_insert(vec![]);
centroid.push(element.clone());
}
let mut output = vec![];
for (_, value) in map {
let mut cluster = vec![];
for element in value {
cluster.push(element);
}
output.push(cluster);
}
Ok((output, iterations))
} }