Added equal kmeans... not very good though

This commit is contained in:
Thomas Forgione 2018-02-26 10:32:25 +01:00
parent 164d15c6e7
commit c9de1ca333
No known key found for this signature in database
GPG Key ID: C75CD416BD1FFCE1
3 changed files with 153 additions and 26 deletions

View File

@ -5,7 +5,7 @@ use std::fmt::{Display, Formatter, Result};
use std::fs::File; use std::fs::File;
use rand::distributions::Range; use rand::distributions::Range;
use rand::distributions::normal::Normal; use rand::distributions::normal::Normal;
use generic_kmeans::{kmeans, Clusterable}; use generic_kmeans::{equal_kmeans, Clusterable};
#[derive(PartialEq, Copy, Clone, Debug)] #[derive(PartialEq, Copy, Clone, Debug)]
struct Vector2<T> { struct Vector2<T> {
@ -57,6 +57,8 @@ impl Clusterable for Vector2<f64> {
fn main() { fn main() {
const VAR: f64 = 1.0;
use rand::distributions::IndependentSample; use rand::distributions::IndependentSample;
let colors = vec![ let colors = vec![
@ -72,7 +74,7 @@ fn main() {
let mut centers = vec![]; let mut centers = vec![];
for _ in 0..cluster_number { for _ in 0..cluster_number {
let center = Vector2::new(range.ind_sample(&mut rng), range.ind_sample(&mut rng)); let center = Vector2::new(range.ind_sample(&mut rng), range.ind_sample(&mut rng));
centers.push((center, Normal::new(center.x, 0.5), Normal::new(center.y, 0.5))); centers.push((center, Normal::new(center.x, VAR), Normal::new(center.y, VAR)));
} }
let mut elements = vec![]; let mut elements = vec![];
@ -92,13 +94,17 @@ fn main() {
]; ];
let (clusters, nb_iterations) = kmeans(initialization, elements, 100000).ok().unwrap(); let (clusters, nb_iterations) = equal_kmeans(initialization, elements, 100000).ok().unwrap();
let clusters = clusters.into_vec_vec();
let mut output = File::create("plot/dat.dat").unwrap(); let mut output = File::create("plot/dat.dat").unwrap();
for (element, &label) in clusters.iter() { for (index, (cluster, color)) in clusters.iter().zip(colors.iter()).enumerate() {
println!("Cluster {}: {} elements", index, cluster.len());
for element in cluster {
use std::io::Write; use std::io::Write;
writeln!(output, "{} {} {}", element.x, element.y, colors[label]).unwrap(); writeln!(output, "{} {} {}", element.x, element.y, color).unwrap();
}
} }
println!("Finished in {} iterations", nb_iterations); println!("Finished in {} iterations", nb_iterations);

View File

@ -1,4 +1,7 @@
use std;
use std::collections::HashMap;
use std::slice::Iter; use std::slice::Iter;
use std::vec::IntoIter;
use std::iter::Zip; use std::iter::Zip;
use clusterable::Clusterable; use clusterable::Clusterable;
@ -78,5 +81,140 @@ impl<T:Clusterable> Kmeans<T> {
return self.elements.iter().zip(self.labels.iter()); return self.elements.iter().zip(self.labels.iter());
} }
pub fn into_iter(self) -> Zip<IntoIter<T>, IntoIter<usize>> {
debug_assert_eq!(self.elements.len(), self.labels.len());
return self.elements.into_iter().zip(self.labels.into_iter());
}
pub fn into_vec_vec(self) -> Vec<Vec<T>> {
let mut map = HashMap::new();
for (element, label) in self.into_iter() {
let mut entry = map.entry(label).or_insert(vec![]);
entry.push(element);
}
let mut output = vec![];
for (_, cluster) in map {
let mut vec = vec![];
for element in cluster {
vec.push(element);
}
output.push(vec);
}
output
}
pub fn to_vec_vec(&self) -> Vec<(Vec<T>, usize)> {
let mut map = HashMap::new();
for (element, label) in self.iter() {
let mut entry = map.entry(*label).or_insert(vec![]);
entry.push(element.clone());
}
let mut output = vec![];
for (label, cluster) in map {
let mut vec = vec![];
for element in cluster {
vec.push(element);
}
output.push((vec, label));
}
output
}
} }
pub enum Error {
IterationsLimitExceeded,
}
pub fn kmeans<T: Clusterable>(centroids: Vec<T>, data: Vec<T>, max_iterations: usize)
-> Result<(Kmeans<T>, usize), Error> {
let mut kmeans = Kmeans::new(centroids, data);
for nb_iterations in 0..max_iterations {
let stable = kmeans.iterate();
if stable {
return Ok((kmeans, nb_iterations));
}
}
Err(Error::IterationsLimitExceeded)
}
pub fn equal_kmeans<T: Clusterable>(centroids: Vec<T>, data: Vec<T>, max_iterations: usize)
-> Result<(Kmeans<T>, usize), Error> {
let number_of_elements = data.len();
let (kmeans, nb_iterations) = kmeans(centroids, data, max_iterations)?;
let mut clusters = kmeans.to_vec_vec();
for &mut (ref mut cluster, ref mut label) in &mut clusters {
cluster.sort_by(|ref e1, ref e2| {
*&e1.distance(&kmeans.centroids[*label])
.partial_cmp(&e2.distance(&kmeans.centroids[*label])).unwrap()
});
}
let max = (number_of_elements as f64 / clusters.len() as f64).ceil() as usize;
let mut to_replace = vec![];
for &mut (ref mut cluster, _) in &mut clusters {
while cluster.len() > max {
to_replace.push(cluster.pop().unwrap());
}
}
for element in to_replace {
// Find the best non full cluster to place it in
let mut best_distance = std::f64::MAX;
let mut best_index = 0;
for (index, centroid) in kmeans.centroids.iter().enumerate() {
let new_distance = element.distance(centroid);
if new_distance < best_distance && clusters[index].0.len() < max {
best_distance = new_distance;
best_index = index;
}
}
clusters[best_index].0.push(element);
}
let mut elements = vec![];
let mut labels = vec![];
let mut centroids = vec![];
for (index, &(ref cluster, _)) in clusters.iter().enumerate() {
for element in cluster {
elements.push(element.clone());
labels.push(index);
}
centroids.push(Clusterable::get_centroid(cluster.iter()).unwrap());
}
let len = centroids.len();
// Build a new k-means from clusters
let new_kmeans = Kmeans {
centroids: centroids,
elements: elements,
labels: labels,
cluster_number: len,
};
Ok((new_kmeans, nb_iterations))
}

View File

@ -3,25 +3,8 @@ pub mod clusterable;
pub use kmeans::Kmeans; pub use kmeans::Kmeans;
pub use clusterable::Clusterable; pub use clusterable::Clusterable;
pub use kmeans::Error;
pub use kmeans::kmeans;
pub use kmeans::equal_kmeans;
pub enum Error {
IterationsLimitExceeded,
}
pub fn kmeans<T: Clusterable>(centroids: Vec<T>, data: Vec<T>, max_iterations: usize)
-> Result<(Kmeans<T>, usize), Error> {
let mut kmeans = Kmeans::new(centroids, data);
for nb_iterations in 0..max_iterations {
let stable = kmeans.iterate();
if stable {
return Ok((kmeans, nb_iterations));
}
}
Err(Error::IterationsLimitExceeded)
}