Cleaning
This commit is contained in:
parent
c9de1ca333
commit
35fec4243d
|
@ -1,43 +0,0 @@
|
|||
use std::fmt::Debug;
|
||||
|
||||
pub trait Clusterable where Self: Sized + Clone + PartialEq + Debug {
|
||||
fn distance(&self, rhs: &Self) -> f64;
|
||||
fn get_centroid<'a, I>(elements: I) -> Option<Self>
|
||||
where I: Iterator<Item = &'a Self>, Self: 'a;
|
||||
}
|
||||
|
||||
macro_rules! impl_clusterable {
|
||||
( $type: ty) => {
|
||||
impl Clusterable for $type {
|
||||
fn distance(&self, rhs: &Self) -> f64 {
|
||||
if self > rhs {
|
||||
*self as f64 - *rhs as f64
|
||||
} else {
|
||||
*rhs as f64 - *self as f64
|
||||
}
|
||||
}
|
||||
|
||||
fn get_centroid<'a, I>(elements: I) -> Option<Self>
|
||||
where I: Iterator<Item = &'a Self>, Self: 'a {
|
||||
|
||||
let mut tmp = 0.0;
|
||||
let mut count = 0.0;
|
||||
for element in elements {
|
||||
tmp += element;
|
||||
count += 1.0;
|
||||
}
|
||||
|
||||
if count > 0.0 {
|
||||
Some(tmp / count)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl_clusterable!(f32);
|
||||
impl_clusterable!(f64);
|
||||
|
113
src/example.rs
113
src/example.rs
|
@ -5,7 +5,12 @@ use std::fmt::{Display, Formatter, Result};
|
|||
use std::fs::File;
|
||||
use rand::distributions::Range;
|
||||
use rand::distributions::normal::Normal;
|
||||
use generic_kmeans::{equal_kmeans, Clusterable};
|
||||
|
||||
use generic_kmeans::kmeans;
|
||||
|
||||
const CLUSTER_NUMBER: usize = 3;
|
||||
const VAR: f64 = 1.0;
|
||||
const MAX: f64 = 10.0;
|
||||
|
||||
#[derive(PartialEq, Copy, Clone, Debug)]
|
||||
struct Vector2<T> {
|
||||
|
@ -28,38 +33,69 @@ impl<T: Display> Display for Vector2<T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl Clusterable for Vector2<f64> {
|
||||
fn distance(&self, other: &Self) -> f64 {
|
||||
(self.x - other.x) * (self.x - other.x) + (self.y - other.y) * (self.y - other.y)
|
||||
fn distance(v1: &Vector2<f64>, v2: &Vector2<f64>) -> f64 {
|
||||
let dx = v2.x - v1.x;
|
||||
let dy = v2.y - v1.y;
|
||||
dx * dx + dy * dy
|
||||
}
|
||||
|
||||
fn centroid(elements: &Vec<Vector2<f64>>) -> Vector2<f64> {
|
||||
let mut v = Vector2::new(0.0, 0.0);
|
||||
|
||||
for element in elements {
|
||||
v.x += element.x;
|
||||
v.y += element.y;
|
||||
}
|
||||
|
||||
fn get_centroid<'a, I>(cluster: I) -> Option<Self>
|
||||
where I: Iterator<Item = &'a Self>, Self: 'a {
|
||||
v.x /= elements.len() as f64;
|
||||
v.y /= elements.len() as f64;
|
||||
|
||||
let mut centroid = Vector2::new(0.0, 0.0);
|
||||
let mut count = 0.0;
|
||||
v
|
||||
}
|
||||
|
||||
for i in cluster {
|
||||
centroid.x += i.x;
|
||||
centroid.y += i.y;
|
||||
count += 1.0;
|
||||
}
|
||||
fn generate_points(centers: &mut Vec<(Vector2<f64>, Normal, Normal)>) -> Vec<Vector2<f64>> {
|
||||
|
||||
if count > 0.0 {
|
||||
centroid.x /= count as f64;
|
||||
centroid.y /= count as f64;
|
||||
Some(centroid)
|
||||
} else {
|
||||
None
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut output = vec![];
|
||||
|
||||
for &mut (_, x_rng, y_rng) in centers.iter_mut() {
|
||||
for _ in 0..100 {
|
||||
use rand::distributions::IndependentSample;
|
||||
output.push(Vector2::new(x_rng.ind_sample(&mut rng), y_rng.ind_sample(&mut rng)));
|
||||
}
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
fn generate_centers(number: usize) -> Vec<(Vector2<f64>, Normal, Normal)> {
|
||||
|
||||
let mut output = vec![];
|
||||
|
||||
let range = Range::new(0.0, MAX);
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
for _ in 0..number {
|
||||
|
||||
use rand::distributions::IndependentSample;
|
||||
let center = Vector2::new(range.ind_sample(&mut rng), range.ind_sample(&mut rng));
|
||||
output.push((center, Normal::new(center.x, VAR), Normal::new(center.y, VAR)));
|
||||
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
fn main() {
|
||||
|
||||
const VAR: f64 = 1.0;
|
||||
let mut centers = generate_centers(CLUSTER_NUMBER);
|
||||
let elements = generate_points(&mut centers);
|
||||
|
||||
use rand::distributions::IndependentSample;
|
||||
let initial = vec![
|
||||
Vector2::new(0.0, 0.0),
|
||||
Vector2::new(10.0, 0.0),
|
||||
Vector2::new(0.0, 10.0),
|
||||
];
|
||||
|
||||
let colors = vec![
|
||||
"blue",
|
||||
|
@ -67,39 +103,13 @@ fn main() {
|
|||
"green",
|
||||
];
|
||||
|
||||
let range = Range::new(0.0, 10.0);
|
||||
let mut rng = rand::thread_rng();
|
||||
let cluster_number = 3;
|
||||
let kmeans = kmeans(elements, initial, &distance, ¢roid, 1000).unwrap();
|
||||
|
||||
let mut centers = vec![];
|
||||
for _ in 0..cluster_number {
|
||||
let center = Vector2::new(range.ind_sample(&mut rng), range.ind_sample(&mut rng));
|
||||
centers.push((center, Normal::new(center.x, VAR), Normal::new(center.y, VAR)));
|
||||
}
|
||||
|
||||
let mut elements = vec![];
|
||||
|
||||
for &mut (_, x_rng, y_rng) in centers.iter_mut() {
|
||||
|
||||
for _ in 0..100 {
|
||||
elements.push(Vector2::new(x_rng.ind_sample(&mut rng), y_rng.ind_sample(&mut rng)));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
let initialization = vec![
|
||||
Vector2::new(0.0,0.0),
|
||||
Vector2::new(10.0,0.0),
|
||||
Vector2::new(0.0,10.0),
|
||||
];
|
||||
|
||||
|
||||
let (clusters, nb_iterations) = equal_kmeans(initialization, elements, 100000).ok().unwrap();
|
||||
let clusters = clusters.into_vec_vec();
|
||||
println!("Converged in {} iterations.", kmeans.1);
|
||||
|
||||
let mut output = File::create("plot/dat.dat").unwrap();
|
||||
|
||||
for (index, (cluster, color)) in clusters.iter().zip(colors.iter()).enumerate() {
|
||||
for (index, (cluster, color)) in kmeans.0.iter().zip(colors.iter()).enumerate() {
|
||||
println!("Cluster {}: {} elements", index, cluster.len());
|
||||
for element in cluster {
|
||||
use std::io::Write;
|
||||
|
@ -107,11 +117,10 @@ fn main() {
|
|||
}
|
||||
}
|
||||
|
||||
println!("Finished in {} iterations", nb_iterations);
|
||||
|
||||
let mut center_file = File::create("plot/centers.dat").unwrap();
|
||||
for (&(center, _, _), color) in centers.iter().zip(&colors) {
|
||||
use std::io::Write;
|
||||
writeln!(center_file, "{} {} {}", center.x, center.y, color).unwrap();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
220
src/kmeans.rs
220
src/kmeans.rs
|
@ -1,220 +0,0 @@
|
|||
use std;
|
||||
use std::collections::HashMap;
|
||||
use std::slice::Iter;
|
||||
use std::vec::IntoIter;
|
||||
use std::iter::Zip;
|
||||
use clusterable::Clusterable;
|
||||
|
||||
pub struct Kmeans<T: Clusterable> {
|
||||
centroids: Vec<T>,
|
||||
elements: Vec<T>,
|
||||
labels: Vec<usize>,
|
||||
cluster_number: usize,
|
||||
}
|
||||
|
||||
impl<T:Clusterable> Kmeans<T> {
|
||||
pub fn new(centroids: Vec<T>, data: Vec<T>) -> Kmeans<T> {
|
||||
let labels = Kmeans::build_labels(¢roids, &data);
|
||||
let cluster_number = centroids.len();
|
||||
Kmeans {
|
||||
centroids: centroids,
|
||||
elements: data,
|
||||
labels: labels,
|
||||
cluster_number: cluster_number
|
||||
}
|
||||
}
|
||||
|
||||
/// \returns True if converged
|
||||
pub fn iterate(&mut self) -> bool {
|
||||
// Update the centroids
|
||||
let centroids = Kmeans::build_centroids(&self.elements, &self.labels, self.cluster_number);
|
||||
|
||||
if self.centroids == centroids {
|
||||
true
|
||||
} else {
|
||||
self.centroids = centroids;
|
||||
Kmeans::update_labels(&self.centroids, &self.elements, &mut self.labels);
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build_labels(centroids: &Vec<T>, data: &Vec<T>) -> Vec<usize> {
|
||||
debug_assert_ne!(0, centroids.len());
|
||||
|
||||
let mut output = vec![0; data.len()];
|
||||
Kmeans::update_labels(centroids, data, &mut output);
|
||||
output
|
||||
}
|
||||
|
||||
pub fn update_labels(centroids: &Vec<T>, data: &Vec<T>, labels: &mut Vec<usize>) {
|
||||
for (element, new_label) in data.iter().zip(labels.iter_mut()) {
|
||||
*new_label = centroids
|
||||
.iter()
|
||||
.enumerate()
|
||||
.min_by(|&(_, c1), &(_, c2)| {
|
||||
c1.distance(element).partial_cmp(&c2.distance(element)).unwrap()
|
||||
}).unwrap().0;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build_centroids(data: &Vec<T>, labels: &Vec<usize>, cluster_number: usize) -> Vec<T> {
|
||||
|
||||
let mut centroids = vec![];
|
||||
for label in 0..cluster_number {
|
||||
|
||||
let to_consider = data
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|&(index, _)| labels[index] == label)
|
||||
.map(|(_, element)| element);
|
||||
|
||||
if let Some(centroid) = T::get_centroid(to_consider) {
|
||||
centroids.push(centroid);
|
||||
}
|
||||
}
|
||||
|
||||
centroids
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> Zip<Iter<T>, Iter<usize>> {
|
||||
debug_assert_eq!(self.elements.len(), self.labels.len());
|
||||
return self.elements.iter().zip(self.labels.iter());
|
||||
}
|
||||
|
||||
pub fn into_iter(self) -> Zip<IntoIter<T>, IntoIter<usize>> {
|
||||
debug_assert_eq!(self.elements.len(), self.labels.len());
|
||||
return self.elements.into_iter().zip(self.labels.into_iter());
|
||||
}
|
||||
|
||||
pub fn into_vec_vec(self) -> Vec<Vec<T>> {
|
||||
let mut map = HashMap::new();
|
||||
|
||||
for (element, label) in self.into_iter() {
|
||||
let mut entry = map.entry(label).or_insert(vec![]);
|
||||
entry.push(element);
|
||||
}
|
||||
|
||||
let mut output = vec![];
|
||||
|
||||
for (_, cluster) in map {
|
||||
let mut vec = vec![];
|
||||
for element in cluster {
|
||||
vec.push(element);
|
||||
}
|
||||
output.push(vec);
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
pub fn to_vec_vec(&self) -> Vec<(Vec<T>, usize)> {
|
||||
let mut map = HashMap::new();
|
||||
|
||||
for (element, label) in self.iter() {
|
||||
let mut entry = map.entry(*label).or_insert(vec![]);
|
||||
entry.push(element.clone());
|
||||
}
|
||||
|
||||
let mut output = vec![];
|
||||
|
||||
for (label, cluster) in map {
|
||||
let mut vec = vec![];
|
||||
for element in cluster {
|
||||
vec.push(element);
|
||||
}
|
||||
output.push((vec, label));
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
pub enum Error {
|
||||
IterationsLimitExceeded,
|
||||
}
|
||||
|
||||
pub fn kmeans<T: Clusterable>(centroids: Vec<T>, data: Vec<T>, max_iterations: usize)
|
||||
-> Result<(Kmeans<T>, usize), Error> {
|
||||
|
||||
let mut kmeans = Kmeans::new(centroids, data);
|
||||
|
||||
for nb_iterations in 0..max_iterations {
|
||||
|
||||
let stable = kmeans.iterate();
|
||||
|
||||
if stable {
|
||||
return Ok((kmeans, nb_iterations));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Err(Error::IterationsLimitExceeded)
|
||||
}
|
||||
|
||||
pub fn equal_kmeans<T: Clusterable>(centroids: Vec<T>, data: Vec<T>, max_iterations: usize)
|
||||
-> Result<(Kmeans<T>, usize), Error> {
|
||||
|
||||
let number_of_elements = data.len();
|
||||
let (kmeans, nb_iterations) = kmeans(centroids, data, max_iterations)?;
|
||||
let mut clusters = kmeans.to_vec_vec();
|
||||
|
||||
for &mut (ref mut cluster, ref mut label) in &mut clusters {
|
||||
cluster.sort_by(|ref e1, ref e2| {
|
||||
*&e1.distance(&kmeans.centroids[*label])
|
||||
.partial_cmp(&e2.distance(&kmeans.centroids[*label])).unwrap()
|
||||
});
|
||||
}
|
||||
|
||||
let max = (number_of_elements as f64 / clusters.len() as f64).ceil() as usize;
|
||||
|
||||
let mut to_replace = vec![];
|
||||
for &mut (ref mut cluster, _) in &mut clusters {
|
||||
while cluster.len() > max {
|
||||
to_replace.push(cluster.pop().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
for element in to_replace {
|
||||
// Find the best non full cluster to place it in
|
||||
let mut best_distance = std::f64::MAX;
|
||||
let mut best_index = 0;
|
||||
|
||||
for (index, centroid) in kmeans.centroids.iter().enumerate() {
|
||||
|
||||
let new_distance = element.distance(centroid);
|
||||
if new_distance < best_distance && clusters[index].0.len() < max {
|
||||
best_distance = new_distance;
|
||||
best_index = index;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
clusters[best_index].0.push(element);
|
||||
}
|
||||
|
||||
let mut elements = vec![];
|
||||
let mut labels = vec![];
|
||||
let mut centroids = vec![];
|
||||
|
||||
for (index, &(ref cluster, _)) in clusters.iter().enumerate() {
|
||||
for element in cluster {
|
||||
elements.push(element.clone());
|
||||
labels.push(index);
|
||||
}
|
||||
centroids.push(Clusterable::get_centroid(cluster.iter()).unwrap());
|
||||
}
|
||||
|
||||
let len = centroids.len();
|
||||
|
||||
// Build a new k-means from clusters
|
||||
let new_kmeans = Kmeans {
|
||||
centroids: centroids,
|
||||
elements: elements,
|
||||
labels: labels,
|
||||
cluster_number: len,
|
||||
};
|
||||
|
||||
Ok((new_kmeans, nb_iterations))
|
||||
|
||||
}
|
136
src/lib.rs
136
src/lib.rs
|
@ -1,10 +1,132 @@
|
|||
pub mod kmeans;
|
||||
pub mod clusterable;
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub use kmeans::Kmeans;
|
||||
pub use clusterable::Clusterable;
|
||||
pub use kmeans::Error;
|
||||
pub use kmeans::kmeans;
|
||||
pub use kmeans::equal_kmeans;
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
TooManyIterations,
|
||||
}
|
||||
|
||||
type Distance<T> = Fn(&T, &T) -> f64;
|
||||
type Centroid<T> = Fn(&Vec<T>) -> T;
|
||||
|
||||
pub struct KmeansData<T: Clone + PartialEq> {
|
||||
pub elements: Vec<T>,
|
||||
pub labels: Vec<usize>,
|
||||
}
|
||||
|
||||
pub struct Kmeans<'a, T: 'a + Clone + PartialEq> {
|
||||
data: KmeansData<T>,
|
||||
centroids: Vec<T>,
|
||||
distance: &'a Distance<T>,
|
||||
centroid: &'a Centroid<T>,
|
||||
}
|
||||
|
||||
impl<'a, T: 'a + Clone + PartialEq> Kmeans<'a, T> {
|
||||
pub fn new(data: Vec<T>, centroids: Vec<T>, distance: &'a Distance<T>, centroid: &'a Centroid<T>) -> Kmeans<'a, T> {
|
||||
|
||||
let len = data.len();
|
||||
Kmeans {
|
||||
data: KmeansData { elements: data, labels: vec![0; len] },
|
||||
centroids: centroids,
|
||||
distance: distance,
|
||||
centroid: centroid,
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// True if converged
|
||||
pub fn iterate(&mut self) -> bool {
|
||||
// Update the labels
|
||||
self.update_labels();
|
||||
|
||||
// Update the centroids
|
||||
let new_centroids = self.compute_centroids();
|
||||
|
||||
let ret = new_centroids == self.centroids;
|
||||
self.centroids = new_centroids;
|
||||
ret
|
||||
}
|
||||
|
||||
fn update_labels(&mut self) {
|
||||
|
||||
let iterator = self.data.elements.iter().zip(self.data.labels.iter_mut());
|
||||
|
||||
for (ref element, ref mut label) in iterator {
|
||||
|
||||
let mut best_distance = std::f64::MAX;
|
||||
let mut best_label = 0;
|
||||
|
||||
for (index, centroid) in self.centroids.iter().enumerate() {
|
||||
|
||||
let new_distance = (self.distance)(*element, centroid);
|
||||
if new_distance < best_distance {
|
||||
best_distance = new_distance;
|
||||
best_label = index;
|
||||
}
|
||||
}
|
||||
|
||||
**label = best_label;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
fn compute_centroids(&self) -> Vec<T> {
|
||||
let mut centroids_map = HashMap::new();
|
||||
|
||||
for (element, label) in self.data.elements.iter().zip(self.data.labels.iter()) {
|
||||
let mut centroid = centroids_map.entry(label).or_insert(vec![]);
|
||||
centroid.push(element.clone());
|
||||
}
|
||||
|
||||
let mut new_centroids = vec![];
|
||||
|
||||
for (_, value) in centroids_map {
|
||||
new_centroids.push((self.centroid)(&value));
|
||||
}
|
||||
|
||||
new_centroids
|
||||
}
|
||||
}
|
||||
|
||||
pub fn kmeans<'a, T: 'a + Clone + PartialEq>(
|
||||
elements: Vec<T>,
|
||||
initial: Vec<T>,
|
||||
distance: &'a Distance<T>,
|
||||
centroid: &'a Centroid<T>,
|
||||
max_iteration: usize,
|
||||
) -> Result<(Vec<Vec<T>>,usize), Error> {
|
||||
|
||||
let mut clusters = Kmeans::new(elements, initial, distance, centroid);
|
||||
|
||||
let mut counter = 0;
|
||||
|
||||
let iterations = loop {
|
||||
|
||||
counter += 1;
|
||||
|
||||
if clusters.iterate() || counter > max_iteration {
|
||||
break counter;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
let mut map = HashMap::new();
|
||||
|
||||
for (element, label) in clusters.data.elements.iter().zip(clusters.data.labels.iter()) {
|
||||
let mut centroid = map.entry(label).or_insert(vec![]);
|
||||
centroid.push(element.clone());
|
||||
}
|
||||
|
||||
let mut output = vec![];
|
||||
|
||||
for (_, value) in map {
|
||||
let mut cluster = vec![];
|
||||
|
||||
for element in value {
|
||||
cluster.push(element);
|
||||
}
|
||||
|
||||
output.push(cluster);
|
||||
}
|
||||
|
||||
Ok((output, iterations))
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue