class scib_metrics.utils.KMeans(n_clusters=8, init='k-means++', n_init=1, max_iter=300, tol=0.0001, seed=0)

JAX implementation of sklearn.cluster.KMeans.

This implementation is limited to Euclidean distance.
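Because distances are Euclidean, each observation is assigned to the centroid with the smallest squared Euclidean distance, and the inertia is the sum of those squared distances. The NumPy sketch below illustrates that assignment step only. It is not the library's JAX code, and `assign_clusters` is a hypothetical name.

```python
import numpy as np


def assign_clusters(X: np.ndarray, centroids: np.ndarray) -> tuple[np.ndarray, float]:
    """Hypothetical helper: assign each row of X to its nearest centroid.

    Returns (labels, inertia), where inertia is the sum of squared Euclidean
    distances from each observation to its assigned centroid.
    """
    # Pairwise squared distances via ||x||^2 - 2 x.c + ||c||^2; shape (n_obs, n_clusters).
    sq_dists = (
        np.sum(X**2, axis=1, keepdims=True)
        - 2.0 * X @ centroids.T
        + np.sum(centroids**2, axis=1)
    )
    # Rounding can produce tiny negative values; clip them to zero.
    sq_dists = np.maximum(sq_dists, 0.0)
    labels = np.argmin(sq_dists, axis=1)
    inertia = float(sq_dists[np.arange(X.shape[0]), labels].sum())
    return labels, inertia


X = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
centroids = np.array([[0.0, 0.5], [10.0, 10.5]])
labels, inertia = assign_clusters(X, centroids)
print(labels, inertia)  # [0 0 1 1] 1.0
```

The expanded form of the squared distance avoids materialising an (n_obs, n_clusters, n_features) difference array, which is the usual reason to prefer it for large inputs.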

Parameters:

  • n_clusters (int (default: 8)) – Number of clusters.

  • init (Literal['k-means++', 'random'] (default: 'k-means++')) –

    Cluster centroid initialization method. One of the following:

    • 'k-means++': Sample initial cluster centroids based on an empirical distribution of the points’ contributions to the overall inertia.

    • 'random': Uniformly sample observations as initial centroids.

  • n_init (int (default: 1)) – Number of times the k-means algorithm will be initialized.

  • max_iter (int (default: 300)) – Maximum number of iterations of the k-means algorithm for a single run.

  • tol (float (default: 0.0001)) – Relative tolerance with regard to inertia to declare convergence.

  • seed (Union[int, Array] (default: 0)) – Random seed.
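The two `init` options can be sketched in NumPy as follows. This is an illustration under assumptions, not the library's implementation: `init_random` and `init_kmeans_plus_plus` are hypothetical names, and the k-means++ variant shown is the basic D² sampling scheme, where each new centroid is drawn with probability proportional to a point's squared distance to its nearest existing centroid (its contribution to the current inertia). The library may use a different variant.

```python
import numpy as np


def init_random(X: np.ndarray, n_clusters: int, rng: np.random.Generator) -> np.ndarray:
    """Hypothetical 'random' init: n_clusters distinct observations chosen uniformly."""
    idx = rng.choice(X.shape[0], size=n_clusters, replace=False)
    return X[idx]


def init_kmeans_plus_plus(X: np.ndarray, n_clusters: int, rng: np.random.Generator) -> np.ndarray:
    """Hypothetical 'k-means++' init using basic D^2 sampling."""
    n_obs = X.shape[0]
    # First centroid: a uniformly random observation.
    centroids = [X[rng.integers(n_obs)]]
    # Squared distance of each observation to its closest chosen centroid,
    # i.e. its contribution to the inertia if clustering stopped here.
    closest_sq = np.sum((X - centroids[0]) ** 2, axis=1)
    for _ in range(1, n_clusters):
        total = closest_sq.sum()
        # Sample proportionally to each point's inertia contribution; fall back
        # to uniform sampling if every point already coincides with a centroid.
        probs = closest_sq / total if total > 0 else None
        new_centroid = X[rng.choice(n_obs, p=probs)]
        centroids.append(new_centroid)
        closest_sq = np.minimum(closest_sq, np.sum((X - new_centroid) ** 2, axis=1))
    return np.stack(centroids)


rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))
print(init_random(X, 3, rng).shape)            # (3, 2)
print(init_kmeans_plus_plus(X, 3, rng).shape)  # (3, 2)
```

Points that are already centroids have zero sampling probability under D² sampling, so k-means++ never picks the same observation twice, and it tends to spread the initial centroids across well-separated groups.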

Methods table


Fit the model to the data.


Fit the model to the data.
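Fitting can be sketched as Lloyd's algorithm: alternate nearest-centroid assignment and centroid updates until the relative decrease in inertia falls below `tol` or `max_iter` iterations have run, repeat this `n_init` times, and keep the run with the lowest inertia. The NumPy sketch below assumes that convergence rule and best-of-`n_init` selection, which follow the parameter descriptions but may not match the library's exact logic. `kmeans_fit` and `_assign` are hypothetical names, and the sketch uses 'random' initialization for brevity.

```python
import numpy as np


def _assign(X, centroids):
    # Squared Euclidean distances, shape (n_obs, n_clusters), clipped at zero.
    d = np.sum(X**2, axis=1, keepdims=True) - 2.0 * X @ centroids.T + np.sum(centroids**2, axis=1)
    d = np.maximum(d, 0.0)
    labels = np.argmin(d, axis=1)
    return labels, float(d[np.arange(X.shape[0]), labels].sum())


def kmeans_fit(X, n_clusters=8, n_init=1, max_iter=300, tol=1e-4, seed=0):
    """Hypothetical Lloyd's-algorithm fit; returns (centroids, labels, inertia)."""
    X = np.asarray(X, dtype=float)
    rng = np.random.default_rng(seed)
    best = None
    for _run in range(n_init):
        # 'random' initialization: distinct observations chosen uniformly (a copy of X rows).
        centroids = X[rng.choice(X.shape[0], size=n_clusters, replace=False)]
        prev_inertia = np.inf
        for _it in range(max_iter):
            labels, inertia = _assign(X, centroids)
            # Assumed convergence rule: relative decrease in inertia below tol.
            if np.isfinite(prev_inertia) and prev_inertia - inertia <= tol * prev_inertia:
                break
            prev_inertia = inertia
            # Move each centroid to the mean of its members; an empty cluster keeps its centroid.
            for k in range(n_clusters):
                members = X[labels == k]
                if members.shape[0] > 0:
                    centroids[k] = members.mean(axis=0)
        # Recompute so labels and inertia match the final centroids.
        labels, inertia = _assign(X, centroids)
        # Keep the restart with the lowest inertia.
        if best is None or inertia < best[2]:
            best = (centroids, labels, inertia)
    return best


X = np.array([[0.0], [1.0], [2.0], [100.0], [101.0], [102.0]])
centroids, labels, inertia = kmeans_fit(X, n_clusters=2, n_init=3)
print(centroids.ravel(), labels, inertia)  # inertia is 4.0 for this toy data
```

Restarting with `n_init > 1` guards against poor local minima from an unlucky initialization, at the cost of running the whole loop several times.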