R/agglomerative.R
cuda_ml_agglomerative_clustering.Rd
Recursively merge the pair of clusters that minimally increases a given linkage distance.
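The merge loop described above can be illustrated with a naive, CPU-only sketch in base R. It assumes single linkage (the distance between two clusters is the distance between their closest members) and Euclidean distance. This page does not say which linkage criterion cuda.ml uses, and `naive_single_linkage()` is a hypothetical helper that is not part of the package. Node ids follow the `children` convention in the return value below: leaves are 0 to nrow(x) - 1, and the node formed in iteration i is nrow(x) + i. The 1-based labels are also an assumption made for this sketch.

```r
# Hypothetical helper: naive single-linkage agglomerative clustering.
# O(n^3) time and O(n^2) memory, so only suitable for tiny inputs.
naive_single_linkage <- function(x, n_clusters = 2L) {
  x <- as.matrix(x)
  n <- nrow(x)
  d <- as.matrix(dist(x))          # full pairwise Euclidean distances
  members <- as.list(seq_len(n))   # row indices in each active cluster
  ids <- seq_len(n) - 1L           # node id of each active cluster (leaves: 0..n-1)
  children <- matrix(NA_integer_, nrow = n - 1L, ncol = 2L)
  labels <- NULL

  # 1-based label for each row, given the current active clusters
  labels_of <- function(members) {
    lab <- integer(n)
    for (j in seq_along(members)) lab[members[[j]]] <- j
    lab
  }

  for (i in seq_len(n - 1L) - 1L) {
    if (length(members) == n_clusters) labels <- labels_of(members)
    k <- length(members)
    best <- c(NA_integer_, NA_integer_)
    best_d <- Inf
    for (a in seq_len(k - 1L)) {
      for (b in (a + 1L):k) {
        # single linkage: distance between the closest pair of members
        dab <- min(d[members[[a]], members[[b]]])
        if (dab < best_d) {
          best_d <- dab
          best <- c(a, b)
        }
      }
    }
    a <- best[1]
    b <- best[2]
    children[i + 1L, ] <- c(ids[a], ids[b])  # record which nodes were merged
    members[[a]] <- c(members[[a]], members[[b]])
    ids[a] <- n + i                          # merged cluster becomes node n + i
    members[[b]] <- NULL
    ids <- ids[-b]
  }
  if (length(members) == n_clusters) labels <- labels_of(members)

  list(n_clusters = n_clusters, children = children, labels = labels)
}

x <- rbind(c(0, 0), c(0, 1), c(10, 10), c(10, 11), c(50, 50))
res <- naive_single_linkage(x, n_clusters = 3L)
print(res$labels)    # 1 1 2 2 3
print(res$children)  # rows: (0, 1), (2, 3), (5, 6), (7, 4)
```

The labels are captured at the moment exactly n_clusters clusters remain, while the loop keeps merging so that `children` describes the full tree.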
cuda_ml_agglomerative_clustering( x, n_clusters = 2L, metric = c("euclidean", "l1", "l2", "manhattan", "cosine"), connectivity = c("pairwise", "knn"), n_neighbors = 15L )
Argument | Description |
---|---|
x | The input matrix or data frame. Each data point should be a row and should consist of numeric values only. |
n_clusters | The number of clusters to find. Default: 2L. |
metric | Metric used for linkage computation. Must be one of "euclidean", "l1", "l2", "manhattan", "cosine". If connectivity is "knn", only "euclidean" is accepted. Default: "euclidean". |
connectivity | The type of connectivity matrix to compute. Must be one of "pairwise", "knn". Default: "pairwise". "pairwise" computes the entire fully-connected graph of pairwise distances between every pair of points. It is fast to compute and can be very fast for smaller datasets, but requires O(n^2) space. "knn" sparsifies the fully-connected connectivity matrix to save memory and enable much larger inputs. "n_neighbors" controls the amount of memory used, and the graph is connected automatically if "n_neighbors" is not large enough to connect it. |
n_neighbors | The number of neighbors to compute when connectivity is "knn". Default: 15L. |
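The trade-off behind connectivity = "knn" can be sketched in base R: each point keeps edges only to its n_neighbors nearest points, so the graph has about n * n_neighbors entries instead of n^2, and a small n_neighbors can leave it disconnected. `knn_connectivity()` is a hypothetical helper, not how cuML builds its graph. For brevity it still computes the dense distance matrix, so it shows the structure of the sparse graph rather than the memory savings, and it only counts connected components instead of connecting them as cuda.ml does.

```r
# Hypothetical helper: symmetric k-nearest-neighbour connectivity graph
# plus a union-find count of its connected components.
knn_connectivity <- function(x, n_neighbors = 15L) {
  d <- as.matrix(dist(as.matrix(x)))  # dense here only for simplicity
  n <- nrow(d)
  k <- min(as.integer(n_neighbors), n - 1L)
  diag(d) <- Inf                      # a point is not its own neighbour
  # Edge list: each point linked to its k nearest other points
  nbrs <- lapply(seq_len(n), function(i) order(d[i, ])[seq_len(k)])
  edges <- cbind(rep(seq_len(n), each = k), unlist(nbrs))

  # Union-find over the edges, treated as undirected
  parent <- seq_len(n)
  find <- function(i) {
    while (parent[i] != i) i <- parent[i]
    i
  }
  for (e in seq_len(nrow(edges))) {
    ri <- find(edges[e, 1])
    rj <- find(edges[e, 2])
    if (ri != rj) parent[ri] <- rj
  }
  roots <- vapply(seq_len(n), find, integer(1))

  list(n_edges = nrow(edges), n_components = length(unique(roots)))
}

# Two tight groups of three points, far apart
x <- rbind(c(0, 0), c(0, 1), c(1, 0), c(100, 100), c(100, 101), c(101, 100))
print(knn_connectivity(x, n_neighbors = 2L)$n_components)  # 2: disconnected
print(knn_connectivity(x, n_neighbors = 3L)$n_components)  # 1: connected
```

With two neighbours every edge stays inside its own group, so the graph splits in two. A third neighbour necessarily reaches across, which is the kind of gap the automatic connection step described in the table is meant to close.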
A clustering object with the following attributes:
- "n_clusters": The number of clusters found by the algorithm.
- "children": The children of each non-leaf node. Values less than nrow(x) correspond to leaves of the tree, which are the original samples. children[i + 1][1] and children[i + 1][2] were merged to form node (nrow(x) + i) in the i-th iteration.
- "labels": The cluster label of each data point.
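To make the `children` encoding concrete, the sketch below replays the first nrow(x) - n_clusters merges and assigns each leaf to the root of its subtree. It assumes `children` is an (nrow(x) - 1) x 2 matrix of 0-based node ids as described above, and that the labels correspond to stopping after that many merges; neither detail is confirmed by this page. `labels_from_children()` is a hypothetical helper, and it numbers clusters 1, 2, ... by first appearance, which may differ from the label values cuda.ml returns.

```r
# Hypothetical helper: recover flat cluster labels from a merge matrix
# whose node ids follow the 0-based convention described above.
labels_from_children <- function(children, n, n_clusters) {
  n <- as.integer(n)
  n_merges <- n - as.integer(n_clusters)
  parent <- rep(NA_integer_, n + n_merges)  # parent[id + 1] = parent of node id
  for (i in seq_len(n_merges) - 1L) {
    # both children of merge i now belong to the new node n + i
    parent[children[i + 1L, ] + 1L] <- n + i
  }
  root_of <- function(id) {
    while (!is.na(parent[id + 1L])) id <- parent[id + 1L]
    id
  }
  roots <- vapply(seq_len(n) - 1L, root_of, integer(1))
  match(roots, unique(roots))  # renumber roots as 1, 2, ...
}

# Merge tree for 5 points: (0,1)->5, (2,3)->6, (5,6)->7, (7,4)->8
children <- rbind(c(0, 1), c(2, 3), c(5, 6), c(7, 4))
print(labels_from_children(children, n = 5, n_clusters = 3))  # 1 1 2 2 3
print(labels_from_children(children, n = 5, n_clusters = 2))  # 1 1 1 1 2
```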
library(cuda.ml)
library(MASS)      # mvrnorm()
library(magrittr)  # %>%
library(purrr)     # map()

set.seed(0L)

# Three well-separated Gaussian blobs of 50 points each
gen_pts <- function() {
  centers <- list(c(1000, 1000), c(-1000, -1000), c(-1000, 1000))
  pts <- centers %>%
    map(~ mvrnorm(50, mu = .x, Sigma = diag(2)))

  rlang::exec(rbind, !!!pts) %>% as.matrix()
}

clust <- cuda_ml_agglomerative_clustering(
  x = gen_pts(),
  metric = "euclidean",
  n_clusters = 3L
)

print(clust$labels)
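Without a GPU, a rough CPU sanity check on the same kind of data is base R's `stats::hclust()` with single linkage followed by `cutree()`. This is a substitute technique chosen for illustration: single linkage is an assumption, since this page does not name the linkage criterion, and the points are drawn with `rnorm()` instead of `MASS::mvrnorm()`, so the exact values differ from the example above. With centers 2000 units apart and unit variance, each generated group should map to exactly one cluster.

```r
# CPU reference sketch (assumes single linkage; not part of cuda.ml)
set.seed(0L)
centers <- list(c(1000, 1000), c(-1000, -1000), c(-1000, 1000))

# 50 points per center with identity covariance, drawn with base R rnorm()
x <- do.call(rbind, lapply(centers, function(mu) {
  cbind(rnorm(50, mean = mu[1]), rnorm(50, mean = mu[2]))
}))
group <- rep(1:3, each = 50)

tree <- hclust(dist(x), method = "single")  # Euclidean distances, single linkage
ref_labels <- cutree(tree, k = 3)

# Cross-tabulate generated groups against recovered clusters
print(table(group, ref_labels))
```

On a GPU machine, comparing `table(group, clust$labels)` with this reference is one way to check that the two clusterings agree up to a relabeling of the clusters.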