from dataclasses import field
from typing import Any
import jax
import jax.numpy as jnp
from jax import Array
from jax.tree_util import PyTreeDef
from jaxtyping import Inexact, PyTree
from ..tree import _dense_to_tree, _get_outer_treedef, _tree_to_dense, matmat, matvec
from ._base import AbstractLinearOperator
from .rules import AbstractBinaryRule
class TreeOperator(AbstractLinearOperator):
    """Operator defined by a generalized matrix as a pytree of pytrees.

    More memory-efficient than dense matrices when the structure allows XLA
    optimizations (symmetries, zeros, shared elements, broadcasting).

    The structure of the generalized matrix is the tree product
    ``outer_treedef x inner_treedef``, analogous to a matrix where rows are
    represented by the outer tree and columns by the inner tree. Leaves within
    the same row must be broadcastable; no such constraint applies across rows.

    Attributes:
        tree: The generalized matrix as a pytree of pytrees.
        outer_treedef: PyTreeDef for rows.
        inner_treedef: PyTreeDef for columns.
        tree_shape: (num_rows, num_cols) in terms of tree leaves.

    Args:
        tree: The generalized matrix, an outer pytree whose leaves are inner
            pytrees.
        in_structure: Structure (shapes and dtypes) of the operator input.
        inner_treedef: PyTreeDef for columns. When omitted, it is the tree
            structure of ``in_structure``.
        outer_treedef: PyTreeDef for rows. When omitted, it is obtained from
            ``_get_outer_treedef(in_structure, tree)``.

    Example:
        To represent the Mueller matrix of a quarter-wave plate with a
        vertical fast axis:

        >>> from furax.obs.stokes import StokesIQUV
        >>> op = TreeOperator(
        ...     StokesIQUV(
        ...         StokesIQUV(1, 0, 0, 0),
        ...         StokesIQUV(0, 1, 0, 0),
        ...         StokesIQUV(0, 0, 0, -1),
        ...         StokesIQUV(0, 0, 1, 0),
        ...     ),
        ...     in_structure=StokesIQUV.structure_for((), jnp.float32)
        ... )
        >>> op.as_matrix()
        Array([[ 1., 0., 0., 0.],
               [ 0., 1., 0., 0.],
               [ 0., 0., 0., -1.],
               [ 0., 0., 1., 0.]], dtype=float32)
        >>> op(StokesIQUV(1., 1., 1., 1.))
        StokesIQUV(i=1.0, q=1.0, u=-1.0, v=1.0)
    """

    # Only `tree` lacks the 'static' metadata flag; both treedefs carry it.
    tree: PyTree[Array]
    inner_treedef: PyTreeDef = field(metadata={'static': True})
    outer_treedef: PyTreeDef = field(metadata={'static': True})

    def __init__(
        self,
        tree: PyTree[PyTree[Any]],
        *,
        in_structure: PyTree[jax.ShapeDtypeStruct],
        inner_treedef: PyTreeDef | None = None,
        outer_treedef: PyTreeDef | None = None,
    ) -> None:
        """Store the generalized matrix and derive any treedef not supplied."""
        if inner_treedef is None:
            inner_treedef = jax.tree.structure(in_structure)
        if outer_treedef is None:
            outer_treedef = _get_outer_treedef(in_structure, tree)
        # Attributes are written through object.__setattr__ rather than plain
        # assignment. NOTE(review): presumably because operators are frozen
        # dataclasses -- confirm against AbstractLinearOperator.
        object.__setattr__(self, 'tree', tree)
        object.__setattr__(self, 'inner_treedef', inner_treedef)
        object.__setattr__(self, 'outer_treedef', outer_treedef)
        object.__setattr__(self, 'in_structure', in_structure)

    @property
    def tree_shape(self) -> tuple[int, int]:
        """Return the number of leaves of the outer and inner structures."""
        return self.outer_treedef.num_leaves, self.inner_treedef.num_leaves

    def mv(self, x: PyTree[Inexact[Array, ' _a']]) -> PyTree[Inexact[Array, ' _b']]:
        """Apply the operator to ``x`` by delegating to ``matvec``."""
        return matvec(self.outer_treedef, self.tree, x)

    def transpose(self) -> AbstractLinearOperator:
        """Return the operator whose tree has its outer and inner levels swapped.

        The new operator takes this operator's ``out_structure`` as its input
        structure; its treedefs are re-derived by the constructor.
        """
        transposed_tree = jax.tree.transpose(self.outer_treedef, self.inner_treedef, self.tree)
        return TreeOperator(transposed_tree, in_structure=self.out_structure)

    def inverse(self) -> AbstractLinearOperator:
        """Return the pseudo-inverse of the operator as a new ``TreeOperator``.

        The tree is converted to a dense array, inverted with
        ``jnp.linalg.pinv`` (Moore-Penrose pseudo-inverse), and converted back
        to a tree with the row and column treedefs exchanged.
        """
        dense = _tree_to_dense(self.outer_treedef, self.inner_treedef, self.tree)
        dense_pinv = jnp.linalg.pinv(dense)
        # Rows of the inverse follow the original columns, hence the swapped treedefs.
        tree = _dense_to_tree(self.inner_treedef, self.outer_treedef, dense_pinv)
        return TreeOperator(tree, in_structure=self.out_structure)
class TreeMultiplicationRule(AbstractBinaryRule):
    """Binary rule for `tree_left @ tree_right`.

    Applies when both operands are ``TreeOperator`` instances and replaces the
    pair with one ``TreeOperator`` built from their trees.
    """

    left_operator_class = TreeOperator
    right_operator_class = TreeOperator

    def apply(
        self, left: AbstractLinearOperator, right: AbstractLinearOperator
    ) -> list[AbstractLinearOperator]:
        """Combine two tree operators into a single one.

        Args:
            left: Left operand of the product; must be a ``TreeOperator``.
            right: Right operand of the product; must be a ``TreeOperator``.

        Returns:
            A one-element list holding the ``TreeOperator`` whose tree is
            computed by ``matmat`` and whose input structure is that of
            ``right``.
        """
        # Narrow the declared abstract types to the concrete class this rule handles.
        assert isinstance(left, TreeOperator)
        assert isinstance(right, TreeOperator)

        product_tree = matmat(left.outer_treedef, left.tree, right.outer_treedef, right.tree)
        fused = TreeOperator(product_tree, in_structure=right.in_structure)
        return [fused]