// kmeans/src/example.rs
// Example program for the `generic_kmeans` crate.

extern crate rand;
extern crate generic_kmeans;
use std::fmt::{Display, Formatter, Result};
use std::fs::File;
use rand::distributions::Range;
use rand::distributions::normal::Normal;
use generic_kmeans::{kmeans, Clusterable};
/// A two-component vector (or point) whose components have type `T`.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Vector2<T> {
    /// First component.
    pub x: T,
    /// Second component.
    pub y: T,
}

impl<T> Vector2<T> {
    /// Builds a vector from its first (`x`) and second (`y`) components.
    pub fn new(x: T, y: T) -> Vector2<T> {
        Vector2 { x, y }
    }
}

/// Renders the vector as `(x, y)`, formatting each component with its own
/// `Display` implementation.
impl<T: Display> Display for Vector2<T> {
    fn fmt(&self, f: &mut Formatter) -> Result {
        let Vector2 { ref x, ref y } = *self;
        write!(f, "({}, {})", x, y)
    }
}
impl Clusterable for Vector2<f64> {
    /// Returns the *squared* Euclidean distance between `self` and `other`
    /// (no square root is taken).
    ///
    /// NOTE(review): squared distance preserves nearest-centroid ordering, so it
    /// is sufficient if `generic_kmeans` only compares distances — confirm the
    /// crate does not require a true metric here.
    fn distance(&self, other: &Self) -> f64 {
        let dx = self.x - other.x;
        let dy = self.y - other.y;
        dx * dx + dy * dy
    }

    /// Returns the arithmetic mean of the points yielded by `cluster`,
    /// or `None` when `cluster` yields no points.
    fn get_centroid<'a, I>(cluster: I) -> Option<Self>
    where
        I: Iterator<Item = &'a Self>,
        Self: 'a,
    {
        // Sum components in iteration order and count points with an integer,
        // rather than a float counter that then needs a redundant cast.
        let (sum_x, sum_y, count) = cluster.fold(
            (0.0, 0.0, 0usize),
            |(sum_x, sum_y, count), point| (sum_x + point.x, sum_y + point.y, count + 1),
        );
        if count == 0 {
            return None;
        }
        let n = count as f64;
        Some(Vector2::new(sum_x / n, sum_y / n))
    }
}
fn main() {
use rand::distributions::IndependentSample;
let colors = vec![
"blue",
"red",
"green",
];
let range = Range::new(0.0, 10.0);
let mut rng = rand::thread_rng();
let cluster_number = 3;
let mut centers = vec![];
for _ in 0..cluster_number {
let center = Vector2::new(range.ind_sample(&mut rng), range.ind_sample(&mut rng));
centers.push((center, Normal::new(center.x, 0.5), Normal::new(center.y, 0.5)));
}
let mut elements = vec![];
for &mut (_, x_rng, y_rng) in centers.iter_mut() {
for _ in 0..100 {
elements.push(Vector2::new(x_rng.ind_sample(&mut rng), y_rng.ind_sample(&mut rng)));
}
}
let initialization = vec![
Vector2::new(0.0,0.0),
Vector2::new(10.0,0.0),
Vector2::new(0.0,10.0),
];
let (clusters, nb_iterations) = kmeans(initialization, elements, 100000).ok().unwrap();
let mut output = File::create("plot/dat.dat").unwrap();
for (element, &label) in clusters.iter() {
use std::io::Write;
writeln!(output, "{} {} {}", element.x, element.y, colors[label]).unwrap();
}
println!("Finished in {} iterations", nb_iterations);
let mut center_file = File::create("plot/centers.dat").unwrap();
for (&(center, _, _), color) in centers.iter().zip(&colors) {
use std::io::Write;
writeln!(center_file, "{} {} {}", center.x, center.y, color).unwrap();
}
}