Grid search for hyperparameter evaluation of clustering in scikit-learn

Jamie Bull · Jan 5, 2016

I'm clustering a sample of about 100 unlabelled records and trying to use grid_search to evaluate the clustering algorithm with various hyperparameters. I'm scoring with silhouette_score, which works fine.

My problem is that I don't need the cross-validation aspect of GridSearchCV/RandomizedSearchCV, but I can't find a plain GridSearch/RandomizedSearch without it. I could write my own, but the ParameterSampler and ParameterGrid objects are very useful.

My next step will be to subclass BaseSearchCV and implement my own _fit() method, but I thought it was worth asking: is there a simpler way to do this, for example by passing something to the cv parameter?
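A hand-rolled search of the kind mentioned above could look roughly like the following. This is only a sketch under assumptions: the toy data from make_blobs stands in for the real records, KMeans is given an explicit n_init and random_state for reproducibility, and the silhouette is computed on the raw features rather than on a precomputed distance matrix.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score
from sklearn.model_selection import ParameterGrid

# Toy stand-in for the ~100 unlabelled records (an assumption, not the asker's data).
X, _ = make_blobs(
    n_samples=100,
    centers=[[0, 0], [10, 10], [-10, 10]],
    cluster_std=0.5,
    random_state=0,
)

param_grid = {"n_clusters": list(range(2, 11))}

results = []
for params in ParameterGrid(param_grid):
    # Fit on the full data set; no train/test split is involved.
    model = KMeans(n_init=10, random_state=0, **params)
    labels = model.fit_predict(X)
    results.append((silhouette_score(X, labels), params))

# The parameter combination with the highest silhouette score wins.
best_score, best_params = max(results, key=lambda item: item[0])
print(best_params, round(best_score, 3))
```

Replacing ParameterGrid(param_grid) with ParameterSampler(param_grid, n_iter=5, random_state=0) would turn the same loop into a randomized search.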

from sklearn import metrics
from sklearn.cluster import KMeans
from sklearn.model_selection import GridSearchCV

# distance_matrix: precomputed pairwise distances between the records

def silhouette_score(estimator, X):
    # Custom scorer: cluster X, then score the labels against the distances
    clusters = estimator.fit_predict(X)
    score = metrics.silhouette_score(distance_matrix, clusters, metric='precomputed')
    return score

ca = KMeans()
param_grid = {"n_clusters": range(2, 11)}

# run grid search
search = GridSearchCV(
    ca,
    param_grid=param_grid,
    scoring=silhouette_score,
    cv= # can I pass something here to only use a single fold?
    )
search.fit(distance_matrix)
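One possible way to fill in the cv argument asked about in the snippet above: GridSearchCV accepts an iterable of (train, test) index pairs, so a single pair in which both halves cover every row amounts to one "fold" on the full data. The sketch below assumes toy data from make_blobs and a scorer named silhouette_scorer, both of which are illustrative and not part of the original question.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score
from sklearn.model_selection import GridSearchCV

# Toy stand-in for the unlabelled records (an assumption for illustration).
X, _ = make_blobs(
    n_samples=100,
    centers=[[0, 0], [10, 10], [-10, 10]],
    cluster_std=0.5,
    random_state=0,
)

def silhouette_scorer(estimator, X, y=None):
    # GridSearchCV passes the already-fitted estimator and the "test" rows,
    # which here are the same rows the estimator was fitted on.
    labels = estimator.predict(X)
    return silhouette_score(X, labels)

# A single split whose train and test indices are both the whole data set.
all_rows = np.arange(X.shape[0])
single_split = [(all_rows, all_rows)]

search = GridSearchCV(
    KMeans(n_init=10, random_state=0),
    param_grid={"n_clusters": list(range(2, 11))},
    scoring=silhouette_scorer,
    cv=single_split,
)
search.fit(X)  # no y: the problem is unsupervised
print(search.best_params_)
```

Because train and test are identical, the reported "test" score is simply the silhouette on the data the model was fitted to, which is what is wanted here but would be misleading for a supervised estimator.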

Answer

erdogant · Jun 18, 2020

The clusteval library will help you to evaluate the data and find the optimal number of clusters. This library contains five methods that can be used to evaluate clusterings: silhouette, dbindex, derivative, dbscan and hdbscan.

pip install clusteval

Choose the evaluation method that best suits your data.

# Import library
from clusteval import clusteval

# Set parameters, using dbscan as an example
ce = clusteval(method='dbscan')

# Fit on the data array X to find the optimal number of clusters using dbscan
results = ce.fit(X)

# Make plot of the cluster evaluation
ce.plot()

# Make scatter plot. Note that the first two coordinates are used for plotting.
ce.scatter(X)

# results is a dict with various output statistics. One of them is the cluster labels.
cluster_labels = results['labx']
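For comparison, two of the scores named above can also be computed with scikit-learn alone. The sketch below assumes that "dbindex" refers to the Davies-Bouldin index and uses toy data from make_blobs with KMeans as the clustering algorithm; it is not a reproduction of what clusteval does internally.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import davies_bouldin_score, silhouette_score

# Toy data with three well-separated groups (an assumption for illustration).
X, _ = make_blobs(
    n_samples=100,
    centers=[[0, 0], [10, 10], [-10, 10]],
    cluster_std=0.5,
    random_state=0,
)

scores = {}
for k in range(2, 8):
    labels = KMeans(n_clusters=k, n_init=10, random_state=0).fit_predict(X)
    scores[k] = {
        "silhouette": silhouette_score(X, labels),          # higher is better
        "davies_bouldin": davies_bouldin_score(X, labels),  # lower is better
    }

# Pick the number of clusters preferred by each criterion.
best_by_silhouette = max(scores, key=lambda k: scores[k]["silhouette"])
best_by_db = min(scores, key=lambda k: scores[k]["davies_bouldin"])
print(best_by_silhouette, best_by_db)
```

The two criteria need not agree on real data, so it can be worth inspecting the full scores dict rather than only the winners.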