Even cleaner
This commit is contained in:
parent
35fec4243d
commit
77ea216075
|
@ -103,13 +103,14 @@ fn main() {
|
||||||
"green",
|
"green",
|
||||||
];
|
];
|
||||||
|
|
||||||
let kmeans = kmeans(elements, initial, &distance, &centroid, 1000).unwrap();
|
let (kmeans, nb_iterations) = kmeans(elements, initial, &distance, &centroid, 1000).unwrap();
|
||||||
|
let kmeans = kmeans.into_iter();
|
||||||
|
|
||||||
println!("Converged in {} iterations.", kmeans.1);
|
println!("Converged in {} iterations.", nb_iterations);
|
||||||
|
|
||||||
let mut output = File::create("plot/dat.dat").unwrap();
|
let mut output = File::create("plot/dat.dat").unwrap();
|
||||||
|
|
||||||
for (index, (cluster, color)) in kmeans.0.iter().zip(colors.iter()).enumerate() {
|
for (index, (cluster, color)) in kmeans.iter().zip(colors.iter()).enumerate() {
|
||||||
println!("Cluster {}: {} elements", index, cluster.len());
|
println!("Cluster {}: {} elements", index, cluster.len());
|
||||||
for element in cluster {
|
for element in cluster {
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
|
72
src/lib.rs
72
src/lib.rs
|
@ -1,3 +1,5 @@
|
||||||
|
use std::slice::Iter;
|
||||||
|
use std::iter::Zip;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
@ -5,23 +7,20 @@ pub enum Error {
|
||||||
TooManyIterations,
|
TooManyIterations,
|
||||||
}
|
}
|
||||||
|
|
||||||
type Distance<T> = Fn(&T, &T) -> f64;
|
|
||||||
type Centroid<T> = Fn(&Vec<T>) -> T;
|
|
||||||
|
|
||||||
pub struct KmeansData<T: Clone + PartialEq> {
|
pub struct KmeansData<T: Clone + PartialEq> {
|
||||||
pub elements: Vec<T>,
|
pub elements: Vec<T>,
|
||||||
pub labels: Vec<usize>,
|
pub labels: Vec<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Kmeans<'a, T: 'a + Clone + PartialEq> {
|
pub struct Kmeans<T: Clone + PartialEq, D: Fn(&T, &T) -> f64, C: Fn(&Vec<T>) -> T> {
|
||||||
data: KmeansData<T>,
|
data: KmeansData<T>,
|
||||||
centroids: Vec<T>,
|
centroids: Vec<T>,
|
||||||
distance: &'a Distance<T>,
|
distance: D,
|
||||||
centroid: &'a Centroid<T>,
|
centroid: C,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: 'a + Clone + PartialEq> Kmeans<'a, T> {
|
impl<T: Clone + PartialEq, D: Fn(&T, &T) -> f64, C: Fn(&Vec<T>) -> T> Kmeans<T, D, C> {
|
||||||
pub fn new(data: Vec<T>, centroids: Vec<T>, distance: &'a Distance<T>, centroid: &'a Centroid<T>) -> Kmeans<'a, T> {
|
pub fn new(data: Vec<T>, centroids: Vec<T>, distance: D, centroid: C) -> Kmeans<T, D, C> {
|
||||||
|
|
||||||
let len = data.len();
|
let len = data.len();
|
||||||
Kmeans {
|
Kmeans {
|
||||||
|
@ -85,33 +84,15 @@ impl<'a, T: 'a + Clone + PartialEq> Kmeans<'a, T> {
|
||||||
|
|
||||||
new_centroids
|
new_centroids
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn iter(&self) -> Zip<Iter<T>, Iter<usize>> {
|
||||||
|
self.data.elements.iter().zip(self.data.labels.iter())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn kmeans<'a, T: 'a + Clone + PartialEq>(
|
pub fn into_iter(self) -> Vec<Vec<T>> {
|
||||||
elements: Vec<T>,
|
|
||||||
initial: Vec<T>,
|
|
||||||
distance: &'a Distance<T>,
|
|
||||||
centroid: &'a Centroid<T>,
|
|
||||||
max_iteration: usize,
|
|
||||||
) -> Result<(Vec<Vec<T>>,usize), Error> {
|
|
||||||
|
|
||||||
let mut clusters = Kmeans::new(elements, initial, distance, centroid);
|
|
||||||
|
|
||||||
let mut counter = 0;
|
|
||||||
|
|
||||||
let iterations = loop {
|
|
||||||
|
|
||||||
counter += 1;
|
|
||||||
|
|
||||||
if clusters.iterate() || counter > max_iteration {
|
|
||||||
break counter;
|
|
||||||
}
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut map = HashMap::new();
|
let mut map = HashMap::new();
|
||||||
|
|
||||||
for (element, label) in clusters.data.elements.iter().zip(clusters.data.labels.iter()) {
|
for (element, label) in self.data.elements.iter().zip(self.data.labels.iter()) {
|
||||||
let mut centroid = map.entry(label).or_insert(vec![]);
|
let mut centroid = map.entry(label).or_insert(vec![]);
|
||||||
centroid.push(element.clone());
|
centroid.push(element.clone());
|
||||||
}
|
}
|
||||||
|
@ -128,5 +109,32 @@ pub fn kmeans<'a, T: 'a + Clone + PartialEq>(
|
||||||
output.push(cluster);
|
output.push(cluster);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok((output, iterations))
|
output
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn kmeans<T: Clone + PartialEq, D: Fn(&T, &T) -> f64, C: Fn(&Vec<T>) -> T>(
|
||||||
|
elements: Vec<T>,
|
||||||
|
initial: Vec<T>,
|
||||||
|
distance: D,
|
||||||
|
centroid: C,
|
||||||
|
max_iteration: usize,
|
||||||
|
) -> Result<(Kmeans<T, D, C>,usize), Error> {
|
||||||
|
|
||||||
|
let mut clusters = Kmeans::new(elements, initial, distance, centroid);
|
||||||
|
|
||||||
|
let mut counter = 0;
|
||||||
|
|
||||||
|
let iterations = loop {
|
||||||
|
|
||||||
|
counter += 1;
|
||||||
|
|
||||||
|
if clusters.iterate() || counter > max_iteration {
|
||||||
|
break counter;
|
||||||
|
}
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok((clusters, iterations))
|
||||||
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue