kmeans/src/lib.rs

141 lines
3.4 KiB
Rust

use std::slice::Iter;
use std::iter::Zip;
use std::collections::HashMap;
/// Errors reported by the k-means driver.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// The iteration budget was used up before the centroids stopped changing.
    TooManyIterations,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match *self {
            Error::TooManyIterations => {
                f.write_str("k-means did not converge within the allowed number of iterations")
            }
        }
    }
}

// Implementing the standard trait lets callers box this error or propagate
// it with `?` into `Box<dyn std::error::Error>`.
impl std::error::Error for Error {}
/// Elements being clustered, paired with the label currently assigned to each.
pub struct KmeansData<T: Clone + PartialEq> {
    /// The elements to cluster.
    pub elements: Vec<T>,
    /// `labels[i]` is the index of the centroid assigned to `elements[i]`.
    pub labels: Vec<usize>,
}

/// State of a k-means run over elements of type `T`.
///
/// `D` measures the distance between an element and a centroid, and `C`
/// reduces the members of one cluster to that cluster's new centroid.
pub struct Kmeans<T: Clone + PartialEq, D: Fn(&T, &T) -> f64, C: Fn(&Vec<T>) -> T> {
    data: KmeansData<T>,
    centroids: Vec<T>,
    distance: D,
    centroid: C,
}

impl<T: Clone + PartialEq, D: Fn(&T, &T) -> f64, C: Fn(&Vec<T>) -> T> Kmeans<T, D, C> {
    /// Creates a new run over `data`, starting from the given `centroids`.
    ///
    /// Every element starts with label `0`; labels only become meaningful
    /// after the first call to [`Kmeans::iterate`].
    pub fn new(data: Vec<T>, centroids: Vec<T>, distance: D, centroid: C) -> Kmeans<T, D, C> {
        let len = data.len();
        Kmeans {
            data: KmeansData {
                elements: data,
                labels: vec![0; len],
            },
            centroids,
            distance,
            centroid,
        }
    }

    /// Performs one assignment/update step.
    ///
    /// Returns `true` if the recomputed centroids are equal to the previous
    /// ones, i.e. the run has converged.
    pub fn iterate(&mut self) -> bool {
        self.update_labels();
        let new_centroids = self.compute_centroids();
        let converged = new_centroids == self.centroids;
        self.centroids = new_centroids;
        converged
    }

    /// Assigns every element the index of its nearest centroid.
    ///
    /// Ties go to the lowest index because the comparison is strict. An
    /// element keeps label `0` when there are no centroids or when no
    /// distance compares below `f64::MAX` (for example, all distances NaN).
    fn update_labels(&mut self) {
        let centroids = &self.centroids;
        let distance = &self.distance;
        let pairs = self.data.elements.iter().zip(self.data.labels.iter_mut());
        for (element, label) in pairs {
            let mut best_distance = std::f64::MAX;
            let mut best_label = 0;
            for (index, centroid) in centroids.iter().enumerate() {
                let candidate = distance(element, centroid);
                if candidate < best_distance {
                    best_distance = candidate;
                    best_label = index;
                }
            }
            *label = best_label;
        }
    }

    /// Recomputes one centroid per cluster, in label order.
    ///
    /// The result is indexed by label so that centroid `i` always belongs to
    /// the elements labelled `i`; this keeps the convergence comparison in
    /// [`Kmeans::iterate`] deterministic. A cluster that received no elements
    /// keeps its previous centroid instead of disappearing.
    fn compute_centroids(&self) -> Vec<T> {
        let mut clusters: Vec<Vec<T>> = vec![Vec::new(); self.centroids.len()];
        for (element, &label) in self.data.elements.iter().zip(self.data.labels.iter()) {
            // Only reachable when there are no centroids at all: every
            // element then carries label 0 and forms a single cluster.
            if label >= clusters.len() {
                clusters.resize(label + 1, Vec::new());
            }
            clusters[label].push(element.clone());
        }
        clusters
            .into_iter()
            .enumerate()
            .filter_map(|(index, members)| {
                if members.is_empty() {
                    // Empty cluster: carry the old centroid over unchanged.
                    self.centroids.get(index).cloned()
                } else {
                    Some((self.centroid)(&members))
                }
            })
            .collect()
    }

    /// Iterates over `(element, label)` pairs in input order.
    pub fn iter(&self) -> Zip<Iter<T>, Iter<usize>> {
        self.data.elements.iter().zip(self.data.labels.iter())
    }

    /// Consumes the run and returns the elements grouped by cluster.
    ///
    /// Clusters are ordered by ascending label and elements keep their input
    /// order within a cluster. Clusters with no elements are omitted.
    pub fn into_iter(self) -> Vec<Vec<T>> {
        let KmeansData { elements, labels } = self.data;
        let mut clusters: Vec<Vec<T>> = Vec::new();
        for (element, label) in elements.into_iter().zip(labels) {
            if label >= clusters.len() {
                clusters.resize(label + 1, Vec::new());
            }
            // Elements are moved, not cloned: `self` is consumed anyway.
            clusters[label].push(element);
        }
        clusters.retain(|cluster| !cluster.is_empty());
        clusters
    }
}
pub fn kmeans<T: Clone + PartialEq, D: Fn(&T, &T) -> f64, C: Fn(&Vec<T>) -> T>(
elements: Vec<T>,
initial: Vec<T>,
distance: D,
centroid: C,
max_iteration: usize,
) -> Result<(Kmeans<T, D, C>,usize), Error> {
let mut clusters = Kmeans::new(elements, initial, distance, centroid);
let mut counter = 0;
let iterations = loop {
counter += 1;
if clusters.iterate() || counter > max_iteration {
break counter;
}
};
Ok((clusters, iterations))
}