Added equal kmeans... not very good though
This commit is contained in:
parent
164d15c6e7
commit
c9de1ca333
|
@ -5,7 +5,7 @@ use std::fmt::{Display, Formatter, Result};
|
|||
use std::fs::File;
|
||||
use rand::distributions::Range;
|
||||
use rand::distributions::normal::Normal;
|
||||
use generic_kmeans::{kmeans, Clusterable};
|
||||
use generic_kmeans::{equal_kmeans, Clusterable};
|
||||
|
||||
#[derive(PartialEq, Copy, Clone, Debug)]
|
||||
struct Vector2<T> {
|
||||
|
@ -57,6 +57,8 @@ impl Clusterable for Vector2<f64> {
|
|||
|
||||
fn main() {
|
||||
|
||||
const VAR: f64 = 1.0;
|
||||
|
||||
use rand::distributions::IndependentSample;
|
||||
|
||||
let colors = vec![
|
||||
|
@ -72,7 +74,7 @@ fn main() {
|
|||
let mut centers = vec![];
|
||||
for _ in 0..cluster_number {
|
||||
let center = Vector2::new(range.ind_sample(&mut rng), range.ind_sample(&mut rng));
|
||||
centers.push((center, Normal::new(center.x, 0.5), Normal::new(center.y, 0.5)));
|
||||
centers.push((center, Normal::new(center.x, VAR), Normal::new(center.y, VAR)));
|
||||
}
|
||||
|
||||
let mut elements = vec![];
|
||||
|
@ -92,13 +94,17 @@ fn main() {
|
|||
];
|
||||
|
||||
|
||||
let (clusters, nb_iterations) = kmeans(initialization, elements, 100000).ok().unwrap();
|
||||
let (clusters, nb_iterations) = equal_kmeans(initialization, elements, 100000).ok().unwrap();
|
||||
let clusters = clusters.into_vec_vec();
|
||||
|
||||
let mut output = File::create("plot/dat.dat").unwrap();
|
||||
|
||||
for (element, &label) in clusters.iter() {
|
||||
use std::io::Write;
|
||||
writeln!(output, "{} {} {}", element.x, element.y, colors[label]).unwrap();
|
||||
for (index, (cluster, color)) in clusters.iter().zip(colors.iter()).enumerate() {
|
||||
println!("Cluster {}: {} elements", index, cluster.len());
|
||||
for element in cluster {
|
||||
use std::io::Write;
|
||||
writeln!(output, "{} {} {}", element.x, element.y, color).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
println!("Finished in {} iterations", nb_iterations);
|
||||
|
|
138
src/kmeans.rs
138
src/kmeans.rs
|
@ -1,4 +1,7 @@
|
|||
use std;
|
||||
use std::collections::HashMap;
|
||||
use std::slice::Iter;
|
||||
use std::vec::IntoIter;
|
||||
use std::iter::Zip;
|
||||
use clusterable::Clusterable;
|
||||
|
||||
|
@ -78,5 +81,140 @@ impl<T:Clusterable> Kmeans<T> {
|
|||
return self.elements.iter().zip(self.labels.iter());
|
||||
}
|
||||
|
||||
pub fn into_iter(self) -> Zip<IntoIter<T>, IntoIter<usize>> {
|
||||
debug_assert_eq!(self.elements.len(), self.labels.len());
|
||||
return self.elements.into_iter().zip(self.labels.into_iter());
|
||||
}
|
||||
|
||||
pub fn into_vec_vec(self) -> Vec<Vec<T>> {
|
||||
let mut map = HashMap::new();
|
||||
|
||||
for (element, label) in self.into_iter() {
|
||||
let mut entry = map.entry(label).or_insert(vec![]);
|
||||
entry.push(element);
|
||||
}
|
||||
|
||||
let mut output = vec![];
|
||||
|
||||
for (_, cluster) in map {
|
||||
let mut vec = vec![];
|
||||
for element in cluster {
|
||||
vec.push(element);
|
||||
}
|
||||
output.push(vec);
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
pub fn to_vec_vec(&self) -> Vec<(Vec<T>, usize)> {
|
||||
let mut map = HashMap::new();
|
||||
|
||||
for (element, label) in self.iter() {
|
||||
let mut entry = map.entry(*label).or_insert(vec![]);
|
||||
entry.push(element.clone());
|
||||
}
|
||||
|
||||
let mut output = vec![];
|
||||
|
||||
for (label, cluster) in map {
|
||||
let mut vec = vec![];
|
||||
for element in cluster {
|
||||
vec.push(element);
|
||||
}
|
||||
output.push((vec, label));
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// Errors reported by the clustering functions.
///
/// `Debug` is derived so that callers can `unwrap()` / `expect()` a
/// `Result<_, Error>` directly instead of going through `.ok()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The algorithm did not stabilise within the allowed number of iterations.
    IterationsLimitExceeded,
}
|
||||
|
||||
pub fn kmeans<T: Clusterable>(centroids: Vec<T>, data: Vec<T>, max_iterations: usize)
|
||||
-> Result<(Kmeans<T>, usize), Error> {
|
||||
|
||||
let mut kmeans = Kmeans::new(centroids, data);
|
||||
|
||||
for nb_iterations in 0..max_iterations {
|
||||
|
||||
let stable = kmeans.iterate();
|
||||
|
||||
if stable {
|
||||
return Ok((kmeans, nb_iterations));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Err(Error::IterationsLimitExceeded)
|
||||
}
|
||||
|
||||
pub fn equal_kmeans<T: Clusterable>(centroids: Vec<T>, data: Vec<T>, max_iterations: usize)
|
||||
-> Result<(Kmeans<T>, usize), Error> {
|
||||
|
||||
let number_of_elements = data.len();
|
||||
let (kmeans, nb_iterations) = kmeans(centroids, data, max_iterations)?;
|
||||
let mut clusters = kmeans.to_vec_vec();
|
||||
|
||||
for &mut (ref mut cluster, ref mut label) in &mut clusters {
|
||||
cluster.sort_by(|ref e1, ref e2| {
|
||||
*&e1.distance(&kmeans.centroids[*label])
|
||||
.partial_cmp(&e2.distance(&kmeans.centroids[*label])).unwrap()
|
||||
});
|
||||
}
|
||||
|
||||
let max = (number_of_elements as f64 / clusters.len() as f64).ceil() as usize;
|
||||
|
||||
let mut to_replace = vec![];
|
||||
for &mut (ref mut cluster, _) in &mut clusters {
|
||||
while cluster.len() > max {
|
||||
to_replace.push(cluster.pop().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
for element in to_replace {
|
||||
// Find the best non full cluster to place it in
|
||||
let mut best_distance = std::f64::MAX;
|
||||
let mut best_index = 0;
|
||||
|
||||
for (index, centroid) in kmeans.centroids.iter().enumerate() {
|
||||
|
||||
let new_distance = element.distance(centroid);
|
||||
if new_distance < best_distance && clusters[index].0.len() < max {
|
||||
best_distance = new_distance;
|
||||
best_index = index;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
clusters[best_index].0.push(element);
|
||||
}
|
||||
|
||||
let mut elements = vec![];
|
||||
let mut labels = vec![];
|
||||
let mut centroids = vec![];
|
||||
|
||||
for (index, &(ref cluster, _)) in clusters.iter().enumerate() {
|
||||
for element in cluster {
|
||||
elements.push(element.clone());
|
||||
labels.push(index);
|
||||
}
|
||||
centroids.push(Clusterable::get_centroid(cluster.iter()).unwrap());
|
||||
}
|
||||
|
||||
let len = centroids.len();
|
||||
|
||||
// Build a new k-means from clusters
|
||||
let new_kmeans = Kmeans {
|
||||
centroids: centroids,
|
||||
elements: elements,
|
||||
labels: labels,
|
||||
cluster_number: len,
|
||||
};
|
||||
|
||||
Ok((new_kmeans, nb_iterations))
|
||||
|
||||
}
|
||||
|
|
23
src/lib.rs
23
src/lib.rs
|
@ -3,25 +3,8 @@ pub mod clusterable;
|
|||
|
||||
pub use kmeans::Kmeans;
|
||||
pub use clusterable::Clusterable;
|
||||
pub use kmeans::Error;
|
||||
pub use kmeans::kmeans;
|
||||
pub use kmeans::equal_kmeans;
|
||||
|
||||
/// Errors reported by the clustering functions.
///
/// `Debug` is derived so that callers can `unwrap()` / `expect()` a
/// `Result<_, Error>` directly instead of going through `.ok()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The algorithm did not stabilise within the allowed number of iterations.
    IterationsLimitExceeded,
}
|
||||
|
||||
pub fn kmeans<T: Clusterable>(centroids: Vec<T>, data: Vec<T>, max_iterations: usize)
|
||||
-> Result<(Kmeans<T>, usize), Error> {
|
||||
|
||||
let mut kmeans = Kmeans::new(centroids, data);
|
||||
|
||||
for nb_iterations in 0..max_iterations {
|
||||
|
||||
let stable = kmeans.iterate();
|
||||
|
||||
if stable {
|
||||
return Ok((kmeans, nb_iterations));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Err(Error::IterationsLimitExceeded)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue