Even cleaner

This commit is contained in:
Thomas Forgione 2018-02-28 15:08:36 +01:00
parent 35fec4243d
commit 77ea216075
No known key found for this signature in database
GPG Key ID: C75CD416BD1FFCE1
2 changed files with 44 additions and 35 deletions

View File

@ -103,13 +103,14 @@ fn main() {
"green", "green",
]; ];
let kmeans = kmeans(elements, initial, &distance, &centroid, 1000).unwrap(); let (kmeans, nb_iterations) = kmeans(elements, initial, &distance, &centroid, 1000).unwrap();
let kmeans = kmeans.into_iter();
println!("Converged in {} iterations.", kmeans.1); println!("Converged in {} iterations.", nb_iterations);
let mut output = File::create("plot/dat.dat").unwrap(); let mut output = File::create("plot/dat.dat").unwrap();
for (index, (cluster, color)) in kmeans.0.iter().zip(colors.iter()).enumerate() { for (index, (cluster, color)) in kmeans.iter().zip(colors.iter()).enumerate() {
println!("Cluster {}: {} elements", index, cluster.len()); println!("Cluster {}: {} elements", index, cluster.len());
for element in cluster { for element in cluster {
use std::io::Write; use std::io::Write;

View File

@ -1,3 +1,5 @@
use std::slice::Iter;
use std::iter::Zip;
use std::collections::HashMap; use std::collections::HashMap;
#[derive(Debug)] #[derive(Debug)]
@ -5,23 +7,20 @@ pub enum Error {
TooManyIterations, TooManyIterations,
} }
type Distance<T> = Fn(&T, &T) -> f64;
type Centroid<T> = Fn(&Vec<T>) -> T;
pub struct KmeansData<T: Clone + PartialEq> { pub struct KmeansData<T: Clone + PartialEq> {
pub elements: Vec<T>, pub elements: Vec<T>,
pub labels: Vec<usize>, pub labels: Vec<usize>,
} }
pub struct Kmeans<'a, T: 'a + Clone + PartialEq> { pub struct Kmeans<T: Clone + PartialEq, D: Fn(&T, &T) -> f64, C: Fn(&Vec<T>) -> T> {
data: KmeansData<T>, data: KmeansData<T>,
centroids: Vec<T>, centroids: Vec<T>,
distance: &'a Distance<T>, distance: D,
centroid: &'a Centroid<T>, centroid: C,
} }
impl<'a, T: 'a + Clone + PartialEq> Kmeans<'a, T> { impl<T: Clone + PartialEq, D: Fn(&T, &T) -> f64, C: Fn(&Vec<T>) -> T> Kmeans<T, D, C> {
pub fn new(data: Vec<T>, centroids: Vec<T>, distance: &'a Distance<T>, centroid: &'a Centroid<T>) -> Kmeans<'a, T> { pub fn new(data: Vec<T>, centroids: Vec<T>, distance: D, centroid: C) -> Kmeans<T, D, C> {
let len = data.len(); let len = data.len();
Kmeans { Kmeans {
@ -85,33 +84,15 @@ impl<'a, T: 'a + Clone + PartialEq> Kmeans<'a, T> {
new_centroids new_centroids
} }
}
pub fn kmeans<'a, T: 'a + Clone + PartialEq>( pub fn iter(&self) -> Zip<Iter<T>, Iter<usize>> {
elements: Vec<T>, self.data.elements.iter().zip(self.data.labels.iter())
initial: Vec<T>,
distance: &'a Distance<T>,
centroid: &'a Centroid<T>,
max_iteration: usize,
) -> Result<(Vec<Vec<T>>,usize), Error> {
let mut clusters = Kmeans::new(elements, initial, distance, centroid);
let mut counter = 0;
let iterations = loop {
counter += 1;
if clusters.iterate() || counter > max_iteration {
break counter;
} }
}; pub fn into_iter(self) -> Vec<Vec<T>> {
let mut map = HashMap::new(); let mut map = HashMap::new();
for (element, label) in clusters.data.elements.iter().zip(clusters.data.labels.iter()) { for (element, label) in self.data.elements.iter().zip(self.data.labels.iter()) {
let mut centroid = map.entry(label).or_insert(vec![]); let mut centroid = map.entry(label).or_insert(vec![]);
centroid.push(element.clone()); centroid.push(element.clone());
} }
@ -128,5 +109,32 @@ pub fn kmeans<'a, T: 'a + Clone + PartialEq>(
output.push(cluster); output.push(cluster);
} }
Ok((output, iterations)) output
}
} }
pub fn kmeans<T: Clone + PartialEq, D: Fn(&T, &T) -> f64, C: Fn(&Vec<T>) -> T>(
elements: Vec<T>,
initial: Vec<T>,
distance: D,
centroid: C,
max_iteration: usize,
) -> Result<(Kmeans<T, D, C>,usize), Error> {
let mut clusters = Kmeans::new(elements, initial, distance, centroid);
let mut counter = 0;
let iterations = loop {
counter += 1;
if clusters.iterate() || counter > max_iteration {
break counter;
}
};
Ok((clusters, iterations))
}