Update, clean, test, example
This commit is contained in:
parent
c3aff5c77d
commit
14455d3693
|
@ -1,6 +1,11 @@
|
|||
[package]
|
||||
name = "kmeans"
|
||||
name = "generic_kmeans"
|
||||
version = "0.1.0"
|
||||
authors = ["Thomas Forgione <thomas@tforgione.fr>"]
|
||||
|
||||
[dependencies]
|
||||
|
||||
[[bin]]
|
||||
name = "example"
|
||||
path = "src/example.rs"
|
||||
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
pub type Cluster<T> = Vec<T>;
|
||||
|
||||
pub trait Clusterable where Self: Sized + Clone + PartialEq {
|
||||
fn distance(&self, rhs: &Self) -> f64;
|
||||
fn get_centroid(elements: &Vec<Self>) -> Option<Self>;
|
||||
}
|
||||
|
||||
macro_rules! impl_clusterable {
|
||||
( $type: ty) => {
|
||||
impl Clusterable for $type {
|
||||
fn distance(&self, rhs: &Self) -> f64 {
|
||||
if self > rhs {
|
||||
*self as f64 - *rhs as f64
|
||||
} else {
|
||||
*rhs as f64 - *self as f64
|
||||
}
|
||||
}
|
||||
|
||||
fn get_centroid(elements: &Vec<Self>) -> Option<Self> {
|
||||
|
||||
if elements.len() == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut tmp = 0.0 as Self;
|
||||
for element in elements {
|
||||
tmp += *element as Self;
|
||||
}
|
||||
Some(tmp / elements.len() as Self)
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl_clusterable!(f32);
|
||||
impl_clusterable!(f64);
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
extern crate generic_kmeans;
|
||||
|
||||
use std::fmt::{Display, Formatter, Result};
|
||||
|
||||
use generic_kmeans::{kmeans, Clusterable};
|
||||
|
||||
#[derive(PartialEq, Clone)]
|
||||
struct Vector2<T> {
|
||||
pub x: T,
|
||||
pub y: T,
|
||||
}
|
||||
|
||||
impl<T> Vector2<T> {
|
||||
pub fn new(x: T, y: T) -> Vector2<T> {
|
||||
Vector2 {
|
||||
x: x,
|
||||
y: y,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Display> Display for Vector2<T> {
|
||||
fn fmt(&self, formatter: &mut Formatter) -> Result {
|
||||
write!(formatter, "({}, {})", self.x, self.y)
|
||||
}
|
||||
}
|
||||
|
||||
impl Clusterable for Vector2<f64> {
|
||||
fn distance(&self, other: &Self) -> f64 {
|
||||
(self.x - other.x) * (self.x - other.y) + (self.y - other.y) * (self.y - other.y)
|
||||
}
|
||||
|
||||
fn get_centroid(cluster: &Vec<Vector2<f64>>) -> Option<Vector2<f64>> {
|
||||
|
||||
let len = cluster.len();
|
||||
|
||||
if len == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut centroid = Vector2::new(0.0, 0.0);
|
||||
|
||||
for i in cluster {
|
||||
centroid.x += i.x;
|
||||
centroid.y += i.y;
|
||||
}
|
||||
|
||||
centroid.x /= len as f64;
|
||||
centroid.y /= len as f64;
|
||||
|
||||
Some(centroid)
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
|
||||
let elements = vec![
|
||||
Vector2::new(8.0, 3.0),
|
||||
Vector2::new(9.0, 3.0),
|
||||
Vector2::new(9.0, 2.0),
|
||||
Vector2::new(1.0, 8.0),
|
||||
Vector2::new(2.0, 9.0),
|
||||
Vector2::new(3.0, 8.0),
|
||||
];
|
||||
|
||||
let initial = vec![
|
||||
Vector2::new(1.0, 10.0),
|
||||
Vector2::new(10.0, 0.0),
|
||||
];
|
||||
|
||||
let (clusters, nb_iterations) = kmeans(initial, elements, 1000).ok().unwrap();
|
||||
|
||||
println!("{}", nb_iterations);
|
||||
|
||||
for (index, cluster) in clusters.iter().enumerate() {
|
||||
println!("CLUSTER {}", index);
|
||||
|
||||
for element in cluster {
|
||||
println!("\t{}", element);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
use std;
|
||||
|
||||
use kmeansdata::KmeansData;
|
||||
use cluster::{Cluster, Clusterable};
|
||||
|
||||
pub struct Kmeans<T: Clusterable> {
|
||||
pub centroids: Vec<T>,
|
||||
pub data: KmeansData<T>,
|
||||
}
|
||||
|
||||
impl<T:Clusterable> Kmeans<T> {
|
||||
pub fn new(centroids: Vec<T>, data: Vec<Vec<T>>) -> Kmeans<T> {
|
||||
Kmeans {
|
||||
centroids: centroids,
|
||||
data: KmeansData::from_clusters(data),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn guess_centroids(data: Vec<Vec<T>>) -> Kmeans<T> {
|
||||
let mut centroids = vec![];
|
||||
for cluster in &data {
|
||||
if let Some(centroid) = T::get_centroid(cluster) {
|
||||
centroids.push(centroid);
|
||||
}
|
||||
}
|
||||
|
||||
Kmeans::new(centroids, data)
|
||||
|
||||
}
|
||||
|
||||
fn from_data(centroids: Vec<T>, data: KmeansData<T>) -> Kmeans<T> {
|
||||
Kmeans {
|
||||
centroids: centroids,
|
||||
data: data,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn next_iteration(self) -> (Kmeans<T>, bool) {
|
||||
let (new_centroids, data) = self.data.iterate(&self.centroids);
|
||||
let stable = new_centroids == self.centroids;
|
||||
(Kmeans::from_data(new_centroids, data), stable)
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> std::slice::Iter<Cluster<T>> {
|
||||
self.data.clusters()
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,168 @@
|
|||
use std;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use cluster::{Clusterable, Cluster};
|
||||
|
||||
|
||||
pub struct KmeansData<T: Clusterable> {
|
||||
clusters: Vec<Vec<T>>,
|
||||
}
|
||||
|
||||
impl<T: Clusterable> KmeansData<T> {
|
||||
|
||||
pub fn from_clusters(data: Vec<Vec<T>>) -> KmeansData<T> {
|
||||
KmeansData {
|
||||
clusters: data,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new(centroids: &Vec<T>, data: KmeansDataIntoIter<T>) -> KmeansData<T> {
|
||||
|
||||
let mut clusters = vec![];
|
||||
|
||||
for _ in centroids {
|
||||
clusters.push(Cluster::new());
|
||||
}
|
||||
|
||||
let mut new_kmeans: KmeansData<T> = KmeansData {
|
||||
clusters: clusters,
|
||||
};
|
||||
|
||||
for element in data {
|
||||
|
||||
// Compute the distance
|
||||
let mut distance = std::f64::MAX;
|
||||
let mut index = 0;
|
||||
|
||||
for (new_index, centroid) in centroids.iter().enumerate() {
|
||||
let new_distance = element.distance(centroid);
|
||||
|
||||
if new_distance < distance {
|
||||
distance = new_distance;
|
||||
index = new_index;
|
||||
}
|
||||
}
|
||||
|
||||
// Add element to the new kmeans
|
||||
new_kmeans.clusters[index].push(element)
|
||||
|
||||
}
|
||||
|
||||
new_kmeans
|
||||
}
|
||||
|
||||
pub fn add_cluster(&mut self) {
|
||||
self.clusters.push(Cluster::new());
|
||||
}
|
||||
|
||||
pub fn iterate(self, centroids: &Vec<T>) -> (Vec<T>, KmeansData<T>) {
|
||||
// Compute the result with the given centroids
|
||||
let result = KmeansData::new(¢roids, self.into_iter());
|
||||
|
||||
// Compute the new centroids
|
||||
let mut new_centroids = vec![];
|
||||
|
||||
for cluster in &result.clusters {
|
||||
|
||||
if let Some(centroid) = T::get_centroid(cluster) {
|
||||
new_centroids.push(centroid);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
(new_centroids, result)
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> KmeansDataIter<T> {
|
||||
KmeansDataIter {
|
||||
global_iter: self.clusters.iter(),
|
||||
local_iter: None,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clusters(&self) -> std::slice::Iter<Cluster<T>> {
|
||||
self.clusters.iter()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
pub struct KmeansDataIter<'a, T> where T:'a, T: Clusterable {
|
||||
global_iter: std::slice::Iter<'a, std::vec::Vec<T>>,
|
||||
local_iter: Option<std::slice::Iter<'a, T>>,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<'a, T: 'a + Clusterable> Iterator for KmeansDataIter<'a, T> {
|
||||
type Item = &'a T;
|
||||
fn next(&mut self) -> Option<&'a T> {
|
||||
if let Some(ref mut local_iter) = self.local_iter {
|
||||
match local_iter.next() {
|
||||
Some(t) => Some(t),
|
||||
None => {
|
||||
if let Some(next) = self.global_iter.next() {
|
||||
*local_iter = next.iter();
|
||||
match local_iter.next() {
|
||||
Some(t) => Some(t),
|
||||
None => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
self.local_iter = match self.global_iter.next() {
|
||||
None => None,
|
||||
Some(t) => Some(t.iter()),
|
||||
};
|
||||
self.next()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct KmeansDataIntoIter<T> where T: Clusterable {
|
||||
global_iter: std::vec::IntoIter<std::vec::Vec<T>>,
|
||||
local_iter: Option<std::vec::IntoIter<T>>,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: Clusterable> Iterator for KmeansDataIntoIter<T> {
|
||||
type Item = T;
|
||||
fn next(&mut self) -> Option<T> {
|
||||
if let Some(ref mut local_iter) = self.local_iter {
|
||||
match local_iter.next() {
|
||||
Some(t) => Some(t),
|
||||
None => {
|
||||
if let Some(next) = self.global_iter.next() {
|
||||
*local_iter = next.into_iter();
|
||||
match local_iter.next() {
|
||||
Some(t) => Some(t),
|
||||
None => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
self.local_iter = match self.global_iter.next() {
|
||||
None => None,
|
||||
Some(t) => Some(t.into_iter()),
|
||||
};
|
||||
self.next()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clusterable> IntoIterator for KmeansData<T> {
|
||||
type Item = T;
|
||||
type IntoIter = KmeansDataIntoIter<T>;
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
Self::IntoIter {
|
||||
global_iter: self.clusters.into_iter(),
|
||||
local_iter: None,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
193
src/lib.rs
193
src/lib.rs
|
@ -1,190 +1,31 @@
|
|||
pub mod kmeans;
|
||||
pub mod kmeansdata;
|
||||
pub mod cluster;
|
||||
pub mod test;
|
||||
|
||||
use std::marker::PhantomData;
|
||||
pub use kmeans::Kmeans;
|
||||
pub use kmeansdata::KmeansData;
|
||||
pub use cluster::{Cluster, Clusterable};
|
||||
|
||||
pub trait Clusterable where Self: Sized {
|
||||
fn distance(&self, rhs: &Self) -> f64;
|
||||
fn get_centroid(elements: &Vec<Self>) -> Option<Self>;
|
||||
pub enum Error {
|
||||
IterationsLimitExceeded,
|
||||
}
|
||||
|
||||
macro_rules! impl_clusterable {
|
||||
( $type: ty) => {
|
||||
impl Clusterable for $type {
|
||||
fn distance(&self, rhs: &Self) -> f64 {
|
||||
if self > rhs {
|
||||
*self as f64 - *rhs as f64
|
||||
} else {
|
||||
*rhs as f64 - *self as f64
|
||||
}
|
||||
}
|
||||
pub fn kmeans<T: Clusterable>(centroids: Vec<T>, data: Vec<T>, max_iterations: usize)
|
||||
-> Result<(Kmeans<T>, usize), Error> {
|
||||
|
||||
fn get_centroid(elements: &Vec<Self>) -> Option<Self> {
|
||||
let mut kmeans = Kmeans::new(centroids, vec![data]);
|
||||
|
||||
if elements.len() == 0 {
|
||||
return None;
|
||||
}
|
||||
for nb_iterations in 0..max_iterations {
|
||||
|
||||
let mut tmp = 0.0 as Self;
|
||||
for element in elements {
|
||||
tmp += *element as Self;
|
||||
}
|
||||
Some(tmp / elements.len() as Self)
|
||||
let (new_kmeans, stable) = kmeans.next_iteration();
|
||||
kmeans = new_kmeans;
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl_clusterable!(f32);
|
||||
impl_clusterable!(f64);
|
||||
|
||||
|
||||
pub type Cluster<T> = Vec<T>;
|
||||
|
||||
pub struct Kmeans<T: Clusterable> {
|
||||
clusters: Vec<Vec<T>>,
|
||||
}
|
||||
|
||||
impl<T: Clusterable> Kmeans<T> {
|
||||
|
||||
pub fn new(centroids: Vec<T>, data: KmeansIntoIter<T>) -> Kmeans<T> {
|
||||
|
||||
let mut clusters = vec![];
|
||||
|
||||
for _ in ¢roids {
|
||||
clusters.push(Cluster::new());
|
||||
if stable {
|
||||
return Ok((kmeans, nb_iterations));
|
||||
}
|
||||
|
||||
let mut new_kmeans: Kmeans<T> = Kmeans {
|
||||
clusters: clusters,
|
||||
};
|
||||
|
||||
for element in data {
|
||||
|
||||
// Compute the distance
|
||||
let mut distance = std::f64::MAX;
|
||||
let mut index = 0;
|
||||
|
||||
for (new_index, centroid) in centroids.iter().enumerate() {
|
||||
let new_distance = element.distance(centroid);
|
||||
|
||||
if new_distance < distance {
|
||||
distance = new_distance;
|
||||
index = new_index;
|
||||
}
|
||||
}
|
||||
|
||||
// Add element to the new kmeans
|
||||
new_kmeans.clusters[index].push(element)
|
||||
|
||||
}
|
||||
|
||||
new_kmeans
|
||||
}
|
||||
|
||||
pub fn add_cluster(&mut self) {
|
||||
self.clusters.push(Cluster::new());
|
||||
}
|
||||
|
||||
pub fn iterate(self) -> Kmeans<T> {
|
||||
let mut centroids = vec![];
|
||||
|
||||
// Compute the centroids
|
||||
for cluster in &self.clusters {
|
||||
|
||||
if let Some(centroid) = T::get_centroid(cluster) {
|
||||
centroids.push(centroid);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Kmeans::new(centroids, self.into_iter())
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> std::slice::Iter<Cluster<T>> {
|
||||
self.clusters.iter()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
pub struct KmeansIter<'a, T> where T:'a, T: Clusterable {
|
||||
global_iter: std::slice::Iter<'a, std::vec::Vec<T>>,
|
||||
local_iter: Option<std::slice::Iter<'a, T>>,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<'a, T: 'a + Clusterable> Iterator for KmeansIter<'a, T> {
|
||||
type Item = &'a T;
|
||||
fn next(&mut self) -> Option<&'a T> {
|
||||
if let Some(ref mut local_iter) = self.local_iter {
|
||||
match local_iter.next() {
|
||||
Some(t) => Some(t),
|
||||
None => {
|
||||
if let Some(next) = self.global_iter.next() {
|
||||
*local_iter = next.iter();
|
||||
match local_iter.next() {
|
||||
Some(t) => Some(t),
|
||||
None => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
self.local_iter = match self.global_iter.next() {
|
||||
None => None,
|
||||
Some(t) => Some(t.iter()),
|
||||
};
|
||||
self.next()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct KmeansIntoIter<T> where T: Clusterable {
|
||||
global_iter: std::vec::IntoIter<std::vec::Vec<T>>,
|
||||
local_iter: Option<std::vec::IntoIter<T>>,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: Clusterable> Iterator for KmeansIntoIter<T> {
|
||||
type Item = T;
|
||||
fn next(&mut self) -> Option<T> {
|
||||
if let Some(ref mut local_iter) = self.local_iter {
|
||||
match local_iter.next() {
|
||||
Some(t) => Some(t),
|
||||
None => {
|
||||
if let Some(next) = self.global_iter.next() {
|
||||
*local_iter = next.into_iter();
|
||||
match local_iter.next() {
|
||||
Some(t) => Some(t),
|
||||
None => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
self.local_iter = match self.global_iter.next() {
|
||||
None => None,
|
||||
Some(t) => Some(t.into_iter()),
|
||||
};
|
||||
self.next()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
impl<T: Clusterable> IntoIterator for Kmeans<T> {
|
||||
type Item = T;
|
||||
type IntoIter = KmeansIntoIter<T>;
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
Self::IntoIter {
|
||||
global_iter: self.clusters.into_iter(),
|
||||
local_iter: None,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
Err(Error::IterationsLimitExceeded)
|
||||
}
|
||||
|
|
48
src/test.rs
48
src/test.rs
|
@ -2,14 +2,12 @@
|
|||
mod test {
|
||||
#[test]
|
||||
fn iterators() {
|
||||
use Kmeans;
|
||||
use kmeansdata::KmeansData;
|
||||
let data = vec![4.0, 5.0, 11.0, 12.0, 13.0];
|
||||
let kmeans = Kmeans {
|
||||
clusters: vec![
|
||||
vec![4.0, 5.0],
|
||||
vec![11.0, 12.0, 13.0],
|
||||
],
|
||||
};
|
||||
let kmeans = KmeansData::from_clusters(vec![
|
||||
vec![4.0, 5.0],
|
||||
vec![11.0, 12.0, 13.0],
|
||||
]);
|
||||
|
||||
for (val1, val2) in kmeans.into_iter().zip(data) {
|
||||
assert_eq!(val1, val2);
|
||||
|
@ -18,19 +16,29 @@ mod test {
|
|||
|
||||
#[test]
|
||||
fn iterate() {
|
||||
use Kmeans;
|
||||
let kmeans = Kmeans {
|
||||
clusters: vec![
|
||||
vec![4.0, 5.0, 11.0, 12.0],
|
||||
vec![13.0],
|
||||
],
|
||||
};
|
||||
let kmeans = kmeans.iterate();
|
||||
let kmeans = kmeans.iterate();
|
||||
let kmeans = kmeans.iterate();
|
||||
let kmeans = kmeans.iterate();
|
||||
use kmeans::Kmeans;
|
||||
|
||||
assert_eq!(kmeans.clusters[0], vec![4.0, 5.0]);
|
||||
assert_eq!(kmeans.clusters[1], vec![11.0, 12.0, 13.0]);
|
||||
let data = vec![
|
||||
vec![4.0, 5.0, 11.0, 12.0],
|
||||
vec![13.0],
|
||||
];
|
||||
|
||||
let solution = vec![
|
||||
vec![4.0, 5.0],
|
||||
vec![11.0, 12.0, 13.0],
|
||||
];
|
||||
|
||||
let mut kmeans = Kmeans::guess_centroids(data.clone());
|
||||
|
||||
for _ in 0..4 {
|
||||
let (new_kmeans, stable) = kmeans.next_iteration();
|
||||
kmeans = new_kmeans;
|
||||
}
|
||||
|
||||
for (k1, k2) in kmeans.iter().zip(solution) {
|
||||
for (i, j) in k1.iter().zip(&k2) {
|
||||
assert_eq!(i, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue