rqutils.math.matrix_ufunc

rqutils.math.matrix_ufunc(op, mat, hermitian=0, with_diagonals=False, npmod=np, save_errors=False)

Apply a unitary-invariant unary matrix operator to an array of normal matrices.

The argument mat must be an array whose last two dimensions hold normal (i.e. square and unitarily diagonalizable) matrices. This function diagonalizes the matrices with unitary transformations, applies op to the eigenvalues, and inverts the diagonalization.
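
The three steps can be illustrated with a minimal NumPy sketch restricted to Hermitian input. The helper name apply_op_hermitian is hypothetical and is not part of rqutils; the actual implementation may differ, in particular for non-Hermitian normal matrices.

```python
import numpy as np

def apply_op_hermitian(op, mat):
    """Hypothetical sketch: apply `op` to Hermitian matrices of shape (..., n, n)."""
    # Step 1: unitary diagonalization, mat = U diag(eigvals) U^dagger.
    # np.linalg.eigh works on stacks of matrices in the last two dimensions.
    eigvals, eigcols = np.linalg.eigh(mat)
    # Step 2: apply the operator to the eigenvalues only.
    op_eigvals = op(eigvals)
    # Step 3: invert the diagonalization, op(mat) = U diag(op(eigvals)) U^dagger.
    eigrows = np.conjugate(np.swapaxes(eigcols, -2, -1))
    op_mat = (eigcols * op_eigvals[..., None, :]) @ eigrows
    return op_mat, op_eigvals

# Usage: the exponential of the Pauli X matrix is cosh(1) * I + sinh(1) * X.
sigma_x = np.array([[0.0, 1.0], [1.0, 0.0]])
exp_x, exp_eigvals = apply_op_hermitian(np.exp, sigma_x)
expected = np.cosh(1.0) * np.eye(2) + np.sinh(1.0) * sigma_x
print(np.allclose(exp_x, expected))
```

Returning op_eigvals alongside the matrix mirrors what the with_diagonals option is documented to provide.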

Diagonalization and gradient

When this function is used with an autodiff library (e.g. JAX), the gradient diverges if mat is diagonal while an input parameter controls its off-diagonal elements. In such cases, switch to an alternative implementation of the operation, if one is available:

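import functools
import jax
import jax.numpy as jnp
from rqutils.math import matrix_ufunc

# X and alternative_X are placeholders: X is the operator passed as op, and
# alternative_X is a replacement function that does not diagonalize mat.

# Reshape the matrix to gather all off-diagonal elements to a block ([:, 1:])
mat_dim = mat.shape[-1]
diag_checker = mat.reshape(-1, mat_dim ** 2)
# The very last element is a part of diagonal -> can ignore for this purpose
diag_checker = diag_checker[:, :-1].reshape(-1, mat_dim - 1, mat_dim + 1)
is_diagonal = ~jnp.any(diag_checker[:, :, 1:], axis=(1, 2))
# True if at least one matrix in the array is diagonal
has_diagonal = jnp.any(is_diagonal)

# Use the alternative for the whole array if any matrix is diagonal
result = jax.lax.cond(has_diagonal,
                      alternative_X,
                      functools.partial(matrix_ufunc, X),
                      mat)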
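
The reshape trick used for the diagonality check can be verified on small concrete arrays. The following is a NumPy sketch under the assumption that the matrix dimension is at least 2; the helper name is_diagonal_batch is hypothetical. In a flattened n × n matrix the diagonal elements sit at indices 0, n + 1, 2(n + 1), …, so after dropping the last element and reshaping into rows of length n + 1, every diagonal element lands in column 0 and every off-diagonal element in columns 1 onwards.

```python
import numpy as np

def is_diagonal_batch(mat):
    """Hypothetical helper: flag which matrices in a (..., n, n) array are diagonal (n >= 2)."""
    n = mat.shape[-1]
    flat = mat.reshape(-1, n * n)
    # Drop the last (diagonal) element so the length n*n - 1 factors as (n - 1) * (n + 1).
    blocks = flat[:, :-1].reshape(-1, n - 1, n + 1)
    # Column 0 of each row holds a diagonal element; columns 1: hold off-diagonal ones.
    return ~np.any(blocks[:, :, 1:], axis=(1, 2))

mats = np.stack([
    np.diag([1.0, 2.0, 3.0]),          # diagonal
    np.arange(9.0).reshape(3, 3),      # has nonzero off-diagonal elements
])
flags = is_diagonal_batch(mats)
print(flags)  # [ True False]
```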

Parameters
  • op (Callable) – Unary operator to be applied to the diagonals of mat.

  • mat (Union[Number, Sequence[Number], numpy.ndarray]) – Array of normal matrices (shape (…, n, n)). No check on normality is performed.

  • hermitian (Union[int, bool]) – 1 or True -> mat is Hermitian, -1 -> mat is anti-Hermitian, 0 or False -> otherwise.

  • with_diagonals (bool) – If True, also return the array op(eigenvalues).

Return type

Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]

Returns

An array corresponding to op(mat). If with_diagonals is True, a second array corresponding to op(eigenvalues) is also returned.