1 #ifndef GENESIS_UTILS_MATH_KMEANS_H_
2 #define GENESIS_UTILS_MATH_KMEANS_H_
46 #include <unordered_set>
90 template<
typename Po
int >
118 size_t run( std::vector<Point>
const& data,
size_t const k )
125 argument_checks_( data, k );
136 runtime_checks_( data, k );
138 size_t iteration = 0;
139 bool changed_assignment;
147 changed_assignment =
lloyd_step( data, assignments_, centroids_ );
150 runtime_checks_( data, k );
153 auto const empty_centroids = get_empty_centroids_();
154 if( ! empty_centroids.empty() ) {
155 LOG_INFO <<
"Empty centroid occurred: " << empty_centroids.size();
157 data, assignments_, centroids_, empty_centroids
162 }
while( changed_assignment && iteration < max_iterations_ );
181 assignments_ = value;
198 if( assignments_.empty() || centroids_.empty() ) {
199 throw std::runtime_error(
"No assignments or centroids set yet." );
201 auto result = std::vector<size_t>( centroids_.size(), 0 );
202 for(
auto ass : assignments_ ) {
203 assert( ass < result.size() ); // assignments are centroid indices in [0, centroids_.size()); result holds one entry per centroid, so equality is already out of range
211 assignments_.clear();
220 std::vector<Point>
const& data
227 return max_iterations_;
233 throw std::runtime_error(
"Cannot use 0 as max_iterations for Kmeans." );
235 max_iterations_ = value;
241 return init_strategy_;
246 init_strategy_ = value;
263 virtual void initialize( std::vector<Point>
const& data,
size_t const k )
266 switch( init_strategy_ ) {
268 init_with_random_assignments_( data, k );
272 init_with_random_centroids_( data, k );
276 init_with_kmeans_plus_plus_( data, k );
284 if( assignments_.size() == 0 && centroids_.size() == 0 ) {
286 init_with_random_centroids_( data, k );
288 }
else if( assignments_.size() == 0 && centroids_.size() > 0 ) {
292 }
else if( assignments_.size() > 0 && centroids_.size() == 0 ) {
298 assert( assignments_.size() > 0 && centroids_.size() > 0 );
303 if( assignments_.size() == 0 ) {
304 assignments_ = std::vector<size_t>( data.size(), 0 );
315 std::vector<Point>
const& data,
325 std::vector<Point>
const& data,
335 return changed_assignment;
339 std::vector<Point>
const& data,
344 bool changed_assignment =
false;
347 #pragma omp parallel for
348 for(
size_t i = 0; i < data.size(); ++i ) {
357 #pragma omp atomic write
358 changed_assignment =
true;
362 return changed_assignment;
369 size_t min_i = std::numeric_limits<size_t>::max();
370 double min_d = std::numeric_limits<double>::max();
373 for(
size_t i = 0; i <
centroids.size(); ++i ) {
381 return { min_i, min_d };
385 std::vector<Point>
const& data,
392 result.variances = std::vector<double>( k, 0.0 );
393 result.counts = std::vector<size_t>( k, 0 );
394 result.distances = std::vector<double>( data.size(), 0.0 );
397 #pragma omp parallel for
398 for(
size_t i = 0; i < data.size(); ++i ) {
406 auto const dist =
distance( centroid, data[ i ] );
407 result.distances[ i ] = dist;
410 #pragma omp atomic update
411 result.variances[ a ] += dist * dist;
412 #pragma omp atomic update
413 ++result.counts[ a ];
417 for(
size_t i = 0; i < k; ++i ) {
418 if( result.counts[ i ] > 0 ) {
419 result.variances[ i ] /= result.counts[ i ];
427 std::vector<Point>
const& data,
430 std::unordered_set<size_t>
const& empty_centroids
434 if( empty_centroids.empty() ) {
439 for(
auto const& ec_idx : empty_centroids ) {
442 assert( clus_info.variances.size() ==
centroids.size() );
443 assert( clus_info.distances.size() == data.size() );
447 auto const max_var_idx =
static_cast<size_t>( std::distance(
448 clus_info.variances.begin(),
449 std::max_element( clus_info.variances.begin(), clus_info.variances.end() )
453 if( clus_info.variances[ max_var_idx ] == 0.0 ) {
460 assert( ec_idx != max_var_idx );
463 assert( clus_info.counts[ ec_idx ] == 0 );
464 assert( clus_info.variances[ ec_idx ] == 0.0 );
467 size_t furth_idx = std::numeric_limits<size_t>::max();
468 double furth_dist = std::numeric_limits<double>::lowest();
469 for(
size_t i = 0; i < data.size(); ++i ) {
473 if( clus_info.distances[ i ] > furth_dist ) {
475 furth_dist = clus_info.distances[ i ];
513 std::vector<Point>
const& data,
519 std::vector<Point>
const& data,
534 void argument_checks_(
535 std::vector<Point>
const& data,
539 if( k > data.size() ) {
540 throw std::runtime_error(
541 "Cannot run Kmeans with more clusters (k == " +
std::to_string( k ) +
546 throw std::runtime_error(
547 "Cannot run Kmeans with zero clusters (k == 0)."
554 throw std::runtime_error(
"Invalid data." );
558 void runtime_checks_(
559 std::vector<Point>
const& data,
562 if( assignments_.size() != data.size() ) {
563 throw std::runtime_error(
568 for(
auto const& assign : assignments_ ) {
570 throw std::runtime_error(
576 if( centroids_.size() != k ) {
577 throw std::runtime_error(
584 void init_with_random_assignments_(
585 std::vector<Point>
const& data,
589 assignments_ = std::vector<size_t>( data.size(), 0 );
593 std::uniform_int_distribution<size_t> distribution( 0, k - 1 );
596 for(
size_t i = 0; i < data.size(); ++i ) {
597 assignments_[ i ] = distribution( engine );
601 void init_with_random_centroids_(
602 std::vector<Point>
const& data,
607 centroids_ = std::vector<Point>();
612 for(
auto const idx : idxs ) {
613 centroids_.push_back( data[idx] );
616 assert( centroids_.size() == k );
619 void init_with_kmeans_plus_plus_(
620 std::vector<Point>
const& data,
625 centroids_ = std::vector<Point>();
631 std::uniform_int_distribution<size_t> first_dist( 0, data.size() - 1 );
632 centroids_.push_back( data[ first_dist(engine) ]);
635 auto data_probs = std::vector<double>( data.size(), 0.0 );
638 for(
size_t i = 1; i < k; ++i ) {
641 #pragma omp parallel for
642 for(
size_t di = 0; di < data.size(); ++di ) {
649 data_probs[ di ] = min_d * min_d;
653 std::discrete_distribution<size_t> distribution(
654 data_probs.begin(), data_probs.end()
656 auto idx = distribution( engine );
657 assert( idx < data.size() );
658 centroids_.push_back( data[ idx ] );
661 assert( centroids_.size() == k );
664 std::unordered_set<size_t> get_empty_centroids_() {
665 auto const k = centroids_.size();
668 auto empties = std::unordered_set<size_t>();
669 for(
size_t i = 0; i < k; ++i ) {
674 for(
size_t i = 0; i < assignments_.size(); ++i ) {
675 assert( assignments_[i] < k );
676 empties.erase( assignments_[i] );
680 if( empties.empty() ) {
686 assert( ! empties.empty() );
696 std::vector<size_t> assignments_;
697 std::vector<Point> centroids_;
699 size_t max_iterations_ = 100;
707 #endif // include guard