This commit is contained in:
Thomas Forgione 2018-02-28 14:31:14 +01:00
parent c9de1ca333
commit 35fec4243d
No known key found for this signature in database
GPG Key ID: C75CD416BD1FFCE1
4 changed files with 190 additions and 322 deletions

View File

@ -1,43 +0,0 @@
use std::fmt::Debug;
pub trait Clusterable where Self: Sized + Clone + PartialEq + Debug {
fn distance(&self, rhs: &Self) -> f64;
fn get_centroid<'a, I>(elements: I) -> Option<Self>
where I: Iterator<Item = &'a Self>, Self: 'a;
}
macro_rules! impl_clusterable {
( $type: ty) => {
impl Clusterable for $type {
fn distance(&self, rhs: &Self) -> f64 {
if self > rhs {
*self as f64 - *rhs as f64
} else {
*rhs as f64 - *self as f64
}
}
fn get_centroid<'a, I>(elements: I) -> Option<Self>
where I: Iterator<Item = &'a Self>, Self: 'a {
let mut tmp = 0.0;
let mut count = 0.0;
for element in elements {
tmp += element;
count += 1.0;
}
if count > 0.0 {
Some(tmp / count)
} else {
None
}
}
}
}
}
impl_clusterable!(f32);
impl_clusterable!(f64);

View File

@ -5,7 +5,12 @@ use std::fmt::{Display, Formatter, Result};
use std::fs::File;
use rand::distributions::Range;
use rand::distributions::normal::Normal;
use generic_kmeans::{equal_kmeans, Clusterable};
use generic_kmeans::kmeans;
const CLUSTER_NUMBER: usize = 3;
const VAR: f64 = 1.0;
const MAX: f64 = 10.0;
#[derive(PartialEq, Copy, Clone, Debug)]
struct Vector2<T> {
@ -28,38 +33,69 @@ impl<T: Display> Display for Vector2<T> {
}
}
impl Clusterable for Vector2<f64> {
fn distance(&self, other: &Self) -> f64 {
(self.x - other.x) * (self.x - other.x) + (self.y - other.y) * (self.y - other.y)
fn distance(v1: &Vector2<f64>, v2: &Vector2<f64>) -> f64 {
let dx = v2.x - v1.x;
let dy = v2.y - v1.y;
dx * dx + dy * dy
}
fn get_centroid<'a, I>(cluster: I) -> Option<Self>
where I: Iterator<Item = &'a Self>, Self: 'a {
fn centroid(elements: &Vec<Vector2<f64>>) -> Vector2<f64> {
let mut v = Vector2::new(0.0, 0.0);
let mut centroid = Vector2::new(0.0, 0.0);
let mut count = 0.0;
for i in cluster {
centroid.x += i.x;
centroid.y += i.y;
count += 1.0;
for element in elements {
v.x += element.x;
v.y += element.y;
}
if count > 0.0 {
centroid.x /= count as f64;
centroid.y /= count as f64;
Some(centroid)
} else {
None
v.x /= elements.len() as f64;
v.y /= elements.len() as f64;
v
}
fn generate_points(centers: &mut Vec<(Vector2<f64>, Normal, Normal)>) -> Vec<Vector2<f64>> {
let mut rng = rand::thread_rng();
let mut output = vec![];
for &mut (_, x_rng, y_rng) in centers.iter_mut() {
for _ in 0..100 {
use rand::distributions::IndependentSample;
output.push(Vector2::new(x_rng.ind_sample(&mut rng), y_rng.ind_sample(&mut rng)));
}
}
output
}
fn generate_centers(number: usize) -> Vec<(Vector2<f64>, Normal, Normal)> {
let mut output = vec![];
let range = Range::new(0.0, MAX);
let mut rng = rand::thread_rng();
for _ in 0..number {
use rand::distributions::IndependentSample;
let center = Vector2::new(range.ind_sample(&mut rng), range.ind_sample(&mut rng));
output.push((center, Normal::new(center.x, VAR), Normal::new(center.y, VAR)));
}
output
}
fn main() {
const VAR: f64 = 1.0;
let mut centers = generate_centers(CLUSTER_NUMBER);
let elements = generate_points(&mut centers);
use rand::distributions::IndependentSample;
let initial = vec![
Vector2::new(0.0, 0.0),
Vector2::new(10.0, 0.0),
Vector2::new(0.0, 10.0),
];
let colors = vec![
"blue",
@ -67,39 +103,13 @@ fn main() {
"green",
];
let range = Range::new(0.0, 10.0);
let mut rng = rand::thread_rng();
let cluster_number = 3;
let kmeans = kmeans(elements, initial, &distance, &centroid, 1000).unwrap();
let mut centers = vec![];
for _ in 0..cluster_number {
let center = Vector2::new(range.ind_sample(&mut rng), range.ind_sample(&mut rng));
centers.push((center, Normal::new(center.x, VAR), Normal::new(center.y, VAR)));
}
let mut elements = vec![];
for &mut (_, x_rng, y_rng) in centers.iter_mut() {
for _ in 0..100 {
elements.push(Vector2::new(x_rng.ind_sample(&mut rng), y_rng.ind_sample(&mut rng)));
}
}
let initialization = vec![
Vector2::new(0.0,0.0),
Vector2::new(10.0,0.0),
Vector2::new(0.0,10.0),
];
let (clusters, nb_iterations) = equal_kmeans(initialization, elements, 100000).ok().unwrap();
let clusters = clusters.into_vec_vec();
println!("Converged in {} iterations.", kmeans.1);
let mut output = File::create("plot/dat.dat").unwrap();
for (index, (cluster, color)) in clusters.iter().zip(colors.iter()).enumerate() {
for (index, (cluster, color)) in kmeans.0.iter().zip(colors.iter()).enumerate() {
println!("Cluster {}: {} elements", index, cluster.len());
for element in cluster {
use std::io::Write;
@ -107,11 +117,10 @@ fn main() {
}
}
println!("Finished in {} iterations", nb_iterations);
let mut center_file = File::create("plot/centers.dat").unwrap();
for (&(center, _, _), color) in centers.iter().zip(&colors) {
use std::io::Write;
writeln!(center_file, "{} {} {}", center.x, center.y, color).unwrap();
}
}

View File

@ -1,220 +0,0 @@
use std;
use std::collections::HashMap;
use std::slice::Iter;
use std::vec::IntoIter;
use std::iter::Zip;
use clusterable::Clusterable;
pub struct Kmeans<T: Clusterable> {
centroids: Vec<T>,
elements: Vec<T>,
labels: Vec<usize>,
cluster_number: usize,
}
impl<T:Clusterable> Kmeans<T> {
pub fn new(centroids: Vec<T>, data: Vec<T>) -> Kmeans<T> {
let labels = Kmeans::build_labels(&centroids, &data);
let cluster_number = centroids.len();
Kmeans {
centroids: centroids,
elements: data,
labels: labels,
cluster_number: cluster_number
}
}
/// \returns True if converged
pub fn iterate(&mut self) -> bool {
// Update the centroids
let centroids = Kmeans::build_centroids(&self.elements, &self.labels, self.cluster_number);
if self.centroids == centroids {
true
} else {
self.centroids = centroids;
Kmeans::update_labels(&self.centroids, &self.elements, &mut self.labels);
false
}
}
pub fn build_labels(centroids: &Vec<T>, data: &Vec<T>) -> Vec<usize> {
debug_assert_ne!(0, centroids.len());
let mut output = vec![0; data.len()];
Kmeans::update_labels(centroids, data, &mut output);
output
}
pub fn update_labels(centroids: &Vec<T>, data: &Vec<T>, labels: &mut Vec<usize>) {
for (element, new_label) in data.iter().zip(labels.iter_mut()) {
*new_label = centroids
.iter()
.enumerate()
.min_by(|&(_, c1), &(_, c2)| {
c1.distance(element).partial_cmp(&c2.distance(element)).unwrap()
}).unwrap().0;
}
}
pub fn build_centroids(data: &Vec<T>, labels: &Vec<usize>, cluster_number: usize) -> Vec<T> {
let mut centroids = vec![];
for label in 0..cluster_number {
let to_consider = data
.iter()
.enumerate()
.filter(|&(index, _)| labels[index] == label)
.map(|(_, element)| element);
if let Some(centroid) = T::get_centroid(to_consider) {
centroids.push(centroid);
}
}
centroids
}
pub fn iter(&self) -> Zip<Iter<T>, Iter<usize>> {
debug_assert_eq!(self.elements.len(), self.labels.len());
return self.elements.iter().zip(self.labels.iter());
}
pub fn into_iter(self) -> Zip<IntoIter<T>, IntoIter<usize>> {
debug_assert_eq!(self.elements.len(), self.labels.len());
return self.elements.into_iter().zip(self.labels.into_iter());
}
pub fn into_vec_vec(self) -> Vec<Vec<T>> {
let mut map = HashMap::new();
for (element, label) in self.into_iter() {
let mut entry = map.entry(label).or_insert(vec![]);
entry.push(element);
}
let mut output = vec![];
for (_, cluster) in map {
let mut vec = vec![];
for element in cluster {
vec.push(element);
}
output.push(vec);
}
output
}
pub fn to_vec_vec(&self) -> Vec<(Vec<T>, usize)> {
let mut map = HashMap::new();
for (element, label) in self.iter() {
let mut entry = map.entry(*label).or_insert(vec![]);
entry.push(element.clone());
}
let mut output = vec![];
for (label, cluster) in map {
let mut vec = vec![];
for element in cluster {
vec.push(element);
}
output.push((vec, label));
}
output
}
}
pub enum Error {
IterationsLimitExceeded,
}
pub fn kmeans<T: Clusterable>(centroids: Vec<T>, data: Vec<T>, max_iterations: usize)
-> Result<(Kmeans<T>, usize), Error> {
let mut kmeans = Kmeans::new(centroids, data);
for nb_iterations in 0..max_iterations {
let stable = kmeans.iterate();
if stable {
return Ok((kmeans, nb_iterations));
}
}
Err(Error::IterationsLimitExceeded)
}
pub fn equal_kmeans<T: Clusterable>(centroids: Vec<T>, data: Vec<T>, max_iterations: usize)
-> Result<(Kmeans<T>, usize), Error> {
let number_of_elements = data.len();
let (kmeans, nb_iterations) = kmeans(centroids, data, max_iterations)?;
let mut clusters = kmeans.to_vec_vec();
for &mut (ref mut cluster, ref mut label) in &mut clusters {
cluster.sort_by(|ref e1, ref e2| {
*&e1.distance(&kmeans.centroids[*label])
.partial_cmp(&e2.distance(&kmeans.centroids[*label])).unwrap()
});
}
let max = (number_of_elements as f64 / clusters.len() as f64).ceil() as usize;
let mut to_replace = vec![];
for &mut (ref mut cluster, _) in &mut clusters {
while cluster.len() > max {
to_replace.push(cluster.pop().unwrap());
}
}
for element in to_replace {
// Find the best non full cluster to place it in
let mut best_distance = std::f64::MAX;
let mut best_index = 0;
for (index, centroid) in kmeans.centroids.iter().enumerate() {
let new_distance = element.distance(centroid);
if new_distance < best_distance && clusters[index].0.len() < max {
best_distance = new_distance;
best_index = index;
}
}
clusters[best_index].0.push(element);
}
let mut elements = vec![];
let mut labels = vec![];
let mut centroids = vec![];
for (index, &(ref cluster, _)) in clusters.iter().enumerate() {
for element in cluster {
elements.push(element.clone());
labels.push(index);
}
centroids.push(Clusterable::get_centroid(cluster.iter()).unwrap());
}
let len = centroids.len();
// Build a new k-means from clusters
let new_kmeans = Kmeans {
centroids: centroids,
elements: elements,
labels: labels,
cluster_number: len,
};
Ok((new_kmeans, nb_iterations))
}

View File

@ -1,10 +1,132 @@
pub mod kmeans;
pub mod clusterable;
use std::collections::HashMap;
pub use kmeans::Kmeans;
pub use clusterable::Clusterable;
pub use kmeans::Error;
pub use kmeans::kmeans;
pub use kmeans::equal_kmeans;
#[derive(Debug)]
pub enum Error {
TooManyIterations,
}
type Distance<T> = Fn(&T, &T) -> f64;
type Centroid<T> = Fn(&Vec<T>) -> T;
pub struct KmeansData<T: Clone + PartialEq> {
pub elements: Vec<T>,
pub labels: Vec<usize>,
}
pub struct Kmeans<'a, T: 'a + Clone + PartialEq> {
data: KmeansData<T>,
centroids: Vec<T>,
distance: &'a Distance<T>,
centroid: &'a Centroid<T>,
}
impl<'a, T: 'a + Clone + PartialEq> Kmeans<'a, T> {
pub fn new(data: Vec<T>, centroids: Vec<T>, distance: &'a Distance<T>, centroid: &'a Centroid<T>) -> Kmeans<'a, T> {
let len = data.len();
Kmeans {
data: KmeansData { elements: data, labels: vec![0; len] },
centroids: centroids,
distance: distance,
centroid: centroid,
}
}
/// True if converged
pub fn iterate(&mut self) -> bool {
// Update the labels
self.update_labels();
// Update the centroids
let new_centroids = self.compute_centroids();
let ret = new_centroids == self.centroids;
self.centroids = new_centroids;
ret
}
fn update_labels(&mut self) {
let iterator = self.data.elements.iter().zip(self.data.labels.iter_mut());
for (ref element, ref mut label) in iterator {
let mut best_distance = std::f64::MAX;
let mut best_label = 0;
for (index, centroid) in self.centroids.iter().enumerate() {
let new_distance = (self.distance)(*element, centroid);
if new_distance < best_distance {
best_distance = new_distance;
best_label = index;
}
}
**label = best_label;
}
}
fn compute_centroids(&self) -> Vec<T> {
let mut centroids_map = HashMap::new();
for (element, label) in self.data.elements.iter().zip(self.data.labels.iter()) {
let mut centroid = centroids_map.entry(label).or_insert(vec![]);
centroid.push(element.clone());
}
let mut new_centroids = vec![];
for (_, value) in centroids_map {
new_centroids.push((self.centroid)(&value));
}
new_centroids
}
}
pub fn kmeans<'a, T: 'a + Clone + PartialEq>(
elements: Vec<T>,
initial: Vec<T>,
distance: &'a Distance<T>,
centroid: &'a Centroid<T>,
max_iteration: usize,
) -> Result<(Vec<Vec<T>>,usize), Error> {
let mut clusters = Kmeans::new(elements, initial, distance, centroid);
let mut counter = 0;
let iterations = loop {
counter += 1;
if clusters.iterate() || counter > max_iteration {
break counter;
}
};
let mut map = HashMap::new();
for (element, label) in clusters.data.elements.iter().zip(clusters.data.labels.iter()) {
let mut centroid = map.entry(label).or_insert(vec![]);
centroid.push(element.clone());
}
let mut output = vec![];
for (_, value) in map {
let mut cluster = vec![];
for element in value {
cluster.push(element);
}
output.push(cluster);
}
Ok((output, iterations))
}