Added equal kmeans... not very good though
This commit is contained in:
parent
164d15c6e7
commit
c9de1ca333
|
@ -5,7 +5,7 @@ use std::fmt::{Display, Formatter, Result};
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use rand::distributions::Range;
|
use rand::distributions::Range;
|
||||||
use rand::distributions::normal::Normal;
|
use rand::distributions::normal::Normal;
|
||||||
use generic_kmeans::{kmeans, Clusterable};
|
use generic_kmeans::{equal_kmeans, Clusterable};
|
||||||
|
|
||||||
#[derive(PartialEq, Copy, Clone, Debug)]
|
#[derive(PartialEq, Copy, Clone, Debug)]
|
||||||
struct Vector2<T> {
|
struct Vector2<T> {
|
||||||
|
@ -57,6 +57,8 @@ impl Clusterable for Vector2<f64> {
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
|
|
||||||
|
const VAR: f64 = 1.0;
|
||||||
|
|
||||||
use rand::distributions::IndependentSample;
|
use rand::distributions::IndependentSample;
|
||||||
|
|
||||||
let colors = vec![
|
let colors = vec![
|
||||||
|
@ -72,7 +74,7 @@ fn main() {
|
||||||
let mut centers = vec![];
|
let mut centers = vec![];
|
||||||
for _ in 0..cluster_number {
|
for _ in 0..cluster_number {
|
||||||
let center = Vector2::new(range.ind_sample(&mut rng), range.ind_sample(&mut rng));
|
let center = Vector2::new(range.ind_sample(&mut rng), range.ind_sample(&mut rng));
|
||||||
centers.push((center, Normal::new(center.x, 0.5), Normal::new(center.y, 0.5)));
|
centers.push((center, Normal::new(center.x, VAR), Normal::new(center.y, VAR)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut elements = vec![];
|
let mut elements = vec![];
|
||||||
|
@ -92,13 +94,17 @@ fn main() {
|
||||||
];
|
];
|
||||||
|
|
||||||
|
|
||||||
let (clusters, nb_iterations) = kmeans(initialization, elements, 100000).ok().unwrap();
|
let (clusters, nb_iterations) = equal_kmeans(initialization, elements, 100000).ok().unwrap();
|
||||||
|
let clusters = clusters.into_vec_vec();
|
||||||
|
|
||||||
let mut output = File::create("plot/dat.dat").unwrap();
|
let mut output = File::create("plot/dat.dat").unwrap();
|
||||||
|
|
||||||
for (element, &label) in clusters.iter() {
|
for (index, (cluster, color)) in clusters.iter().zip(colors.iter()).enumerate() {
|
||||||
use std::io::Write;
|
println!("Cluster {}: {} elements", index, cluster.len());
|
||||||
writeln!(output, "{} {} {}", element.x, element.y, colors[label]).unwrap();
|
for element in cluster {
|
||||||
|
use std::io::Write;
|
||||||
|
writeln!(output, "{} {} {}", element.x, element.y, color).unwrap();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
println!("Finished in {} iterations", nb_iterations);
|
println!("Finished in {} iterations", nb_iterations);
|
||||||
|
|
138
src/kmeans.rs
138
src/kmeans.rs
|
@ -1,4 +1,7 @@
|
||||||
|
use std;
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::slice::Iter;
|
use std::slice::Iter;
|
||||||
|
use std::vec::IntoIter;
|
||||||
use std::iter::Zip;
|
use std::iter::Zip;
|
||||||
use clusterable::Clusterable;
|
use clusterable::Clusterable;
|
||||||
|
|
||||||
|
@ -78,5 +81,140 @@ impl<T:Clusterable> Kmeans<T> {
|
||||||
return self.elements.iter().zip(self.labels.iter());
|
return self.elements.iter().zip(self.labels.iter());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn into_iter(self) -> Zip<IntoIter<T>, IntoIter<usize>> {
|
||||||
|
debug_assert_eq!(self.elements.len(), self.labels.len());
|
||||||
|
return self.elements.into_iter().zip(self.labels.into_iter());
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn into_vec_vec(self) -> Vec<Vec<T>> {
|
||||||
|
let mut map = HashMap::new();
|
||||||
|
|
||||||
|
for (element, label) in self.into_iter() {
|
||||||
|
let mut entry = map.entry(label).or_insert(vec![]);
|
||||||
|
entry.push(element);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut output = vec![];
|
||||||
|
|
||||||
|
for (_, cluster) in map {
|
||||||
|
let mut vec = vec![];
|
||||||
|
for element in cluster {
|
||||||
|
vec.push(element);
|
||||||
|
}
|
||||||
|
output.push(vec);
|
||||||
|
}
|
||||||
|
|
||||||
|
output
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_vec_vec(&self) -> Vec<(Vec<T>, usize)> {
|
||||||
|
let mut map = HashMap::new();
|
||||||
|
|
||||||
|
for (element, label) in self.iter() {
|
||||||
|
let mut entry = map.entry(*label).or_insert(vec![]);
|
||||||
|
entry.push(element.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut output = vec![];
|
||||||
|
|
||||||
|
for (label, cluster) in map {
|
||||||
|
let mut vec = vec![];
|
||||||
|
for element in cluster {
|
||||||
|
vec.push(element);
|
||||||
|
}
|
||||||
|
output.push((vec, label));
|
||||||
|
}
|
||||||
|
|
||||||
|
output
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Errors reported by the clustering entry points.
///
/// Derives `Debug` so that `Result::unwrap` / `Result::expect` can be used on
/// results carrying this error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The algorithm did not reach a stable state within the allowed number
    /// of iterations.
    IterationsLimitExceeded,
}

impl ::std::fmt::Display for Error {
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        match *self {
            Error::IterationsLimitExceeded => {
                write!(f, "k-means did not converge within the iteration limit")
            }
        }
    }
}
|
||||||
|
|
||||||
|
pub fn kmeans<T: Clusterable>(centroids: Vec<T>, data: Vec<T>, max_iterations: usize)
|
||||||
|
-> Result<(Kmeans<T>, usize), Error> {
|
||||||
|
|
||||||
|
let mut kmeans = Kmeans::new(centroids, data);
|
||||||
|
|
||||||
|
for nb_iterations in 0..max_iterations {
|
||||||
|
|
||||||
|
let stable = kmeans.iterate();
|
||||||
|
|
||||||
|
if stable {
|
||||||
|
return Ok((kmeans, nb_iterations));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(Error::IterationsLimitExceeded)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn equal_kmeans<T: Clusterable>(centroids: Vec<T>, data: Vec<T>, max_iterations: usize)
|
||||||
|
-> Result<(Kmeans<T>, usize), Error> {
|
||||||
|
|
||||||
|
let number_of_elements = data.len();
|
||||||
|
let (kmeans, nb_iterations) = kmeans(centroids, data, max_iterations)?;
|
||||||
|
let mut clusters = kmeans.to_vec_vec();
|
||||||
|
|
||||||
|
for &mut (ref mut cluster, ref mut label) in &mut clusters {
|
||||||
|
cluster.sort_by(|ref e1, ref e2| {
|
||||||
|
*&e1.distance(&kmeans.centroids[*label])
|
||||||
|
.partial_cmp(&e2.distance(&kmeans.centroids[*label])).unwrap()
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let max = (number_of_elements as f64 / clusters.len() as f64).ceil() as usize;
|
||||||
|
|
||||||
|
let mut to_replace = vec![];
|
||||||
|
for &mut (ref mut cluster, _) in &mut clusters {
|
||||||
|
while cluster.len() > max {
|
||||||
|
to_replace.push(cluster.pop().unwrap());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for element in to_replace {
|
||||||
|
// Find the best non full cluster to place it in
|
||||||
|
let mut best_distance = std::f64::MAX;
|
||||||
|
let mut best_index = 0;
|
||||||
|
|
||||||
|
for (index, centroid) in kmeans.centroids.iter().enumerate() {
|
||||||
|
|
||||||
|
let new_distance = element.distance(centroid);
|
||||||
|
if new_distance < best_distance && clusters[index].0.len() < max {
|
||||||
|
best_distance = new_distance;
|
||||||
|
best_index = index;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
clusters[best_index].0.push(element);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut elements = vec![];
|
||||||
|
let mut labels = vec![];
|
||||||
|
let mut centroids = vec![];
|
||||||
|
|
||||||
|
for (index, &(ref cluster, _)) in clusters.iter().enumerate() {
|
||||||
|
for element in cluster {
|
||||||
|
elements.push(element.clone());
|
||||||
|
labels.push(index);
|
||||||
|
}
|
||||||
|
centroids.push(Clusterable::get_centroid(cluster.iter()).unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
let len = centroids.len();
|
||||||
|
|
||||||
|
// Build a new k-means from clusters
|
||||||
|
let new_kmeans = Kmeans {
|
||||||
|
centroids: centroids,
|
||||||
|
elements: elements,
|
||||||
|
labels: labels,
|
||||||
|
cluster_number: len,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok((new_kmeans, nb_iterations))
|
||||||
|
|
||||||
|
}
|
||||||
|
|
23
src/lib.rs
23
src/lib.rs
|
@ -3,25 +3,8 @@ pub mod clusterable;
|
||||||
|
|
||||||
pub use kmeans::Kmeans;
|
pub use kmeans::Kmeans;
|
||||||
pub use clusterable::Clusterable;
|
pub use clusterable::Clusterable;
|
||||||
|
pub use kmeans::Error;
|
||||||
|
pub use kmeans::kmeans;
|
||||||
|
pub use kmeans::equal_kmeans;
|
||||||
|
|
||||||
pub enum Error {
|
|
||||||
IterationsLimitExceeded,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn kmeans<T: Clusterable>(centroids: Vec<T>, data: Vec<T>, max_iterations: usize)
|
|
||||||
-> Result<(Kmeans<T>, usize), Error> {
|
|
||||||
|
|
||||||
let mut kmeans = Kmeans::new(centroids, data);
|
|
||||||
|
|
||||||
for nb_iterations in 0..max_iterations {
|
|
||||||
|
|
||||||
let stable = kmeans.iterate();
|
|
||||||
|
|
||||||
if stable {
|
|
||||||
return Ok((kmeans, nb_iterations));
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
Err(Error::IterationsLimitExceeded)
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in New Issue