Update, clean, test, example
This commit is contained in:
parent
c3aff5c77d
commit
14455d3693
|
@ -1,6 +1,11 @@
|
||||||
[package]
|
[package]
|
||||||
name = "kmeans"
|
name = "generic_kmeans"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
authors = ["Thomas Forgione <thomas@tforgione.fr>"]
|
authors = ["Thomas Forgione <thomas@tforgione.fr>"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "example"
|
||||||
|
path = "src/example.rs"
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
pub type Cluster<T> = Vec<T>;
|
||||||
|
|
||||||
|
pub trait Clusterable where Self: Sized + Clone + PartialEq {
|
||||||
|
fn distance(&self, rhs: &Self) -> f64;
|
||||||
|
fn get_centroid(elements: &Vec<Self>) -> Option<Self>;
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! impl_clusterable {
|
||||||
|
( $type: ty) => {
|
||||||
|
impl Clusterable for $type {
|
||||||
|
fn distance(&self, rhs: &Self) -> f64 {
|
||||||
|
if self > rhs {
|
||||||
|
*self as f64 - *rhs as f64
|
||||||
|
} else {
|
||||||
|
*rhs as f64 - *self as f64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_centroid(elements: &Vec<Self>) -> Option<Self> {
|
||||||
|
|
||||||
|
if elements.len() == 0 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut tmp = 0.0 as Self;
|
||||||
|
for element in elements {
|
||||||
|
tmp += *element as Self;
|
||||||
|
}
|
||||||
|
Some(tmp / elements.len() as Self)
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_clusterable!(f32);
|
||||||
|
impl_clusterable!(f64);
|
||||||
|
|
|
@ -0,0 +1,82 @@
|
||||||
|
extern crate generic_kmeans;
|
||||||
|
|
||||||
|
use std::fmt::{Display, Formatter, Result};
|
||||||
|
|
||||||
|
use generic_kmeans::{kmeans, Clusterable};
|
||||||
|
|
||||||
|
#[derive(PartialEq, Clone)]
|
||||||
|
struct Vector2<T> {
|
||||||
|
pub x: T,
|
||||||
|
pub y: T,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> Vector2<T> {
|
||||||
|
pub fn new(x: T, y: T) -> Vector2<T> {
|
||||||
|
Vector2 {
|
||||||
|
x: x,
|
||||||
|
y: y,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Display> Display for Vector2<T> {
|
||||||
|
fn fmt(&self, formatter: &mut Formatter) -> Result {
|
||||||
|
write!(formatter, "({}, {})", self.x, self.y)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Clusterable for Vector2<f64> {
|
||||||
|
fn distance(&self, other: &Self) -> f64 {
|
||||||
|
(self.x - other.x) * (self.x - other.y) + (self.y - other.y) * (self.y - other.y)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_centroid(cluster: &Vec<Vector2<f64>>) -> Option<Vector2<f64>> {
|
||||||
|
|
||||||
|
let len = cluster.len();
|
||||||
|
|
||||||
|
if len == 0 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut centroid = Vector2::new(0.0, 0.0);
|
||||||
|
|
||||||
|
for i in cluster {
|
||||||
|
centroid.x += i.x;
|
||||||
|
centroid.y += i.y;
|
||||||
|
}
|
||||||
|
|
||||||
|
centroid.x /= len as f64;
|
||||||
|
centroid.y /= len as f64;
|
||||||
|
|
||||||
|
Some(centroid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
|
||||||
|
let elements = vec![
|
||||||
|
Vector2::new(8.0, 3.0),
|
||||||
|
Vector2::new(9.0, 3.0),
|
||||||
|
Vector2::new(9.0, 2.0),
|
||||||
|
Vector2::new(1.0, 8.0),
|
||||||
|
Vector2::new(2.0, 9.0),
|
||||||
|
Vector2::new(3.0, 8.0),
|
||||||
|
];
|
||||||
|
|
||||||
|
let initial = vec![
|
||||||
|
Vector2::new(1.0, 10.0),
|
||||||
|
Vector2::new(10.0, 0.0),
|
||||||
|
];
|
||||||
|
|
||||||
|
let (clusters, nb_iterations) = kmeans(initial, elements, 1000).ok().unwrap();
|
||||||
|
|
||||||
|
println!("{}", nb_iterations);
|
||||||
|
|
||||||
|
for (index, cluster) in clusters.iter().enumerate() {
|
||||||
|
println!("CLUSTER {}", index);
|
||||||
|
|
||||||
|
for element in cluster {
|
||||||
|
println!("\t{}", element);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,48 @@
|
||||||
|
use std;
|
||||||
|
|
||||||
|
use kmeansdata::KmeansData;
|
||||||
|
use cluster::{Cluster, Clusterable};
|
||||||
|
|
||||||
|
pub struct Kmeans<T: Clusterable> {
|
||||||
|
pub centroids: Vec<T>,
|
||||||
|
pub data: KmeansData<T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T:Clusterable> Kmeans<T> {
|
||||||
|
pub fn new(centroids: Vec<T>, data: Vec<Vec<T>>) -> Kmeans<T> {
|
||||||
|
Kmeans {
|
||||||
|
centroids: centroids,
|
||||||
|
data: KmeansData::from_clusters(data),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn guess_centroids(data: Vec<Vec<T>>) -> Kmeans<T> {
|
||||||
|
let mut centroids = vec![];
|
||||||
|
for cluster in &data {
|
||||||
|
if let Some(centroid) = T::get_centroid(cluster) {
|
||||||
|
centroids.push(centroid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Kmeans::new(centroids, data)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
fn from_data(centroids: Vec<T>, data: KmeansData<T>) -> Kmeans<T> {
|
||||||
|
Kmeans {
|
||||||
|
centroids: centroids,
|
||||||
|
data: data,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn next_iteration(self) -> (Kmeans<T>, bool) {
|
||||||
|
let (new_centroids, data) = self.data.iterate(&self.centroids);
|
||||||
|
let stable = new_centroids == self.centroids;
|
||||||
|
(Kmeans::from_data(new_centroids, data), stable)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn iter(&self) -> std::slice::Iter<Cluster<T>> {
|
||||||
|
self.data.clusters()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,168 @@
|
||||||
|
use std;
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
use cluster::{Clusterable, Cluster};
|
||||||
|
|
||||||
|
|
||||||
|
pub struct KmeansData<T: Clusterable> {
|
||||||
|
clusters: Vec<Vec<T>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Clusterable> KmeansData<T> {
|
||||||
|
|
||||||
|
pub fn from_clusters(data: Vec<Vec<T>>) -> KmeansData<T> {
|
||||||
|
KmeansData {
|
||||||
|
clusters: data,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new(centroids: &Vec<T>, data: KmeansDataIntoIter<T>) -> KmeansData<T> {
|
||||||
|
|
||||||
|
let mut clusters = vec![];
|
||||||
|
|
||||||
|
for _ in centroids {
|
||||||
|
clusters.push(Cluster::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut new_kmeans: KmeansData<T> = KmeansData {
|
||||||
|
clusters: clusters,
|
||||||
|
};
|
||||||
|
|
||||||
|
for element in data {
|
||||||
|
|
||||||
|
// Compute the distance
|
||||||
|
let mut distance = std::f64::MAX;
|
||||||
|
let mut index = 0;
|
||||||
|
|
||||||
|
for (new_index, centroid) in centroids.iter().enumerate() {
|
||||||
|
let new_distance = element.distance(centroid);
|
||||||
|
|
||||||
|
if new_distance < distance {
|
||||||
|
distance = new_distance;
|
||||||
|
index = new_index;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add element to the new kmeans
|
||||||
|
new_kmeans.clusters[index].push(element)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
new_kmeans
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn add_cluster(&mut self) {
|
||||||
|
self.clusters.push(Cluster::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn iterate(self, centroids: &Vec<T>) -> (Vec<T>, KmeansData<T>) {
|
||||||
|
// Compute the result with the given centroids
|
||||||
|
let result = KmeansData::new(¢roids, self.into_iter());
|
||||||
|
|
||||||
|
// Compute the new centroids
|
||||||
|
let mut new_centroids = vec![];
|
||||||
|
|
||||||
|
for cluster in &result.clusters {
|
||||||
|
|
||||||
|
if let Some(centroid) = T::get_centroid(cluster) {
|
||||||
|
new_centroids.push(centroid);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
(new_centroids, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn iter(&self) -> KmeansDataIter<T> {
|
||||||
|
KmeansDataIter {
|
||||||
|
global_iter: self.clusters.iter(),
|
||||||
|
local_iter: None,
|
||||||
|
_phantom: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clusters(&self) -> std::slice::Iter<Cluster<T>> {
|
||||||
|
self.clusters.iter()
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct KmeansDataIter<'a, T> where T:'a, T: Clusterable {
|
||||||
|
global_iter: std::slice::Iter<'a, std::vec::Vec<T>>,
|
||||||
|
local_iter: Option<std::slice::Iter<'a, T>>,
|
||||||
|
_phantom: PhantomData<T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T: 'a + Clusterable> Iterator for KmeansDataIter<'a, T> {
|
||||||
|
type Item = &'a T;
|
||||||
|
fn next(&mut self) -> Option<&'a T> {
|
||||||
|
if let Some(ref mut local_iter) = self.local_iter {
|
||||||
|
match local_iter.next() {
|
||||||
|
Some(t) => Some(t),
|
||||||
|
None => {
|
||||||
|
if let Some(next) = self.global_iter.next() {
|
||||||
|
*local_iter = next.iter();
|
||||||
|
match local_iter.next() {
|
||||||
|
Some(t) => Some(t),
|
||||||
|
None => None,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
self.local_iter = match self.global_iter.next() {
|
||||||
|
None => None,
|
||||||
|
Some(t) => Some(t.iter()),
|
||||||
|
};
|
||||||
|
self.next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct KmeansDataIntoIter<T> where T: Clusterable {
|
||||||
|
global_iter: std::vec::IntoIter<std::vec::Vec<T>>,
|
||||||
|
local_iter: Option<std::vec::IntoIter<T>>,
|
||||||
|
_phantom: PhantomData<T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Clusterable> Iterator for KmeansDataIntoIter<T> {
|
||||||
|
type Item = T;
|
||||||
|
fn next(&mut self) -> Option<T> {
|
||||||
|
if let Some(ref mut local_iter) = self.local_iter {
|
||||||
|
match local_iter.next() {
|
||||||
|
Some(t) => Some(t),
|
||||||
|
None => {
|
||||||
|
if let Some(next) = self.global_iter.next() {
|
||||||
|
*local_iter = next.into_iter();
|
||||||
|
match local_iter.next() {
|
||||||
|
Some(t) => Some(t),
|
||||||
|
None => None,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
self.local_iter = match self.global_iter.next() {
|
||||||
|
None => None,
|
||||||
|
Some(t) => Some(t.into_iter()),
|
||||||
|
};
|
||||||
|
self.next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Clusterable> IntoIterator for KmeansData<T> {
|
||||||
|
type Item = T;
|
||||||
|
type IntoIter = KmeansDataIntoIter<T>;
|
||||||
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
|
Self::IntoIter {
|
||||||
|
global_iter: self.clusters.into_iter(),
|
||||||
|
local_iter: None,
|
||||||
|
_phantom: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
201
src/lib.rs
201
src/lib.rs
|
@ -1,190 +1,31 @@
|
||||||
|
pub mod kmeans;
|
||||||
|
pub mod kmeansdata;
|
||||||
|
pub mod cluster;
|
||||||
pub mod test;
|
pub mod test;
|
||||||
|
|
||||||
use std::marker::PhantomData;
|
pub use kmeans::Kmeans;
|
||||||
|
pub use kmeansdata::KmeansData;
|
||||||
|
pub use cluster::{Cluster, Clusterable};
|
||||||
|
|
||||||
pub trait Clusterable where Self: Sized {
|
pub enum Error {
|
||||||
fn distance(&self, rhs: &Self) -> f64;
|
IterationsLimitExceeded,
|
||||||
fn get_centroid(elements: &Vec<Self>) -> Option<Self>;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! impl_clusterable {
|
pub fn kmeans<T: Clusterable>(centroids: Vec<T>, data: Vec<T>, max_iterations: usize)
|
||||||
( $type: ty) => {
|
-> Result<(Kmeans<T>, usize), Error> {
|
||||||
impl Clusterable for $type {
|
|
||||||
fn distance(&self, rhs: &Self) -> f64 {
|
let mut kmeans = Kmeans::new(centroids, vec![data]);
|
||||||
if self > rhs {
|
|
||||||
*self as f64 - *rhs as f64
|
for nb_iterations in 0..max_iterations {
|
||||||
} else {
|
|
||||||
*rhs as f64 - *self as f64
|
let (new_kmeans, stable) = kmeans.next_iteration();
|
||||||
}
|
kmeans = new_kmeans;
|
||||||
|
|
||||||
|
if stable {
|
||||||
|
return Ok((kmeans, nb_iterations));
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_centroid(elements: &Vec<Self>) -> Option<Self> {
|
|
||||||
|
|
||||||
if elements.len() == 0 {
|
|
||||||
return None;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut tmp = 0.0 as Self;
|
Err(Error::IterationsLimitExceeded)
|
||||||
for element in elements {
|
|
||||||
tmp += *element as Self;
|
|
||||||
}
|
|
||||||
Some(tmp / elements.len() as Self)
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl_clusterable!(f32);
|
|
||||||
impl_clusterable!(f64);
|
|
||||||
|
|
||||||
|
|
||||||
pub type Cluster<T> = Vec<T>;
|
|
||||||
|
|
||||||
pub struct Kmeans<T: Clusterable> {
|
|
||||||
clusters: Vec<Vec<T>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: Clusterable> Kmeans<T> {
|
|
||||||
|
|
||||||
pub fn new(centroids: Vec<T>, data: KmeansIntoIter<T>) -> Kmeans<T> {
|
|
||||||
|
|
||||||
let mut clusters = vec![];
|
|
||||||
|
|
||||||
for _ in ¢roids {
|
|
||||||
clusters.push(Cluster::new());
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut new_kmeans: Kmeans<T> = Kmeans {
|
|
||||||
clusters: clusters,
|
|
||||||
};
|
|
||||||
|
|
||||||
for element in data {
|
|
||||||
|
|
||||||
// Compute the distance
|
|
||||||
let mut distance = std::f64::MAX;
|
|
||||||
let mut index = 0;
|
|
||||||
|
|
||||||
for (new_index, centroid) in centroids.iter().enumerate() {
|
|
||||||
let new_distance = element.distance(centroid);
|
|
||||||
|
|
||||||
if new_distance < distance {
|
|
||||||
distance = new_distance;
|
|
||||||
index = new_index;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add element to the new kmeans
|
|
||||||
new_kmeans.clusters[index].push(element)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
new_kmeans
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn add_cluster(&mut self) {
|
|
||||||
self.clusters.push(Cluster::new());
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn iterate(self) -> Kmeans<T> {
|
|
||||||
let mut centroids = vec![];
|
|
||||||
|
|
||||||
// Compute the centroids
|
|
||||||
for cluster in &self.clusters {
|
|
||||||
|
|
||||||
if let Some(centroid) = T::get_centroid(cluster) {
|
|
||||||
centroids.push(centroid);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
Kmeans::new(centroids, self.into_iter())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn iter(&self) -> std::slice::Iter<Cluster<T>> {
|
|
||||||
self.clusters.iter()
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct KmeansIter<'a, T> where T:'a, T: Clusterable {
|
|
||||||
global_iter: std::slice::Iter<'a, std::vec::Vec<T>>,
|
|
||||||
local_iter: Option<std::slice::Iter<'a, T>>,
|
|
||||||
_phantom: PhantomData<T>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a, T: 'a + Clusterable> Iterator for KmeansIter<'a, T> {
|
|
||||||
type Item = &'a T;
|
|
||||||
fn next(&mut self) -> Option<&'a T> {
|
|
||||||
if let Some(ref mut local_iter) = self.local_iter {
|
|
||||||
match local_iter.next() {
|
|
||||||
Some(t) => Some(t),
|
|
||||||
None => {
|
|
||||||
if let Some(next) = self.global_iter.next() {
|
|
||||||
*local_iter = next.iter();
|
|
||||||
match local_iter.next() {
|
|
||||||
Some(t) => Some(t),
|
|
||||||
None => None,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
self.local_iter = match self.global_iter.next() {
|
|
||||||
None => None,
|
|
||||||
Some(t) => Some(t.iter()),
|
|
||||||
};
|
|
||||||
self.next()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct KmeansIntoIter<T> where T: Clusterable {
|
|
||||||
global_iter: std::vec::IntoIter<std::vec::Vec<T>>,
|
|
||||||
local_iter: Option<std::vec::IntoIter<T>>,
|
|
||||||
_phantom: PhantomData<T>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: Clusterable> Iterator for KmeansIntoIter<T> {
|
|
||||||
type Item = T;
|
|
||||||
fn next(&mut self) -> Option<T> {
|
|
||||||
if let Some(ref mut local_iter) = self.local_iter {
|
|
||||||
match local_iter.next() {
|
|
||||||
Some(t) => Some(t),
|
|
||||||
None => {
|
|
||||||
if let Some(next) = self.global_iter.next() {
|
|
||||||
*local_iter = next.into_iter();
|
|
||||||
match local_iter.next() {
|
|
||||||
Some(t) => Some(t),
|
|
||||||
None => None,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
self.local_iter = match self.global_iter.next() {
|
|
||||||
None => None,
|
|
||||||
Some(t) => Some(t.into_iter()),
|
|
||||||
};
|
|
||||||
self.next()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
impl<T: Clusterable> IntoIterator for Kmeans<T> {
|
|
||||||
type Item = T;
|
|
||||||
type IntoIter = KmeansIntoIter<T>;
|
|
||||||
fn into_iter(self) -> Self::IntoIter {
|
|
||||||
Self::IntoIter {
|
|
||||||
global_iter: self.clusters.into_iter(),
|
|
||||||
local_iter: None,
|
|
||||||
_phantom: PhantomData,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
40
src/test.rs
40
src/test.rs
|
@ -2,14 +2,12 @@
|
||||||
mod test {
|
mod test {
|
||||||
#[test]
|
#[test]
|
||||||
fn iterators() {
|
fn iterators() {
|
||||||
use Kmeans;
|
use kmeansdata::KmeansData;
|
||||||
let data = vec![4.0, 5.0, 11.0, 12.0, 13.0];
|
let data = vec![4.0, 5.0, 11.0, 12.0, 13.0];
|
||||||
let kmeans = Kmeans {
|
let kmeans = KmeansData::from_clusters(vec![
|
||||||
clusters: vec![
|
|
||||||
vec![4.0, 5.0],
|
vec![4.0, 5.0],
|
||||||
vec![11.0, 12.0, 13.0],
|
vec![11.0, 12.0, 13.0],
|
||||||
],
|
]);
|
||||||
};
|
|
||||||
|
|
||||||
for (val1, val2) in kmeans.into_iter().zip(data) {
|
for (val1, val2) in kmeans.into_iter().zip(data) {
|
||||||
assert_eq!(val1, val2);
|
assert_eq!(val1, val2);
|
||||||
|
@ -18,19 +16,29 @@ mod test {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn iterate() {
|
fn iterate() {
|
||||||
use Kmeans;
|
use kmeans::Kmeans;
|
||||||
let kmeans = Kmeans {
|
|
||||||
clusters: vec![
|
let data = vec![
|
||||||
vec![4.0, 5.0, 11.0, 12.0],
|
vec![4.0, 5.0, 11.0, 12.0],
|
||||||
vec![13.0],
|
vec![13.0],
|
||||||
],
|
];
|
||||||
};
|
|
||||||
let kmeans = kmeans.iterate();
|
|
||||||
let kmeans = kmeans.iterate();
|
|
||||||
let kmeans = kmeans.iterate();
|
|
||||||
let kmeans = kmeans.iterate();
|
|
||||||
|
|
||||||
assert_eq!(kmeans.clusters[0], vec![4.0, 5.0]);
|
let solution = vec![
|
||||||
assert_eq!(kmeans.clusters[1], vec![11.0, 12.0, 13.0]);
|
vec![4.0, 5.0],
|
||||||
|
vec![11.0, 12.0, 13.0],
|
||||||
|
];
|
||||||
|
|
||||||
|
let mut kmeans = Kmeans::guess_centroids(data.clone());
|
||||||
|
|
||||||
|
for _ in 0..4 {
|
||||||
|
let (new_kmeans, stable) = kmeans.next_iteration();
|
||||||
|
kmeans = new_kmeans;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (k1, k2) in kmeans.iter().zip(solution) {
|
||||||
|
for (i, j) in k1.iter().zip(&k2) {
|
||||||
|
assert_eq!(i, j);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue