Source code for furax.core._trees

from dataclasses import field
from typing import Any

import jax
import jax.numpy as jnp
from jax import Array
from jax.tree_util import PyTreeDef
from jaxtyping import Inexact, PyTree

from ..tree import _dense_to_tree, _get_outer_treedef, _tree_to_dense, matmat, matvec
from ._base import AbstractLinearOperator
from .rules import AbstractBinaryRule


class TreeOperator(AbstractLinearOperator):
    """Operator defined by a generalized matrix as a pytree of pytrees.

    More memory-efficient than dense matrices when the structure allows XLA optimizations
    (symmetries, zeros, shared elements, broadcasting).

    The structure of the generalized matrix is the tree product: outer_treedef x inner_treedef,
    analogous to a matrix where rows are represented by an outer tree and columns by the inner
    tree. Leaves within the same row must be broadcastable; no such constraint applies across
    rows.

    Attributes:
        tree: The generalized matrix as a pytree of pytrees.
        outer_treedef: PyTreeDef for rows.
        inner_treedef: PyTreeDef for columns.
        tree_shape: (num_rows, num_cols) in terms of tree leaves.

    Example:
        To represent the Mueller Matrix of a quarter-wave plate with a vertical fast-axis:

        >>> from furax.obs.stokes import StokesIQUV
        >>> op = TreeOperator(
        ...     StokesIQUV(
        ...         StokesIQUV(1, 0, 0, 0),
        ...         StokesIQUV(0, 1, 0, 0),
        ...         StokesIQUV(0, 0, 0, -1),
        ...         StokesIQUV(0, 0, 1, 0),
        ...     ),
        ...     in_structure=StokesIQUV.structure_for((), jnp.float32)
        ... )
        >>> op.as_matrix()
        Array([[ 1.,  0.,  0.,  0.],
               [ 0.,  1.,  0.,  0.],
               [ 0.,  0.,  0., -1.],
               [ 0.,  0.,  1.,  0.]], dtype=float32)
        >>> op(StokesIQUV(1., 1., 1., 1.))
        StokesIQUV(i=1.0, q=1.0, u=-1.0, v=1.0)
    """

    tree: PyTree[Array]
    # The treedefs carry `static` field metadata; presumably this keeps them as pytree
    # metadata rather than traced leaves (equinox-style convention) — TODO confirm in base.
    inner_treedef: PyTreeDef = field(metadata={'static': True})
    outer_treedef: PyTreeDef = field(metadata={'static': True})

    def __init__(
        self,
        tree: PyTree[PyTree[Any]],
        *,
        in_structure: PyTree[jax.ShapeDtypeStruct],
        inner_treedef: PyTreeDef | None = None,
        outer_treedef: PyTreeDef | None = None,
    ) -> None:
        """Build the operator from its generalized matrix.

        Args:
            tree: The generalized matrix: an outer (row) pytree whose leaves are inner (column)
                pytrees.
            in_structure: Structure of the operator input.
            inner_treedef: Column treedef. Defaults to the treedef of ``in_structure``.
            outer_treedef: Row treedef. Defaults to the treedef derived from ``in_structure``
                and ``tree`` by ``_get_outer_treedef``.
        """
        if inner_treedef is None:
            inner_treedef = jax.tree.structure(in_structure)
        if outer_treedef is None:
            outer_treedef = _get_outer_treedef(in_structure, tree)
        # NOTE(review): attributes are set through object.__setattr__, which suggests the base
        # operator is frozen and rejects normal attribute assignment — confirm in base class.
        object.__setattr__(self, 'tree', tree)
        object.__setattr__(self, 'inner_treedef', inner_treedef)
        object.__setattr__(self, 'outer_treedef', outer_treedef)
        object.__setattr__(self, 'in_structure', in_structure)

    @property
    def tree_shape(self) -> tuple[int, int]:
        """Return the number of leaves of the outer and inner structures."""
        return self.outer_treedef.num_leaves, self.inner_treedef.num_leaves

    def mv(self, x: PyTree[Inexact[Array, ' _a']]) -> PyTree[Inexact[Array, ' _b']]:
        """Apply the generalized matrix to ``x`` via ``matvec`` over the row treedef.

        Args:
            x: Input pytree, expected to match ``in_structure``.

        Returns:
            The matrix-vector product computed by ``matvec``.
        """
        return matvec(self.outer_treedef, self.tree, x)

    def transpose(self) -> AbstractLinearOperator:
        """Return the transposed operator.

        Rows and columns are swapped with ``jax.tree.transpose``. The result takes the output
        structure of ``self`` as its input structure.

        The swapped treedefs are known by construction: ``jax.tree.transpose(outer, inner, t)``
        returns an ``inner``-of-``outer`` tree. They are therefore passed explicitly instead of
        re-inferring the row treedef from the transposed tree.
        """
        transposed_tree = jax.tree.transpose(self.outer_treedef, self.inner_treedef, self.tree)
        return TreeOperator(
            transposed_tree,
            in_structure=self.out_structure,
            inner_treedef=self.outer_treedef,
            outer_treedef=self.inner_treedef,
        )

    def inverse(self) -> AbstractLinearOperator:
        """Return the Moore-Penrose pseudo-inverse of the operator as a TreeOperator.

        The generalized matrix is densified, pseudo-inverted with ``jnp.linalg.pinv``, and
        converted back to a tree. Because ``pinv`` is used, this is the true inverse only when
        the dense matrix is invertible.

        The pseudo-inverse maps the output structure of ``self`` back to its input structure.
        Its row treedef is therefore ``self.inner_treedef`` and its column treedef is
        ``self.outer_treedef``. Both are passed explicitly rather than re-inferred.
        """
        dense = _tree_to_dense(self.outer_treedef, self.inner_treedef, self.tree)
        dense_pinv = jnp.linalg.pinv(dense)
        # Argument order mirrors _tree_to_dense(rows, cols, ...): rows are now the input side.
        tree = _dense_to_tree(self.inner_treedef, self.outer_treedef, dense_pinv)
        return TreeOperator(
            tree,
            in_structure=self.out_structure,
            inner_treedef=self.outer_treedef,
            outer_treedef=self.inner_treedef,
        )
class TreeMultiplicationRule(AbstractBinaryRule):
    """Binary rule for `tree_left @ tree_right`.

    Replaces the composition of two TreeOperators by a single TreeOperator. Its generalized
    matrix is the tree-matrix product computed by ``matmat``.
    """

    left_operator_class = TreeOperator
    right_operator_class = TreeOperator

    def apply(
        self, left: AbstractLinearOperator, right: AbstractLinearOperator
    ) -> list[AbstractLinearOperator]:
        """Fuse ``left @ right`` into one TreeOperator.

        Args:
            left: Left operand (applied second); must be a TreeOperator.
            right: Right operand (applied first); must be a TreeOperator.

        Returns:
            A one-element list holding the fused TreeOperator, which takes the input structure
            of ``right``.
        """
        # The class attributes above declare TreeOperator on both sides; the asserts narrow
        # the static types so the tree attributes can be accessed.
        assert isinstance(left, TreeOperator)
        assert isinstance(right, TreeOperator)
        product_tree = matmat(left.outer_treedef, left.tree, right.outer_treedef, right.tree)
        return [TreeOperator(product_tree, in_structure=right.in_structure)]