Update, clean, test, example

This commit is contained in:
Thomas Forgione 2018-02-16 17:21:05 +01:00
parent c3aff5c77d
commit 14455d3693
No known key found for this signature in database
GPG Key ID: C75CD416BD1FFCE1
7 changed files with 387 additions and 197 deletions

View File

@ -1,6 +1,11 @@
[package]
name = "kmeans"
name = "generic_kmeans"
version = "0.1.0"
authors = ["Thomas Forgione <thomas@tforgione.fr>"]
[dependencies]
[[bin]]
name = "example"
path = "src/example.rs"

38
src/cluster.rs Normal file
View File

@ -0,0 +1,38 @@
pub type Cluster<T> = Vec<T>;
pub trait Clusterable where Self: Sized + Clone + PartialEq {
fn distance(&self, rhs: &Self) -> f64;
fn get_centroid(elements: &Vec<Self>) -> Option<Self>;
}
macro_rules! impl_clusterable {
( $type: ty) => {
impl Clusterable for $type {
fn distance(&self, rhs: &Self) -> f64 {
if self > rhs {
*self as f64 - *rhs as f64
} else {
*rhs as f64 - *self as f64
}
}
fn get_centroid(elements: &Vec<Self>) -> Option<Self> {
if elements.len() == 0 {
return None;
}
let mut tmp = 0.0 as Self;
for element in elements {
tmp += *element as Self;
}
Some(tmp / elements.len() as Self)
}
}
}
}
impl_clusterable!(f32);
impl_clusterable!(f64);

82
src/example.rs Normal file
View File

@ -0,0 +1,82 @@
extern crate generic_kmeans;
use std::fmt::{Display, Formatter, Result};
use generic_kmeans::{kmeans, Clusterable};
#[derive(PartialEq, Clone)]
struct Vector2<T> {
pub x: T,
pub y: T,
}
impl<T> Vector2<T> {
pub fn new(x: T, y: T) -> Vector2<T> {
Vector2 {
x: x,
y: y,
}
}
}
impl<T: Display> Display for Vector2<T> {
fn fmt(&self, formatter: &mut Formatter) -> Result {
write!(formatter, "({}, {})", self.x, self.y)
}
}
impl Clusterable for Vector2<f64> {
fn distance(&self, other: &Self) -> f64 {
(self.x - other.x) * (self.x - other.y) + (self.y - other.y) * (self.y - other.y)
}
fn get_centroid(cluster: &Vec<Vector2<f64>>) -> Option<Vector2<f64>> {
let len = cluster.len();
if len == 0 {
return None;
}
let mut centroid = Vector2::new(0.0, 0.0);
for i in cluster {
centroid.x += i.x;
centroid.y += i.y;
}
centroid.x /= len as f64;
centroid.y /= len as f64;
Some(centroid)
}
}
fn main() {
let elements = vec![
Vector2::new(8.0, 3.0),
Vector2::new(9.0, 3.0),
Vector2::new(9.0, 2.0),
Vector2::new(1.0, 8.0),
Vector2::new(2.0, 9.0),
Vector2::new(3.0, 8.0),
];
let initial = vec![
Vector2::new(1.0, 10.0),
Vector2::new(10.0, 0.0),
];
let (clusters, nb_iterations) = kmeans(initial, elements, 1000).ok().unwrap();
println!("{}", nb_iterations);
for (index, cluster) in clusters.iter().enumerate() {
println!("CLUSTER {}", index);
for element in cluster {
println!("\t{}", element);
}
}
}

48
src/kmeans.rs Normal file
View File

@ -0,0 +1,48 @@
use std;
use kmeansdata::KmeansData;
use cluster::{Cluster, Clusterable};
pub struct Kmeans<T: Clusterable> {
pub centroids: Vec<T>,
pub data: KmeansData<T>,
}
impl<T:Clusterable> Kmeans<T> {
pub fn new(centroids: Vec<T>, data: Vec<Vec<T>>) -> Kmeans<T> {
Kmeans {
centroids: centroids,
data: KmeansData::from_clusters(data),
}
}
pub fn guess_centroids(data: Vec<Vec<T>>) -> Kmeans<T> {
let mut centroids = vec![];
for cluster in &data {
if let Some(centroid) = T::get_centroid(cluster) {
centroids.push(centroid);
}
}
Kmeans::new(centroids, data)
}
fn from_data(centroids: Vec<T>, data: KmeansData<T>) -> Kmeans<T> {
Kmeans {
centroids: centroids,
data: data,
}
}
pub fn next_iteration(self) -> (Kmeans<T>, bool) {
let (new_centroids, data) = self.data.iterate(&self.centroids);
let stable = new_centroids == self.centroids;
(Kmeans::from_data(new_centroids, data), stable)
}
pub fn iter(&self) -> std::slice::Iter<Cluster<T>> {
self.data.clusters()
}
}

168
src/kmeansdata.rs Normal file
View File

@ -0,0 +1,168 @@
use std;
use std::marker::PhantomData;
use cluster::{Clusterable, Cluster};
pub struct KmeansData<T: Clusterable> {
clusters: Vec<Vec<T>>,
}
impl<T: Clusterable> KmeansData<T> {
pub fn from_clusters(data: Vec<Vec<T>>) -> KmeansData<T> {
KmeansData {
clusters: data,
}
}
pub fn new(centroids: &Vec<T>, data: KmeansDataIntoIter<T>) -> KmeansData<T> {
let mut clusters = vec![];
for _ in centroids {
clusters.push(Cluster::new());
}
let mut new_kmeans: KmeansData<T> = KmeansData {
clusters: clusters,
};
for element in data {
// Compute the distance
let mut distance = std::f64::MAX;
let mut index = 0;
for (new_index, centroid) in centroids.iter().enumerate() {
let new_distance = element.distance(centroid);
if new_distance < distance {
distance = new_distance;
index = new_index;
}
}
// Add element to the new kmeans
new_kmeans.clusters[index].push(element)
}
new_kmeans
}
pub fn add_cluster(&mut self) {
self.clusters.push(Cluster::new());
}
pub fn iterate(self, centroids: &Vec<T>) -> (Vec<T>, KmeansData<T>) {
// Compute the result with the given centroids
let result = KmeansData::new(&centroids, self.into_iter());
// Compute the new centroids
let mut new_centroids = vec![];
for cluster in &result.clusters {
if let Some(centroid) = T::get_centroid(cluster) {
new_centroids.push(centroid);
}
}
(new_centroids, result)
}
pub fn iter(&self) -> KmeansDataIter<T> {
KmeansDataIter {
global_iter: self.clusters.iter(),
local_iter: None,
_phantom: PhantomData,
}
}
pub fn clusters(&self) -> std::slice::Iter<Cluster<T>> {
self.clusters.iter()
}
}
pub struct KmeansDataIter<'a, T> where T:'a, T: Clusterable {
global_iter: std::slice::Iter<'a, std::vec::Vec<T>>,
local_iter: Option<std::slice::Iter<'a, T>>,
_phantom: PhantomData<T>,
}
impl<'a, T: 'a + Clusterable> Iterator for KmeansDataIter<'a, T> {
type Item = &'a T;
fn next(&mut self) -> Option<&'a T> {
if let Some(ref mut local_iter) = self.local_iter {
match local_iter.next() {
Some(t) => Some(t),
None => {
if let Some(next) = self.global_iter.next() {
*local_iter = next.iter();
match local_iter.next() {
Some(t) => Some(t),
None => None,
}
} else {
None
}
}
}
} else {
self.local_iter = match self.global_iter.next() {
None => None,
Some(t) => Some(t.iter()),
};
self.next()
}
}
}
pub struct KmeansDataIntoIter<T> where T: Clusterable {
global_iter: std::vec::IntoIter<std::vec::Vec<T>>,
local_iter: Option<std::vec::IntoIter<T>>,
_phantom: PhantomData<T>,
}
impl<T: Clusterable> Iterator for KmeansDataIntoIter<T> {
type Item = T;
fn next(&mut self) -> Option<T> {
if let Some(ref mut local_iter) = self.local_iter {
match local_iter.next() {
Some(t) => Some(t),
None => {
if let Some(next) = self.global_iter.next() {
*local_iter = next.into_iter();
match local_iter.next() {
Some(t) => Some(t),
None => None,
}
} else {
None
}
}
}
} else {
self.local_iter = match self.global_iter.next() {
None => None,
Some(t) => Some(t.into_iter()),
};
self.next()
}
}
}
impl<T: Clusterable> IntoIterator for KmeansData<T> {
type Item = T;
type IntoIter = KmeansDataIntoIter<T>;
fn into_iter(self) -> Self::IntoIter {
Self::IntoIter {
global_iter: self.clusters.into_iter(),
local_iter: None,
_phantom: PhantomData,
}
}
}

View File

@ -1,190 +1,31 @@
pub mod kmeans;
pub mod kmeansdata;
pub mod cluster;
pub mod test;
use std::marker::PhantomData;
pub use kmeans::Kmeans;
pub use kmeansdata::KmeansData;
pub use cluster::{Cluster, Clusterable};
pub trait Clusterable where Self: Sized {
fn distance(&self, rhs: &Self) -> f64;
fn get_centroid(elements: &Vec<Self>) -> Option<Self>;
pub enum Error {
IterationsLimitExceeded,
}
macro_rules! impl_clusterable {
( $type: ty) => {
impl Clusterable for $type {
fn distance(&self, rhs: &Self) -> f64 {
if self > rhs {
*self as f64 - *rhs as f64
} else {
*rhs as f64 - *self as f64
}
}
pub fn kmeans<T: Clusterable>(centroids: Vec<T>, data: Vec<T>, max_iterations: usize)
-> Result<(Kmeans<T>, usize), Error> {
fn get_centroid(elements: &Vec<Self>) -> Option<Self> {
let mut kmeans = Kmeans::new(centroids, vec![data]);
if elements.len() == 0 {
return None;
}
for nb_iterations in 0..max_iterations {
let mut tmp = 0.0 as Self;
for element in elements {
tmp += *element as Self;
}
Some(tmp / elements.len() as Self)
let (new_kmeans, stable) = kmeans.next_iteration();
kmeans = new_kmeans;
}
}
}
}
impl_clusterable!(f32);
impl_clusterable!(f64);
pub type Cluster<T> = Vec<T>;
pub struct Kmeans<T: Clusterable> {
clusters: Vec<Vec<T>>,
}
impl<T: Clusterable> Kmeans<T> {
pub fn new(centroids: Vec<T>, data: KmeansIntoIter<T>) -> Kmeans<T> {
let mut clusters = vec![];
for _ in &centroids {
clusters.push(Cluster::new());
if stable {
return Ok((kmeans, nb_iterations));
}
let mut new_kmeans: Kmeans<T> = Kmeans {
clusters: clusters,
};
for element in data {
// Compute the distance
let mut distance = std::f64::MAX;
let mut index = 0;
for (new_index, centroid) in centroids.iter().enumerate() {
let new_distance = element.distance(centroid);
if new_distance < distance {
distance = new_distance;
index = new_index;
}
}
// Add element to the new kmeans
new_kmeans.clusters[index].push(element)
}
new_kmeans
}
pub fn add_cluster(&mut self) {
self.clusters.push(Cluster::new());
}
pub fn iterate(self) -> Kmeans<T> {
let mut centroids = vec![];
// Compute the centroids
for cluster in &self.clusters {
if let Some(centroid) = T::get_centroid(cluster) {
centroids.push(centroid);
}
}
Kmeans::new(centroids, self.into_iter())
}
pub fn iter(&self) -> std::slice::Iter<Cluster<T>> {
self.clusters.iter()
}
}
pub struct KmeansIter<'a, T> where T:'a, T: Clusterable {
global_iter: std::slice::Iter<'a, std::vec::Vec<T>>,
local_iter: Option<std::slice::Iter<'a, T>>,
_phantom: PhantomData<T>,
}
impl<'a, T: 'a + Clusterable> Iterator for KmeansIter<'a, T> {
type Item = &'a T;
fn next(&mut self) -> Option<&'a T> {
if let Some(ref mut local_iter) = self.local_iter {
match local_iter.next() {
Some(t) => Some(t),
None => {
if let Some(next) = self.global_iter.next() {
*local_iter = next.iter();
match local_iter.next() {
Some(t) => Some(t),
None => None,
}
} else {
None
}
}
}
} else {
self.local_iter = match self.global_iter.next() {
None => None,
Some(t) => Some(t.iter()),
};
self.next()
}
}
}
pub struct KmeansIntoIter<T> where T: Clusterable {
global_iter: std::vec::IntoIter<std::vec::Vec<T>>,
local_iter: Option<std::vec::IntoIter<T>>,
_phantom: PhantomData<T>,
}
impl<T: Clusterable> Iterator for KmeansIntoIter<T> {
type Item = T;
fn next(&mut self) -> Option<T> {
if let Some(ref mut local_iter) = self.local_iter {
match local_iter.next() {
Some(t) => Some(t),
None => {
if let Some(next) = self.global_iter.next() {
*local_iter = next.into_iter();
match local_iter.next() {
Some(t) => Some(t),
None => None,
}
} else {
None
}
}
}
} else {
self.local_iter = match self.global_iter.next() {
None => None,
Some(t) => Some(t.into_iter()),
};
self.next()
}
}
}
impl<T: Clusterable> IntoIterator for Kmeans<T> {
type Item = T;
type IntoIter = KmeansIntoIter<T>;
fn into_iter(self) -> Self::IntoIter {
Self::IntoIter {
global_iter: self.clusters.into_iter(),
local_iter: None,
_phantom: PhantomData,
}
}
Err(Error::IterationsLimitExceeded)
}

View File

@ -2,14 +2,12 @@
mod test {
#[test]
fn iterators() {
use Kmeans;
use kmeansdata::KmeansData;
let data = vec![4.0, 5.0, 11.0, 12.0, 13.0];
let kmeans = Kmeans {
clusters: vec![
vec![4.0, 5.0],
vec![11.0, 12.0, 13.0],
],
};
let kmeans = KmeansData::from_clusters(vec![
vec![4.0, 5.0],
vec![11.0, 12.0, 13.0],
]);
for (val1, val2) in kmeans.into_iter().zip(data) {
assert_eq!(val1, val2);
@ -18,19 +16,29 @@ mod test {
#[test]
fn iterate() {
use Kmeans;
let kmeans = Kmeans {
clusters: vec![
vec![4.0, 5.0, 11.0, 12.0],
vec![13.0],
],
};
let kmeans = kmeans.iterate();
let kmeans = kmeans.iterate();
let kmeans = kmeans.iterate();
let kmeans = kmeans.iterate();
use kmeans::Kmeans;
assert_eq!(kmeans.clusters[0], vec![4.0, 5.0]);
assert_eq!(kmeans.clusters[1], vec![11.0, 12.0, 13.0]);
let data = vec![
vec![4.0, 5.0, 11.0, 12.0],
vec![13.0],
];
let solution = vec![
vec![4.0, 5.0],
vec![11.0, 12.0, 13.0],
];
let mut kmeans = Kmeans::guess_centroids(data.clone());
for _ in 0..4 {
let (new_kmeans, stable) = kmeans.next_iteration();
kmeans = new_kmeans;
}
for (k1, k2) in kmeans.iter().zip(solution) {
for (i, j) in k1.iter().zip(&k2) {
assert_eq!(i, j);
}
}
}
}