This commit is contained in:
Thomas Forgione
2018-02-20 14:54:40 +01:00
parent 2d0ec8ceae
commit 164d15c6e7
13 changed files with 134 additions and 626 deletions

View File

@@ -1,38 +0,0 @@
pub type Cluster<T> = Vec<T>;
pub trait Clusterable where Self: Sized + Clone + PartialEq {
fn distance(&self, rhs: &Self) -> f64;
fn get_centroid(elements: &Vec<Self>) -> Option<Self>;
}
macro_rules! impl_clusterable {
( $type: ty) => {
impl Clusterable for $type {
fn distance(&self, rhs: &Self) -> f64 {
if self > rhs {
*self as f64 - *rhs as f64
} else {
*rhs as f64 - *self as f64
}
}
fn get_centroid(elements: &Vec<Self>) -> Option<Self> {
if elements.len() == 0 {
return None;
}
let mut tmp = 0.0 as Self;
for element in elements {
tmp += *element as Self;
}
Some(tmp / elements.len() as Self)
}
}
}
}
impl_clusterable!(f32);
impl_clusterable!(f64);

43
src/clusterable.rs Normal file
View File

@@ -0,0 +1,43 @@
use std::fmt::Debug;
pub trait Clusterable where Self: Sized + Clone + PartialEq + Debug {
fn distance(&self, rhs: &Self) -> f64;
fn get_centroid<'a, I>(elements: I) -> Option<Self>
where I: Iterator<Item = &'a Self>, Self: 'a;
}
macro_rules! impl_clusterable {
( $type: ty) => {
impl Clusterable for $type {
fn distance(&self, rhs: &Self) -> f64 {
if self > rhs {
*self as f64 - *rhs as f64
} else {
*rhs as f64 - *self as f64
}
}
fn get_centroid<'a, I>(elements: I) -> Option<Self>
where I: Iterator<Item = &'a Self>, Self: 'a {
let mut tmp = 0.0;
let mut count = 0.0;
for element in elements {
tmp += element;
count += 1.0;
}
if count > 0.0 {
Some(tmp / count)
} else {
None
}
}
}
}
}
impl_clusterable!(f32);
impl_clusterable!(f64);

View File

@@ -7,7 +7,7 @@ use rand::distributions::Range;
use rand::distributions::normal::Normal;
use generic_kmeans::{kmeans, Clusterable};
#[derive(PartialEq, Copy, Clone)]
#[derive(PartialEq, Copy, Clone, Debug)]
struct Vector2<T> {
pub x: T,
pub y: T,
@@ -30,28 +30,28 @@ impl<T: Display> Display for Vector2<T> {
impl Clusterable for Vector2<f64> {
fn distance(&self, other: &Self) -> f64 {
(self.x - other.x) * (self.x - other.y) + (self.y - other.y) * (self.y - other.y)
(self.x - other.x) * (self.x - other.x) + (self.y - other.y) * (self.y - other.y)
}
fn get_centroid(cluster: &Vec<Vector2<f64>>) -> Option<Vector2<f64>> {
let len = cluster.len();
if len == 0 {
return None;
}
fn get_centroid<'a, I>(cluster: I) -> Option<Self>
where I: Iterator<Item = &'a Self>, Self: 'a {
let mut centroid = Vector2::new(0.0, 0.0);
let mut count = 0.0;
for i in cluster {
centroid.x += i.x;
centroid.y += i.y;
count += 1.0;
}
centroid.x /= len as f64;
centroid.y /= len as f64;
Some(centroid)
if count > 0.0 {
centroid.x /= count as f64;
centroid.y /= count as f64;
Some(centroid)
} else {
None
}
}
}
@@ -85,21 +85,24 @@ fn main() {
}
let initialization = vec![
Vector2::new(0.0,0.0),
Vector2::new(10.0,0.0),
Vector2::new(0.0,10.0),
];
let (clusters, nb_iterations) = kmeans(
centers.iter().map(|x| x.clone().0).collect::<Vec<_>>(), elements, 100000).ok().unwrap();
println!("{}", nb_iterations);
let (clusters, nb_iterations) = kmeans(initialization, elements, 100000).ok().unwrap();
let mut output = File::create("plot/dat.dat").unwrap();
for (cluster, color) in clusters.iter().zip(&colors) {
for element in cluster {
use std::io::Write;
writeln!(output, "{} {} {}", element.x, element.y, color).unwrap();
}
for (element, &label) in clusters.iter() {
use std::io::Write;
writeln!(output, "{} {} {}", element.x, element.y, colors[label]).unwrap();
}
println!("Finished in {} iterations", nb_iterations);
let mut center_file = File::create("plot/centers.dat").unwrap();
for (&(center, _, _), color) in centers.iter().zip(&colors) {
use std::io::Write;

View File

@@ -1,48 +1,82 @@
use std;
use kmeansdata::KmeansData;
use cluster::{Cluster, Clusterable};
use std::slice::Iter;
use std::iter::Zip;
use clusterable::Clusterable;
pub struct Kmeans<T: Clusterable> {
pub centroids: Vec<T>,
pub data: KmeansData<T>,
centroids: Vec<T>,
elements: Vec<T>,
labels: Vec<usize>,
cluster_number: usize,
}
impl<T:Clusterable> Kmeans<T> {
pub fn new(centroids: Vec<T>, data: Vec<Vec<T>>) -> Kmeans<T> {
pub fn new(centroids: Vec<T>, data: Vec<T>) -> Kmeans<T> {
let labels = Kmeans::build_labels(&centroids, &data);
let cluster_number = centroids.len();
Kmeans {
centroids: centroids,
data: KmeansData::from_clusters(data),
elements: data,
labels: labels,
cluster_number: cluster_number
}
}
pub fn guess_centroids(data: Vec<Vec<T>>) -> Kmeans<T> {
/// \returns True if converged
pub fn iterate(&mut self) -> bool {
// Update the centroids
let centroids = Kmeans::build_centroids(&self.elements, &self.labels, self.cluster_number);
if self.centroids == centroids {
true
} else {
self.centroids = centroids;
Kmeans::update_labels(&self.centroids, &self.elements, &mut self.labels);
false
}
}
pub fn build_labels(centroids: &Vec<T>, data: &Vec<T>) -> Vec<usize> {
debug_assert_ne!(0, centroids.len());
let mut output = vec![0; data.len()];
Kmeans::update_labels(centroids, data, &mut output);
output
}
pub fn update_labels(centroids: &Vec<T>, data: &Vec<T>, labels: &mut Vec<usize>) {
for (element, new_label) in data.iter().zip(labels.iter_mut()) {
*new_label = centroids
.iter()
.enumerate()
.min_by(|&(_, c1), &(_, c2)| {
c1.distance(element).partial_cmp(&c2.distance(element)).unwrap()
}).unwrap().0;
}
}
pub fn build_centroids(data: &Vec<T>, labels: &Vec<usize>, cluster_number: usize) -> Vec<T> {
let mut centroids = vec![];
for cluster in &data {
if let Some(centroid) = T::get_centroid(cluster) {
centroids.push(centroid);
for label in 0..cluster_number {
let to_consider = data
.iter()
.enumerate()
.filter(|&(index, _)| labels[index] == label)
.map(|(_, element)| element);
if let Some(centroid) = T::get_centroid(to_consider) {
centroids.push(centroid);
}
}
Kmeans::new(centroids, data)
centroids
}
fn from_data(centroids: Vec<T>, data: KmeansData<T>) -> Kmeans<T> {
Kmeans {
centroids: centroids,
data: data,
}
pub fn iter(&self) -> Zip<Iter<T>, Iter<usize>> {
debug_assert_eq!(self.elements.len(), self.labels.len());
return self.elements.iter().zip(self.labels.iter());
}
pub fn next_iteration(self) -> (Kmeans<T>, bool) {
let (new_centroids, data) = self.data.iterate(&self.centroids);
let stable = new_centroids == self.centroids;
(Kmeans::from_data(new_centroids, data), stable)
}
pub fn iter(&self) -> std::slice::Iter<Cluster<T>> {
self.data.clusters()
}
}

View File

@@ -1,168 +0,0 @@
use std;
use std::marker::PhantomData;
use cluster::{Clusterable, Cluster};
pub struct KmeansData<T: Clusterable> {
clusters: Vec<Vec<T>>,
}
impl<T: Clusterable> KmeansData<T> {
pub fn from_clusters(data: Vec<Vec<T>>) -> KmeansData<T> {
KmeansData {
clusters: data,
}
}
pub fn new(centroids: &Vec<T>, data: KmeansDataIntoIter<T>) -> KmeansData<T> {
let mut clusters = vec![];
for _ in centroids {
clusters.push(Cluster::new());
}
let mut new_kmeans: KmeansData<T> = KmeansData {
clusters: clusters,
};
for element in data {
// Compute the distance
let mut distance = std::f64::MAX;
let mut index = 0;
for (new_index, centroid) in centroids.iter().enumerate() {
let new_distance = element.distance(centroid);
if new_distance < distance {
distance = new_distance;
index = new_index;
}
}
// Add element to the new kmeans
new_kmeans.clusters[index].push(element)
}
new_kmeans
}
pub fn add_cluster(&mut self) {
self.clusters.push(Cluster::new());
}
pub fn iterate(self, centroids: &Vec<T>) -> (Vec<T>, KmeansData<T>) {
// Compute the result with the given centroids
let result = KmeansData::new(&centroids, self.into_iter());
// Compute the new centroids
let mut new_centroids = vec![];
for cluster in &result.clusters {
if let Some(centroid) = T::get_centroid(cluster) {
new_centroids.push(centroid);
}
}
(new_centroids, result)
}
pub fn iter(&self) -> KmeansDataIter<T> {
KmeansDataIter {
global_iter: self.clusters.iter(),
local_iter: None,
_phantom: PhantomData,
}
}
pub fn clusters(&self) -> std::slice::Iter<Cluster<T>> {
self.clusters.iter()
}
}
pub struct KmeansDataIter<'a, T> where T:'a, T: Clusterable {
global_iter: std::slice::Iter<'a, std::vec::Vec<T>>,
local_iter: Option<std::slice::Iter<'a, T>>,
_phantom: PhantomData<T>,
}
impl<'a, T: 'a + Clusterable> Iterator for KmeansDataIter<'a, T> {
type Item = &'a T;
fn next(&mut self) -> Option<&'a T> {
if let Some(ref mut local_iter) = self.local_iter {
match local_iter.next() {
Some(t) => Some(t),
None => {
if let Some(next) = self.global_iter.next() {
*local_iter = next.iter();
match local_iter.next() {
Some(t) => Some(t),
None => None,
}
} else {
None
}
}
}
} else {
self.local_iter = match self.global_iter.next() {
None => None,
Some(t) => Some(t.iter()),
};
self.next()
}
}
}
pub struct KmeansDataIntoIter<T> where T: Clusterable {
global_iter: std::vec::IntoIter<std::vec::Vec<T>>,
local_iter: Option<std::vec::IntoIter<T>>,
_phantom: PhantomData<T>,
}
impl<T: Clusterable> Iterator for KmeansDataIntoIter<T> {
type Item = T;
fn next(&mut self) -> Option<T> {
if let Some(ref mut local_iter) = self.local_iter {
match local_iter.next() {
Some(t) => Some(t),
None => {
if let Some(next) = self.global_iter.next() {
*local_iter = next.into_iter();
match local_iter.next() {
Some(t) => Some(t),
None => None,
}
} else {
None
}
}
}
} else {
self.local_iter = match self.global_iter.next() {
None => None,
Some(t) => Some(t.into_iter()),
};
self.next()
}
}
}
impl<T: Clusterable> IntoIterator for KmeansData<T> {
type Item = T;
type IntoIter = KmeansDataIntoIter<T>;
fn into_iter(self) -> Self::IntoIter {
Self::IntoIter {
global_iter: self.clusters.into_iter(),
local_iter: None,
_phantom: PhantomData,
}
}
}

View File

@@ -1,11 +1,8 @@
pub mod kmeans;
pub mod kmeansdata;
pub mod cluster;
pub mod test;
pub mod clusterable;
pub use kmeans::Kmeans;
pub use kmeansdata::KmeansData;
pub use cluster::{Cluster, Clusterable};
pub use clusterable::Clusterable;
pub enum Error {
IterationsLimitExceeded,
@@ -14,12 +11,11 @@ pub enum Error {
pub fn kmeans<T: Clusterable>(centroids: Vec<T>, data: Vec<T>, max_iterations: usize)
-> Result<(Kmeans<T>, usize), Error> {
let mut kmeans = Kmeans::new(centroids, vec![data]);
let mut kmeans = Kmeans::new(centroids, data);
for nb_iterations in 0..max_iterations {
let (new_kmeans, stable) = kmeans.next_iteration();
kmeans = new_kmeans;
let stable = kmeans.iterate();
if stable {
return Ok((kmeans, nb_iterations));

View File

@@ -1,44 +0,0 @@
#[cfg(test)]
mod test {
#[test]
fn iterators() {
use kmeansdata::KmeansData;
let data = vec![4.0, 5.0, 11.0, 12.0, 13.0];
let kmeans = KmeansData::from_clusters(vec![
vec![4.0, 5.0],
vec![11.0, 12.0, 13.0],
]);
for (val1, val2) in kmeans.into_iter().zip(data) {
assert_eq!(val1, val2);
}
}
#[test]
fn iterate() {
use kmeans::Kmeans;
let data = vec![
vec![4.0, 5.0, 11.0, 12.0],
vec![13.0],
];
let solution = vec![
vec![4.0, 5.0],
vec![11.0, 12.0, 13.0],
];
let mut kmeans = Kmeans::guess_centroids(data.clone());
for _ in 0..4 {
let (new_kmeans, stable) = kmeans.next_iteration();
kmeans = new_kmeans;
}
for (k1, k2) in kmeans.iter().zip(solution) {
for (i, j) in k1.iter().zip(&k2) {
assert_eq!(i, j);
}
}
}
}