141 lines
3.4 KiB
Rust
141 lines
3.4 KiB
Rust
use std::slice::Iter;
|
|
use std::iter::Zip;
|
|
use std::collections::HashMap;
|
|
|
|
/// Errors reported by the k-means routines in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// The iteration budget was exhausted before the centroids stopped changing.
    TooManyIterations,
}

impl std::fmt::Display for Error {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match *self {
            Error::TooManyIterations => f.write_str(
                "k-means did not converge within the maximum number of iterations",
            ),
        }
    }
}

// Implementing the standard trait lets callers propagate this error with `?`
// into `Box<dyn std::error::Error>` and similar containers.
impl std::error::Error for Error {}
|
|
|
|
/// Elements to cluster, paired position-by-position with their cluster labels.
///
/// No trait bounds are placed on `T` here; the code that actually needs
/// `Clone`/`PartialEq` states those bounds on its own `impl` blocks.
#[derive(Debug, Clone, PartialEq)]
pub struct KmeansData<T> {
    /// The items being clustered, in the order they were supplied.
    pub elements: Vec<T>,
    /// `labels[i]` is the index of the centroid currently assigned to `elements[i]`.
    pub labels: Vec<usize>,
}
|
|
|
|
pub struct Kmeans<T: Clone + PartialEq, D: Fn(&T, &T) -> f64, C: Fn(&Vec<T>) -> T> {
|
|
data: KmeansData<T>,
|
|
centroids: Vec<T>,
|
|
distance: D,
|
|
centroid: C,
|
|
}
|
|
|
|
impl<T: Clone + PartialEq, D: Fn(&T, &T) -> f64, C: Fn(&Vec<T>) -> T> Kmeans<T, D, C> {
|
|
pub fn new(data: Vec<T>, centroids: Vec<T>, distance: D, centroid: C) -> Kmeans<T, D, C> {
|
|
|
|
let len = data.len();
|
|
Kmeans {
|
|
data: KmeansData { elements: data, labels: vec![0; len] },
|
|
centroids: centroids,
|
|
distance: distance,
|
|
centroid: centroid,
|
|
}
|
|
|
|
}
|
|
|
|
/// True if converged
|
|
pub fn iterate(&mut self) -> bool {
|
|
// Update the labels
|
|
self.update_labels();
|
|
|
|
// Update the centroids
|
|
let new_centroids = self.compute_centroids();
|
|
|
|
let ret = new_centroids == self.centroids;
|
|
self.centroids = new_centroids;
|
|
ret
|
|
}
|
|
|
|
fn update_labels(&mut self) {
|
|
|
|
let iterator = self.data.elements.iter().zip(self.data.labels.iter_mut());
|
|
|
|
for (ref element, ref mut label) in iterator {
|
|
|
|
let mut best_distance = std::f64::MAX;
|
|
let mut best_label = 0;
|
|
|
|
for (index, centroid) in self.centroids.iter().enumerate() {
|
|
|
|
let new_distance = (self.distance)(*element, centroid);
|
|
if new_distance < best_distance {
|
|
best_distance = new_distance;
|
|
best_label = index;
|
|
}
|
|
}
|
|
|
|
**label = best_label;
|
|
}
|
|
|
|
}
|
|
|
|
fn compute_centroids(&self) -> Vec<T> {
|
|
let mut centroids_map = HashMap::new();
|
|
|
|
for (element, label) in self.data.elements.iter().zip(self.data.labels.iter()) {
|
|
let mut centroid = centroids_map.entry(label).or_insert(vec![]);
|
|
centroid.push(element.clone());
|
|
}
|
|
|
|
let mut new_centroids = vec![];
|
|
|
|
for (_, value) in centroids_map {
|
|
new_centroids.push((self.centroid)(&value));
|
|
}
|
|
|
|
new_centroids
|
|
}
|
|
|
|
pub fn iter(&self) -> Zip<Iter<T>, Iter<usize>> {
|
|
self.data.elements.iter().zip(self.data.labels.iter())
|
|
}
|
|
|
|
pub fn into_iter(self) -> Vec<Vec<T>> {
|
|
let mut map = HashMap::new();
|
|
|
|
for (element, label) in self.data.elements.iter().zip(self.data.labels.iter()) {
|
|
let mut centroid = map.entry(label).or_insert(vec![]);
|
|
centroid.push(element.clone());
|
|
}
|
|
|
|
let mut output = vec![];
|
|
|
|
for (_, value) in map {
|
|
let mut cluster = vec![];
|
|
|
|
for element in value {
|
|
cluster.push(element);
|
|
}
|
|
|
|
output.push(cluster);
|
|
}
|
|
|
|
output
|
|
}
|
|
}
|
|
|
|
pub fn kmeans<T: Clone + PartialEq, D: Fn(&T, &T) -> f64, C: Fn(&Vec<T>) -> T>(
|
|
elements: Vec<T>,
|
|
initial: Vec<T>,
|
|
distance: D,
|
|
centroid: C,
|
|
max_iteration: usize,
|
|
) -> Result<(Kmeans<T, D, C>,usize), Error> {
|
|
|
|
let mut clusters = Kmeans::new(elements, initial, distance, centroid);
|
|
|
|
let mut counter = 0;
|
|
|
|
let iterations = loop {
|
|
|
|
counter += 1;
|
|
|
|
if clusters.iterate() || counter > max_iteration {
|
|
break counter;
|
|
}
|
|
|
|
};
|
|
|
|
Ok((clusters, iterations))
|
|
}
|
|
|