kmeans/src/example.rs

128 lines
3.1 KiB
Rust

extern crate rand;
extern crate generic_kmeans;
use std::fmt::{Display, Formatter, Result};
use std::fs::File;
use rand::distributions::Range;
use rand::distributions::normal::Normal;
use generic_kmeans::kmeans;
// Number of synthetic cluster centers requested from `generate_centers` in `main`.
const CLUSTER_NUMBER: usize = 3;
// Controls the spread of the normal distribution used on each axis around a
// cluster center (consumed by `generate_centers`).
const VAR: f64 = 1.0;
// Upper bound of the range from which each center coordinate is drawn; the
// lower bound used in `generate_centers` is 0.0.
const MAX: f64 = 10.0;
/// A pair of coordinates of the same type `T`.
///
/// Used in this example as a 2-D point with `T = f64`.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Vector2<T> {
    /// First coordinate.
    pub x: T,
    /// Second coordinate.
    pub y: T,
}
impl<T> Vector2<T> {
    /// Builds a vector from its two coordinates, stored in the order given.
    pub fn new(x: T, y: T) -> Vector2<T> {
        Vector2 { x, y }
    }
}
impl<T: Display> Display for Vector2<T> {
fn fmt(&self, formatter: &mut Formatter) -> Result {
write!(formatter, "({}, {})", self.x, self.y)
}
}
/// Returns the *squared* Euclidean distance between `v1` and `v2`.
///
/// No square root is taken, so the result is `dx² + dy²`. This keeps the same
/// ordering as the true distance when comparing candidates.
fn distance(v1: &Vector2<f64>, v2: &Vector2<f64>) -> f64 {
    let squared = |delta: f64| delta * delta;
    squared(v2.x - v1.x) + squared(v2.y - v1.y)
}
fn centroid(elements: &Vec<Vector2<f64>>) -> Vector2<f64> {
let mut v = Vector2::new(0.0, 0.0);
for element in elements {
v.x += element.x;
v.y += element.y;
}
v.x /= elements.len() as f64;
v.y /= elements.len() as f64;
v
}
/// Draws 100 random points for every entry of `centers`.
///
/// For each `(_, x_dist, y_dist)` entry, a point's `x` is sampled from
/// `x_dist` and its `y` from `y_dist`. Points are returned grouped by center,
/// in the order the centers appear.
fn generate_points(centers: &mut Vec<(Vector2<f64>, Normal, Normal)>) -> Vec<Vector2<f64>> {
    use rand::distributions::IndependentSample;

    const POINTS_PER_CENTER: usize = 100;

    let mut rng = rand::thread_rng();
    let mut points = Vec::new();
    for &mut (_, x_normal, y_normal) in centers.iter_mut() {
        points.extend((0..POINTS_PER_CENTER).map(|_| {
            // Sample x before y, one pair per point.
            let x = x_normal.ind_sample(&mut rng);
            let y = y_normal.ind_sample(&mut rng);
            Vector2::new(x, y)
        }));
    }
    points
}
/// Generates `number` random cluster centers with their sampling distributions.
///
/// Each center has both coordinates drawn from the range `0.0..MAX`. It is
/// returned with two normal distributions, one per axis, whose means are the
/// center's `x` and `y` respectively.
///
/// `Normal::new` takes a *standard deviation* as its second argument, whereas
/// the constant is named `VAR` (variance), so its square root is passed. For
/// the current `VAR = 1.0` this yields the same distributions as before.
fn generate_centers(number: usize) -> Vec<(Vector2<f64>, Normal, Normal)> {
    use rand::distributions::IndependentSample;

    let range = Range::new(0.0, MAX);
    let mut rng = rand::thread_rng();
    let std_dev = VAR.sqrt();

    (0..number)
        .map(|_| {
            // Sample x before y for each center.
            let x = range.ind_sample(&mut rng);
            let y = range.ind_sample(&mut rng);
            let center = Vector2::new(x, y);
            (
                center,
                Normal::new(center.x, std_dev),
                Normal::new(center.y, std_dev),
            )
        })
        .collect()
}
/// Generates synthetic clustered points, runs `kmeans` on them and writes the
/// results to `plot/dat.dat` and `plot/centers.dat`.
///
/// * `plot/dat.dat` gets one `x y color` line per clustered point, the color
///   identifying the cluster the point was assigned to.
/// * `plot/centers.dat` gets one `x y color` line per generated center.
///
/// # Panics
///
/// Panics if `kmeans` returns an error, or if either output file cannot be
/// created or written (for instance when the `plot/` directory is missing).
fn main() {
    use std::io::{BufWriter, Write};

    let mut centers = generate_centers(CLUSTER_NUMBER);
    let elements = generate_points(&mut centers);

    // Starting centroids are fixed, three in total, independent of
    // `CLUSTER_NUMBER`.
    let initial = vec![
        Vector2::new(0.0, 0.0),
        Vector2::new(10.0, 0.0),
        Vector2::new(0.0, 10.0),
    ];
    // One color per cluster. The `zip` calls below stop at the shorter
    // sequence, so any cluster or center beyond the third is silently skipped.
    let colors = vec![
        "blue",
        "red",
        "green",
    ];

    // NOTE(review): the last argument is presumably an iteration cap — confirm
    // against the generic_kmeans documentation.
    let (kmeans, nb_iterations) = kmeans(elements, initial, distance, centroid, 1000)
        .expect("k-means clustering failed");
    // NOTE(review): `.iter()` is called on this value below; whether that
    // compiles depends on the type generic_kmeans returns — confirm.
    let kmeans = kmeans.into_iter();
    println!("Converged in {} iterations.", nb_iterations);

    // Buffer the writes: one line is emitted per point.
    let mut output = BufWriter::new(
        File::create("plot/dat.dat").expect("cannot create plot/dat.dat"),
    );
    for (index, (cluster, color)) in kmeans.iter().zip(colors.iter()).enumerate() {
        println!("Cluster {}: {} elements", index, cluster.len());
        for element in cluster {
            writeln!(output, "{} {} {}", element.x, element.y, color)
                .expect("cannot write to plot/dat.dat");
        }
    }
    // Flush explicitly: errors raised while dropping a `BufWriter` are ignored.
    output.flush().expect("cannot flush plot/dat.dat");

    let mut center_file = BufWriter::new(
        File::create("plot/centers.dat").expect("cannot create plot/centers.dat"),
    );
    for (&(center, _, _), color) in centers.iter().zip(&colors) {
        writeln!(center_file, "{} {} {}", center.x, center.y, color)
            .expect("cannot write to plot/centers.dat");
    }
    center_file.flush().expect("cannot flush plot/centers.dat");
}