API reference

Solvers

jaxclust._src.solvers.kruskals(S: Array, ncc: int) → Tuple[Array, Array]

Calculates the adjacency matrix and cluster connectivity matrix of the minimum weight ncc-spanning forest using Kruskal’s algorithm.

Parameters:
  • S (jax.Array) – similarity matrix.

  • ncc (int) – number of connected components.

Returns:

A, M

Return type:

Tuple[jax.Array, jax.Array]

\(A_{ij} = 1\) if the edge (i, j) is in the forest.

\(M_{ij} = 1\) if i and j are in the same connected component of the forest.
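To make the meaning of A and M concrete, here is a minimal pure-Python sketch of Kruskal's algorithm stopped early at ncc components. It is not the library's implementation: the name kruskal_forest is hypothetical, the input is a nested list rather than a jax.Array, and it assumes the entries are edge weights to be minimized (as the description above states) and that A is stored symmetrically.

```python
def kruskal_forest(W, ncc):
    """Reference sketch: minimum weight spanning forest with ncc components.

    W is a symmetric n x n nested list of edge weights (an assumption of this
    sketch; the documented function takes a jax.Array S).
    """
    n = len(W)
    parent = list(range(n))  # union-find forest, one tree per component

    def find(i):
        # Follow parent pointers to the root, halving paths on the way.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    # Candidate edges (i < j), lowest weight first.
    edges = sorted((W[i][j], i, j) for i in range(n) for j in range(i + 1, n))

    A = [[0] * n for _ in range(n)]
    components = n
    for _, i, j in edges:
        if components == ncc:
            break  # the forest already has the requested number of components
        ri, rj = find(i), find(j)
        if ri != rj:  # the edge joins two different trees, so it adds no cycle
            parent[ri] = rj
            A[i][j] = A[j][i] = 1
            components -= 1

    # M[i][j] = 1 when i and j share a root, i.e. lie in the same component.
    roots = [find(i) for i in range(n)]
    M = [[int(roots[i] == roots[j]) for j in range(n)] for i in range(n)]
    return A, M


W = [
    [0.0, 1.0, 4.0, 5.0],
    [1.0, 0.0, 3.0, 6.0],
    [4.0, 3.0, 0.0, 2.0],
    [5.0, 6.0, 2.0, 0.0],
]
A, M = kruskal_forest(W, ncc=2)
```

With ncc=2 the sketch keeps the two lowest-weight edges, (0, 1) and (2, 3), so M splits the four nodes into the blocks {0, 1} and {2, 3}.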

jaxclust._src.solvers.kruskals_prims_pre(S: Array, ncc: int) → Tuple[Array, Array]

Calculates the adjacency matrix and cluster connectivity matrix of the minimum weight ncc-spanning forest. Uses Prim’s algorithm to construct the full spanning tree, then applies Kruskal’s algorithm to the edges in the spanning tree in order to calculate the forest.

Parameters:
  • S (jax.Array) – similarity matrix.

  • ncc (int) – number of connected components.

Returns:

A, M

Return type:

Tuple[jax.Array, jax.Array]

\(A_{ij} = 1\) if the edge (i, j) is in the forest.

\(M_{ij} = 1\) if i and j are in the same connected component of the forest.

jaxclust._src.solvers.ckruskals(S: Array, ncc: int, C: Array) → Tuple[Array, Array]

Calculates the adjacency matrix and cluster connectivity matrix of the minimum weight ncc-spanning forest. Uses a biased heuristic based on Kruskal’s algorithm to create the forest.

Parameters:
  • S (jax.Array) – similarity matrix.

  • ncc (int) – number of connected components.

  • C (jax.Array) – constraint matrix.

Returns:

A, M

Return type:

Tuple[jax.Array, jax.Array]

\(A_{ij} = 1\) if the edge (i, j) is in the forest.

\(M_{ij} = 1\) if i and j are in the same connected component of the forest.

\(C_{ij}=1\) if (i, j) has a must-link (ml) constraint.

\(C_{ij}=-1\) if (i, j) has a must-not-link (mnl) constraint.

\(C_{ij}=0\) if (i, j) has no constraints.
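The encoding of C above can be illustrated with a small helper that fills a matrix from lists of index pairs. This is a sketch only: build_constraints is a hypothetical name, not part of jaxclust, it returns a nested list rather than a jax.Array, and it assumes that constraints are symmetric in (i, j).

```python
def build_constraints(n, must_link, must_not_link):
    """Build an n x n constraint matrix with entries in {-1, 0, 1}.

    must_link and must_not_link are iterables of (i, j) index pairs.
    Symmetric storage is an assumption of this sketch.
    """
    C = [[0] * n for _ in range(n)]  # 0 means "no constraint"
    for i, j in must_link:
        C[i][j] = C[j][i] = 1  # must-link (ml)
    for i, j in must_not_link:
        if C[i][j] == 1:
            raise ValueError(f"pair ({i}, {j}) is both must-link and must-not-link")
        C[i][j] = C[j][i] = -1  # must-not-link (mnl)
    return C


C = build_constraints(4, must_link=[(0, 1)], must_not_link=[(1, 3)])
```

A nested list built this way can be converted with jax.numpy.asarray before being passed on as the C argument.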

jaxclust._src.solvers.ckruskals_prims_post(S: Array, ncc: int, C: Array) → Tuple[Array, Array]

Calculates the adjacency matrix and cluster connectivity matrix of the minimum weight ncc-spanning forest. Uses a biased heuristic based on Kruskal’s algorithm to create the forest, then applies Prim’s algorithm to recalculate the spanning tree of each connected component in the forest. The result is therefore guaranteed to be at least as good as the solution returned by ckruskals.

Parameters:
  • S (jax.Array) – similarity matrix.

  • ncc (int) – number of connected components.

  • C (jax.Array) – constraint matrix.

Returns:

A, M

Return type:

Tuple[jax.Array, jax.Array]

\(A_{ij} = 1\) if the edge (i, j) is in the forest.

\(M_{ij} = 1\) if i and j are in the same connected component of the forest.

\(C_{ij}=1\) if (i, j) has a must-link (ml) constraint.

\(C_{ij}=-1\) if (i, j) has a must-not-link (mnl) constraint.

\(C_{ij}=0\) if (i, j) has no constraints.

Perturbations

jaxclust._src.perturbations.make_pert_flp_solver(flp_solver: Callable, constrained: bool, num_samples: int = 1000, noise=<jaxclust._src.perturbations.Normal object>, control_variate: bool = False) → Callable

Creates a perturbed solver of the maximum weight k-connected-component forest lp (flp).

Parameters:
  • flp_solver (Callable) – an flp solver from jaxclust.solvers.

  • constrained (bool) – indicates if flp_solver takes constraints.

  • num_samples (int, optional) – number of samples for the Monte Carlo (MC) estimator. Defaults to 1000.

  • noise (Class, optional) – noise distribution. Defaults to Normal().

  • control_variate (bool) – whether to use a control variate for the Jacobians of the adjacency and connectivity matrices.

Returns:

An flp solver taking the same arguments as flp_solver, followed by two additional arguments:

  • sigma (float) – magnitude of the noise.

  • rng (jax.random.PRNGKey).

Return type:

Callable
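The idea of a Monte Carlo perturbed solver can be sketched in pure Python as follows, assuming the usual construction in which the solver is run on num_samples noisy copies of S and the outputs are averaged. Everything here is illustrative rather than the library's code: make_perturbed and toy_solver are hypothetical names, toy_solver is a stand-in that links only the lowest-weight pair and ignores ncc, an integer seed replaces the jax.random.PRNGKey, and no Jacobian or control-variate logic is included.

```python
import random


def toy_solver(S, ncc):
    # Hypothetical stand-in solver: links the single lowest-weight pair.
    # It ignores ncc and exists only to keep this sketch self-contained.
    n = len(S)
    _, i, j = min((S[a][b], a, b) for a in range(n) for b in range(a + 1, n))
    A = [[0] * n for _ in range(n)]
    A[i][j] = A[j][i] = 1
    M = [[int(r == c or A[r][c] == 1) for c in range(n)] for r in range(n)]
    return A, M


def make_perturbed(solver, num_samples=1000):
    """Wrap solver(S, ncc) into perturbed(S, ncc, sigma, seed)."""

    def perturbed(S, ncc, sigma, seed):
        rng = random.Random(seed)  # plain seed; the library takes a PRNG key
        n = len(S)
        A_sum = [[0.0] * n for _ in range(n)]
        M_sum = [[0.0] * n for _ in range(n)]
        for _ in range(num_samples):
            # Add independent Gaussian noise of magnitude sigma to every entry.
            S_pert = [[S[i][j] + sigma * rng.gauss(0.0, 1.0) for j in range(n)]
                      for i in range(n)]
            A, M = solver(S_pert, ncc)
            for i in range(n):
                for j in range(n):
                    A_sum[i][j] += A[i][j]
                    M_sum[i][j] += M[i][j]
        # Monte Carlo estimate: the sample mean of the solver outputs.
        A_bar = [[x / num_samples for x in row] for row in A_sum]
        M_bar = [[x / num_samples for x in row] for row in M_sum]
        return A_bar, M_bar

    return perturbed


S = [
    [0.0, 1.0, 1.2],
    [1.0, 0.0, 3.0],
    [1.2, 3.0, 0.0],
]
A_bar, M_bar = make_perturbed(toy_solver, num_samples=200)(S, 2, sigma=0.5, seed=0)
A_exact, M_exact = make_perturbed(toy_solver, num_samples=5)(S, 2, sigma=0.0, seed=0)
```

With sigma set to 0 every sample equals S, so the average reproduces the unperturbed solver output. With sigma above 0 the entries of A_bar and M_bar become fractions between 0 and 1, which is what makes the perturbed outputs vary smoothly with S.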