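//! K-means clustering state: elements grouped into clusters, one k-means
//! assignment/update step, and iterators that walk the clusters' elements.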
use std;
use std::marker::PhantomData;
use cluster::{Clusterable, Cluster};
pub struct KmeansData<T: Clusterable> {
|
|
clusters: Vec<Vec<T>>,
|
|
}
impl<T: Clusterable> KmeansData<T> {
|
|
|
|
pub fn from_clusters(data: Vec<Vec<T>>) -> KmeansData<T> {
|
|
KmeansData {
|
|
clusters: data,
|
|
}
|
|
}
|
|
|
|
pub fn new(centroids: &Vec<T>, data: KmeansDataIntoIter<T>) -> KmeansData<T> {
|
|
|
|
let mut clusters = vec![];
|
|
|
|
for _ in centroids {
|
|
clusters.push(Cluster::new());
|
|
}
|
|
|
|
let mut new_kmeans: KmeansData<T> = KmeansData {
|
|
clusters: clusters,
|
|
};
|
|
|
|
for element in data {
|
|
|
|
// Compute the distance
|
|
let mut distance = std::f64::MAX;
|
|
let mut index = 0;
|
|
|
|
for (new_index, centroid) in centroids.iter().enumerate() {
|
|
let new_distance = element.distance(centroid);
|
|
|
|
if new_distance < distance {
|
|
distance = new_distance;
|
|
index = new_index;
|
|
}
|
|
}
|
|
|
|
// Add element to the new kmeans
|
|
new_kmeans.clusters[index].push(element)
|
|
|
|
}
|
|
|
|
new_kmeans
|
|
}
|
|
|
|
pub fn add_cluster(&mut self) {
|
|
self.clusters.push(Cluster::new());
|
|
}
|
|
|
|
pub fn iterate(self, centroids: &Vec<T>) -> (Vec<T>, KmeansData<T>) {
|
|
// Compute the result with the given centroids
|
|
let result = KmeansData::new(¢roids, self.into_iter());
|
|
|
|
// Compute the new centroids
|
|
let mut new_centroids = vec![];
|
|
|
|
for cluster in &result.clusters {
|
|
|
|
if let Some(centroid) = T::get_centroid(cluster) {
|
|
new_centroids.push(centroid);
|
|
}
|
|
|
|
}
|
|
|
|
(new_centroids, result)
|
|
}
|
|
|
|
pub fn iter(&self) -> KmeansDataIter<T> {
|
|
KmeansDataIter {
|
|
global_iter: self.clusters.iter(),
|
|
local_iter: None,
|
|
_phantom: PhantomData,
|
|
}
|
|
}
|
|
|
|
pub fn clusters(&self) -> std::slice::Iter<Cluster<T>> {
|
|
self.clusters.iter()
|
|
}
|
|
|
|
}
pub struct KmeansDataIter<'a, T> where T:'a, T: Clusterable {
|
|
global_iter: std::slice::Iter<'a, std::vec::Vec<T>>,
|
|
local_iter: Option<std::slice::Iter<'a, T>>,
|
|
_phantom: PhantomData<T>,
|
|
}
impl<'a, T: 'a + Clusterable> Iterator for KmeansDataIter<'a, T> {
    type Item = &'a T;

    /// Returns the next element, moving on to later clusters as each one is
    /// exhausted.
    ///
    /// Empty clusters are skipped rather than treated as the end of the data.
    /// With no clusters at all, this returns `None` instead of recursing.
    /// Once exhausted, it keeps returning `None`.
    fn next(&mut self) -> Option<&'a T> {
        loop {
            // Drain the cluster currently being read, if any.
            if let Some(ref mut local_iter) = self.local_iter {
                if let Some(t) = local_iter.next() {
                    return Some(t);
                }
            }
            // The current cluster is done (or none was started yet). Advance
            // to the next one and loop again, which skips empty clusters
            // instead of reporting them as the end of iteration.
            match self.global_iter.next() {
                Some(cluster) => self.local_iter = Some(cluster.iter()),
                None => return None,
            }
        }
    }
}
pub struct KmeansDataIntoIter<T> where T: Clusterable {
|
|
global_iter: std::vec::IntoIter<std::vec::Vec<T>>,
|
|
local_iter: Option<std::vec::IntoIter<T>>,
|
|
_phantom: PhantomData<T>,
|
|
}
impl<T: Clusterable> Iterator for KmeansDataIntoIter<T> {
    type Item = T;

    /// Returns the next owned element, moving on to later clusters as each
    /// one is exhausted.
    ///
    /// Empty clusters are skipped rather than treated as the end of the data.
    /// This matters because `KmeansData::new` can leave clusters empty, and
    /// `KmeansData::iterate` re-feeds the data through this iterator. With no
    /// clusters at all, this returns `None` instead of recursing. Once
    /// exhausted, it keeps returning `None`.
    fn next(&mut self) -> Option<T> {
        loop {
            // Drain the cluster currently being read, if any.
            if let Some(ref mut local_iter) = self.local_iter {
                if let Some(t) = local_iter.next() {
                    return Some(t);
                }
            }
            // The current cluster is done (or none was started yet). Take
            // ownership of the next one and loop again, so empty clusters are
            // skipped instead of ending iteration early.
            match self.global_iter.next() {
                Some(cluster) => self.local_iter = Some(cluster.into_iter()),
                None => return None,
            }
        }
    }
}
impl<T: Clusterable> IntoIterator for KmeansData<T> {
|
|
type Item = T;
|
|
type IntoIter = KmeansDataIntoIter<T>;
|
|
fn into_iter(self) -> Self::IntoIter {
|
|
Self::IntoIter {
|
|
global_iter: self.clusters.into_iter(),
|
|
local_iter: None,
|
|
_phantom: PhantomData,
|
|
}
|
|
}
|
|
}