# Source code for furax.core._blocks: block row, block diagonal and block column operators, and their composition rules.

import functools
from abc import ABC
from collections.abc import Callable
from typing import Any

import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsl
from jax import Array
from jax.tree_util import PyTreeDef
from jaxtyping import Inexact, PyTree

from ..tree import add
from ._base import AbstractLinearOperator, AdditionOperator, IdentityOperator
from .rules import AbstractBinaryRule


class AbstractBlockOperator(AbstractLinearOperator, ABC):
    """Base class of the operators assembled from a pytree of linear operators.

    Attributes:
        blocks: A pytree whose leaves are linear operators.
    """

    blocks: PyTree[AbstractLinearOperator]

    def __init__(self, blocks: PyTree[AbstractLinearOperator]) -> None:
        # The blocks are stored first, since the input structure is derived from them.
        object.__setattr__(self, 'blocks', blocks)
        in_structure = self._tree_map(lambda block: block.in_structure)
        super().__init__(in_structure=in_structure)

    @property
    def out_structure(self) -> PyTree[jax.ShapeDtypeStruct]:
        """Returns the output structures of the blocks, arranged like ``blocks``."""
        return self._tree_map(lambda block: block.out_structure)

    def reduce(self) -> AbstractLinearOperator:
        """Returns an operator of the same class in which every block has been reduced."""
        reduced_blocks = self._tree_map(lambda block: block.reduce())
        return type(self)(reduced_blocks)

    @property
    def block_leaves(self) -> list[AbstractLinearOperator]:
        """Returns the flat list of operators."""

        def is_operator(node: Any) -> bool:
            # Operators are treated as leaves: the flattening does not descend into them.
            return isinstance(node, AbstractLinearOperator)

        return jax.tree.leaves(self.blocks, is_leaf=is_operator)

    def _tree_map(self, f: Callable[..., Any], *args: Any) -> Any:
        """Maps ``f`` over the blocks and the extra pytrees ``args``, operators being leaves."""

        def is_operator(node: Any) -> bool:
            return isinstance(node, AbstractLinearOperator)

        return jax.tree.map(f, self.blocks, *args, is_leaf=is_operator)


class BlockRowOperator(AbstractBlockOperator):
    """Operator that horizontally concatenates block operators: [A | B | C].

    Each block is applied to its own part of a pytree input and the results are summed.
    All blocks must have the same output structure.

    Transpose: BlockRowOperator.T = BlockColumnOperator

    Attributes:
        blocks: A pytree of operators (all with identical output structure).

    Example:
        >>> x = jnp.array([1, 2], jnp.float32)
        >>> I = IdentityOperator(in_structure=jax.ShapeDtypeStruct((2,), jnp.float32))
        >>> op_list = BlockRowOperator([I, 2*I, 3*I])
        >>> op_list.as_matrix()
        Array([[1., 0., 2., 0., 3., 0.],
               [0., 1., 0., 2., 0., 3.]], dtype=float32)
        >>> op_list([x, x, x])
        Array([ 6., 12.], dtype=float32)
        >>> op_dict = BlockRowOperator({'a': I, 'b': 2*I, 'c': 3*I})
        >>> op_dict({'a': x, 'b': x, 'c': x})
        Array([ 6., 12.], dtype=float32)
    """

    def __init__(self, blocks: PyTree[AbstractLinearOperator]) -> None:
        """Stores the blocks and checks that they share one output structure.

        Raises:
            ValueError: If a block's output structure differs from that of the first block.
        """
        super().__init__(blocks)
        try:
            operators = self.block_leaves
            ref_structure = operators[0].out_structure
        except (AttributeError, TypeError):
            # During JAX/equinox tree operations, operators may have boolean placeholders
            return

        # Collect the output structures that disagree with the first block's.
        invalid_structures = []
        for operator in operators[1:]:
            structure = operator.out_structure
            if structure != ref_structure:
                invalid_structures.append(structure)
        if not invalid_structures:
            return

        structures_as_str = '\n - '.join(map(str, invalid_structures))
        raise ValueError(
            f'The operators in a BlockRowOperator must have the same output structure:\n'
            f' - {ref_structure}\n'
            f' - {structures_as_str}'
        )

    def mv(self, x: PyTree[Inexact[Array, ' _b']]) -> PyTree[Inexact[Array, ' _a']]:
        """Applies each block to its part of ``x`` and returns the sum of the outputs."""
        treedef: PyTreeDef = jax.tree.structure(
            self.blocks, is_leaf=lambda op: isinstance(op, AbstractLinearOperator)
        )
        operators = self.block_leaves
        # Split the input following the layout of the blocks: one sub-tree per operator.
        inputs = treedef.flatten_up_to(x)
        # map() is lazy: each block output is produced right before it is accumulated.
        outputs = map(lambda block, leaf: block(leaf), operators, inputs)
        return functools.reduce(add, outputs)

    def transpose(self) -> AbstractLinearOperator:
        """Returns the block column operator made of the transposed blocks."""
        transposed_blocks = self._tree_map(lambda block: block.T)
        return BlockColumnOperator(transposed_blocks)

    @property
    def out_structure(self) -> PyTree[jax.ShapeDtypeStruct]:
        """Returns the output structure of the first block."""
        first_block = self.block_leaves[0]
        return first_block.out_structure

    def as_matrix(self) -> Inexact[Array, 'a b']:
        """Returns the block matrices stacked horizontally."""
        matrices = [block.as_matrix() for block in self.block_leaves]
        return jnp.hstack(matrices)


class BlockDiagonalOperator(AbstractBlockOperator):
    """Operator with independent diagonal blocks: diag(A, B, C).

    Each block is applied independently to its own part of a pytree input.
    There are no constraints on the block input/output structures.

    The inverse is the block diagonal of individual inverses (if all blocks are square).

    Attributes:
        blocks: A pytree of operators.

    Example:
        >>> x = jnp.array([1, 2], jnp.float32)
        >>> H = DenseBlockDiagonalOperator(
        ...     jnp.array([[0, 1], [1, 0]]),
        ...     jax.ShapeDtypeStruct((2,), jnp.float32)
        ... )
        >>> H.as_matrix()
        Array([[0., 1.],
               [1., 0.]], dtype=float32)
        >>> op_list = BlockDiagonalOperator([H, 2*H, 3*H])
        >>> op_list.as_matrix()
        Array([[0., 1., 0., 0., 0., 0.],
               [1., 0., 0., 0., 0., 0.],
               [0., 0., 0., 2., 0., 0.],
               [0., 0., 2., 0., 0., 0.],
               [0., 0., 0., 0., 0., 3.],
               [0., 0., 0., 0., 3., 0.]], dtype=float32)
        >>> op_list([x, x, x])
        [Array([2., 1.], dtype=float32),
         Array([4., 2.], dtype=float32),
         Array([6., 3.], dtype=float32)]
        >>> op_dict = BlockDiagonalOperator({'a': H, 'b': 2*H, 'c': 3*H})
        >>> op_dict({'a': x, 'b': x, 'c': x})
        {'a': Array([2., 1.], dtype=float32),
         'b': Array([4., 2.], dtype=float32),
         'c': Array([6., 3.], dtype=float32)}
    """

    def __init__(self, blocks: PyTree[AbstractLinearOperator]) -> None:
        # required: otherwise, the parent constructor would not be called by the
        # dataclass-generated constructor
        super().__init__(blocks)

    def mv(self, vector: PyTree[Inexact[Array, ' _b']]) -> PyTree[Inexact[Array, ' _a']]:
        """Applies each block to the matching part of ``vector``."""

        def apply_block(block: AbstractLinearOperator, part: Any) -> Any:
            return block.mv(part)

        return self._tree_map(apply_block, vector)

    def transpose(self) -> AbstractLinearOperator:
        """Returns the block diagonal operator made of the transposed blocks."""
        transposed_blocks = self._tree_map(lambda block: block.T)
        return BlockDiagonalOperator(transposed_blocks)

    def inverse(self) -> AbstractLinearOperator:
        """Returns the block diagonal of the block inverses when every block is square."""
        squareness = self._tree_map(lambda block: block.in_structure == block.out_structure)
        if jax.tree.all(squareness):
            return BlockDiagonalOperator(self._tree_map(lambda block: block.I))
        # if some of the blocks are not square, let's defer to the default inverse method
        return super().inverse()

    def as_matrix(self) -> Inexact[Array, 'a b']:
        """Returns the block matrices arranged along the diagonal."""
        matrices = [block.as_matrix() for block in self.block_leaves]
        return jsl.block_diag(*matrices)  # type: ignore[no-any-return]

    def reduce(self) -> AbstractLinearOperator:
        """BlockDiagonalOperator([I, I, ...]) -> I."""
        op = super().reduce()
        assert isinstance(op, BlockDiagonalOperator)
        has_non_identity = any(
            not isinstance(block, IdentityOperator) for block in op.block_leaves
        )
        if has_non_identity:
            return op
        return IdentityOperator(in_structure=self.in_structure)


class BlockColumnOperator(AbstractBlockOperator):
    """Operator that vertically stacks block operators: [A; B; C].

    Every block is applied to the same input and the outputs are returned as a pytree.
    All blocks must have the same input structure.

    Transpose: BlockColumnOperator.T = BlockRowOperator

    Attributes:
        blocks: A pytree of operators (all with identical input structure).

    Example:
        >>> x = jnp.array([1, 2], jnp.float32)
        >>> I = IdentityOperator(in_structure=jax.ShapeDtypeStruct((2,), jnp.float32))
        >>> op_list = BlockColumnOperator([I, I, I])
        >>> op_list.as_matrix()
        Array([[1., 0.],
               [0., 1.],
               [1., 0.],
               [0., 1.],
               [1., 0.],
               [0., 1.]], dtype=float32)
        >>> op_list(x)
        [Array([1., 2.], dtype=float32),
         Array([1., 2.], dtype=float32),
         Array([1., 2.], dtype=float32)]
        >>> op_dict = BlockColumnOperator({'a': I, 'b': I, 'c': I})
        >>> op_dict(x)
        {'a': Array([1., 2.], dtype=float32),
         'b': Array([1., 2.], dtype=float32),
         'c': Array([1., 2.], dtype=float32)}
    """

    def __init__(self, blocks: PyTree[AbstractLinearOperator]) -> None:
        """Stores the blocks and checks that they share one input structure.

        Raises:
            ValueError: If a block's input structure differs from that of the first block.
        """
        super().__init__(blocks)
        try:
            operators = self.block_leaves
            ref_structure = operators[0].in_structure
        except (AttributeError, TypeError):
            # During JAX/equinox tree operations, operators may have boolean placeholders
            return

        # The blocks all read the same input: use the first block's input structure
        # instead of the per-block pytree handed to the parent constructor.
        object.__setattr__(self, 'in_structure', ref_structure)

        # Collect the input structures that disagree with the first block's.
        invalid_structures = []
        for operator in operators[1:]:
            structure = operator.in_structure
            if structure != ref_structure:
                invalid_structures.append(structure)
        if not invalid_structures:
            return

        structures_as_str = '\n - '.join(map(str, invalid_structures))
        raise ValueError(
            f'The operators in a BlockColumnOperator must have the same input structure:\n'
            f' - {ref_structure}\n'
            f' - {structures_as_str}'
        )

    def mv(self, vector: PyTree[Inexact[Array, ' _b']]) -> PyTree[Inexact[Array, ' _a']]:
        """Applies every block to ``vector`` and returns the outputs arranged like ``blocks``."""

        def apply_block(block: AbstractLinearOperator) -> Any:
            return block.mv(vector)

        return self._tree_map(apply_block)

    def transpose(self) -> AbstractLinearOperator:
        """Returns the block row operator made of the transposed blocks."""
        transposed_blocks = self._tree_map(lambda block: block.T)
        return BlockRowOperator(transposed_blocks)

    def as_matrix(self) -> Inexact[Array, 'a b']:
        """Returns the block matrices stacked vertically."""
        matrices = [block.as_matrix() for block in self.block_leaves]
        return jnp.vstack(matrices)


class AbstractBlockDiagonalRule(AbstractBinaryRule):
    """Base binary rule composing two block operators block by block.

    Attributes:
        reduced_class: The class instantiated with the pytree of composed blocks.
    """

    reduced_class: type[AbstractBlockOperator] | type[AdditionOperator]

    def apply(
        self, left: AbstractLinearOperator, right: AbstractLinearOperator
    ) -> list[AbstractLinearOperator]:
        """Composes matching blocks of ``left`` and ``right`` into one reduced operator.

        Returns:
            A single-element list holding the reduced ``reduced_class`` instance.
        """
        assert isinstance(left, AbstractBlockOperator)  # mypy assert
        assert isinstance(right, AbstractBlockOperator)  # mypy assert
        # Pair the blocks of both operands following the layout of the left operand.
        composed_blocks = left._tree_map(lambda lhs, rhs: lhs @ rhs, right.blocks)
        reduced = self.reduced_class(composed_blocks).reduce()
        return [reduced]


class BlockRowBlockDiagonalRule(AbstractBlockDiagonalRule):
    """Binary rule for the composition of a block row and a block diagonal operator.

    BlockRow(A_i, ...) @ BlockDiagonal(B_i, ...) = BlockRow(A_i @ B_i, ...)
    """

    left_operator_class = BlockRowOperator
    right_operator_class = BlockDiagonalOperator
    reduced_class = BlockRowOperator


class BlockDiagonalBlockColumnRule(AbstractBlockDiagonalRule):
    """Binary rule for the composition of a block diagonal and a block column operator.

    BlockDiagonal(A_i, ...) @ BlockColumn(B_i, ...) = BlockColumn(A_i @ B_i, ...)
    """

    left_operator_class = BlockDiagonalOperator
    right_operator_class = BlockColumnOperator
    reduced_class = BlockColumnOperator


class BlockDiagonalBlockDiagonalRule(AbstractBlockDiagonalRule):
    """Binary rule for the composition of two block diagonal operators.

    BlockDiagonal(A_i, ...) @ BlockDiagonal(B_i, ...) = BlockDiagonal(A_i @ B_i, ...)
    """

    left_operator_class = BlockDiagonalOperator
    right_operator_class = BlockDiagonalOperator
    reduced_class = BlockDiagonalOperator


class BlockRowBlockColumnRule(AbstractBlockDiagonalRule):
    """Binary rule for the composition of a block row and a block column operator.

    BlockRow(A_i, ...) @ BlockColumn(B_i, ...) = Σ A_i @ B_i
    """

    left_operator_class = BlockRowOperator
    right_operator_class = BlockColumnOperator
    reduced_class = AdditionOperator