Package de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans

Source Code of de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.KMeansPlusPlusInitialMeans

package de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans;

/*
This file is part of ELKI:
Environment for Developing KDD-Applications Supported by Index-Structures

Copyright (C) 2012
Ludwig-Maximilians-Universität München
Lehr- und Forschungseinheit für Datenbanksysteme
ELKI Development Team

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU Affero General Public License for more details.

You should have received a copy of the GNU Affero General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import de.lmu.ifi.dbs.elki.data.NumberVector;
import de.lmu.ifi.dbs.elki.database.ids.ArrayDBIDs;
import de.lmu.ifi.dbs.elki.database.ids.DBID;
import de.lmu.ifi.dbs.elki.database.ids.DBIDIter;
import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil;
import de.lmu.ifi.dbs.elki.database.ids.ModifiableDBIDs;
import de.lmu.ifi.dbs.elki.database.query.distance.DistanceQuery;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
import de.lmu.ifi.dbs.elki.distance.distancefunction.PrimitiveDistanceFunction;
import de.lmu.ifi.dbs.elki.distance.distancefunction.PrimitiveDoubleDistanceFunction;
import de.lmu.ifi.dbs.elki.distance.distancevalue.NumberDistance;
import de.lmu.ifi.dbs.elki.logging.LoggingUtil;
import de.lmu.ifi.dbs.elki.math.linearalgebra.Vector;
import de.lmu.ifi.dbs.elki.utilities.documentation.Reference;
import de.lmu.ifi.dbs.elki.utilities.exceptions.AbortException;

/**
* K-Means++ initialization for k-means.
*
* Reference:
* <p>
* D. Arthur, S. Vassilvitskii<br />
* k-means++: the advantages of careful seeding<br />
* In: Proc. of the Eighteenth Annual ACM-SIAM Symposium on Discrete Algorithms,
* SODA 2007
* </p>
*
* @author Erich Schubert
*
* @param <V> Vector type
* @param <D> Distance type
*/
@Reference(authors = "D. Arthur, S. Vassilvitskii", title = "k-means++: the advantages of careful seeding", booktitle = "Proc. of the Eighteenth Annual ACM-SIAM Symposium on Discrete Algorithms, SODA 2007", url = "http://dx.doi.org/10.1145/1283383.1283494")
public class KMeansPlusPlusInitialMeans<V extends NumberVector<V, ?>, D extends NumberDistance<D, ?>> extends AbstractKMeansInitialization<V> {
  /**
   * Constructor.
   *
   * @param seed Random seed.
   */
  public KMeansPlusPlusInitialMeans(Long seed) {
    super(seed);
  }

  @Override
  public List<Vector> chooseInitialMeans(Relation<V> relation, int k, PrimitiveDistanceFunction<? super V, ?> distanceFunction) {
    // Get a distance query
    if(!(distanceFunction.getDistanceFactory() instanceof NumberDistance)) {
      throw new AbortException("K-Means++ initialization can only be used with numerical distances.");
    }
    @SuppressWarnings("unchecked")
    final PrimitiveDistanceFunction<? super V, D> distF = (PrimitiveDistanceFunction<? super V, D>) distanceFunction;
    DistanceQuery<V, D> distQ = relation.getDatabase().getDistanceQuery(relation, distF);

    // Chose first mean
    List<Vector> means = new ArrayList<Vector>(k);

    Random random = (seed != null) ? new Random(seed) : new Random();
    DBID first = DBIDUtil.randomSample(relation.getDBIDs(), 1, random.nextLong()).iterator().next();
    means.add(relation.get(first).getColumnVector());

    ModifiableDBIDs chosen = DBIDUtil.newHashSet(k);
    chosen.add(first);
    ArrayDBIDs ids = DBIDUtil.ensureArray(relation.getDBIDs());
    // Initialize weights
    double[] weights = new double[ids.size()];
    double weightsum = initialWeights(weights, ids, first, distQ);
    while(means.size() < k) {
      if(weightsum > Double.MAX_VALUE) {
        LoggingUtil.warning("Could not choose a reasonable mean for k-means++ - too many data points, too large squared distances?");
      }
      if(weightsum < Double.MIN_NORMAL) {
        LoggingUtil.warning("Could not choose a reasonable mean for k-means++ - to few data points?");
      }
      double r = random.nextDouble() * weightsum;
      int pos = 0;
      while(r > 0 && pos < weights.length) {
        r -= weights[pos];
        pos++;
      }
      // Add new mean:
      DBID newmean = ids.get(pos);
      means.add(relation.get(newmean).getColumnVector());
      chosen.add(newmean);
      // Update weights:
      weights[pos] = 0.0;
      // Choose optimized version for double distances, if applicable.
      if (distF instanceof PrimitiveDoubleDistanceFunction) {
        @SuppressWarnings("unchecked")
        PrimitiveDoubleDistanceFunction<V> ddist = (PrimitiveDoubleDistanceFunction<V>) distF;
        weightsum = updateWeights(weights, ids, newmean, ddist, relation);
      } else {
        weightsum = updateWeights(weights, ids, newmean, distQ);
      }
    }

    return means;
  }

  /**
   * Initialize the weight list.
   *
   * @param weights Weight list
   * @param ids IDs
   * @param latest Added ID
   * @param distQ Distance query
   * @return Weight sum
   */
  protected double initialWeights(double[] weights, ArrayDBIDs ids, DBID latest, DistanceQuery<V, D> distQ) {
    double weightsum = 0.0;
    DBIDIter it = ids.iter();
    for(int i = 0; i < weights.length; i++, it.advance()) {
      DBID id = it.getDBID();
      if(latest.equals(id)) {
        weights[i] = 0.0;
      }
      else {
        double d = distQ.distance(latest, id).doubleValue();
        weights[i] = d * d;
      }
      weightsum += weights[i];
    }
    return weightsum;
  }

  /**
   * Update the weight list.
   *
   * @param weights Weight list
   * @param ids IDs
   * @param latest Added ID
   * @param distQ Distance query
   * @return Weight sum
   */
  protected double updateWeights(double[] weights, ArrayDBIDs ids, DBID latest, DistanceQuery<V, D> distQ) {
    double weightsum = 0.0;
    DBIDIter it = ids.iter();
    for(int i = 0; i < weights.length; i++, it.advance()) {
      DBID id = it.getDBID();
      if(weights[i] > 0.0) {
        double d = distQ.distance(latest, id).doubleValue();
        weights[i] = Math.min(weights[i], d * d);
        weightsum += weights[i];
      }
    }
    return weightsum;
  }

  /**
   * Update the weight list.
   *
   * @param weights Weight list
   * @param ids IDs
   * @param latest Added ID
   * @param distF Distance function
   * @return Weight sum
   */
  protected double updateWeights(double[] weights, ArrayDBIDs ids, DBID latest, PrimitiveDoubleDistanceFunction<V> distF, Relation<V> rel) {
    final V lv = rel.get(latest);
    double weightsum = 0.0;
    DBIDIter it = ids.iter();
    for(int i = 0; i < weights.length; i++, it.advance()) {
      DBID id = it.getDBID();
      if(weights[i] > 0.0) {
        double d = distF.doubleDistance(lv, rel.get(id));
        weights[i] = Math.min(weights[i], d * d);
        weightsum += weights[i];
      }
    }
    return weightsum;
  }

  /**
   * Parameterization class.
   *
   * @author Erich Schubert
   *
   * @apiviz.exclude
   */
  public static class Parameterizer<V extends NumberVector<V, ?>, D extends NumberDistance<D, ?>> extends AbstractKMeansInitialization.Parameterizer<V> {
    @Override
    protected KMeansPlusPlusInitialMeans<V, D> makeInstance() {
      return new KMeansPlusPlusInitialMeans<V, D>(seed);
    }
  }
}
TOP

Related Classes of de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.KMeansPlusPlusInitialMeans

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.