Source code for furax.core._dense

import functools as ft
from dataclasses import field

import jax
from jax import Array
from jax import numpy as jnp
from jaxtyping import Inexact, PyTree

from furax.tree import is_leaf

from ._base import AbstractLinearOperator


[docs] class DenseBlockDiagonalOperator(AbstractLinearOperator): """Operator that applies block diagonal dense matrices via einsum. Only the diagonal blocks are stored, making this more memory-efficient than a full dense matrix. The operation is defined by einsum subscripts. Attributes: blocks: The dense blocks as an array (at least 2D). subscripts: Einsum subscripts defining the operation (default: 'ij...,j...->i...'). Example: For a matrix made of three 2x4 diagonal blocks, and input block columns of three blocks of four elements each, the operator can be written as: >>> blocks = jnp.arange(24).reshape(3, 2, 4) >>> op = DenseBlockDiagonalOperator( ... blocks, jax.ShapeDtypeStruct((3, 4), jnp.int32), 'imn,in->im') >>> op.as_matrix() Array([[ 0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0], [ 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0], [ 0, 0, 0, 0, 8, 9, 10, 11, 0, 0, 0, 0], [ 0, 0, 0, 0, 12, 13, 14, 15, 0, 0, 0, 0], [ 0, 0, 0, 0, 0, 0, 0, 0, 16, 17, 18, 19], [ 0, 0, 0, 0, 0, 0, 0, 0, 20, 21, 22, 23]], dtype=int32) The axes along which the operator is block diagonal can be non-leading dimensions. As a matter of fact, by default, the diagonal axes are assumed to be "on the right". The notion of block diagonality should be understood in a tensor context. The representation of this operator as a 2d matrix, which relies on the row-major layout, may not be block diagonal. >>> blocks = jnp.arange(24).reshape(3, 2, 4) >>> op = DenseBlockDiagonalOperator(blocks, jax.ShapeDtypeStruct((2, 4), jnp.int32)) >>> op.as_matrix() Array([[ 0, 0, 0, 0, 4, 0, 0, 0], [ 0, 1, 0, 0, 0, 5, 0, 0], [ 0, 0, 2, 0, 0, 0, 6, 0], [ 0, 0, 0, 3, 0, 0, 0, 7], [ 8, 0, 0, 0, 12, 0, 0, 0], [ 0, 9, 0, 0, 0, 13, 0, 0], [ 0, 0, 10, 0, 0, 0, 14, 0], [ 0, 0, 0, 11, 0, 0, 0, 15], [16, 0, 0, 0, 20, 0, 0, 0], [ 0, 17, 0, 0, 0, 21, 0, 0], [ 0, 0, 18, 0, 0, 0, 22, 0], [ 0, 0, 0, 19, 0, 0, 0, 23]], dtype=int32) """ blocks: Inexact[Array, '...'] subscripts: str = field(default='ij...,j...->i...', metadata={'static': True}) def __post_init__(self) -> None: subscripts = self.subscripts.replace(' ', '') if subscripts != self.subscripts: object.__setattr__(self, 'subscripts', subscripts) if not jax.tree.all(jax.tree.map(lambda leaf: len(leaf.shape) >= 2, self.blocks)): raise ValueError('The blocks should at least have 2 dimensions.') self._parse_subscripts(subscripts)
[docs] def mv(self, x: PyTree[Array, '...']) -> PyTree[Array]: if is_leaf(x): return jnp.einsum(self.subscripts, self.blocks, x) leaves, treedef = jax.tree.flatten(x) if is_leaf(self.blocks): return jax.tree.unflatten( treedef, [jnp.einsum(self.subscripts, self.blocks, leaf) for leaf in leaves] ) return jax.tree.map(ft.partial(jnp.einsum, self.subscripts), self.blocks, x)
[docs] def transpose(self) -> AbstractLinearOperator: return DenseBlockDiagonalOperator( blocks=self.blocks, in_structure=self.out_structure, subscripts=self._get_transposed_subscripts(self.subscripts), )
@staticmethod def _parse_subscripts(subscripts: str) -> tuple[str, str, str]: split_subscripts = subscripts.split(',') if len(split_subscripts) != 2: raise ValueError(f'There should be a single comma in the subscripts: {subscripts!r}."') left_subscripts, subscripts = split_subscripts split_subscripts = subscripts.split('->') if len(split_subscripts) != 2: raise ValueError('Explicit mode (with `->) is required for the einsum subscripts.') right_subscripts, result_subscripts = split_subscripts return left_subscripts, right_subscripts, result_subscripts @staticmethod def _get_transposed_subscripts(subscripts: str) -> PyTree[jax.ShapeDtypeStruct]: """Returns the einsum subscripts for the transpose operation. Examples: ij...,j...->i... gives ji...,j...->i... hij...,hj...->hi... gives hji...,hj...->hi... ikj,kj->ki gives jki,kj->ki """ lefts, rights, results = DenseBlockDiagonalOperator._parse_subscripts(subscripts) lefts_as_set = set(lefts.replace('...', '')) rights_as_set = set(rights.replace('...', '')) results_as_set = set(results.replace('...', '')) # the sum axis is in the subscripts left and right but not in result sum_axis_as_set = lefts_as_set & rights_as_set - results_as_set if len(sum_axis_as_set) != 1: raise ValueError(f'The summation should be performed in one axis {subscripts!r}.') sum_axis = sum_axis_as_set.pop() # the transpose axis is in the subscripts left and result but not in right transpose_axis_as_set = lefts_as_set & results_as_set - rights_as_set if len(transpose_axis_as_set) == 0: raise ValueError(f'No transposition axis has been specified {subscripts!r}.') if len(transpose_axis_as_set) > 1: raise ValueError(f'Several transposition axes have been specified: {subscripts!r}.') transpose_axis = transpose_axis_as_set.pop() # we swap the transpose and sum axes sum_axis_number = lefts.index(sum_axis) transpose_axis_number = lefts.index(transpose_axis) lefts_as_list = list(lefts) lefts_as_list[sum_axis_number] = transpose_axis lefts_as_list[transpose_axis_number] = sum_axis lefts = ''.join(lefts_as_list) transpose_axis_number = results.index(transpose_axis) results_as_list = list(results) results_as_list[transpose_axis_number] = sum_axis expected_results = ''.join(results_as_list) if expected_results != rights: raise ValueError( f'The dimensions of the inputs {rights!r} cannot be reordered ' f'into {expected_results!r}.' ) return f'{lefts},{rights}->{results}'