First try
This commit is contained in:
commit
c3aff5c77d
|
@ -0,0 +1,4 @@
|
|||
|
||||
/target/
|
||||
**/*.rs.bk
|
||||
Cargo.lock
|
|
@ -0,0 +1,6 @@
|
|||
[package]
|
||||
name = "kmeans"
|
||||
version = "0.1.0"
|
||||
authors = ["Thomas Forgione <thomas@tforgione.fr>"]
|
||||
|
||||
[dependencies]
|
|
@ -0,0 +1,190 @@
|
|||
pub mod test;
|
||||
|
||||
use std::marker::PhantomData;
|
||||
|
||||
pub trait Clusterable where Self: Sized {
|
||||
fn distance(&self, rhs: &Self) -> f64;
|
||||
fn get_centroid(elements: &Vec<Self>) -> Option<Self>;
|
||||
}
|
||||
|
||||
macro_rules! impl_clusterable {
|
||||
( $type: ty) => {
|
||||
impl Clusterable for $type {
|
||||
fn distance(&self, rhs: &Self) -> f64 {
|
||||
if self > rhs {
|
||||
*self as f64 - *rhs as f64
|
||||
} else {
|
||||
*rhs as f64 - *self as f64
|
||||
}
|
||||
}
|
||||
|
||||
fn get_centroid(elements: &Vec<Self>) -> Option<Self> {
|
||||
|
||||
if elements.len() == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut tmp = 0.0 as Self;
|
||||
for element in elements {
|
||||
tmp += *element as Self;
|
||||
}
|
||||
Some(tmp / elements.len() as Self)
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl_clusterable!(f32);
|
||||
impl_clusterable!(f64);
|
||||
|
||||
|
||||
pub type Cluster<T> = Vec<T>;
|
||||
|
||||
pub struct Kmeans<T: Clusterable> {
|
||||
clusters: Vec<Vec<T>>,
|
||||
}
|
||||
|
||||
impl<T: Clusterable> Kmeans<T> {
|
||||
|
||||
pub fn new(centroids: Vec<T>, data: KmeansIntoIter<T>) -> Kmeans<T> {
|
||||
|
||||
let mut clusters = vec![];
|
||||
|
||||
for _ in ¢roids {
|
||||
clusters.push(Cluster::new());
|
||||
}
|
||||
|
||||
let mut new_kmeans: Kmeans<T> = Kmeans {
|
||||
clusters: clusters,
|
||||
};
|
||||
|
||||
for element in data {
|
||||
|
||||
// Compute the distance
|
||||
let mut distance = std::f64::MAX;
|
||||
let mut index = 0;
|
||||
|
||||
for (new_index, centroid) in centroids.iter().enumerate() {
|
||||
let new_distance = element.distance(centroid);
|
||||
|
||||
if new_distance < distance {
|
||||
distance = new_distance;
|
||||
index = new_index;
|
||||
}
|
||||
}
|
||||
|
||||
// Add element to the new kmeans
|
||||
new_kmeans.clusters[index].push(element)
|
||||
|
||||
}
|
||||
|
||||
new_kmeans
|
||||
}
|
||||
|
||||
pub fn add_cluster(&mut self) {
|
||||
self.clusters.push(Cluster::new());
|
||||
}
|
||||
|
||||
pub fn iterate(self) -> Kmeans<T> {
|
||||
let mut centroids = vec![];
|
||||
|
||||
// Compute the centroids
|
||||
for cluster in &self.clusters {
|
||||
|
||||
if let Some(centroid) = T::get_centroid(cluster) {
|
||||
centroids.push(centroid);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Kmeans::new(centroids, self.into_iter())
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> std::slice::Iter<Cluster<T>> {
|
||||
self.clusters.iter()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
pub struct KmeansIter<'a, T> where T:'a, T: Clusterable {
|
||||
global_iter: std::slice::Iter<'a, std::vec::Vec<T>>,
|
||||
local_iter: Option<std::slice::Iter<'a, T>>,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<'a, T: 'a + Clusterable> Iterator for KmeansIter<'a, T> {
|
||||
type Item = &'a T;
|
||||
fn next(&mut self) -> Option<&'a T> {
|
||||
if let Some(ref mut local_iter) = self.local_iter {
|
||||
match local_iter.next() {
|
||||
Some(t) => Some(t),
|
||||
None => {
|
||||
if let Some(next) = self.global_iter.next() {
|
||||
*local_iter = next.iter();
|
||||
match local_iter.next() {
|
||||
Some(t) => Some(t),
|
||||
None => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
self.local_iter = match self.global_iter.next() {
|
||||
None => None,
|
||||
Some(t) => Some(t.iter()),
|
||||
};
|
||||
self.next()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct KmeansIntoIter<T> where T: Clusterable {
|
||||
global_iter: std::vec::IntoIter<std::vec::Vec<T>>,
|
||||
local_iter: Option<std::vec::IntoIter<T>>,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: Clusterable> Iterator for KmeansIntoIter<T> {
|
||||
type Item = T;
|
||||
fn next(&mut self) -> Option<T> {
|
||||
if let Some(ref mut local_iter) = self.local_iter {
|
||||
match local_iter.next() {
|
||||
Some(t) => Some(t),
|
||||
None => {
|
||||
if let Some(next) = self.global_iter.next() {
|
||||
*local_iter = next.into_iter();
|
||||
match local_iter.next() {
|
||||
Some(t) => Some(t),
|
||||
None => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
self.local_iter = match self.global_iter.next() {
|
||||
None => None,
|
||||
Some(t) => Some(t.into_iter()),
|
||||
};
|
||||
self.next()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
impl<T: Clusterable> IntoIterator for Kmeans<T> {
|
||||
type Item = T;
|
||||
type IntoIter = KmeansIntoIter<T>;
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
Self::IntoIter {
|
||||
global_iter: self.clusters.into_iter(),
|
||||
local_iter: None,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
#[cfg(test)]
|
||||
mod test {
|
||||
#[test]
|
||||
fn iterators() {
|
||||
use Kmeans;
|
||||
let data = vec![4.0, 5.0, 11.0, 12.0, 13.0];
|
||||
let kmeans = Kmeans {
|
||||
clusters: vec![
|
||||
vec![4.0, 5.0],
|
||||
vec![11.0, 12.0, 13.0],
|
||||
],
|
||||
};
|
||||
|
||||
for (val1, val2) in kmeans.into_iter().zip(data) {
|
||||
assert_eq!(val1, val2);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn iterate() {
|
||||
use Kmeans;
|
||||
let kmeans = Kmeans {
|
||||
clusters: vec![
|
||||
vec![4.0, 5.0, 11.0, 12.0],
|
||||
vec![13.0],
|
||||
],
|
||||
};
|
||||
let kmeans = kmeans.iterate();
|
||||
let kmeans = kmeans.iterate();
|
||||
let kmeans = kmeans.iterate();
|
||||
let kmeans = kmeans.iterate();
|
||||
|
||||
assert_eq!(kmeans.clusters[0], vec![4.0, 5.0]);
|
||||
assert_eq!(kmeans.clusters[1], vec![11.0, 12.0, 13.0]);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue