Some clustering techniques allow you to fit models to data, and you can then feed whatever data you like to the model and it will try to classify your samples.
For instance, the Scikit-learn k-means model has a fit method that calculates the centroids for k-means clustering from the data you give it. You can then call the predict method with whatever data you like, and it will assign each sample to the cluster with the nearest centroid. The fit_predict method combines the two steps: it fits the model and returns the cluster labels for that same data.
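As a minimal sketch of that fit/predict workflow, the snippet below fits k-means to a tiny made-up dataset and then labels two new points. The arrays `X_fit` and `X_new`, and the choices of `n_init` and `random_state`, are assumptions for illustration only.

```python
import numpy as np
from sklearn.cluster import KMeans

# Two well-separated toy blobs (made-up data, not from the Iris example).
X_fit = np.array([[0.0, 0.0], [0.2, 0.1], [0.1, 0.3],
                  [5.0, 5.0], [5.2, 4.9], [4.9, 5.1]])
X_new = np.array([[0.1, 0.1], [5.1, 5.0]])

km = KMeans(n_clusters=2, n_init=10, random_state=0)
km.fit(X_fit)                    # computes the two centroids
new_labels = km.predict(X_new)   # nearest-centroid assignment for unseen samples

# fit_predict fits a model and returns labels for the same data in one call.
fit_labels = KMeans(n_clusters=2, n_init=10, random_state=0).fit_predict(X_fit)

print(km.cluster_centers_)
print(new_labels, fit_labels)
```

The cluster numbers themselves are arbitrary, so what matters is which samples share a label, not whether a given blob is called 0 or 1.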
Other clustering models don’t have standalone predict methods and don’t let you fit the model to data and then run it on new data, simply because of the way the clustering algorithm works.
The Scikit-learn AgglomerativeClustering and DBSCAN models both have fit methods that run the algorithm on your data, but they do not have predict methods, so you can’t then use the fitted models on arbitrary data. Instead, you can access the calculated clusters for the data you fed to fit, via the labels_ attribute.
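The following sketch illustrates the difference on a small made-up dataset: both models expose labels_ after fit, and a simple hasattr check shows that neither offers predict. The data and the DBSCAN parameters (`eps`, `min_samples`) are assumptions chosen so the toy example forms two clusters.

```python
import numpy as np
from sklearn.cluster import AgglomerativeClustering, DBSCAN

# Two well-separated toy blobs (made-up data for illustration).
X_fit = np.array([[0.0, 0.0], [0.2, 0.1], [0.1, 0.3],
                  [5.0, 5.0], [5.2, 4.9], [4.9, 5.1]])

ac = AgglomerativeClustering(n_clusters=2).fit(X_fit)
db = DBSCAN(eps=1.0, min_samples=2).fit(X_fit)

# Cluster assignments exist only for the samples passed to fit.
ac_labels = ac.labels_
db_labels = db.labels_   # DBSCAN marks noise samples with -1

# Neither fitted model can label new samples directly.
has_predict = {type(m).__name__: hasattr(m, "predict") for m in (ac, db)}

print(ac_labels, db_labels)
print(has_predict)
```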
What If We Want to Classify Additional Data?
If you want to classify new data using existing cluster information and your model doesn’t allow it, one approach is to classify new samples by checking the clusters of their nearest neighbours and assigning them to whichever cluster is most highly represented among nearest neighbours.
We can do this using Scikit-learn’s KNeighborsClassifier model.
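To make the majority-vote idea concrete, here is a rough sketch that performs the vote by hand with NearestNeighbors and compares it with KNeighborsClassifier using its default uniform weighting. The arrays `X_known`, `cluster_labels` and `X_new` are made-up stand-ins for already-clustered data and new samples, and `n_neighbors=3` is an arbitrary choice.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors, KNeighborsClassifier

# Samples that already have cluster labels (made-up data for illustration).
X_known = np.array([[0.0, 0.0], [0.2, 0.1], [0.1, 0.3],
                    [5.0, 5.0], [5.2, 4.9], [4.9, 5.1]])
cluster_labels = np.array([0, 0, 0, 1, 1, 1])
X_new = np.array([[0.3, 0.2], [4.8, 5.3]])

# Manual vote: look up each new sample's 3 nearest known samples and take
# the most common cluster label among them.
# np.bincount needs non-negative labels, so this would not handle DBSCAN's -1.
nn = NearestNeighbors(n_neighbors=3).fit(X_known)
_, neighbour_idx = nn.kneighbors(X_new)
neighbour_labels = cluster_labels[neighbour_idx]
voted = np.array([np.bincount(row).argmax() for row in neighbour_labels])

# KNeighborsClassifier wraps the same neighbour search and vote in fit/predict.
knn = KNeighborsClassifier(n_neighbors=3).fit(X_known, cluster_labels)
predicted = knn.predict(X_new)

print(voted, predicted)
```

In practice the classifier is the more convenient choice, since it provides the familiar fit and predict interface that the clustering model lacks.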
In this example, we’ll do the following:
- Load the Iris data, then condense the two most similar clusters into one cluster. This gives us some data that should be easy to classify into two clusters by any of a variety of possible techniques.
- Split this data into a “training” segment and a “test” segment.
- Fit agglomerative clustering to the training segment. This gives us y_predicted: the predicted clusters for the training segment.
- Use the KNeighborsClassifier from Scikit-learn to try to predict the clusters of the samples in the test data segment, based on the clusters we found by agglomerative clustering on the training data segment.
- Plot the actual clusters for the test data segment, vs. the clusters found by the above combination of techniques.
```python
from sklearn.cluster import AgglomerativeClustering
from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import load_iris
from sklearn.metrics import normalized_mutual_info_score
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sn

"""
Load the iris data set, then condense the two similar species into one.
"""
iris = load_iris(as_frame=True)
df = iris['data']
df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width']

X = df
y = iris['target']

# Change all 2's to 1, condensing the two
# hard-to-distinguish clusters into 1 cluster.
y = [1 if s == 2 else s for s in y]

"""
Do train test split, then run agglomerative clustering on the training data.
"""
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7, shuffle=True)

ac = AgglomerativeClustering(n_clusters=2)
ac.fit(X_train)
y_predicted = ac.labels_

"""
Use the nearest neighbour classifier to find clusters for
the test data samples, based on existing clusters
"""
nc = KNeighborsClassifier(n_neighbors=5)
nc.fit(X_train, y_predicted)
y_new = nc.predict(X_test)

"""
Plot the true clusters for the test segment, vs. the clusters
found by using nearest neighbours classification on the results
of the agglomerative clustering.
"""
plot_x = 0
plot_y = 1

x_label = df.columns[plot_x]
y_label = df.columns[plot_y]

fig = plt.figure()
fig.suptitle("Nearest Neighbours after Agglomerative Clustering Iris Flower Dataset")

# Left panel: true clusters for the test samples.
ax = fig.add_subplot(121)
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
ax.set_title("True Clusters")
sn.scatterplot(data=X_test, x=x_label, y=y_label, hue=y_test, palette='pastel', ax=ax)

# Right panel: clusters assigned by the nearest neighbour classifier.
ax = fig.add_subplot(122)
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
ax.set_title("Predicted Clusters")
sn.scatterplot(data=X_test, x=x_label, y=y_label, hue=y_new, palette='pastel', ax=ax)

plt.show()

print("Mutual info score on training data: ", normalized_mutual_info_score(y_train, y_predicted))
print("Mutual info score on test data: ", normalized_mutual_info_score(y_test, y_new))
```
```
Mutual info score on training data:  1.0
Mutual info score on test data:  1.0
```
We can see both from the normalised mutual information scores and from the scatter plots that the results we obtain by doing this are excellent.
Agglomerative clustering distinguished the two clusters in the modified dataset very well, and classification based on nearest neighbours did a great job of assigning the additional samples from the test data segment to those clusters.
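One detail worth keeping in mind when reading these scores: the cluster numbers produced by agglomerative clustering are arbitrary, so cluster 0 need not correspond to the species labelled 0. Normalised mutual information does not depend on how the clusters are numbered, which is why it is a suitable measure here. The small sketch below, using made-up label lists, illustrates this.

```python
from sklearn.metrics import normalized_mutual_info_score

# Made-up labels for illustration only.
true_labels = [0, 0, 0, 1, 1, 1]
swapped = [1, 1, 1, 0, 0, 0]      # same grouping, cluster numbers exchanged
one_wrong = [0, 0, 1, 1, 1, 1]    # one sample placed in the wrong cluster

# Identical groupings score 1 even though the numbering differs.
score_swapped = normalized_mutual_info_score(true_labels, swapped)

# A genuinely different grouping scores below 1.
score_wrong = normalized_mutual_info_score(true_labels, one_wrong)

print(score_swapped, score_wrong)
```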