Clustering
This page describes clustering algorithms in MLlib. The guide for clustering in the RDD-based API also has relevant information about these algorithms.
Table of Contents
K-means
k-means is one of the most commonly used clustering algorithms that clusters the data points into a predefined number of clusters. The MLlib implementation includes a parallelized variant of the k-means++ method called kmeans||.
KMeans is implemented as an Estimator and generates a KMeansModel as the base model.
Input Columns
| Param name | Type(s) | Default | Description | 
|---|---|---|---|
| featuresCol | Vector | "features" | Feature vector | 
Output Columns
| Param name | Type(s) | Default | Description | 
|---|---|---|---|
| predictionCol | Int | "prediction" | Index of the predicted cluster center | 
Examples
Refer to the Scala API docs for more details.
import org.apache.spark.ml.clustering.KMeans
// Load the sample data set using the libsvm data source.
val points = spark.read
  .format("libsvm")
  .load("data/mllib/sample_kmeans_data.txt")
// Fit a k-means model with two clusters and a fixed seed.
val fitted = new KMeans()
  .setK(2)
  .setSeed(1L)
  .fit(points)
// Measure clustering quality as the Within Set Sum of Squared Errors.
val cost = fitted.computeCost(points)
println(s"Within Set Sum of Squared Errors = $cost")
// Print every cluster center found by the model.
println("Cluster Centers: ")
for (center <- fitted.clusterCenters) println(center)
Refer to the Java API docs for more details.
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.clustering.KMeans;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
// Load the sample data set using the libsvm data source.
Dataset<Row> points = spark.read()
  .format("libsvm")
  .load("data/mllib/sample_kmeans_data.txt");
// Fit a k-means model with two clusters and a fixed seed.
KMeansModel fitted = new KMeans()
  .setK(2)
  .setSeed(1L)
  .fit(points);
// Measure clustering quality as the Within Set Sum of Squared Errors.
double cost = fitted.computeCost(points);
System.out.println("Within Set Sum of Squared Errors = " + cost);
// Print every cluster center found by the model.
System.out.println("Cluster Centers: ");
for (Vector c : fitted.clusterCenters()) {
  System.out.println(c);
}
Refer to the Python API docs for more details.
from pyspark.ml.clustering import KMeans
# Load the sample data set using the libsvm data source.
points = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")
# Fit a k-means model with two clusters and a fixed seed.
fitted = KMeans(k=2, seed=1).fit(points)
# Measure clustering quality as the Within Set Sum of Squared Errors.
cost = fitted.computeCost(points)
print("Within Set Sum of Squared Errors = " + str(cost))
# Print every cluster center found by the model.
print("Cluster Centers: ")
for c in fitted.clusterCenters():
    print(c)
Refer to the R API docs for more details.
# Build a SparkDataFrame from the Titanic data and split it with weights 7:3 (seed 2)
titanic <- as.data.frame(Titanic)
titanicDF <- createDataFrame(titanic)
splits <- randomSplit(titanicDF, c(7, 3), 2)
trainDF <- splits[[1]]
testDF <- splits[[2]]
# Fit a k-means model with three clusters using spark.kmeans
kmeansModel <- spark.kmeans(trainDF, ~ Class + Sex + Age + Freq, k = 3)
# Summarize the fitted model
summary(kmeansModel)
# Show the first fitted values from the k-means model
head(fitted(kmeansModel))
# Predict on the second split and show the first rows
predictions <- predict(kmeansModel, testDF)
head(predictions)
Latent Dirichlet allocation (LDA)
LDA is implemented as an Estimator that supports both EMLDAOptimizer and OnlineLDAOptimizer,
and generates an LDAModel as the base model. Expert users may cast an LDAModel generated by
EMLDAOptimizer to a DistributedLDAModel if needed.
Examples
Refer to the Scala API docs for more details.
import org.apache.spark.ml.clustering.LDA
// Load the sample corpus using the libsvm data source.
val corpus = spark.read.format("libsvm").load("data/mllib/sample_lda_libsvm_data.txt")
// Fit an LDA model with k = 10 and at most 10 iterations.
val ldaModel = new LDA()
  .setK(10)
  .setMaxIter(10)
  .fit(corpus)
// Evaluate the fitted model on the same corpus.
val logLik = ldaModel.logLikelihood(corpus)
val logPerp = ldaModel.logPerplexity(corpus)
println(s"The lower bound on the log likelihood of the entire corpus: $logLik")
println(s"The upper bound on perplexity: $logPerp")
// Describe each topic by its 3 top-weighted terms.
val topicTerms = ldaModel.describeTopics(3)
println("The topics described by their top-weighted terms:")
topicTerms.show(truncate = false)
// Apply the model to the corpus and display the result without truncation.
val withTopics = ldaModel.transform(corpus)
withTopics.show(truncate = false)
Refer to the Java API docs for more details.
import org.apache.spark.ml.clustering.LDA;
import org.apache.spark.ml.clustering.LDAModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
// Load the sample corpus using the libsvm data source.
Dataset<Row> corpus =
  spark.read().format("libsvm").load("data/mllib/sample_lda_libsvm_data.txt");
// Fit an LDA model with k = 10 and at most 10 iterations.
LDAModel ldaModel = new LDA()
  .setK(10)
  .setMaxIter(10)
  .fit(corpus);
// Evaluate the fitted model on the same corpus.
double logLik = ldaModel.logLikelihood(corpus);
double logPerp = ldaModel.logPerplexity(corpus);
System.out.println("The lower bound on the log likelihood of the entire corpus: " + logLik);
System.out.println("The upper bound on perplexity: " + logPerp);
// Describe each topic by its 3 top-weighted terms.
Dataset<Row> topicTerms = ldaModel.describeTopics(3);
System.out.println("The topics described by their top-weighted terms:");
topicTerms.show(false);
// Apply the model to the corpus and display the result without truncation.
Dataset<Row> withTopics = ldaModel.transform(corpus);
withTopics.show(false);
Refer to the Python API docs for more details.
from pyspark.ml.clustering import LDA
# Load the sample corpus using the libsvm data source.
corpus = spark.read.format("libsvm").load("data/mllib/sample_lda_libsvm_data.txt")
# Fit an LDA model with k = 10 and at most 10 iterations.
lda_model = LDA().setK(10).setMaxIter(10).fit(corpus)
# Evaluate the fitted model on the same corpus.
log_lik = lda_model.logLikelihood(corpus)
log_perp = lda_model.logPerplexity(corpus)
print("The lower bound on the log likelihood of the entire corpus: " + str(log_lik))
print("The upper bound on perplexity: " + str(log_perp))
# Describe each topic by its 3 top-weighted terms.
topic_terms = lda_model.describeTopics(3)
print("The topics described by their top-weighted terms:")
topic_terms.show(truncate=False)
# Apply the model to the corpus and display the result without truncation.
with_topics = lda_model.transform(corpus)
with_topics.show(truncate=False)
Refer to the R API docs for more details.
# Load the sample corpus; the same data is used as both training and test set
df <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm")
training <- df
test <- df
# Fit a latent Dirichlet allocation model with spark.lda
model <- spark.lda(training, k = 10, maxIter = 10)
# Model summary
summary(model)
# Posterior probabilities
posterior <- spark.posterior(model, test)
head(posterior)
# The log perplexity of the LDA model
logPerplexity <- spark.perplexity(model, test)
print(paste0("The upper bound on perplexity: ", logPerplexity))
Bisecting k-means
Bisecting k-means is a kind of hierarchical clustering using a divisive (or “top-down”) approach: all observations start in one cluster, and splits are performed recursively as one moves down the hierarchy.
Bisecting k-means can often be much faster than regular k-means, but it will generally produce a different clustering.
BisectingKMeans is implemented as an Estimator and generates a BisectingKMeansModel as the base model.
Examples
Refer to the Scala API docs for more details.
import org.apache.spark.ml.clustering.BisectingKMeans
// Load the sample data set using the libsvm data source.
val points = spark.read
  .format("libsvm")
  .load("data/mllib/sample_kmeans_data.txt")
// Fit a bisecting k-means model with two clusters and a fixed seed.
val bisecting = new BisectingKMeans()
  .setK(2)
  .setSeed(1)
  .fit(points)
// Compute the clustering cost on the input data.
val sse = bisecting.computeCost(points)
println(s"Within Set Sum of Squared Errors = $sse")
// Print every cluster center found by the model.
println("Cluster Centers: ")
for (center <- bisecting.clusterCenters) println(center)
Refer to the Java API docs for more details.
import org.apache.spark.ml.clustering.BisectingKMeans;
import org.apache.spark.ml.clustering.BisectingKMeansModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
// Load the sample data set using the libsvm data source.
Dataset<Row> points = spark.read()
  .format("libsvm")
  .load("data/mllib/sample_kmeans_data.txt");
// Fit a bisecting k-means model with two clusters and a fixed seed.
BisectingKMeansModel bisecting = new BisectingKMeans()
  .setK(2)
  .setSeed(1)
  .fit(points);
// Compute the clustering cost on the input data.
double sse = bisecting.computeCost(points);
System.out.println("Within Set Sum of Squared Errors = " + sse);
// Print every cluster center found by the model.
System.out.println("Cluster Centers: ");
for (Vector c : bisecting.clusterCenters()) {
  System.out.println(c);
}
Refer to the Python API docs for more details.
from pyspark.ml.clustering import BisectingKMeans
# Load the sample data set using the libsvm data source.
points = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")
# Fit a bisecting k-means model with two clusters and a fixed seed.
bisecting = BisectingKMeans(k=2, seed=1).fit(points)
# Compute the clustering cost on the input data.
sse = bisecting.computeCost(points)
print("Within Set Sum of Squared Errors = " + str(sse))
# Print every cluster center found by the model.
print("Cluster Centers: ")
for c in bisecting.clusterCenters():
    print(c)
Refer to the R API docs for more details.
# Turn the Titanic data into a SparkDataFrame
titanic <- as.data.frame(Titanic)
titanicDF <- createDataFrame(titanic)
# Fit a bisecting k-means model with four centers
bkmModel <- spark.bisectingKmeans(titanicDF, Class ~ Survived, k = 4)
# Get the fitted result ("centers") from the bisecting k-means model
centersDF <- fitted(bkmModel, "centers")
# Summarize the fitted result and show the first rows
head(summary(centersDF))
# Predict on the training data and show each class next to its prediction
predicted <- predict(bkmModel, titanicDF)
head(select(predicted, "Class", "prediction"))
Gaussian Mixture Model (GMM)
A Gaussian Mixture Model
represents a composite distribution whereby points are drawn from one of k Gaussian sub-distributions,
each with its own probability. The spark.ml implementation uses the
expectation-maximization
algorithm to induce the maximum-likelihood model given a set of samples.
GaussianMixture is implemented as an Estimator and generates a GaussianMixtureModel as the base
model.
Input Columns
| Param name | Type(s) | Default | Description | 
|---|---|---|---|
| featuresCol | Vector | "features" | Feature vector | 
Output Columns
| Param name | Type(s) | Default | Description | 
|---|---|---|---|
| predictionCol | Int | "prediction" | Index of the predicted cluster center | 
| probabilityCol | Vector | "probability" | Probability of each cluster | 
Examples
Refer to the Scala API docs for more details.
import org.apache.spark.ml.clustering.GaussianMixture
// Load the sample data set using the libsvm data source.
val points = spark.read
  .format("libsvm")
  .load("data/mllib/sample_kmeans_data.txt")
// Fit a Gaussian mixture model with k = 2.
val mixture = new GaussianMixture()
  .setK(2)
  .fit(points)
// Print the weight, mean and covariance of every Gaussian in the mixture.
(0 until mixture.getK).foreach { i =>
  val g = mixture.gaussians(i)
  println(s"Gaussian $i:\nweight=${mixture.weights(i)}\n" +
      s"mu=${g.mean}\nsigma=\n${g.cov}\n")
}
Refer to the Java API docs for more details.
import org.apache.spark.ml.clustering.GaussianMixture;
import org.apache.spark.ml.clustering.GaussianMixtureModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
// Load the sample data set using the libsvm data source.
Dataset<Row> points = spark.read()
  .format("libsvm")
  .load("data/mllib/sample_kmeans_data.txt");
// Fit a Gaussian mixture model with k = 2.
GaussianMixtureModel mixture = new GaussianMixture().setK(2).fit(points);
// Print the weight, mean and covariance of every Gaussian in the mixture.
int numGaussians = mixture.getK();
for (int idx = 0; idx < numGaussians; idx++) {
  System.out.printf("Gaussian %d:\nweight=%f\nmu=%s\nsigma=\n%s\n\n",
          idx,
          mixture.weights()[idx],
          mixture.gaussians()[idx].mean(),
          mixture.gaussians()[idx].cov());
}
Refer to the Python API docs for more details.
from pyspark.ml.clustering import GaussianMixture
# Load the sample data set using the libsvm data source.
points = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")
# Fit a Gaussian mixture model with k = 2 and a fixed seed.
mixture = GaussianMixture(k=2, seed=538009335).fit(points)
# Display the fitted Gaussians as a DataFrame without truncating values.
print("Gaussians shown as a DataFrame: ")
mixture.gaussiansDF.show(truncate=False)
Refer to the R API docs for more details.
# Load the sample data; it is used for both fitting and prediction
samples <- read.df("data/mllib/sample_kmeans_data.txt", source = "libsvm")
# Fit a Gaussian mixture clustering model with k = 2 via spark.gaussianMixture
gmmModel <- spark.gaussianMixture(samples, ~ features, k = 2)
# Summarize the fitted model
summary(gmmModel)
# Predict on the same data and show the first rows
clustered <- predict(gmmModel, samples)
head(clustered)
