Source code for furax.core._axes

from collections.abc import Sequence
from dataclasses import field
from math import prod

import jax
from jax import Array
from jax import numpy as jnp
from jaxtyping import Inexact, PyTree

from ._base import AbstractLinearOperator, IdentityOperator, TransposeOperator
from .rules import AbstractBinaryRule, NoReduction

# Names exported by ``from <this module> import *``; the other classes defined in this module
# (rules, abstract base, transpose wrapper) are not listed here.
__all__ = [
    'MoveAxisOperator',
    'RavelOperator',
    'ReshapeOperator',
]


class MoveAxisOperator(AbstractLinearOperator):
    """Linear operator relocating axes of every leaf: y = jnp.moveaxis(x, source, destination).

    The operator is orthogonal: both its transpose and its inverse are the reverse axis move,
    i.e. the same operator with ``source`` and ``destination`` swapped.

    Attributes:
        source: The source axis or axes to move, stored as a tuple.
        destination: The destination axis or axes, stored as a tuple.

    Example:
        >>> in_structure = jax.ShapeDtypeStruct((2, 3), jnp.float32)
        >>> op = MoveAxisOperator(0, 1, in_structure=in_structure)
        >>> op(jnp.array([[1., 1, 1], [2, 2, 2]]))
        Array([[1., 2.],
               [1., 2.],
               [1., 2.]], dtype=float32)
    """

    source: tuple[int, ...] = field(metadata={'static': True})
    destination: tuple[int, ...] = field(metadata={'static': True})

    def __init__(
        self,
        source: int | Sequence[int],
        destination: int | Sequence[int],
        *,
        in_structure: PyTree[jax.ShapeDtypeStruct],
    ) -> None:
        """Store the axes as tuples together with the input structure.

        Args:
            source: A single axis or a sequence of axes to move.
            destination: A single axis or a sequence of axes giving the target positions.
            in_structure: The structure of the operator input.
        """

        def as_axes(axes: int | Sequence[int]) -> tuple[int, ...]:
            # A lone integer becomes a 1-tuple; tuples are kept as they are; any other
            # sequence is converted to a tuple.
            if isinstance(axes, int):
                return (axes,)
            return axes if isinstance(axes, tuple) else tuple(axes)

        attributes = (
            ('source', as_axes(source)),
            ('destination', as_axes(destination)),
            ('in_structure', in_structure),
        )
        # The attributes are written through object.__setattr__, as in the original code.
        for name, value in attributes:
            object.__setattr__(self, name, value)

    def mv(self, x: PyTree[Array, '...']) -> PyTree[Array, '...']:
        """Apply ``jnp.moveaxis`` with the stored axes to each leaf of ``x``."""

        def move(leaf: Array) -> Array:
            return jnp.moveaxis(leaf, self.source, self.destination)

        return jax.tree.map(move, x)

    def transpose(self) -> AbstractLinearOperator:
        """Return the reverse move, defined on this operator's output structure."""
        return MoveAxisOperator(self.destination, self.source, in_structure=self.out_structure)

    # The inverse is the same reverse move as the transpose.
    inverse = transpose
class MoveAxisInverseRule(AbstractBinaryRule):
    """Binary rule for ``move_axis.T @ move_axis = I``.

    Note:
        We cannot simply decorate MoveAxisOperator with :orthogonal: because it is not square,
        in the sense that its input and output structures are different.
    """

    left_operator_class = MoveAxisOperator
    right_operator_class = MoveAxisOperator

    def apply(
        self, left: AbstractLinearOperator, right: AbstractLinearOperator
    ) -> list[AbstractLinearOperator]:
        """Cancel a pair of mutually reverse axis moves.

        Returns:
            An empty list when ``left`` moves ``right``'s destination axes back to ``right``'s
            source axes.

        Raises:
            NoReduction: If the two moves are not the reverse of each other.
        """
        assert isinstance(left, MoveAxisOperator)  # mypy assert
        assert isinstance(right, MoveAxisOperator)  # mypy assert
        if (left.source, left.destination) == (right.destination, right.source):
            return []
        raise NoReduction


# Note: if an algebraic rule to compose MoveAxisOperators is to be implemented, it may be best
# to implement a class TransposeOperator wrapping jnp.transpose and transform MoveAxisOperator
# instances into TransposeOperator instances. That way, it would be easier to include reductions
# for new operators, such as SwapAxesOperator, etc.


class AbstractRavelOrReshapeOperator(AbstractLinearOperator):
    """Base class shared by RavelOperator and ReshapeOperator."""

    def as_matrix(self) -> Inexact[Array, 'a b']:
        """Return an identity matrix of size ``in_size`` with dtype ``out_promoted_dtype``."""
        return jnp.eye(self.in_size, dtype=self.out_promoted_dtype)

    def transpose(self) -> AbstractLinearOperator:
        """Return a ReshapeTransposeOperator wrapping this operator."""
        return ReshapeTransposeOperator(self)  # type: ignore[arg-type]

    def reduce(self) -> AbstractLinearOperator:
        """Return an IdentityOperator when input and output structures are equal, else self."""
        structure_unchanged = self.out_structure == self.in_structure
        return IdentityOperator(in_structure=self.in_structure) if structure_unchanged else self
class RavelOperator(AbstractRavelOrReshapeOperator):
    """Operator that flattens pytree leaves: y = x.ravel().

    By default, all dimensions are flattened. Use ``first_axis`` and ``last_axis`` to flatten
    only a subset of contiguous axes.

    This operator is orthogonal: its transpose restores the original shape.

    Attributes:
        first_axis: The first axis to flatten (default: 0).
        last_axis: The last axis to flatten (default: -1).

    Example:
        To flatten the leaves of a pytree:

        >>> in_structure = jax.ShapeDtypeStruct((2, 3), jnp.float32)
        >>> op = RavelOperator(in_structure=in_structure)
        >>> op.out_structure
        ShapeDtypeStruct(shape=(6,), dtype=float32)

        To flatten the first two axes of the leaves of a pytree:

        >>> import furax as fx
        >>> x = [jnp.ones((2, 2)), jnp.ones((2, 2, 8))]
        >>> op = RavelOperator(0, 1, in_structure=fx.tree.as_structure(x))
        >>> op.out_structure
        [ShapeDtypeStruct(shape=(4,), dtype=float32),
         ShapeDtypeStruct(shape=(4, 8), dtype=float32)]

        To flatten the last two axes of the leaves of a pytree:

        >>> import furax as fx
        >>> x = [jnp.ones((2, 2, 3)), jnp.ones((2, 8))]
        >>> op = RavelOperator(-2, -1, in_structure=fx.tree.as_structure(x))
        >>> op.out_structure
        [ShapeDtypeStruct(shape=(2, 6), dtype=float32),
         ShapeDtypeStruct(shape=(16,), dtype=float32)]
    """

    first_axis: int = field(default=0, metadata={'static': True})
    last_axis: int = field(default=-1, metadata={'static': True})

    def __post_init__(self) -> None:
        """Check that ``first_axis`` does not come after ``last_axis``.

        Raises:
            ValueError: If both axes have the same sign and ``last_axis < first_axis``, or if
                they have opposite signs and, for some leaf of the input structure, the first
                axis resolves to a position after the last one.
        """
        first_axis, last_axis = self.first_axis, self.last_axis
        in_structure = self.in_structure

        same_sign = (first_axis < 0) == (last_axis < 0)
        if same_sign:
            # Axes of the same sign can be ordered without knowing the rank of the leaves.
            if last_axis < first_axis:
                raise ValueError(
                    f'the first axis ({first_axis}) to be flattened should be before the last one '
                    f'({last_axis}).'
                )
            return

        # Mixed signs: the order depends on the number of dimensions of each leaf.
        for leaf in jax.tree.leaves(in_structure):
            start = first_axis + leaf.ndim if first_axis < 0 else first_axis
            stop = last_axis + leaf.ndim if last_axis < 0 else last_axis
            if start > stop:
                raise ValueError(
                    f'there are no dimensions between {first_axis} and {last_axis} '
                    f'to be flattened in leaf of shape {leaf.shape}.'
                )

    def mv(self, x: PyTree[Inexact[Array, ' _a']]) -> PyTree[Inexact[Array, ' _b']]:
        """Flatten the axes from ``first_axis`` to ``last_axis`` (inclusive) of each leaf."""

        def to_non_negative(axis: int, ndim: int) -> int:
            # Negative axes count from the end of the leaf shape.
            return axis + ndim if axis < 0 else axis

        def flatten(leaf: Inexact[Array, ' _a']) -> Inexact[Array, ' _b']:
            start = to_non_negative(self.first_axis, leaf.ndim)
            stop = to_non_negative(self.last_axis, leaf.ndim)
            assert start <= stop, 'unreachable'
            if start == stop:
                # A single axis is selected: nothing to merge.
                return leaf
            merged_shape = (*leaf.shape[:start], -1, *leaf.shape[stop + 1 :])
            return leaf.reshape(merged_shape)

        return jax.tree.map(flatten, x)
class ReshapeOperator(AbstractRavelOrReshapeOperator):
    """Operator that reshapes pytree leaves: y = x.reshape(shape).

    This operator is orthogonal: its transpose restores the original shape.

    Attributes:
        shape: The new shape of the pytree leaves. Use -1 for one inferred dimension.
    """

    shape: tuple[int, ...] = field(metadata={'static': True})

    def __post_init__(self) -> None:
        """Check that every leaf of the input structure can take the requested shape.

        Raises:
            ValueError: If the requested shape is malformed or if its number of elements
                differs from that of a leaf.
        """
        super().__post_init__()
        for leaf in jax.tree.leaves(self.in_structure):
            new_shape = self._normalize_shape(self.shape, leaf.shape)
            if leaf.size != prod(new_shape):
                raise ValueError(f'invalid new shape {self.shape} for leaf of shape {leaf.shape}.')

    @staticmethod
    def _normalize_shape(shape: tuple[int, ...], leaf_shape: tuple[int, ...]) -> tuple[int, ...]:
        """Replace the -1 placeholder of ``shape``, if any, by the dimension it stands for.

        Args:
            shape: The requested shape, with at most one -1 entry.
            leaf_shape: The current shape of the leaf to be reshaped.

        Returns:
            ``shape`` unchanged when it has no -1 entry, otherwise ``shape`` with the -1 entry
            replaced by the inferred dimension.

        Raises:
            ValueError: If an entry is smaller than -1, if more than one -1 is given, or if
                the unknown dimension cannot be inferred as an integer.
        """
        if any(size < -1 for size in shape):
            raise ValueError(f'reshape new sizes should be all positive, got {shape}.')
        try:
            index = shape.index(-1)
        except ValueError:
            return shape
        before = shape[:index]
        after = shape[index + 1 :]
        if -1 in after:
            raise ValueError('can only specify one unknown dimension.')
        # Integer arithmetic only: float division raised ZeroDivisionError when another entry
        # of the shape was 0 and could lose exactness for very large sizes.
        known_size = prod(before + after)
        leaf_size = prod(leaf_shape)
        if known_size == 0 or leaf_size % known_size != 0:
            raise ValueError(f'cannot reshape array of shape {leaf_shape} into shape {shape}.')
        return before + (leaf_size // known_size,) + after

    def mv(self, x: PyTree[Inexact[Array, ' _a']]) -> PyTree[Inexact[Array, ' _b']]:
        """Reshape each leaf of ``x`` to ``self.shape``."""
        return jax.tree.map(lambda leaf: leaf.reshape(self.shape), x)
class ReshapeTransposeOperator(TransposeOperator):
    """Transpose of a reshape or ravel operator."""

    operator: ReshapeOperator | RavelOperator

    def mv(self, x: PyTree[Inexact[Array, ' _a']]) -> PyTree[Inexact[Array, ' _b']]:
        """Reshape each leaf of ``x`` to the shape of the matching ``out_structure`` leaf."""

        def restore(
            leaf: Inexact[Array, ' _a'], target: jax.ShapeDtypeStruct
        ) -> Inexact[Array, ' _b']:
            return leaf.reshape(target.shape)

        return jax.tree.map(restore, x, self.out_structure)


class ReshapeInverseRule(AbstractBinaryRule):
    """Binary rule for ``reshape.T @ reshape = I`` and ``reshape @ reshape.T = I``.

    Note:
        We cannot simply decorate ReshapeOperator with :orthogonal: because it is not square,
        in the sense that its input and output structures are different.
    """

    left_operator_class = (AbstractRavelOrReshapeOperator, ReshapeTransposeOperator)
    right_operator_class = (AbstractRavelOrReshapeOperator, ReshapeTransposeOperator)

    def apply(
        self, left: AbstractLinearOperator, right: AbstractLinearOperator
    ) -> list[AbstractLinearOperator]:
        """Cancel an operator composed with the transpose wrapper that holds it.

        Returns:
            An empty list when one operand is a ReshapeTransposeOperator whose ``operator``
            attribute is (by identity) the other operand.

        Raises:
            NoReduction: If the operands are not such a pair.
        """
        if isinstance(left, AbstractRavelOrReshapeOperator):
            # reshape @ reshape.T
            if isinstance(right, ReshapeTransposeOperator) and right.operator is left:
                return []
            raise NoReduction
        if isinstance(left, ReshapeTransposeOperator):
            # reshape.T @ reshape
            if isinstance(right, AbstractRavelOrReshapeOperator) and left.operator is right:
                return []
            raise NoReduction
        assert False, 'unreachable'