diff --git a/.gitignore b/.gitignore index 143b1ca..3328184 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ /target/ **/*.rs.bk Cargo.lock +plot diff --git a/plot/centers.dat b/plot/centers.dat deleted file mode 100644 index 9e62446..0000000 --- a/plot/centers.dat +++ /dev/null @@ -1,3 +0,0 @@ -7.362969772457262 2.2292211784138405 blue -7.561872618222036 3.4619500908356526 red -3.9379822748704396 0.35795551057766106 green diff --git a/plot/dat.dat b/plot/dat.dat deleted file mode 100644 index 31dc24a..0000000 --- a/plot/dat.dat +++ /dev/null @@ -1,300 +0,0 @@ -3.1722920796470753 -0.733279293053285 blue -3.167068508218321 0.10323154299065418 blue -3.103950084700323 0.23930517934518614 blue -3.3798146480806963 -0.1827565585954446 blue -2.978283637278063 0.06453933903045073 blue -3.6356181662397535 -0.4044583199837899 blue -2.95372747507819 0.374453759303532 blue -3.3920150133049254 -0.4190603852985506 blue -3.6094600800276067 -0.2144185251138886 red -2.7739400381261774 0.712059481698815 red -3.846841411138542 -0.4888722200579573 red -3.6637633637955247 -0.20523517320505846 red -3.502927671921002 -0.01739950379852473 red -3.570670098713616 -0.04271440668919613 red -4.258771306128453 -0.7638447827738868 red -3.9409335680538744 -0.34461733565342756 red -3.172851060943594 0.5155595500900624 red -3.3660242079847826 0.3176395813234718 red -3.882627446966355 -0.3755163859233328 red -3.258525363706499 0.4069846744870096 red -3.495556786207888 0.3994046031736014 red -3.507425251682668 0.37043602544443244 red -3.927449678905854 -0.12931373550029374 red -3.9401081984845097 -0.21366400121478335 red -3.998440026339032 -0.232044527944281 red -3.485801013529333 0.25181078736199886 red -3.817636850420355 -0.12783722608293957 red -3.480959460450858 0.3356245969338822 red -3.733075996312495 0.019096925850280277 red -4.146995700224781 -0.04460715552241867 red -3.579793889310732 0.5229615388110218 red -3.6484363281436494 0.6045383904999224 red -3.4003348878727695 0.79816498433734 red -3.8439726772185896 0.2756903935248908 red -2.8383760073987534 1.3882253053077225 red -3.592774371838402 0.6789321749560016 red -3.7674662297399286 0.2921517888999538 red -4.115608810992105 -0.07025695746073285 red -3.528690841064493 0.516424034979121 red -3.217234748563143 0.9442394356410887 red -4.131774238527101 -0.04224023552247741 red -3.561080757822947 0.4818635464029179 red -4.2408854028156755 -0.2450782482196694 red -3.6730022742656243 0.29741785124017694 red -4.034026334642208 0.1488231988663128 red -4.119128183686852 -0.1580794094811806 red -3.9753355044557983 0.39627299770817237 red -3.6732827650690103 0.8059859451506826 red -4.488509416398572 -0.1462499511423424 red -4.32661224091649 0.07139296218783381 red -3.613505127994868 0.8377964314698141 red -3.87153299037606 0.5812571966391673 red -3.9902303145964706 0.5859036858002181 red -4.362303526272025 -0.042013012245863846 red -4.134369220146103 0.03349891857151277 red -4.177740948577232 0.12358770141391692 red -3.9540793283557014 0.3356584069120508 red -4.401108898000167 -0.003114398045010347 red -3.498514749532378 0.8800866500487596 red -4.41301152664003 -0.027628861527589144 red -4.215618119197788 0.3840095817249771 red -3.5482423999839146 0.7632539555148093 red -3.6246636191468666 0.7301238699155173 red -4.130094170140202 0.3078994155838442 red -4.034092528846325 0.3317083678197986 red -3.8796159868011064 0.3997569887460371 red -4.219000990049526 0.26810755366819106 red -4.247938088992708 0.39267087884272767 red -3.8030383348439853 0.5663640405892496 red -4.015106218988656 0.2762165711361133 red -3.6998523298348918 1.073558644892152 red -4.756491100686767 -0.20007878482051766 red -4.57835590857826 0.5851382590935945 red -3.941012536283112 1.135248406562449 red -4.283375404221392 0.698265954177747 red -5.196520463690655 0.8269139783274901 red -4.587045998171549 0.9493130438890539 red -4.277924679227115 0.5999740967807994 red -4.79655404175757 -0.03468448586681655 red -4.350986448867171 0.3460195729412589 red -4.3535938869662685 0.42185514891968456 red -4.516734825818239 0.6399861627972976 red -4.577579986201398 0.2693246854906323 red -4.047085482678346 0.9438333498946828 red -4.530054468395934 0.8253339759171412 red -4.62730518036874 0.39979489250253814 red -4.333874585129452 0.4127800284367539 red -4.61684222466966 0.5302988833055177 red -4.725281756162216 0.5076632573165651 red -4.450184320080639 1.4068875202461746 red -4.58559612719457 0.42148197697475753 red -4.517186886807359 1.1835934496680203 red -4.28283274886123 0.8329269399680973 red -4.224190877954441 0.5405588945003291 red -3.7912279751660867 1.2939721795466221 red -3.9822079025333577 1.0351892011246622 red -4.199080672684683 0.5206542057383026 red -4.329610262190835 0.32290838658859233 red -4.3612270824589405 1.0268438488576352 red -4.520204476958637 1.4227425201993709 red -7.722136516021652 1.9470066555970975 red -7.4916746875072855 2.1180961930221773 red -6.504637622726477 2.61151640305994 red -7.177999201946461 2.6255334521727467 red -7.26910394349391 2.234788707898471 red -6.117168606550697 1.866392307558659 red -7.623229088939814 1.792731727437522 red -7.807723494032717 1.84200160683787 red -6.706556568713833 2.2612285132853485 red -6.967195552131049 2.501871065272528 red -7.175843899364814 0.8574262374353125 red -6.8750422920422904 1.2371567676668918 red -7.619809736193145 1.3103889429511164 red -5.953387500844007 2.755233188216527 red -7.723489084817691 1.9463428958783504 red -6.644063412274511 2.5586672455593362 red -7.733605889106188 2.079736410311278 red -7.510572393106473 2.2061220380991124 red -6.856216013634419 2.6544853108123023 red -6.6930033087268885 2.245099227226044 red -7.247942190451006 2.1627019434042754 red -7.741449214385919 2.182334922711081 red -6.746942537637999 1.9959361510642295 red -7.649143630640873 2.22441486698875 red -6.9358767615672035 1.485569780201074 red -7.113591869260028 2.2928004264536828 red -7.427120398051516 2.383864858715106 red -7.349659579825405 2.0949180923760453 red -6.395333641483336 1.7743677054052447 red -6.466838685136713 1.939393843109669 red -7.253548883412662 2.1030668375095667 red -8.07939042280926 2.04729205871798 red -7.548984251183588 1.5499638356328398 red -6.1032067830466215 2.838456614189033 red -7.098490045615251 1.8897716920941956 red -7.59862527364878 2.3456095485144837 red -7.6801188008605985 2.344753403464505 red -7.49670020820185 1.726529815379416 red -7.146293141419195 1.252784057975802 red -7.321579062344307 2.5070505017099913 red -6.798977581928631 2.8315301607072727 red -7.5570303896807625 1.5791506580459989 red -7.272672897968238 1.9990630354210432 red -7.515475563843841 1.762927204669434 red -7.824459641746373 2.1013238310168627 red -7.532377761687111 2.0904990744009906 red -6.4714917643178085 2.782633453745481 red -7.493513215157782 2.1060717676124643 red -7.130875386373429 2.561241884726998 red -7.1965107684319065 2.090188844246861 red -7.7937385851348395 2.257283846237026 red -7.641528665375075 1.4752990325524786 red -7.195530312258788 2.086730017606406 red -7.520649659019042 2.0002708025432363 red -7.902751665079369 2.2023323361646625 red -8.120404543041419 2.0665508009908136 red -6.415331460879512 1.68412961311886 red -7.364998983795844 1.4399331873149213 red -8.14234624185481 2.0750561226272866 red -7.1704254033423585 2.181982728843698 red -7.402057697660059 1.533038191027845 red -7.664029025362594 1.7820627943094123 red -7.669404374935257 1.8538765822930114 red -7.371627210148008 1.558805724375508 red -7.0511942299360495 2.3419639393047356 red -7.329465356185301 2.177923167701359 red -7.771413252571511 2.11242037492263 red -7.663725996090565 1.7126632923352003 red -7.397374912219294 2.399953582766828 red -7.399196915900515 1.5693121453666254 red -7.147337668439287 2.108371971884583 red -7.20952148595069 1.2275066517040163 red -7.500514079976969 2.2498592424305035 red -7.56327947368118 1.212640674473267 red -6.850697308173663 2.489187247297768 red -6.818838516960285 2.265360970377559 red -6.909640919530763 2.4401077434705067 red -6.654089805070571 2.8505331135797998 red -8.000114875345277 2.119595362298255 red -6.960334760381899 2.70748387546396 red -7.838299264896003 2.715157569111897 red -7.633463429101306 2.47314113739334 red -7.416688794086579 2.889935392036926 red -7.805609921172734 2.5989409709922384 red -7.539781828336027 2.9330292220419616 red -7.597116683779243 3.2260333749067014 red -8.466410885039707 2.7377320271447263 red -7.133556641250677 2.9471790306833077 red -7.769517340230654 2.7763025185765584 red -7.337931150414927 2.5522719526839173 red -7.74023489812714 3.065068199792865 red -7.683759543909188 2.364155681212309 red -7.571610243520974 2.4839540444329544 red -7.095520936370959 2.92914400483635 red -8.041128388484529 2.7894817375794916 red -7.144407268059606 2.9691452384737635 red -8.505593495433358 2.1466363274907447 red -8.632508392616 2.22257963053884 red -7.911123376257175 2.5213466516039627 red -7.6940185408964545 3.352293659663364 red -7.991456661120536 2.589987966181656 red -8.087698508849215 2.556181174894577 red -7.718150718475244 2.8094487856195003 red -7.47315392406206 3.4782308361876915 red -8.012813012736501 3.122791194321132 red -8.174108122697525 3.429234781425609 red -7.766161280844902 3.7561802057042244 red -7.263545882712765 3.860939909657454 red -8.096615450151507 3.445162969565429 red -6.80705612289653 3.695299133705912 red -7.787262071030116 3.0400121637120074 red -7.921825185220113 4.11562189496709 red -7.352488965852293 3.0962605817874653 red -8.130241325680757 3.12580170376337 red -7.23034764678279 3.647675952767801 red -6.767616299967328 4.5653279969715985 red -7.7395783849343465 3.869413921383619 red -8.158085116611321 2.971554554199345 red -8.07002060099343 2.821712503112572 red -7.062756873013518 3.589476422002454 red -7.677978837449784 3.223400111739516 red -7.494413355292032 3.257440270821772 red -8.288820117028758 3.292305271320214 red -7.649362000563177 4.271160426138838 red -7.616020117241795 3.95881434948141 red -7.195507257591824 3.9428433603688404 red -7.245687352457268 2.9363826343846484 red -7.588216917623389 3.3532148773495645 red -7.126858553632669 2.9739175202264714 red -7.983026848285219 3.824046212867809 red -7.475292436908247 3.323994322421076 red -7.936496484568236 3.4794162984228825 red -6.76205135484123 3.5843519824089682 red -7.426875937925392 4.038873122243772 red -8.21951306389346 3.6399123259998225 red -7.324026068271806 3.061370274841475 red -8.569032256116909 3.3329612438310345 red -7.318164533243619 3.069397489626727 red -6.958855114032112 3.280972206830141 red -8.692122242415982 3.6918095437918073 red -7.2383605768073185 4.266375385002773 red -7.370919915669186 3.7015685565455962 red -8.150362283646688 3.329309338393711 red -7.479996619904384 3.0546845699004916 red -7.906092072147054 3.876850938641441 red -7.180680249472975 3.9577888988740364 red -7.322190245719994 3.2826421528571625 red -7.255438971347838 3.513450999321192 red -8.317600496173556 4.034622634968662 red -7.88276796985749 2.7966518004810754 red -8.103899733693975 4.001561001136776 red -6.954538631721563 2.976496115773739 red -7.240396443716532 3.2846074175197377 red -7.5850349627285185 2.8942969475507017 red -7.287112882089894 3.216029227750091 red -7.623359140518172 2.9593967166439348 red -7.713574159860361 3.4895881829426822 red -7.025980576805368 3.6433631066041405 red -7.444754728796709 3.483957785955234 red -6.838126438683099 3.3484180783059685 red -7.263944759925077 2.9654682573147193 red -7.6315888677034645 3.2466591797995314 red -6.826319178392804 4.484284866115551 red -7.360383098198501 3.7990932090392033 red -8.094085099203799 3.836014940244799 red -7.321057251480907 3.174050781865412 red -7.549705399139268 2.5146338343626593 red -7.1448124339102765 3.9680523242145584 red -7.734875951386552 3.2714316107507564 red -9.675705035077282 3.1167373037722834 red -7.444144597466469 3.089028651011974 red -6.783074702273208 3.211128587952996 red -7.47683218710337 2.518766753491712 red -7.163021503821117 3.252526493441669 red -7.980469959385454 4.158922284653899 red -8.52415725815523 3.11232846252088 red -7.697640498922979 4.2285295178410545 red -7.979057752484617 3.772083324553033 red -7.129657086439503 4.270166790068663 red -7.751753583966503 3.793402426923823 red -7.962022594729247 3.557862833899373 red -7.860303846423593 3.179038118142694 red -7.594168785731357 3.6788496848681413 red -9.124385644351904 2.468503337895192 red -7.713454204203212 2.9490935786628114 red -6.978844559548216 3.0178337639638806 red -6.906575699777272 3.3487729496964964 red -7.581454792889804 4.080280777611024 red -7.769797486055793 3.568644324142498 red -7.514981043050376 3.1549995905568915 red -8.076143978924529 3.470220123136699 red -7.29875139666393 3.3347908360349003 red -8.625871616822147 3.10608064651563 red -7.017251111494703 3.781879721569966 red -7.228136277849404 2.9054637393699094 red -7.270449993677676 3.01328472629758 red -7.3710753704641245 3.5688528791612377 red -7.933270825676227 3.5624377469972264 red -7.471503469540416 3.202424426475704 red -6.741829988777134 3.7063692230617864 red diff --git a/plot/diagram.png b/plot/diagram.png deleted file mode 100644 index 0b334d1..0000000 Binary files a/plot/diagram.png and /dev/null differ diff --git a/plot/plot.gpi b/plot/plot.gpi deleted file mode 100644 index 2b39561..0000000 --- a/plot/plot.gpi +++ /dev/null @@ -1,11 +0,0 @@ -set terminal png size 1280,720 enhanced font "Sans,15" -set output "diagram.png" -set pointsize 3 -set xrange [0:10] -set yrange [0:10] -plot "< awk '{if($3 == \"red\") print}' dat.dat" u 1:2 t "red" pt 4, \ - "< awk '{if($3 == \"green\") print}' dat.dat" u 1:2 t "green" pt 4, \ - "< awk '{if($3 == \"blue\") print}' dat.dat" u 1:2 t "blue" pt 4, \ - "< awk '{if($3 == \"red\") print}' centers.dat" u 1:2 t "red" pt 7, \ - "< awk '{if($3 == \"green\") print}' centers.dat" u 1:2 t "green" pt 7, \ - "< awk '{if($3 == \"blue\") print}' centers.dat" u 1:2 t "blue" pt 7 diff --git a/plot/regen.sh b/plot/regen.sh deleted file mode 100755 index 1636408..0000000 --- a/plot/regen.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/usr/bin/env bash -cd .. -cargo run --release --bin example -cd plot -gnuplot plot.gpi diff --git a/src/cluster.rs b/src/cluster.rs deleted file mode 100644 index 7b0e8ae..0000000 --- a/src/cluster.rs +++ /dev/null @@ -1,38 +0,0 @@ -pub type Cluster = Vec; - -pub trait Clusterable where Self: Sized + Clone + PartialEq { - fn distance(&self, rhs: &Self) -> f64; - fn get_centroid(elements: &Vec) -> Option; -} - -macro_rules! impl_clusterable { - ( $type: ty) => { - impl Clusterable for $type { - fn distance(&self, rhs: &Self) -> f64 { - if self > rhs { - *self as f64 - *rhs as f64 - } else { - *rhs as f64 - *self as f64 - } - } - - fn get_centroid(elements: &Vec) -> Option { - - if elements.len() == 0 { - return None; - } - - let mut tmp = 0.0 as Self; - for element in elements { - tmp += *element as Self; - } - Some(tmp / elements.len() as Self) - - } - } - } -} - -impl_clusterable!(f32); -impl_clusterable!(f64); - diff --git a/src/clusterable.rs b/src/clusterable.rs new file mode 100644 index 0000000..f685424 --- /dev/null +++ b/src/clusterable.rs @@ -0,0 +1,43 @@ +use std::fmt::Debug; + +pub trait Clusterable where Self: Sized + Clone + PartialEq + Debug { + fn distance(&self, rhs: &Self) -> f64; + fn get_centroid<'a, I>(elements: I) -> Option + where I: Iterator, Self: 'a; +} + +macro_rules! impl_clusterable { + ( $type: ty) => { + impl Clusterable for $type { + fn distance(&self, rhs: &Self) -> f64 { + if self > rhs { + *self as f64 - *rhs as f64 + } else { + *rhs as f64 - *self as f64 + } + } + + fn get_centroid<'a, I>(elements: I) -> Option + where I: Iterator, Self: 'a { + + let mut tmp = 0.0; + let mut count = 0.0; + for element in elements { + tmp += element; + count += 1.0; + } + + if count > 0.0 { + Some(tmp / count) + } else { + None + } + + } + } + } +} + +impl_clusterable!(f32); +impl_clusterable!(f64); + diff --git a/src/example.rs b/src/example.rs index a2ad4c4..79c5668 100644 --- a/src/example.rs +++ b/src/example.rs @@ -7,7 +7,7 @@ use rand::distributions::Range; use rand::distributions::normal::Normal; use generic_kmeans::{kmeans, Clusterable}; -#[derive(PartialEq, Copy, Clone)] +#[derive(PartialEq, Copy, Clone, Debug)] struct Vector2 { pub x: T, pub y: T, @@ -30,28 +30,28 @@ impl Display for Vector2 { impl Clusterable for Vector2 { fn distance(&self, other: &Self) -> f64 { - (self.x - other.x) * (self.x - other.y) + (self.y - other.y) * (self.y - other.y) + (self.x - other.x) * (self.x - other.x) + (self.y - other.y) * (self.y - other.y) } - fn get_centroid(cluster: &Vec>) -> Option> { - - let len = cluster.len(); - - if len == 0 { - return None; - } + fn get_centroid<'a, I>(cluster: I) -> Option + where I: Iterator, Self: 'a { let mut centroid = Vector2::new(0.0, 0.0); + let mut count = 0.0; for i in cluster { centroid.x += i.x; centroid.y += i.y; + count += 1.0; } - centroid.x /= len as f64; - centroid.y /= len as f64; - - Some(centroid) + if count > 0.0 { + centroid.x /= count as f64; + centroid.y /= count as f64; + Some(centroid) + } else { + None + } } } @@ -85,21 +85,24 @@ fn main() { } + let initialization = vec![ + Vector2::new(0.0,0.0), + Vector2::new(10.0,0.0), + Vector2::new(0.0,10.0), + ]; - let (clusters, nb_iterations) = kmeans( - centers.iter().map(|x| x.clone().0).collect::>(), elements, 100000).ok().unwrap(); - println!("{}", nb_iterations); + let (clusters, nb_iterations) = kmeans(initialization, elements, 100000).ok().unwrap(); let mut output = File::create("plot/dat.dat").unwrap(); - for (cluster, color) in clusters.iter().zip(&colors) { - for element in cluster { - use std::io::Write; - writeln!(output, "{} {} {}", element.x, element.y, color).unwrap(); - } + for (element, &label) in clusters.iter() { + use std::io::Write; + writeln!(output, "{} {} {}", element.x, element.y, colors[label]).unwrap(); } + println!("Finished in {} iterations", nb_iterations); + let mut center_file = File::create("plot/centers.dat").unwrap(); for (&(center, _, _), color) in centers.iter().zip(&colors) { use std::io::Write; diff --git a/src/kmeans.rs b/src/kmeans.rs index 29a3c8d..a245dc2 100644 --- a/src/kmeans.rs +++ b/src/kmeans.rs @@ -1,48 +1,82 @@ -use std; - -use kmeansdata::KmeansData; -use cluster::{Cluster, Clusterable}; +use std::slice::Iter; +use std::iter::Zip; +use clusterable::Clusterable; pub struct Kmeans { - pub centroids: Vec, - pub data: KmeansData, + centroids: Vec, + elements: Vec, + labels: Vec, + cluster_number: usize, } impl Kmeans { - pub fn new(centroids: Vec, data: Vec>) -> Kmeans { + pub fn new(centroids: Vec, data: Vec) -> Kmeans { + let labels = Kmeans::build_labels(¢roids, &data); + let cluster_number = centroids.len(); Kmeans { centroids: centroids, - data: KmeansData::from_clusters(data), + elements: data, + labels: labels, + cluster_number: cluster_number } } - pub fn guess_centroids(data: Vec>) -> Kmeans { + /// \returns True if converged + pub fn iterate(&mut self) -> bool { + // Update the centroids + let centroids = Kmeans::build_centroids(&self.elements, &self.labels, self.cluster_number); + + if self.centroids == centroids { + true + } else { + self.centroids = centroids; + Kmeans::update_labels(&self.centroids, &self.elements, &mut self.labels); + false + } + } + + pub fn build_labels(centroids: &Vec, data: &Vec) -> Vec { + debug_assert_ne!(0, centroids.len()); + + let mut output = vec![0; data.len()]; + Kmeans::update_labels(centroids, data, &mut output); + output + } + + pub fn update_labels(centroids: &Vec, data: &Vec, labels: &mut Vec) { + for (element, new_label) in data.iter().zip(labels.iter_mut()) { + *new_label = centroids + .iter() + .enumerate() + .min_by(|&(_, c1), &(_, c2)| { + c1.distance(element).partial_cmp(&c2.distance(element)).unwrap() + }).unwrap().0; + } + } + + pub fn build_centroids(data: &Vec, labels: &Vec, cluster_number: usize) -> Vec { + let mut centroids = vec![]; - for cluster in &data { - if let Some(centroid) = T::get_centroid(cluster) { - centroids.push(centroid); + for label in 0..cluster_number { + + let to_consider = data + .iter() + .enumerate() + .filter(|&(index, _)| labels[index] == label) + .map(|(_, element)| element); + + if let Some(centroid) = T::get_centroid(to_consider) { + centroids.push(centroid); } } - Kmeans::new(centroids, data) - + centroids } - fn from_data(centroids: Vec, data: KmeansData) -> Kmeans { - Kmeans { - centroids: centroids, - data: data, - } + pub fn iter(&self) -> Zip, Iter> { + debug_assert_eq!(self.elements.len(), self.labels.len()); + return self.elements.iter().zip(self.labels.iter()); } - pub fn next_iteration(self) -> (Kmeans, bool) { - let (new_centroids, data) = self.data.iterate(&self.centroids); - let stable = new_centroids == self.centroids; - (Kmeans::from_data(new_centroids, data), stable) - } - - pub fn iter(&self) -> std::slice::Iter> { - self.data.clusters() - } } diff --git a/src/kmeansdata.rs b/src/kmeansdata.rs deleted file mode 100644 index aca75ce..0000000 --- a/src/kmeansdata.rs +++ /dev/null @@ -1,168 +0,0 @@ -use std; -use std::marker::PhantomData; - -use cluster::{Clusterable, Cluster}; - - -pub struct KmeansData { - clusters: Vec>, -} - -impl KmeansData { - - pub fn from_clusters(data: Vec>) -> KmeansData { - KmeansData { - clusters: data, - } - } - - pub fn new(centroids: &Vec, data: KmeansDataIntoIter) -> KmeansData { - - let mut clusters = vec![]; - - for _ in centroids { - clusters.push(Cluster::new()); - } - - let mut new_kmeans: KmeansData = KmeansData { - clusters: clusters, - }; - - for element in data { - - // Compute the distance - let mut distance = std::f64::MAX; - let mut index = 0; - - for (new_index, centroid) in centroids.iter().enumerate() { - let new_distance = element.distance(centroid); - - if new_distance < distance { - distance = new_distance; - index = new_index; - } - } - - // Add element to the new kmeans - new_kmeans.clusters[index].push(element) - - } - - new_kmeans - } - - pub fn add_cluster(&mut self) { - self.clusters.push(Cluster::new()); - } - - pub fn iterate(self, centroids: &Vec) -> (Vec, KmeansData) { - // Compute the result with the given centroids - let result = KmeansData::new(¢roids, self.into_iter()); - - // Compute the new centroids - let mut new_centroids = vec![]; - - for cluster in &result.clusters { - - if let Some(centroid) = T::get_centroid(cluster) { - new_centroids.push(centroid); - } - - } - - (new_centroids, result) - } - - pub fn iter(&self) -> KmeansDataIter { - KmeansDataIter { - global_iter: self.clusters.iter(), - local_iter: None, - _phantom: PhantomData, - } - } - - pub fn clusters(&self) -> std::slice::Iter> { - self.clusters.iter() - } - -} - -pub struct KmeansDataIter<'a, T> where T:'a, T: Clusterable { - global_iter: std::slice::Iter<'a, std::vec::Vec>, - local_iter: Option>, - _phantom: PhantomData, -} - -impl<'a, T: 'a + Clusterable> Iterator for KmeansDataIter<'a, T> { - type Item = &'a T; - fn next(&mut self) -> Option<&'a T> { - if let Some(ref mut local_iter) = self.local_iter { - match local_iter.next() { - Some(t) => Some(t), - None => { - if let Some(next) = self.global_iter.next() { - *local_iter = next.iter(); - match local_iter.next() { - Some(t) => Some(t), - None => None, - } - } else { - None - } - } - } - } else { - self.local_iter = match self.global_iter.next() { - None => None, - Some(t) => Some(t.iter()), - }; - self.next() - } - } -} - -pub struct KmeansDataIntoIter where T: Clusterable { - global_iter: std::vec::IntoIter>, - local_iter: Option>, - _phantom: PhantomData, -} - -impl Iterator for KmeansDataIntoIter { - type Item = T; - fn next(&mut self) -> Option { - if let Some(ref mut local_iter) = self.local_iter { - match local_iter.next() { - Some(t) => Some(t), - None => { - if let Some(next) = self.global_iter.next() { - *local_iter = next.into_iter(); - match local_iter.next() { - Some(t) => Some(t), - None => None, - } - } else { - None - } - } - } - } else { - self.local_iter = match self.global_iter.next() { - None => None, - Some(t) => Some(t.into_iter()), - }; - self.next() - } - } -} - -impl IntoIterator for KmeansData { - type Item = T; - type IntoIter = KmeansDataIntoIter; - fn into_iter(self) -> Self::IntoIter { - Self::IntoIter { - global_iter: self.clusters.into_iter(), - local_iter: None, - _phantom: PhantomData, - } - } -} diff --git a/src/lib.rs b/src/lib.rs index 84dd9d6..dcc08ff 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,11 +1,8 @@ pub mod kmeans; -pub mod kmeansdata; -pub mod cluster; -pub mod test; +pub mod clusterable; pub use kmeans::Kmeans; -pub use kmeansdata::KmeansData; -pub use cluster::{Cluster, Clusterable}; +pub use clusterable::Clusterable; pub enum Error { IterationsLimitExceeded, @@ -14,12 +11,11 @@ pub enum Error { pub fn kmeans(centroids: Vec, data: Vec, max_iterations: usize) -> Result<(Kmeans, usize), Error> { - let mut kmeans = Kmeans::new(centroids, vec![data]); + let mut kmeans = Kmeans::new(centroids, data); for nb_iterations in 0..max_iterations { - let (new_kmeans, stable) = kmeans.next_iteration(); - kmeans = new_kmeans; + let stable = kmeans.iterate(); if stable { return Ok((kmeans, nb_iterations)); diff --git a/src/test.rs b/src/test.rs deleted file mode 100644 index e78028a..0000000 --- a/src/test.rs +++ /dev/null @@ -1,44 +0,0 @@ -#[cfg(test)] -mod test { - #[test] - fn iterators() { - use kmeansdata::KmeansData; - let data = vec![4.0, 5.0, 11.0, 12.0, 13.0]; - let kmeans = KmeansData::from_clusters(vec![ - vec![4.0, 5.0], - vec![11.0, 12.0, 13.0], - ]); - - for (val1, val2) in kmeans.into_iter().zip(data) { - assert_eq!(val1, val2); - } - } - - #[test] - fn iterate() { - use kmeans::Kmeans; - - let data = vec![ - vec![4.0, 5.0, 11.0, 12.0], - vec![13.0], - ]; - - let solution = vec![ - vec![4.0, 5.0], - vec![11.0, 12.0, 13.0], - ]; - - let mut kmeans = Kmeans::guess_centroids(data.clone()); - - for _ in 0..4 { - let (new_kmeans, stable) = kmeans.next_iteration(); - kmeans = new_kmeans; - } - - for (k1, k2) in kmeans.iter().zip(solution) { - for (i, j) in k1.iter().zip(&k2) { - assert_eq!(i, j); - } - } - } -}