Cleaning
This commit is contained in:
@@ -1,38 +0,0 @@
|
||||
/// A cluster is the list of elements grouped together.
pub type Cluster<T> = Vec<T>;

/// Operations a type must provide to be clustered.
pub trait Clusterable where Self: Sized + Clone + PartialEq {
    /// Returns the distance between `self` and `rhs`.
    fn distance(&self, rhs: &Self) -> f64;

    /// Returns the centroid of `elements`, or `None` when none can be computed.
    fn get_centroid(elements: &Vec<Self>) -> Option<Self>;
}

/// Implements `Clusterable` for a primitive floating-point type.
///
/// The distance is the absolute difference of the two values and the centroid
/// is their arithmetic mean (`None` for an empty slice).
macro_rules! impl_clusterable {
    ($type:ty) => {
        impl Clusterable for $type {
            fn distance(&self, rhs: &Self) -> f64 {
                // Same value as subtracting the smaller from the larger operand.
                (*self as f64 - *rhs as f64).abs()
            }

            fn get_centroid(elements: &Vec<Self>) -> Option<Self> {
                if elements.is_empty() {
                    return None;
                }

                // Accumulate from +0.0, left to right.
                let sum: $type = elements.iter().fold(0.0, |acc, &value| acc + value);
                Some(sum / elements.len() as $type)
            }
        }
    };
}

impl_clusterable!(f32);
impl_clusterable!(f64);
|
||||
|
||||
43
src/clusterable.rs
Normal file
43
src/clusterable.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
/// Operations a type must provide to be clustered.
pub trait Clusterable where Self: Sized + Clone + PartialEq + Debug {
    /// Returns the distance between `self` and `rhs`.
    fn distance(&self, rhs: &Self) -> f64;

    /// Returns the centroid of the elements yielded by `elements`, or `None`
    /// when none can be computed.
    fn get_centroid<'a, I>(elements: I) -> Option<Self>
        where I: Iterator<Item = &'a Self>, Self: 'a;
}

/// Implements `Clusterable` for a primitive floating-point type.
///
/// The distance is the absolute difference of the two values and the centroid
/// is their arithmetic mean (`None` when the iterator yields nothing).
macro_rules! impl_clusterable {
    ($type:ty) => {
        impl Clusterable for $type {
            fn distance(&self, rhs: &Self) -> f64 {
                // Same value as subtracting the smaller from the larger operand.
                (*self as f64 - *rhs as f64).abs()
            }

            fn get_centroid<'a, I>(elements: I) -> Option<Self>
                where I: Iterator<Item = &'a Self>, Self: 'a {

                let mut sum: $type = 0.0;
                // Count with an integer: an `f32` counter stops increasing once
                // it reaches 2^24, which would skew the mean of large inputs.
                let mut count: usize = 0;

                for element in elements {
                    sum += *element;
                    count += 1;
                }

                if count > 0 {
                    Some(sum / count as $type)
                } else {
                    None
                }
            }
        }
    };
}

impl_clusterable!(f32);
impl_clusterable!(f64);
|
||||
|
||||
@@ -7,7 +7,7 @@ use rand::distributions::Range;
|
||||
use rand::distributions::normal::Normal;
|
||||
use generic_kmeans::{kmeans, Clusterable};
|
||||
|
||||
#[derive(PartialEq, Copy, Clone)]
|
||||
#[derive(PartialEq, Copy, Clone, Debug)]
|
||||
struct Vector2<T> {
|
||||
pub x: T,
|
||||
pub y: T,
|
||||
@@ -30,28 +30,28 @@ impl<T: Display> Display for Vector2<T> {
|
||||
|
||||
impl Clusterable for Vector2<f64> {
|
||||
fn distance(&self, other: &Self) -> f64 {
|
||||
(self.x - other.x) * (self.x - other.y) + (self.y - other.y) * (self.y - other.y)
|
||||
(self.x - other.x) * (self.x - other.x) + (self.y - other.y) * (self.y - other.y)
|
||||
}
|
||||
|
||||
fn get_centroid(cluster: &Vec<Vector2<f64>>) -> Option<Vector2<f64>> {
|
||||
|
||||
let len = cluster.len();
|
||||
|
||||
if len == 0 {
|
||||
return None;
|
||||
}
|
||||
fn get_centroid<'a, I>(cluster: I) -> Option<Self>
|
||||
where I: Iterator<Item = &'a Self>, Self: 'a {
|
||||
|
||||
let mut centroid = Vector2::new(0.0, 0.0);
|
||||
let mut count = 0.0;
|
||||
|
||||
for i in cluster {
|
||||
centroid.x += i.x;
|
||||
centroid.y += i.y;
|
||||
count += 1.0;
|
||||
}
|
||||
|
||||
centroid.x /= len as f64;
|
||||
centroid.y /= len as f64;
|
||||
|
||||
Some(centroid)
|
||||
if count > 0.0 {
|
||||
centroid.x /= count as f64;
|
||||
centroid.y /= count as f64;
|
||||
Some(centroid)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,21 +85,24 @@ fn main() {
|
||||
|
||||
}
|
||||
|
||||
let initialization = vec![
|
||||
Vector2::new(0.0,0.0),
|
||||
Vector2::new(10.0,0.0),
|
||||
Vector2::new(0.0,10.0),
|
||||
];
|
||||
|
||||
let (clusters, nb_iterations) = kmeans(
|
||||
centers.iter().map(|x| x.clone().0).collect::<Vec<_>>(), elements, 100000).ok().unwrap();
|
||||
|
||||
println!("{}", nb_iterations);
|
||||
let (clusters, nb_iterations) = kmeans(initialization, elements, 100000).ok().unwrap();
|
||||
|
||||
let mut output = File::create("plot/dat.dat").unwrap();
|
||||
|
||||
for (cluster, color) in clusters.iter().zip(&colors) {
|
||||
for element in cluster {
|
||||
use std::io::Write;
|
||||
writeln!(output, "{} {} {}", element.x, element.y, color).unwrap();
|
||||
}
|
||||
for (element, &label) in clusters.iter() {
|
||||
use std::io::Write;
|
||||
writeln!(output, "{} {} {}", element.x, element.y, colors[label]).unwrap();
|
||||
}
|
||||
|
||||
println!("Finished in {} iterations", nb_iterations);
|
||||
|
||||
let mut center_file = File::create("plot/centers.dat").unwrap();
|
||||
for (&(center, _, _), color) in centers.iter().zip(&colors) {
|
||||
use std::io::Write;
|
||||
|
||||
@@ -1,48 +1,82 @@
|
||||
use std;
|
||||
|
||||
use kmeansdata::KmeansData;
|
||||
use cluster::{Cluster, Clusterable};
|
||||
use std::slice::Iter;
|
||||
use std::iter::Zip;
|
||||
use clusterable::Clusterable;
|
||||
|
||||
pub struct Kmeans<T: Clusterable> {
|
||||
pub centroids: Vec<T>,
|
||||
pub data: KmeansData<T>,
|
||||
centroids: Vec<T>,
|
||||
elements: Vec<T>,
|
||||
labels: Vec<usize>,
|
||||
cluster_number: usize,
|
||||
}
|
||||
|
||||
impl<T:Clusterable> Kmeans<T> {
|
||||
pub fn new(centroids: Vec<T>, data: Vec<Vec<T>>) -> Kmeans<T> {
|
||||
pub fn new(centroids: Vec<T>, data: Vec<T>) -> Kmeans<T> {
|
||||
let labels = Kmeans::build_labels(&centroids, &data);
|
||||
let cluster_number = centroids.len();
|
||||
Kmeans {
|
||||
centroids: centroids,
|
||||
data: KmeansData::from_clusters(data),
|
||||
elements: data,
|
||||
labels: labels,
|
||||
cluster_number: cluster_number
|
||||
}
|
||||
}
|
||||
|
||||
pub fn guess_centroids(data: Vec<Vec<T>>) -> Kmeans<T> {
|
||||
/// \returns True if converged
|
||||
pub fn iterate(&mut self) -> bool {
|
||||
// Update the centroids
|
||||
let centroids = Kmeans::build_centroids(&self.elements, &self.labels, self.cluster_number);
|
||||
|
||||
if self.centroids == centroids {
|
||||
true
|
||||
} else {
|
||||
self.centroids = centroids;
|
||||
Kmeans::update_labels(&self.centroids, &self.elements, &mut self.labels);
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build_labels(centroids: &Vec<T>, data: &Vec<T>) -> Vec<usize> {
|
||||
debug_assert_ne!(0, centroids.len());
|
||||
|
||||
let mut output = vec![0; data.len()];
|
||||
Kmeans::update_labels(centroids, data, &mut output);
|
||||
output
|
||||
}
|
||||
|
||||
/// Sets each label to the index of the centroid nearest to the matching element.
///
/// `data` and `labels` are walked in lockstep, so only the first
/// `min(data.len(), labels.len())` labels are written. On equal distances the
/// centroid with the lowest index wins. A NaN distance is ordered after every
/// other distance instead of aborting the comparison.
///
/// # Panics
///
/// Panics if `centroids` is empty while there is at least one label to write.
pub fn update_labels(centroids: &Vec<T>, data: &Vec<T>, labels: &mut Vec<usize>) {
    for (element, new_label) in data.iter().zip(labels.iter_mut()) {
        *new_label = centroids
            .iter()
            .enumerate()
            // Compute each distance once rather than once per comparison.
            .map(|(index, centroid)| (index, centroid.distance(element)))
            .min_by(|&(_, d1), &(_, d2)| {
                // `partial_cmp` is `None` only when a NaN is involved; rank NaN last.
                d1.partial_cmp(&d2)
                    .unwrap_or_else(|| d1.is_nan().cmp(&d2.is_nan()))
            })
            .expect("update_labels requires at least one centroid")
            .0;
    }
}
|
||||
|
||||
pub fn build_centroids(data: &Vec<T>, labels: &Vec<usize>, cluster_number: usize) -> Vec<T> {
|
||||
|
||||
let mut centroids = vec![];
|
||||
for cluster in &data {
|
||||
if let Some(centroid) = T::get_centroid(cluster) {
|
||||
centroids.push(centroid);
|
||||
for label in 0..cluster_number {
|
||||
|
||||
let to_consider = data
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|&(index, _)| labels[index] == label)
|
||||
.map(|(_, element)| element);
|
||||
|
||||
if let Some(centroid) = T::get_centroid(to_consider) {
|
||||
centroids.push(centroid);
|
||||
}
|
||||
}
|
||||
|
||||
Kmeans::new(centroids, data)
|
||||
|
||||
centroids
|
||||
}
|
||||
|
||||
fn from_data(centroids: Vec<T>, data: KmeansData<T>) -> Kmeans<T> {
|
||||
Kmeans {
|
||||
centroids: centroids,
|
||||
data: data,
|
||||
}
|
||||
pub fn iter(&self) -> Zip<Iter<T>, Iter<usize>> {
|
||||
debug_assert_eq!(self.elements.len(), self.labels.len());
|
||||
return self.elements.iter().zip(self.labels.iter());
|
||||
}
|
||||
|
||||
pub fn next_iteration(self) -> (Kmeans<T>, bool) {
|
||||
let (new_centroids, data) = self.data.iterate(&self.centroids);
|
||||
let stable = new_centroids == self.centroids;
|
||||
(Kmeans::from_data(new_centroids, data), stable)
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> std::slice::Iter<Cluster<T>> {
|
||||
self.data.clusters()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
use std;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use cluster::{Clusterable, Cluster};
|
||||
|
||||
|
||||
pub struct KmeansData<T: Clusterable> {
|
||||
clusters: Vec<Vec<T>>,
|
||||
}
|
||||
|
||||
impl<T: Clusterable> KmeansData<T> {
|
||||
|
||||
pub fn from_clusters(data: Vec<Vec<T>>) -> KmeansData<T> {
|
||||
KmeansData {
|
||||
clusters: data,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new(centroids: &Vec<T>, data: KmeansDataIntoIter<T>) -> KmeansData<T> {
|
||||
|
||||
let mut clusters = vec![];
|
||||
|
||||
for _ in centroids {
|
||||
clusters.push(Cluster::new());
|
||||
}
|
||||
|
||||
let mut new_kmeans: KmeansData<T> = KmeansData {
|
||||
clusters: clusters,
|
||||
};
|
||||
|
||||
for element in data {
|
||||
|
||||
// Compute the distance
|
||||
let mut distance = std::f64::MAX;
|
||||
let mut index = 0;
|
||||
|
||||
for (new_index, centroid) in centroids.iter().enumerate() {
|
||||
let new_distance = element.distance(centroid);
|
||||
|
||||
if new_distance < distance {
|
||||
distance = new_distance;
|
||||
index = new_index;
|
||||
}
|
||||
}
|
||||
|
||||
// Add element to the new kmeans
|
||||
new_kmeans.clusters[index].push(element)
|
||||
|
||||
}
|
||||
|
||||
new_kmeans
|
||||
}
|
||||
|
||||
pub fn add_cluster(&mut self) {
|
||||
self.clusters.push(Cluster::new());
|
||||
}
|
||||
|
||||
pub fn iterate(self, centroids: &Vec<T>) -> (Vec<T>, KmeansData<T>) {
|
||||
// Compute the result with the given centroids
|
||||
let result = KmeansData::new(¢roids, self.into_iter());
|
||||
|
||||
// Compute the new centroids
|
||||
let mut new_centroids = vec![];
|
||||
|
||||
for cluster in &result.clusters {
|
||||
|
||||
if let Some(centroid) = T::get_centroid(cluster) {
|
||||
new_centroids.push(centroid);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
(new_centroids, result)
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> KmeansDataIter<T> {
|
||||
KmeansDataIter {
|
||||
global_iter: self.clusters.iter(),
|
||||
local_iter: None,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clusters(&self) -> std::slice::Iter<Cluster<T>> {
|
||||
self.clusters.iter()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// Borrowing iterator that yields the elements of each cluster in turn.
///
/// The element type carries no trait bound here; the `impl` blocks that build
/// this iterator keep their own bounds.
pub struct KmeansDataIter<'a, T> where T: 'a {
    /// Clusters not yet started.
    global_iter: std::slice::Iter<'a, std::vec::Vec<T>>,
    /// Cluster currently being walked, if any.
    local_iter: Option<std::slice::Iter<'a, T>>,
    _phantom: PhantomData<T>,
}

impl<'a, T: 'a> Iterator for KmeansDataIter<'a, T> {
    type Item = &'a T;

    /// Returns the next element, skipping any number of empty clusters.
    ///
    /// `None` is returned only once every cluster has been consumed, including
    /// when there are no clusters at all.
    fn next(&mut self) -> Option<&'a T> {
        loop {
            if let Some(ref mut local_iter) = self.local_iter {
                if let Some(element) = local_iter.next() {
                    return Some(element);
                }
            }

            // Current cluster is exhausted (or none started): move to the next
            // one, which may be empty, hence the loop.
            match self.global_iter.next() {
                Some(cluster) => self.local_iter = Some(cluster.iter()),
                None => return None,
            }
        }
    }
}
|
||||
|
||||
/// Owning iterator that yields the elements of each cluster in turn.
///
/// The element type carries no trait bound here; the `impl` blocks that build
/// this iterator keep their own bounds.
pub struct KmeansDataIntoIter<T> {
    /// Clusters not yet started.
    global_iter: std::vec::IntoIter<std::vec::Vec<T>>,
    /// Cluster currently being drained, if any.
    local_iter: Option<std::vec::IntoIter<T>>,
    _phantom: PhantomData<T>,
}

impl<T> Iterator for KmeansDataIntoIter<T> {
    type Item = T;

    /// Returns the next element, skipping any number of empty clusters.
    ///
    /// `None` is returned only once every cluster has been consumed, including
    /// when there are no clusters at all.
    fn next(&mut self) -> Option<T> {
        loop {
            if let Some(ref mut local_iter) = self.local_iter {
                if let Some(element) = local_iter.next() {
                    return Some(element);
                }
            }

            // Current cluster is exhausted (or none started): move to the next
            // one, which may be empty, hence the loop.
            match self.global_iter.next() {
                Some(cluster) => self.local_iter = Some(cluster.into_iter()),
                None => return None,
            }
        }
    }
}
|
||||
|
||||
impl<T: Clusterable> IntoIterator for KmeansData<T> {
|
||||
type Item = T;
|
||||
type IntoIter = KmeansDataIntoIter<T>;
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
Self::IntoIter {
|
||||
global_iter: self.clusters.into_iter(),
|
||||
local_iter: None,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
12
src/lib.rs
12
src/lib.rs
@@ -1,11 +1,8 @@
|
||||
pub mod kmeans;
|
||||
pub mod kmeansdata;
|
||||
pub mod cluster;
|
||||
pub mod test;
|
||||
pub mod clusterable;
|
||||
|
||||
pub use kmeans::Kmeans;
|
||||
pub use kmeansdata::KmeansData;
|
||||
pub use cluster::{Cluster, Clusterable};
|
||||
pub use clusterable::Clusterable;
|
||||
|
||||
pub enum Error {
|
||||
IterationsLimitExceeded,
|
||||
@@ -14,12 +11,11 @@ pub enum Error {
|
||||
pub fn kmeans<T: Clusterable>(centroids: Vec<T>, data: Vec<T>, max_iterations: usize)
|
||||
-> Result<(Kmeans<T>, usize), Error> {
|
||||
|
||||
let mut kmeans = Kmeans::new(centroids, vec![data]);
|
||||
let mut kmeans = Kmeans::new(centroids, data);
|
||||
|
||||
for nb_iterations in 0..max_iterations {
|
||||
|
||||
let (new_kmeans, stable) = kmeans.next_iteration();
|
||||
kmeans = new_kmeans;
|
||||
let stable = kmeans.iterate();
|
||||
|
||||
if stable {
|
||||
return Ok((kmeans, nb_iterations));
|
||||
|
||||
44
src/test.rs
44
src/test.rs
@@ -1,44 +0,0 @@
|
||||
#[cfg(test)]
mod test {
    /// Consuming a `KmeansData` must yield every element, cluster after cluster.
    #[test]
    fn iterators() {
        use kmeansdata::KmeansData;

        let data = vec![4.0, 5.0, 11.0, 12.0, 13.0];
        let kmeans = KmeansData::from_clusters(vec![
            vec![4.0, 5.0],
            vec![11.0, 12.0, 13.0],
        ]);

        // Compare whole sequences: a `zip` would stop at the shorter side and
        // hide missing elements.
        let flattened: Vec<f64> = kmeans.into_iter().collect();
        assert_eq!(flattened, data);
    }

    /// Four iterations on the sample data must produce the expected clusters.
    #[test]
    fn iterate() {
        use kmeans::Kmeans;

        let data = vec![
            vec![4.0, 5.0, 11.0, 12.0],
            vec![13.0],
        ];

        let solution = vec![
            vec![4.0, 5.0],
            vec![11.0, 12.0, 13.0],
        ];

        let mut kmeans = Kmeans::guess_centroids(data);

        for _ in 0..4 {
            let (new_kmeans, _stable) = kmeans.next_iteration();
            kmeans = new_kmeans;
        }

        // Compare whole clusters so that a missing cluster or element fails.
        let clusters: Vec<Vec<f64>> = kmeans.iter().cloned().collect();
        assert_eq!(clusters, solution);
    }
}
|
||||
Reference in New Issue
Block a user