First try

This commit is contained in:
Thomas Forgione 2018-02-16 15:34:13 +01:00
commit c3aff5c77d
No known key found for this signature in database
GPG Key ID: C75CD416BD1FFCE1
4 changed files with 236 additions and 0 deletions

4
.gitignore vendored Normal file
View File

@ -0,0 +1,4 @@
/target/
**/*.rs.bk
Cargo.lock

6
Cargo.toml Normal file
View File

@ -0,0 +1,6 @@
[package]
name = "kmeans"
version = "0.1.0"
authors = ["Thomas Forgione <thomas@tforgione.fr>"]
[dependencies]

190
src/lib.rs Normal file
View File

@ -0,0 +1,190 @@
pub mod test;
use std::marker::PhantomData;
pub trait Clusterable where Self: Sized {
fn distance(&self, rhs: &Self) -> f64;
fn get_centroid(elements: &Vec<Self>) -> Option<Self>;
}
macro_rules! impl_clusterable {
( $type: ty) => {
impl Clusterable for $type {
fn distance(&self, rhs: &Self) -> f64 {
if self > rhs {
*self as f64 - *rhs as f64
} else {
*rhs as f64 - *self as f64
}
}
fn get_centroid(elements: &Vec<Self>) -> Option<Self> {
if elements.len() == 0 {
return None;
}
let mut tmp = 0.0 as Self;
for element in elements {
tmp += *element as Self;
}
Some(tmp / elements.len() as Self)
}
}
}
}
impl_clusterable!(f32);
impl_clusterable!(f64);
pub type Cluster<T> = Vec<T>;
pub struct Kmeans<T: Clusterable> {
clusters: Vec<Vec<T>>,
}
impl<T: Clusterable> Kmeans<T> {
pub fn new(centroids: Vec<T>, data: KmeansIntoIter<T>) -> Kmeans<T> {
let mut clusters = vec![];
for _ in &centroids {
clusters.push(Cluster::new());
}
let mut new_kmeans: Kmeans<T> = Kmeans {
clusters: clusters,
};
for element in data {
// Compute the distance
let mut distance = std::f64::MAX;
let mut index = 0;
for (new_index, centroid) in centroids.iter().enumerate() {
let new_distance = element.distance(centroid);
if new_distance < distance {
distance = new_distance;
index = new_index;
}
}
// Add element to the new kmeans
new_kmeans.clusters[index].push(element)
}
new_kmeans
}
pub fn add_cluster(&mut self) {
self.clusters.push(Cluster::new());
}
pub fn iterate(self) -> Kmeans<T> {
let mut centroids = vec![];
// Compute the centroids
for cluster in &self.clusters {
if let Some(centroid) = T::get_centroid(cluster) {
centroids.push(centroid);
}
}
Kmeans::new(centroids, self.into_iter())
}
pub fn iter(&self) -> std::slice::Iter<Cluster<T>> {
self.clusters.iter()
}
}
pub struct KmeansIter<'a, T> where T:'a, T: Clusterable {
global_iter: std::slice::Iter<'a, std::vec::Vec<T>>,
local_iter: Option<std::slice::Iter<'a, T>>,
_phantom: PhantomData<T>,
}
impl<'a, T: 'a + Clusterable> Iterator for KmeansIter<'a, T> {
type Item = &'a T;
fn next(&mut self) -> Option<&'a T> {
if let Some(ref mut local_iter) = self.local_iter {
match local_iter.next() {
Some(t) => Some(t),
None => {
if let Some(next) = self.global_iter.next() {
*local_iter = next.iter();
match local_iter.next() {
Some(t) => Some(t),
None => None,
}
} else {
None
}
}
}
} else {
self.local_iter = match self.global_iter.next() {
None => None,
Some(t) => Some(t.iter()),
};
self.next()
}
}
}
pub struct KmeansIntoIter<T> where T: Clusterable {
global_iter: std::vec::IntoIter<std::vec::Vec<T>>,
local_iter: Option<std::vec::IntoIter<T>>,
_phantom: PhantomData<T>,
}
impl<T: Clusterable> Iterator for KmeansIntoIter<T> {
type Item = T;
fn next(&mut self) -> Option<T> {
if let Some(ref mut local_iter) = self.local_iter {
match local_iter.next() {
Some(t) => Some(t),
None => {
if let Some(next) = self.global_iter.next() {
*local_iter = next.into_iter();
match local_iter.next() {
Some(t) => Some(t),
None => None,
}
} else {
None
}
}
}
} else {
self.local_iter = match self.global_iter.next() {
None => None,
Some(t) => Some(t.into_iter()),
};
self.next()
}
}
}
impl<T: Clusterable> IntoIterator for Kmeans<T> {
type Item = T;
type IntoIter = KmeansIntoIter<T>;
fn into_iter(self) -> Self::IntoIter {
Self::IntoIter {
global_iter: self.clusters.into_iter(),
local_iter: None,
_phantom: PhantomData,
}
}
}

36
src/test.rs Normal file
View File

@ -0,0 +1,36 @@
#[cfg(test)]
mod test {
#[test]
fn iterators() {
use Kmeans;
let data = vec![4.0, 5.0, 11.0, 12.0, 13.0];
let kmeans = Kmeans {
clusters: vec![
vec![4.0, 5.0],
vec![11.0, 12.0, 13.0],
],
};
for (val1, val2) in kmeans.into_iter().zip(data) {
assert_eq!(val1, val2);
}
}
#[test]
fn iterate() {
use Kmeans;
let kmeans = Kmeans {
clusters: vec![
vec![4.0, 5.0, 11.0, 12.0],
vec![13.0],
],
};
let kmeans = kmeans.iterate();
let kmeans = kmeans.iterate();
let kmeans = kmeans.iterate();
let kmeans = kmeans.iterate();
assert_eq!(kmeans.clusters[0], vec![4.0, 5.0]);
assert_eq!(kmeans.clusters[1], vec![11.0, 12.0, 13.0]);
}
}