A toolkit for working with phylogenetic data.
v0.20.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
utils/math/kmeans.hpp
Go to the documentation of this file.
1 #ifndef GENESIS_UTILS_MATH_KMEANS_H_
2 #define GENESIS_UTILS_MATH_KMEANS_H_
3 
4 /*
5  Genesis - A toolkit for working with phylogenetic data.
6  Copyright (C) 2014-2018 Lucas Czech and HITS gGmbH
7 
8  This program is free software: you can redistribute it and/or modify
9  it under the terms of the GNU General Public License as published by
10  the Free Software Foundation, either version 3 of the License, or
11  (at your option) any later version.
12 
13  This program is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16  GNU General Public License for more details.
17 
18  You should have received a copy of the GNU General Public License
19  along with this program. If not, see <http://www.gnu.org/licenses/>.
20 
21  Contact:
22  Lucas Czech <lucas.czech@h-its.org>
23  Exelixis Lab, Heidelberg Institute for Theoretical Studies
24  Schloss-Wolfsbrunnenweg 35, D-69118 Heidelberg, Germany
25 */
26 
37 
38 #include <algorithm>
39 #include <cassert>
40 #include <cstddef>
41 #include <functional>
42 #include <limits>
43 #include <random>
44 #include <stdexcept>
45 #include <string>
46 #include <unordered_set>
47 #include <utility>
48 #include <vector>
49 
50 #ifdef GENESIS_OPENMP
51 # include <omp.h>
52 #endif
53 
54 namespace genesis {
55 namespace utils {
56 
57 // =================================================================================================
58 // K-Means Helpers
59 // =================================================================================================
60 
/**
 * @brief Enum of the initialization strategies offered by the Kmeans implementation.
 *
 * NOTE(review): the declaration line and the first three enumerators were lost in this
 * listing. Their names are reconstructed from the three `init_with_*_()` helpers of Kmeans;
 * only `kNone` is visible in the listing. Confirm against the original header.
 */
enum class KmeansInitializationStrategy
{
    /** @brief Presumably: assign each data point to a random cluster. */
    kRandomAssignments,

    /** @brief Presumably: pick k random data points as the initial centroids. */
    kRandomCentroids,

    /** @brief Presumably: pick the initial centroids with the k-means++ scheme. */
    kKmeansPlusPlus,

    /** @brief Do not run any of the built-in initializations. */
    kNone
};
71 
/**
 * @brief Helper POD that stores per-cluster and per-datum statistics of a clustering.
 *
 * It is filled by Kmeans::cluster_info().
 */
struct KmeansClusteringInfo
{
    /** @brief Per cluster: mean of the squared distances of its data points to the centroid. */
    std::vector<double> variances;

    /** @brief Per cluster: number of data points assigned to it. */
    std::vector<size_t> counts;

    /** @brief Per data point: distance to the centroid it is assigned to. */
    std::vector<double> distances;
};
82 
83 // =================================================================================================
84 // Generic K-Means Class
85 // =================================================================================================
86 
90 template< typename Point >
91 class Kmeans
92 {
93 public:
94 
95  // -------------------------------------------------------------------------
96  // Typedefs and Constants
97  // -------------------------------------------------------------------------
98 
99  using value_type = Point;
100 
101  // -------------------------------------------------------------------------
102  // Constructors and Rule of Five
103  // -------------------------------------------------------------------------
104 
105  Kmeans() = default;
106  virtual ~Kmeans() = default;
107 
108  Kmeans( Kmeans const& ) = default;
109  Kmeans( Kmeans&& ) = default;
110 
111  Kmeans& operator= ( Kmeans const& ) = default;
112  Kmeans& operator= ( Kmeans&& ) = default;
113 
114  // -------------------------------------------------------------------------
115  // Member Functions
116  // -------------------------------------------------------------------------
117 
118  size_t run( std::vector<Point> const& data, size_t const k )
119  {
120  if( report_initialization ) {
122  }
123 
124  // Run basic checks. This throws if necessary.
125  argument_checks_( data, k );
126 
127  // Init assigments and centroids.
128  initialize( data, k );
129 
130  // Call the hook.
131  pre_loop_hook( data, assignments_, centroids_ );
132 
133  // By now, the result vectors should be filled correctly.
134  // This replaces asserts. It is slightly more expensive, but this class offers so much
135  // expansion points and custom behaviour, that we better check thoroughly.
136  runtime_checks_( data, k );
137 
138  size_t iteration = 0;
139  bool changed_assigment;
140 
141  do {
142  // Start a new iteration.
143  if( report_iteration ) {
144  report_iteration( iteration + 1 );
145  }
146 
147  changed_assigment = lloyd_step( data, assignments_, centroids_ );
148 
149  // Check again.
150  runtime_checks_( data, k );
151 
152  // Check if there are empty centroids, and if so, treat them.
153  auto const empty_centroids = get_empty_centroids_();
154  if( ! empty_centroids.empty() ) {
155  LOG_INFO << "Empty centroid occurred: " << empty_centroids.size();
156  changed_assigment |= treat_empty_centroids(
157  data, assignments_, centroids_, empty_centroids
158  );
159  }
160 
161  ++iteration;
162  } while( changed_assigment && iteration < max_iterations_ );
163 
164  // Call the hook.
165  post_loop_hook( data, assignments_, centroids_ );
166 
167  return iteration;
168  }
169 
170  // -------------------------------------------------------------------------
171  // Data Access
172  // -------------------------------------------------------------------------
173 
174  std::vector<size_t> const& assignments() const
175  {
176  return assignments_;
177  }
178 
179  Kmeans& assignments( std::vector<size_t> const& value )
180  {
181  assignments_ = value;
182  return *this;
183  }
184 
185  std::vector<Point> const& centroids() const
186  {
187  return centroids_;
188  }
189 
190  Kmeans& centroids( std::vector<Point> const& value )
191  {
192  centroids_ = value;
193  return *this;
194  }
195 
196  std::vector<size_t> cluster_sizes() const
197  {
198  if( assignments_.empty() || centroids_.empty() ) {
199  throw std::runtime_error( "No assignments or centroids set yet." );
200  }
201  auto result = std::vector<size_t>( centroids_.size(), 0 );
202  for( auto ass : assignments_ ) {
203  assert( ass <= result.size() );
204  ++result[ ass ];
205  }
206  return result;
207  }
208 
209  void clear()
210  {
211  assignments_.clear();
212  centroids_.clear();
213  }
214 
215  // -------------------------------------------------------------------------
216  // Properties
217  // -------------------------------------------------------------------------
218 
220  std::vector<Point> const& data
221  ) const {
222  return cluster_info( data, assignments_, centroids_ );
223  }
224 
225  size_t max_iterations() const
226  {
227  return max_iterations_;
228  }
229 
230  Kmeans& max_iterations( size_t value )
231  {
232  if( value == 0 ) {
233  throw std::runtime_error( "Cannot use 0 as max_iterations for Kmeans." );
234  }
235  max_iterations_ = value;
236  return *this;
237  }
238 
240  {
241  return init_strategy_;
242  }
243 
245  {
246  init_strategy_ = value;
247  return *this;
248  }
249 
250  // -------------------------------------------------------------------------
251  // Progress Report
252  // -------------------------------------------------------------------------
253 
254  std::function<void( void )> report_initialization;
255  std::function<void( size_t i )> report_iteration;
256 
257  // -------------------------------------------------------------------------
258  // Virtual Functions
259  // -------------------------------------------------------------------------
260 
261 protected:
262 
263  virtual void initialize( std::vector<Point> const& data, size_t const k )
264  {
265  // Select init stragegies.
266  switch( init_strategy_ ) {
268  init_with_random_assignments_( data, k );
269  break;
270  }
272  init_with_random_centroids_( data, k );
273  break;
274  }
276  init_with_kmeans_plus_plus_( data, k );
277  break;
278  }
279  default: {}
280  }
281 
282  // If the strategy did not yield useful values, we still need to init somehow,
283  // so do this now. This also applies if kNone was selected for init, but no centroids were set.
284  if( assignments_.size() == 0 && centroids_.size() == 0 ) {
285  // Nothing given: Sample random centroids from the data.
286  init_with_random_centroids_( data, k );
287 
288  } else if( assignments_.size() == 0 && centroids_.size() > 0 ) {
289  // Centroids given, but no assigments: Nothing to do for now.
290  // We will calculate the proper assigments in the main loop.
291 
292  } else if( assignments_.size() > 0 && centroids_.size() == 0 ) {
293  // Assignments given, but not centroids: Caculate the latter.
294  update_centroids( data, assignments_, centroids_ );
295 
296  } else {
297  // Both given: Nothing to do.
298  assert( assignments_.size() > 0 && centroids_.size() > 0 );
299  }
300 
301  // If we do not have an assigment vector yet, make one. It will be assigned proper values
302  // once we enter the main loop.
303  if( assignments_.size() == 0 ) {
304  assignments_ = std::vector<size_t>( data.size(), 0 );
305  }
306  }
307 
308  virtual bool data_validation( std::vector<Point> const& data ) const
309  {
310  (void) data;
311  return true;
312  }
313 
314  virtual void pre_loop_hook(
315  std::vector<Point> const& data,
316  std::vector<size_t>& assignments,
317  std::vector<Point>& centroids
318  ) {
319  (void) data;
320  (void) assignments;
321  (void) centroids;
322  }
323 
324  virtual bool lloyd_step(
325  std::vector<Point> const& data,
326  std::vector<size_t>& assignments,
327  std::vector<Point>& centroids
328  ) {
329  // Calculate new assignments and check whether they changed.
330  auto changed_assigment = assign_to_centroids( data, centroids, assignments );
331 
332  // Recalculate the centroids.
333  update_centroids( data, assignments, centroids );
334 
335  return changed_assigment;
336  }
337 
338  virtual bool assign_to_centroids(
339  std::vector<Point> const& data,
340  std::vector<Point> const& centroids,
341  std::vector<size_t>& assignments
342  ) {
343  // Store whether anything changed.
344  bool changed_assigment = false;
345 
346  // Assign each Point to its nearest centroid.
347  #pragma omp parallel for
348  for( size_t i = 0; i < data.size(); ++i ) {
349  auto const new_idx = find_nearest_cluster( centroids, data[i] ).first;
350 
351  if( new_idx != assignments[i] ) {
352  // Update the assignment. No need for locking, as each thread works on its own i.
353  assignments[i] = new_idx;
354 
355  // If we have a new assigment for this datum, we need to do another loop iteration.
356  // Do this atomically, as all threads use this variable.
357  #pragma omp atomic write
358  changed_assigment = true;
359  }
360  }
361 
362  return changed_assigment;
363  }
364 
365  virtual std::pair<size_t, double> find_nearest_cluster(
366  std::vector<Point> const& centroids,
367  Point const& datum
368  ) const {
369  size_t min_i = std::numeric_limits<size_t>::max();
370  double min_d = std::numeric_limits<double>::max();
371 
372  assert( centroids.size() > 0 );
373  for( size_t i = 0; i < centroids.size(); ++i ) {
374  auto const dist = distance( datum, centroids[i] );
375  if( dist < min_d ) {
376  min_i = i;
377  min_d = dist;
378  }
379  }
380 
381  return { min_i, min_d };
382  }
383 
385  std::vector<Point> const& data,
386  std::vector<size_t> const& assignments,
387  std::vector<Point> const& centroids
388  ) const {
389  auto const k = centroids.size();
390 
391  auto result = KmeansClusteringInfo();
392  result.variances = std::vector<double>( k, 0.0 );
393  result.counts = std::vector<size_t>( k, 0 );
394  result.distances = std::vector<double>( data.size(), 0.0 );
395 
396  // Work through the data and assigments and accumulate.
397  #pragma omp parallel for
398  for( size_t i = 0; i < data.size(); ++i ) {
399 
400  // Shorthands.
401  auto const a = assignments[ i ];
402  assert( a < k );
403  auto const& centroid = centroids[ a ];
404 
405  // Get dist from datum to centroid.
406  auto const dist = distance( centroid, data[ i ] );
407  result.distances[ i ] = dist;
408 
409  // Update centroid accumulators.
410  #pragma omp atomic update
411  result.variances[ a ] += dist * dist;
412  #pragma omp atomic update
413  ++result.counts[ a ];
414  }
415 
416  // Build the mean dist to get the variance for each centroid.
417  for( size_t i = 0; i < k; ++i ) {
418  if( result.counts[ i ] > 0 ) {
419  result.variances[ i ] /= result.counts[ i ];
420  }
421  }
422 
423  return result;
424  }
425 
426  virtual bool treat_empty_centroids(
427  std::vector<Point> const& data,
428  std::vector<size_t>& assignments,
429  std::vector<Point>& centroids,
430  std::unordered_set<size_t> const& empty_centroids
431  ) {
432  // If there are not empty centroids, we have nothing to do, and did not change anything,
433  // so return false.
434  if( empty_centroids.empty() ) {
435  return false;
436  }
437 
438  // Process all empty centroid indices.
439  for( auto const& ec_idx : empty_centroids ) {
440  // Get variances and counts of clusters and distances from data to them.
441  auto clus_info = cluster_info( data, assignments, centroids );
442  assert( clus_info.variances.size() == centroids.size() );
443  assert( clus_info.distances.size() == data.size() );
444  assert( data.size() == assignments.size() );
445 
446  // Get index of centroid with max variance.
447  auto const max_var_idx = static_cast<size_t>( std::distance(
448  clus_info.variances.begin(),
449  std::max_element( clus_info.variances.begin(), clus_info.variances.end() )
450  ));
451 
452  // If the max variance is 0, we cannot do anything. All points are the same.
453  if( clus_info.variances[ max_var_idx ] == 0.0 ) {
454  return false;
455  }
456 
457  // The empty centroid cannot be the same as the one we want to take a point from,
458  // because empty clusters have a variance of 0. If this assertion fails, we probably
459  // ran an analysis with k >> data.size() or so.
460  assert( ec_idx != max_var_idx );
461 
462  // The current empty cluster should actually be empty.
463  assert( clus_info.counts[ ec_idx ] == 0 );
464  assert( clus_info.variances[ ec_idx ] == 0.0 );
465 
466  // Find the point in the max var cluster that is furthest away from the centroid.
467  size_t furth_idx = std::numeric_limits<size_t>::max();
468  double furth_dist = std::numeric_limits<double>::lowest();
469  for( size_t i = 0; i < data.size(); ++i ) {
470  if( assignments[i] != max_var_idx ) {
471  continue;
472  }
473  if( clus_info.distances[ i ] > furth_dist ) {
474  furth_idx = i;
475  furth_dist = clus_info.distances[ i ];
476  }
477  }
478 
479  // The point needs to be part of the max var cluster, otherwise something went wrong.
480  assert( assignments[ furth_idx ] == max_var_idx );
481 
482  // Add the point to the empty cluster.
483  assignments[ furth_idx ] = ec_idx;
484 
485  // The following is some test code that could be used to avoid calculating the cluster
486  // info for each empty cluster. However, as we probably almost never run into this
487  // function anyway, we better keep it simple and calculate the info fresh every time.
488 
489  // // Adjust variance and count of the max var cluster.
490  // auto const dsqrt = clus_info.distances[ furth_idx ] * clus_info.distances[ furth_idx ];
491  // clus_info.variances[ max_var_idx ] *= clus_info.counts[ max_var_idx ];
492  // clus_info.variances[ max_var_idx ] -= dsqrt;
493  // --clus_info.counts[ max_var_idx ];
494  // clus_info.variances[ max_var_idx ] /= clus_info.counts[ max_var_idx ];
495  //
496  // // Adjust count of the (now not any longer) empty cluster.
497  // ++clus_info.counts[ ec_idx ];
498 
499  // Finally, we need to update the centroids in order to reflect the changes.
500  update_centroids( data, assignments, centroids );
501  }
502 
503  // Now we return true, because we changed some assignments.
504  return true;
505  }
506 
507  virtual double distance(
508  Point const& lhs,
509  Point const& rhs
510  ) const = 0;
511 
512  virtual void update_centroids(
513  std::vector<Point> const& data,
514  std::vector<size_t> const& assignments,
515  std::vector<Point>& centroids
516  ) = 0;
517 
518  virtual void post_loop_hook(
519  std::vector<Point> const& data,
520  std::vector<size_t>& assignments,
521  std::vector<Point>& centroids
522  ) {
523  (void) data;
524  (void) assignments;
525  (void) centroids;
526  }
527 
528  // -------------------------------------------------------------------------
529  // Internal Functions
530  // -------------------------------------------------------------------------
531 
532 private:
533 
534  void argument_checks_(
535  std::vector<Point> const& data,
536  size_t const k
537  ) const {
538  // Basic checks.
539  if( k > data.size() ) {
540  throw std::runtime_error(
541  "Cannot run Kmeans with more clusters (k == " + std::to_string( k ) +
542  ") than data points (" + std::to_string( data.size() ) + ")"
543  );
544  }
545  if( k == 0 ) {
546  throw std::runtime_error(
547  "Cannot run Kmeans with zero clusters (k == 0)."
548  );
549  }
550 
551  // Validate the data. The function might also throw on its own, in order
552  // to provide a more helpful message about what is actually invalid about the data.
553  if( ! data_validation( data ) ) {
554  throw std::runtime_error( "Invalid data." );
555  }
556  }
557 
558  void runtime_checks_(
559  std::vector<Point> const& data,
560  size_t const k
561  ) const {
562  if( assignments_.size() != data.size() ) {
563  throw std::runtime_error(
564  "Assignments has size " + std::to_string( assignments_.size() ) +
565  " but data has size " + std::to_string( data.size() ) + "."
566  );
567  }
568  for( auto const& assign : assignments_ ) {
569  if( assign >= k ) {
570  throw std::runtime_error(
571  "Invalid assignment " + std::to_string( assign ) +
572  " >= k = " + std::to_string( k ) + "."
573  );
574  }
575  }
576  if( centroids_.size() != k ) {
577  throw std::runtime_error(
578  "Centroids has size " + std::to_string( centroids_.size() ) +
579  " but k is " + std::to_string( k ) + "."
580  );
581  }
582  }
583 
584  void init_with_random_assignments_(
585  std::vector<Point> const& data,
586  size_t const k
587  ) {
588  // Prepare a vector of the desired size.
589  assignments_ = std::vector<size_t>( data.size(), 0 );
590 
591  // Prepare a random distribution in range [0,k).
592  auto& engine = Options::get().random_engine();
593  std::uniform_int_distribution<size_t> distribution( 0, k - 1 );
594 
595  // Assign random cluster indices for each data point.
596  for( size_t i = 0; i < data.size(); ++i ) {
597  assignments_[ i ] = distribution( engine );
598  }
599  }
600 
601  void init_with_random_centroids_(
602  std::vector<Point> const& data,
603  size_t const k
604  ) {
605  // Prepare centroids vector. Empty, because we don't want to assume any default
606  // constructor for the Points.
607  centroids_ = std::vector<Point>();
608 
609  // Select k unique numbers out of the interval [ 0, data.size() ),
610  // and copy those data points to the centroids.
611  auto idxs = select_without_replacement( k, data.size() );
612  for( auto const idx : idxs ) {
613  centroids_.push_back( data[idx] );
614  }
615 
616  assert( centroids_.size() == k );
617  }
618 
619  void init_with_kmeans_plus_plus_(
620  std::vector<Point> const& data,
621  size_t const k
622  ) {
623  // Prepare centroids vector. Empty, because we don't want to assume any default
624  // constructor for the Points.
625  centroids_ = std::vector<Point>();
626 
627  // Shorthand.
628  auto& engine = Options::get().random_engine();
629 
630  // Use a random point as first centroid.
631  std::uniform_int_distribution<size_t> first_dist( 0, data.size() - 1 );
632  centroids_.push_back( data[ first_dist(engine) ]);
633 
634  // Prepare a vector of probabilities to select each data point as a centroid.
635  auto data_probs = std::vector<double>( data.size(), 0.0 );
636 
637  // Add more centroids.
638  for( size_t i = 1; i < k; ++i ) {
639 
640  // For each data point...
641  #pragma omp parallel for
642  for( size_t di = 0; di < data.size(); ++di ) {
643 
644  // ...find the closest centroid (of the ones that are produced so far), ...
645  double const min_d = find_nearest_cluster( centroids_, data[ di ] ).second;
646 
647  // ...and use its square as probability to select this point.
648  // (No need for OpenMP locking here, as di is unique to each thread).
649  data_probs[ di ] = min_d * min_d;
650  }
651 
652  // Now select a new centroid from the data, according to the given probabilities.
653  std::discrete_distribution<size_t> distribution(
654  data_probs.begin(), data_probs.end()
655  );
656  auto idx = distribution( engine );
657  assert( idx < data.size() );
658  centroids_.push_back( data[ idx ] );
659  }
660 
661  assert( centroids_.size() == k );
662  }
663 
664  std::unordered_set<size_t> get_empty_centroids_() {
665  auto const k = centroids_.size();
666 
667  // Fill a list with all numbers up to k...
668  auto empties = std::unordered_set<size_t>();
669  for( size_t i = 0; i < k; ++i ) {
670  empties.insert( i );
671  }
672 
673  // ... then remove all assigned ones again.
674  for( size_t i = 0; i < assignments_.size(); ++i ) {
675  assert( assignments_[i] < k );
676  empties.erase( assignments_[i] );
677 
678  // Prematurely exit if there is nothing else to remove.
679  // We don't want to go to all assignments if not necessary.
680  if( empties.empty() ) {
681  return empties;
682  }
683  }
684 
685  // If we are here, there are empty centroids, otherwise we'd have prematurely exited.
686  assert( ! empties.empty() );
687  return empties;
688  }
689 
690  // -------------------------------------------------------------------------
691  // Data Members
692  // -------------------------------------------------------------------------
693 
694 private:
695 
696  std::vector<size_t> assignments_;
697  std::vector<Point> centroids_;
698 
699  size_t max_iterations_ = 100;
701 
702 };
703 
704 } // namespace utils
705 } // namespace genesis
706 
707 #endif // include guard
Helper POD that stores the variances and number of data points of each centroid, as well as the distance from each data point to its assigned centroid.
KmeansInitializationStrategy
Enum of the initialization strategies offered by the Kmeans implementation.
virtual bool treat_empty_centroids(std::vector< Point > const &data, std::vector< size_t > &assignments, std::vector< Point > &centroids, std::unordered_set< size_t > const &empty_centroids)
Kmeans & centroids(std::vector< Point > const &value)
size_t run(std::vector< Point > const &data, size_t const k)
std::vector< Point > const & centroids() const
std::string to_string(T const &v)
Return a string representation of a given value.
Definition: string.hpp:381
virtual void initialize(std::vector< Point > const &data, size_t const k)
virtual std::pair< size_t, double > find_nearest_cluster(std::vector< Point > const &centroids, Point const &datum) const
virtual void pre_loop_hook(std::vector< Point > const &data, std::vector< size_t > &assignments, std::vector< Point > &centroids)
std::function< void(size_t i)> report_iteration
virtual bool lloyd_step(std::vector< Point > const &data, std::vector< size_t > &assignments, std::vector< Point > &centroids)
virtual ~Kmeans()=default
size_t max_iterations() const
Kmeans & max_iterations(size_t value)
virtual bool assign_to_centroids(std::vector< Point > const &data, std::vector< Point > const &centroids, std::vector< size_t > &assignments)
virtual KmeansClusteringInfo cluster_info(std::vector< Point > const &data, std::vector< size_t > const &assignments, std::vector< Point > const &centroids) const
Provides easy and fast logging functionality.
virtual bool data_validation(std::vector< Point > const &data) const
KmeansInitializationStrategy initialization_strategy() const
std::default_random_engine & random_engine()
Returns the default engine for random number generation.
Definition: options.hpp:155
Kmeans & operator=(Kmeans const &)=default
std::vector< size_t > select_without_replacement(size_t const k, size_t const n)
Select k many unique numbers out of the range [ 0, n ).
Definition: random.cpp:46
Kmeans & assignments(std::vector< size_t > const &value)
Kmeans & initialization_strategy(KmeansInitializationStrategy value)
KmeansClusteringInfo cluster_info(std::vector< Point > const &data) const
std::vector< size_t > cluster_sizes() const
std::function< void(void)> report_initialization
virtual void post_loop_hook(std::vector< Point > const &data, std::vector< size_t > &assignments, std::vector< Point > &centroids)
static Options & get()
Returns a single instance of this class.
Definition: options.hpp:60
virtual void update_centroids(std::vector< Point > const &data, std::vector< size_t > const &assignments, std::vector< Point > &centroids)=0
virtual double distance(Point const &lhs, Point const &rhs) const =0
std::vector< size_t > const & assignments() const
#define LOG_INFO
Log an info message. See genesis::utils::LoggingLevel.
Definition: logging.hpp:98