API reference
Solvers
- jaxclust._src.solvers.kruskals(S: Array, ncc: int) → Tuple[Array, Array]
Calculates the adjacency matrix and cluster connectivity matrix of the minimum weight ncc-spanning forest using Kruskal’s algorithm.
- Parameters:
S (jax.Array) – similarity matrix.
ncc (int) – number of connected components.
- Returns:
A, M
- Return type:
Tuple[jax.Array, jax.Array]
\(A_{ij} = 1\) if the edge (i, j) is in the forest.
\(M_{ij} = 1\) if i and j are in the same connected component of the forest.
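The pure-NumPy sketch below illustrates the computation described above. It is not jaxclust's implementation, and the name kruskals_sketch is hypothetical. It assumes that larger entries of S mean more similar points. Under that assumption, edges are added in decreasing order of S, which is the minimum-weight forest under edge weights -S. Stopping after n - ncc accepted edges leaves exactly ncc components.

```python
import numpy as np


def kruskals_sketch(S: np.ndarray, ncc: int):
    """Hypothetical NumPy sketch of an ncc-spanning forest via Kruskal's algorithm.

    Assumption: larger S[i, j] means more similar, so edges are taken in
    decreasing order of S (a minimum-weight forest under weights -S).
    """
    n = S.shape[0]
    parent = list(range(n))

    def find(i: int) -> int:
        # Union-find root lookup with path halving.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    # All undirected edges (i < j), strongest similarity first.
    edges = sorted(
        ((S[i, j], i, j) for i in range(n) for j in range(i + 1, n)),
        reverse=True,
    )
    A = np.zeros((n, n), dtype=int)
    n_edges = 0
    for _, i, j in edges:
        if n_edges == n - ncc:
            break  # a forest with n - ncc edges on n nodes has ncc components
        ri, rj = find(i), find(j)
        if ri != rj:  # skip edges that would close a cycle
            parent[ri] = rj
            A[i, j] = A[j, i] = 1
            n_edges += 1

    # M[i, j] = 1 iff i and j share a union-find root (so M[i, i] = 1).
    roots = np.array([find(i) for i in range(n)])
    M = (roots[:, None] == roots[None, :]).astype(int)
    return A, M


S = np.array([
    [0.0, 0.9, 0.1, 0.2],
    [0.9, 0.0, 0.3, 0.1],
    [0.1, 0.3, 0.0, 0.8],
    [0.2, 0.1, 0.8, 0.0],
])
A, M = kruskals_sketch(S, ncc=2)
```

In this example the two strongest edges, (0, 1) and (2, 3), are kept. As a result, M is block diagonal with components {0, 1} and {2, 3}.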
- jaxclust._src.solvers.kruskals_prims_pre(S: Array, ncc: int) → Tuple[Array, Array]
Calculates the adjacency matrix and cluster connectivity matrix of the minimum weight ncc-spanning forest. Uses Prim’s algorithm to construct the full spanning tree, then applies Kruskal’s algorithm to the edges in the spanning tree in order to calculate the forest.
- Parameters:
S (jax.Array) – similarity matrix.
ncc (int) – number of connected components.
- Returns:
A, M
- Return type:
Tuple[jax.Array, jax.Array]
\(A_{ij} = 1\) if the edge (i, j) is in the forest.
\(M_{ij} = 1\) if i and j are in the same connected component of the forest.
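The following NumPy sketch shows the two-stage approach. The function names are hypothetical, this is not jaxclust's code, and it keeps the assumption that larger S means more similar. Dense Prim's algorithm builds a maximum-similarity spanning tree in O(n^2). The n - 1 tree edges contain no cycle, so running Kruskal's algorithm on them reduces to keeping the n - ncc strongest tree edges. Removing the ncc - 1 weakest edges of an optimal spanning tree in this way yields an optimal ncc-component forest.

```python
import numpy as np


def prim_tree_edges(S: np.ndarray):
    """Dense Prim's algorithm: edges (weight, u, v) of a maximum-similarity spanning tree."""
    n = S.shape[0]
    in_tree = np.zeros(n, dtype=bool)
    in_tree[0] = True
    best = S[0].astype(float)      # best similarity from each vertex to the tree
    link = np.zeros(n, dtype=int)  # tree vertex attaining that similarity
    edges = []
    for _ in range(n - 1):
        v = int(np.argmax(np.where(in_tree, -np.inf, best)))
        edges.append((best[v], int(link[v]), v))
        in_tree[v] = True
        closer = S[v] > best
        best = np.where(closer, S[v], best)
        link = np.where(closer, v, link)
    return edges


def kruskals_prims_pre_sketch(S: np.ndarray, ncc: int):
    """Hypothetical sketch: Prim's spanning tree, then Kruskal on its edges."""
    n = S.shape[0]
    tree = sorted(prim_tree_edges(S), reverse=True)
    A = np.zeros((n, n), dtype=int)
    for _, u, v in tree[: n - ncc]:  # tree edges can never form a cycle
        A[u, v] = A[v, u] = 1

    # Label the connected components of the forest with an iterative DFS.
    labels = -np.ones(n, dtype=int)
    for start in range(n):
        if labels[start] >= 0:
            continue
        labels[start] = start
        stack = [start]
        while stack:
            u = stack.pop()
            for w in np.flatnonzero(A[u]):
                if labels[w] < 0:
                    labels[w] = start
                    stack.append(w)
    M = (labels[:, None] == labels[None, :]).astype(int)
    return A, M


S = np.array([
    [0.0, 0.9, 0.1, 0.2],
    [0.9, 0.0, 0.3, 0.1],
    [0.1, 0.3, 0.0, 0.8],
    [0.2, 0.1, 0.8, 0.0],
])
A, M = kruskals_prims_pre_sketch(S, ncc=2)
```

Here Prim's tree has edges (0, 1), (1, 2) and (2, 3). Dropping the weakest tree edge, (1, 2), leaves the components {0, 1} and {2, 3}.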
- jaxclust._src.solvers.ckruskals(S: Array, ncc: int, C: Array) → Tuple[Array, Array]
Calculates the adjacency matrix and cluster connectivity matrix of the minimum weight ncc-spanning forest. Uses a biased heuristic based on Kruskal’s algorithm, guided by the constraint matrix C, to create the forest.
- Parameters:
S (jax.Array) – similarity matrix.
ncc (int) – number of connected components.
C (jax.Array) – constraint matrix.
- Returns:
A, M
- Return type:
Tuple[jax.Array, jax.Array]
\(A_{ij} = 1\) if the edge (i, j) is in the forest.
\(M_{ij} = 1\) if i and j are in the same connected component of the forest.
\(C_{ij}=1\) if (i, j) has a must-link (ml) constraint.
\(C_{ij}=-1\) if (i, j) has a must-not-link (mnl) constraint.
\(C_{ij}=0\) if (i, j) has no constraints.
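Following the definitions above, the constraint matrix can be built from lists of index pairs, and a connectivity matrix M can be checked against it. The helper names are hypothetical. The sketch also assumes constraints are symmetric, meaning a constraint on (i, j) also applies to (j, i). A check like this shows whether a returned M honours every must-link and must-not-link pair.

```python
import numpy as np


def make_constraint_matrix(n, must_link, must_not_link):
    """Build C with C[i, j] = 1 (ml), -1 (mnl) or 0 (none); assumed symmetric."""
    C = np.zeros((n, n), dtype=int)
    for i, j in must_link:
        C[i, j] = C[j, i] = 1
    for i, j in must_not_link:
        C[i, j] = C[j, i] = -1
    return C


def count_violations(M, C):
    """Return (#broken ml pairs, #broken mnl pairs) for connectivity matrix M."""
    # Each unordered pair appears twice in a symmetric matrix, hence // 2.
    ml_broken = np.sum((C == 1) & (M == 0)) // 2
    mnl_broken = np.sum((C == -1) & (M == 1)) // 2
    return int(ml_broken), int(mnl_broken)


M = np.array([
    [1, 1, 0, 0],
    [1, 1, 0, 0],
    [0, 0, 1, 1],
    [0, 0, 1, 1],
])
C = make_constraint_matrix(4, must_link=[(0, 1), (0, 3)], must_not_link=[(1, 2)])
print(count_violations(M, C))  # (1, 0): the must-link (0, 3) is broken
```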
- jaxclust._src.solvers.ckruskals_prims_post(S: Array, ncc: int, C: Array) → Tuple[Array, Array]
Calculates the adjacency matrix and cluster connectivity matrix of the minimum weight ncc-spanning forest. Uses a biased heuristic based on Kruskal’s algorithm to create the forest, then applies Prim’s algorithm to recompute the spanning tree of each connected component in the forest. The result is therefore guaranteed to be at least as good as that of ckruskals.
- Parameters:
S (jax.Array) – similarity matrix.
ncc (int) – number of connected components.
C (jax.Array) – constraint matrix.
- Returns:
A, M
- Return type:
Tuple[jax.Array, jax.Array]
\(A_{ij} = 1\) if the edge (i, j) is in the forest.
\(M_{ij} = 1\) if i and j are in the same connected component of the forest.
\(C_{ij}=1\) if (i, j) has a must-link (ml) constraint.
\(C_{ij}=-1\) if (i, j) has a must-not-link (mnl) constraint.
\(C_{ij}=0\) if (i, j) has no constraints.
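The post-processing step can be sketched in NumPy as follows. The names are hypothetical, and the sketch keeps the assumption that larger S means more similar and that M[i, i] = 1. The vertex partition in M is kept fixed, so any constraint the heuristic forest satisfied stays satisfied. Each component's edges are then replaced by a maximum-similarity spanning tree of that component, which cannot lower the total forest similarity.

```python
import numpy as np


def prim_tree_edges(S):
    """Dense Prim's algorithm: edges (u, v) of a maximum-similarity spanning tree."""
    n = S.shape[0]
    in_tree = np.zeros(n, dtype=bool)
    in_tree[0] = True
    best = S[0].astype(float)      # best similarity from each vertex to the tree
    link = np.zeros(n, dtype=int)  # tree vertex attaining that similarity
    edges = []
    for _ in range(n - 1):
        v = int(np.argmax(np.where(in_tree, -np.inf, best)))
        edges.append((int(link[v]), v))
        in_tree[v] = True
        closer = S[v] > best
        best = np.where(closer, S[v], best)
        link = np.where(closer, v, link)
    return edges


def prims_post_sketch(S, M):
    """Rebuild each component of connectivity matrix M as a spanning tree of S."""
    n = S.shape[0]
    A = np.zeros((n, n), dtype=int)
    done = np.zeros(n, dtype=bool)
    for i in range(n):
        if done[i]:
            continue
        idx = np.flatnonzero(M[i])  # vertices in i's component (M[i, i] = 1 assumed)
        done[idx] = True
        for a, b in prim_tree_edges(S[np.ix_(idx, idx)]):
            A[idx[a], idx[b]] = A[idx[b], idx[a]] = 1
    return A, M


def forest_similarity(A, S):
    """Total similarity of the forest, counting each undirected edge once."""
    return float((A * S).sum() / 2)


S = np.array([
    [0.0, 0.9, 0.5],
    [0.9, 0.0, 0.8],
    [0.5, 0.8, 0.0],
])
A_init = np.array([[0, 0, 1], [0, 0, 1], [1, 1, 0]])  # a suboptimal tree 0-2-1
A_new, _ = prims_post_sketch(S, np.ones((3, 3), dtype=int))
```

On this three-point example, the post-processed tree uses edges (0, 1) and (1, 2). This raises the total similarity from 1.3 to 1.7.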
Perturbations
- jaxclust._src.perturbations.make_pert_flp_solver(flp_solver: Callable, constrained: bool, num_samples: int = 1000, noise=Normal(), control_variate: bool = False) → Callable
Creates a perturbed solver for the maximum weight k-connected-component forest linear program (flp).
- Parameters:
flp_solver (Callable) – an flp solver from jaxclust.solvers.
constrained (bool) – indicates whether flp_solver takes a constraint matrix C.
num_samples (int, optional) – number of samples for MC estimator. Defaults to 1000.
noise (Class, optional) – noise distribution. Defaults to Normal().
control_variate (bool, optional) – whether to use a control variate when estimating the Jacobians of the adjacency and connectivity matrices. Defaults to False.
- Returns:
an flp solver that takes the same arguments as flp_solver, followed by two additional arguments:
sigma (float) – magnitude of the noise.
rng (jax.random.PRNGKey) – random key used to sample the noise.
- Return type:
Callable
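In the perturbed-optimizer framework, the forward output can be read as a Monte Carlo average of the deterministic solver applied to noisy copies of S. The NumPy sketch below shows only that forward estimate for an unconstrained solver. It assumes additive noise sigma * Z with a symmetrised standard-normal Z, which is not necessarily jaxclust's noise model. It also omits the Jacobian estimator and control variate that make_pert_flp_solver provides. All function names are hypothetical, and a compact Kruskal solver is included only so the block runs on its own.

```python
import numpy as np


def kruskal_forest(S, ncc):
    """Compact Kruskal ncc-forest (larger S = more similar); returns float A, M."""
    n = S.shape[0]
    parent = list(range(n))

    def find(i):
        # Union-find root lookup with path halving.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    edges = sorted(
        ((S[i, j], i, j) for i in range(n) for j in range(i + 1, n)),
        reverse=True,
    )
    A = np.zeros((n, n))
    added = 0
    for _, i, j in edges:
        if added == n - ncc:
            break
        ri, rj = find(i), find(j)
        if ri != rj:
            parent[ri] = rj
            A[i, j] = A[j, i] = 1.0
            added += 1
    roots = np.array([find(i) for i in range(n)])
    return A, (roots[:, None] == roots[None, :]).astype(float)


def perturbed_forward(solver, S, ncc, sigma, num_samples=1000, seed=0):
    """Monte Carlo estimate of E[solver(S + sigma * Z)] for both A and M."""
    rng = np.random.default_rng(seed)
    n = S.shape[0]
    A_mean = np.zeros((n, n))
    M_mean = np.zeros((n, n))
    for _ in range(num_samples):
        Z = rng.standard_normal((n, n))
        Z = (Z + Z.T) / 2  # assumption: symmetric noise keeps S symmetric
        A, M = solver(S + sigma * Z, ncc)
        A_mean += A / num_samples
        M_mean += M / num_samples
    return A_mean, M_mean


S = np.array([
    [0.0, 0.9, 0.1, 0.2],
    [0.9, 0.0, 0.3, 0.1],
    [0.1, 0.3, 0.0, 0.8],
    [0.2, 0.1, 0.8, 0.0],
])
A_bar, M_bar = perturbed_forward(kruskal_forest, S, ncc=2, sigma=0.5, num_samples=200)
```

With a very small sigma, the averages collapse to the deterministic A and M. Larger sigma tends to produce fractional entries, which vary smoothly with S. Per the signature above, the function returned by make_pert_flp_solver is called with flp_solver's arguments followed by sigma and rng. That is, (S, ncc, sigma, rng) when constrained=False, or (S, ncc, C, sigma, rng) when constrained=True.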