Linear Operators API

This section provides detailed API documentation for Furax linear operators.

Core Operators#

Abstract Base Class#

class furax.core._base.AbstractLinearOperator(*, in_structure=None)[source]#

Bases: ABC

Base class for linear operators.

Parameters:

in_structure (PyTree[jax._src.api.ShapeDtypeStruct])

class_tags: ClassVar[OperatorTag] = 0#
in_structure: PyTree[jax._src.api.ShapeDtypeStruct] = None#
property tags: OperatorTag#

Get the tags for this operator instance.

property is_square: bool#
property is_symmetric: bool#
property is_orthogonal: bool#
property is_diagonal: bool#
property is_tridiagonal: bool#
property is_lower_triangular: bool#
property is_upper_triangular: bool#
property is_positive_semidefinite: bool#
property is_negative_semidefinite: bool#
abstract mv(x)[source]#
Parameters:

x (PyTree[jaxtyping.Inexact[Array, '_a']])

Return type:

PyTree[jaxtyping.Inexact[Array, '_b']]

reduce()[source]#

Returns a linear operator with a reduced structure.

Return type:

AbstractLinearOperator

as_matrix()[source]#

Returns the operator as a dense matrix.

Input and output PyTrees are flattened and concatenated.

Return type:

Inexact[Array, 'a b']
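The flattening convention described for as_matrix() can be illustrated with a small NumPy sketch. It is not Furax's implementation: the helper dense_matrix and the toy map toy_mv are hypothetical, plain dicts stand in for PyTrees, and leaves are taken in dictionary order. Each column is obtained by applying the linear map to one basis vector of the flattened, concatenated input.

```python
import numpy as np

def dense_matrix(mv, in_shapes):
    """Dense matrix of a linear map acting on a dict of arrays (hypothetical helper).

    mv: linear function, dict[str, ndarray] -> dict[str, ndarray]
    in_shapes: dict mapping each input leaf name to its shape
    """
    sizes = {k: int(np.prod(s)) for k, s in in_shapes.items()}
    n = sum(sizes.values())
    columns = []
    for j in range(n):
        # j-th basis vector of the flattened, concatenated input
        flat = np.zeros(n)
        flat[j] = 1.0
        # split the flat vector back into leaves of the requested shapes
        x, start = {}, 0
        for k, s in in_shapes.items():
            x[k] = flat[start:start + sizes[k]].reshape(s)
            start += sizes[k]
        y = mv(x)
        # flatten and concatenate the output leaves to form one column
        columns.append(np.concatenate([np.ravel(v) for v in y.values()]))
    return np.stack(columns, axis=1)

def toy_mv(x):
    # toy linear map: scale leaf 'a' by 2, sum leaf 'b' over its last axis
    return {'a': 2 * x['a'], 'b': x['b'].sum(axis=-1)}

matrix = dense_matrix(toy_mv, {'a': (2,), 'b': (1, 3)})
```

The resulting matrix has one row per output element (3) and one column per input element (5).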

transpose()[source]#
Return type:

AbstractLinearOperator

property T: AbstractLinearOperator#
inverse()[source]#
Return type:

AbstractLinearOperator

property I: AbstractLinearOperator#
property out_structure: PyTree[jax._src.api.ShapeDtypeStruct]#
property in_size: int#

The number of elements in the input PyTree.

property out_size: int#

The number of elements in the output PyTree.

property in_promoted_dtype: dtype[Any]#

Returns the promoted data type of the operator’s input leaves.

property out_promoted_dtype: dtype[Any]#

Returns the promoted data type of the operator’s output leaves.

Diagonal Operators#

class furax.core._diagonal.DiagonalOperator(diagonal, *, axis_destination=-1, in_structure=None)[source]#

Bases: BroadcastDiagonalOperator

Operator that performs element-wise multiplication: D(x) = d * x.

The diagonal operator is symmetric and square. Its inverse divides by the diagonal values (zeros are handled by returning zero).
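To illustrate the zero handling of the inverse, the following NumPy sketch inverts a diagonal while mapping zero entries to zero. It shows one plausible way to obtain that behaviour and is not taken from Furax; the helper name safe_inverse_diagonal is hypothetical.

```python
import numpy as np

def safe_inverse_diagonal(d):
    """Reciprocal of d, with zero entries mapped to zero (hypothetical helper)."""
    nonzero = d != 0
    # replace zeros by one before dividing so that no division by zero occurs
    safe = np.where(nonzero, d, 1.0)
    return np.where(nonzero, 1.0 / safe, 0.0)

d = np.array([2.0, 0.0, 4.0])
x = np.array([1.0, 5.0, 8.0])

y = d * x                                # forward application
x_back = safe_inverse_diagonal(d) * y    # the entry zeroed by d is not recovered
```

With a zero on the diagonal the result behaves like a pseudo-inverse: x_back matches x only where d is non-zero.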

The multiplication axes can be specified via axis_destination:
  • axis_destination=0: diagonal[:, None] * x (multiply along first axis)

  • axis_destination=-1: diagonal * x (standard broadcasting, default)

Variables:
  • diagonal – The diagonal values.

  • axis_destination (int | tuple[int, ...]) – The axes of the input along which the diagonal values are applied. If a sequence is given, it must contain as many axes as the diagonal has dimensions. If a non-negative integer is given, the axes are (axis, ..., axis + diagonal.ndim - 1). If a negative integer is given, the axes are (axis - diagonal.ndim + 1, ..., axis).
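The axis_destination rules can be sketched in plain Python. The function destination_axes below is a hypothetical helper, not part of Furax, and it assumes that both scalar forms yield exactly diagonal.ndim consecutive axes: starting at the given axis when it is non-negative, ending at it when it is negative.

```python
def destination_axes(axis_destination, diagonal_ndim):
    """Input axes matched to the diagonal's dimensions (hypothetical helper)."""
    if isinstance(axis_destination, tuple):
        # explicit form: one axis per diagonal dimension
        if len(axis_destination) != diagonal_ndim:
            raise ValueError('expected one axis per diagonal dimension')
        return axis_destination
    if axis_destination >= 0:
        # consecutive axes starting at axis_destination
        return tuple(range(axis_destination, axis_destination + diagonal_ndim))
    # consecutive axes ending at axis_destination
    return tuple(range(axis_destination - diagonal_ndim + 1, axis_destination + 1))

first_axis = destination_axes(0, 1)       # diagonal[:, None] * x style
trailing = destination_axes(-1, 2)        # standard right-aligned broadcasting
```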

Example

>>> import jax
>>> import furax as fx
>>> from furax.core._diagonal import DiagonalOperator
>>> from numpy.testing import assert_allclose
>>> key_gain, key_tod, key_common = jax.random.split(jax.random.PRNGKey(0), 3)
>>> detector_count = 3
>>> sample_count = 10
>>> x = {
...     'tod': jax.random.normal(key_tod, (detector_count, sample_count)),
...     'ground': jax.random.normal(key_common, (detector_count,)),
... }
>>> detector_gains = jax.random.normal(key_gain, (detector_count,)) / 100 + 1
>>> op = DiagonalOperator(
...     detector_gains, axis_destination=0, in_structure=fx.tree.as_structure(x)
... )
>>> y = op(x)
>>> assert_allclose(x['tod'] * detector_gains[:, None], y['tod'])
>>> assert_allclose(x['ground'] * detector_gains, y['ground'])
inverse()[source]#
Return type:

AbstractLinearOperator

as_matrix()[source]#

Returns the operator as a dense matrix.

Input and output PyTrees are flattened and concatenated.

Return type:

Inexact[Array, 'a b']

class_tags: ClassVar[OperatorTag] = 11#
property out_structure#
transpose()#
class furax.core._diagonal.BroadcastDiagonalOperator(diagonal, *, axis_destination=-1, in_structure=None)[source]#

Bases: AbstractLinearOperator

Operator that performs element-wise multiplication with broadcasting.

Unlike DiagonalOperator, this operator can change the output shape through broadcasting, making it non-square. Depending on the broadcasting direction:

  • Left broadcasting (extending dimensions on the left): equivalent to a block column operator with diagonal blocks.

  • Right broadcasting (extending dimensions on the right): equivalent to a block diagonal operator with column blocks.

Variables:
  • diagonal – The diagonal values.

  • axis_destination (int | tuple[int, ...]) – The axes of the input along which the generalized diagonal values are applied. If a sequence is given, it must contain as many axes as the diagonal has dimensions. If a non-negative integer is given, the axes are (axis, ..., axis + diagonal.ndim - 1). If a negative integer is given, the axes are (axis - diagonal.ndim + 1, ..., axis).

Example

>>> import furax as fx
>>> from furax.core._diagonal import BroadcastDiagonalOperator
>>> import jax.numpy as jnp
>>> from numpy.testing import assert_allclose
>>> x = jnp.array([1, 2, 3])
>>> values = jnp.array([[1, 1, 1], [2, 1, 0]])
>>> op = BroadcastDiagonalOperator(
...     values, in_structure=fx.tree.as_structure(x), axis_destination=-1
... )
>>> assert_allclose(op(x), jnp.array([[1, 2, 3], [2, 2, 0]]))
>>> op.as_matrix()
Array([[1, 0, 0],
       [0, 1, 0],
       [0, 0, 1],
       [2, 0, 0],
       [0, 1, 0],
       [0, 0, 0]], dtype=int32)
>>> x = jnp.array([1, 2])
>>> values = jnp.array([[2, 3, 1], [1, 0, 1]])
>>> op = BroadcastDiagonalOperator(
...     values, in_structure=fx.tree.as_structure(x), axis_destination=0
... )
>>> assert_allclose(op(x), jnp.array([[2, 3, 1], [2, 0, 2]]))
>>> op.as_matrix()
Array([[2, 0],
       [3, 0],
       [1, 0],
       [0, 1],
       [0, 0],
       [0, 1]], dtype=int32)
>>> x = jnp.array([[0, 1, 2], [2, 3, 4]])
>>> values = jnp.array([2, 1])
>>> op = BroadcastDiagonalOperator(
...     values, in_structure=fx.tree.as_structure(x), axis_destination=0
... )
>>> assert_allclose(op(x), jnp.array([[0, 2, 4], [2, 3, 4]]))
>>> op.as_matrix()
Array([[2, 0, 0, 0, 0, 0],
       [0, 2, 0, 0, 0, 0],
       [0, 0, 2, 0, 0, 0],
       [0, 0, 0, 1, 0, 0],
       [0, 0, 0, 0, 1, 0],
       [0, 0, 0, 0, 0, 1]], dtype=int32)
axis_destination: int | tuple[int, ...] = -1#
property diagonal: Inexact[Array, '...']#
mv(x)[source]#
Parameters:

x (PyTree[jaxtyping.Inexact[Array, '...']])

Return type:

PyTree[jaxtyping.Inexact[Array, '...']]

Block Operators#

class furax.core._blocks.BlockDiagonalOperator(blocks)[source]#

Bases: AbstractBlockOperator

Operator with independent diagonal blocks: diag(A, B, C).

Applies each block independently to the corresponding part of a pytree input. No constraints on block input/output structures.

The inverse is the block diagonal of individual inverses (if all blocks are square).
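The statement about the inverse can be checked on a small case. The NumPy sketch below uses plain matrices and Python lists in place of Furax operators and pytrees; the blocks A and B are arbitrary illustrative values.

```python
import numpy as np

A = np.array([[2.0, 0.0], [1.0, 1.0]])
B = np.array([[0.0, 1.0], [1.0, 0.0]])

# dense equivalent of diag(A, B)
full = np.block([[A, np.zeros((2, 2))], [np.zeros((2, 2)), B]])

# block-wise application: each block acts on its own part of the input
x = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]
y = [A @ x[0], B @ x[1]]

# inverse applied block by block: diag(A, B)^-1 = diag(A^-1, B^-1)
x_back = [np.linalg.inv(A) @ y[0], np.linalg.inv(B) @ y[1]]
```

Applying the blocks independently gives the same result as multiplying by the dense matrix, without storing the off-diagonal zeros.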

Variables:

blocks (jaxtyping.PyTree[furax.core._base.AbstractLinearOperator]) – A pytree of operators.

Parameters:

blocks (PyTree[furax.core._base.AbstractLinearOperator])

Example

>>> x = jnp.array([1, 2], jnp.float32)
>>> H = DenseBlockDiagonalOperator(
...     jnp.array([[0, 1], [1, 0]]),
...     in_structure=jax.ShapeDtypeStruct((2,), jnp.float32)
... )
>>> H.as_matrix()
Array([[0., 1.],
       [1., 0.]], dtype=float32)
>>> op_list = BlockDiagonalOperator([H, 2*H, 3*H])
>>> op_list.as_matrix()
Array([[0., 1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 2., 0., 0.],
       [0., 0., 2., 0., 0., 0.],
       [0., 0., 0., 0., 0., 3.],
       [0., 0., 0., 0., 3., 0.]], dtype=float32)
>>> op_list([x, x, x])
[Array([2., 1.], dtype=float32),
 Array([4., 2.], dtype=float32),
 Array([6., 3.], dtype=float32)]
>>> op_dict = BlockDiagonalOperator({'a': H, 'b': 2*H, 'c': 3*H})
>>> op_dict({'a': x, 'b': x, 'c': x})
{'a': Array([2., 1.], dtype=float32),
 'b': Array([4., 2.], dtype=float32),
 'c': Array([6., 3.], dtype=float32)}
mv(vector)[source]#
Parameters:

vector (PyTree[jaxtyping.Inexact[Array, '_b']])

Return type:

PyTree[jaxtyping.Inexact[Array, '_a']]

transpose()[source]#
Return type:

AbstractLinearOperator

inverse()[source]#
Return type:

AbstractLinearOperator

as_matrix()[source]#

Returns the operator as a dense matrix.

Input and output PyTrees are flattened and concatenated.

Return type:

Inexact[Array, 'a b']

reduce()[source]#

BlockDiagonalOperator([I, I, …]) -> I.

Return type:

AbstractLinearOperator

class furax.core._blocks.BlockRowOperator(blocks)[source]#

Bases: AbstractBlockOperator

Operator that horizontally concatenates block operators: [A | B | C].

Applies each block to the corresponding part of a pytree input and sums the results. All blocks must have the same output structure.

Transpose: BlockRowOperator.T = BlockColumnOperator
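The "apply and sum" behaviour and the transpose relation can be sketched with dense NumPy matrices standing in for the blocks. This is an illustration of the linear algebra only, with arbitrary values for A and B; it does not use Furax.

```python
import numpy as np

A = np.array([[1.0, 2.0], [3.0, 4.0]])
B = np.array([[0.0, 1.0], [1.0, 0.0]])
x_a = np.array([1.0, 2.0])
x_b = np.array([3.0, 4.0])

# block row [A | B]: apply each block to its own input part, then sum
row_result = A @ x_a + B @ x_b
assert np.allclose(row_result, np.hstack([A, B]) @ np.concatenate([x_a, x_b]))

# transposing the block row gives a block column of transposed blocks:
# every block is applied to the same input and the outputs are kept separate
y = np.array([1.0, -1.0])
row_T_result = [A.T @ y, B.T @ y]
assert np.allclose(np.concatenate(row_T_result), np.hstack([A, B]).T @ y)
```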

Variables:

blocks (jaxtyping.PyTree[furax.core._base.AbstractLinearOperator]) – A pytree of operators (all with identical output structure).

Parameters:

blocks (PyTree[furax.core._base.AbstractLinearOperator])

Example

>>> x = jnp.array([1, 2], jnp.float32)
>>> I = IdentityOperator(in_structure=jax.ShapeDtypeStruct((2,), jnp.float32))
>>> op_list = BlockRowOperator([I, 2*I, 3*I])
>>> op_list.as_matrix()
Array([[1., 0., 2., 0., 3., 0.],
       [0., 1., 0., 2., 0., 3.]], dtype=float32)
>>> op_list([x, x, x])
Array([ 6., 12.], dtype=float32)
>>> op_dict = BlockRowOperator({'a': I, 'b': 2*I, 'c': 3*I})
>>> op_dict({'a': x, 'b': x, 'c': x})
Array([ 6., 12.], dtype=float32)
mv(x)[source]#
Parameters:

x (PyTree[jaxtyping.Inexact[Array, '_b']])

Return type:

PyTree[jaxtyping.Inexact[Array, '_a']]

transpose()[source]#
Return type:

AbstractLinearOperator

property out_structure: PyTree[jax._src.api.ShapeDtypeStruct]#
as_matrix()[source]#

Returns the operator as a dense matrix.

Input and output PyTrees are flattened and concatenated.

Return type:

Inexact[Array, 'a b']

class furax.core._blocks.BlockColumnOperator(blocks)[source]#

Bases: AbstractBlockOperator

Operator that vertically stacks block operators: [A; B; C].

Applies each block to the same input and returns a pytree of outputs. All blocks must have the same input structure.

Transpose: BlockColumnOperator.T = BlockRowOperator

Variables:

blocks (jaxtyping.PyTree[furax.core._base.AbstractLinearOperator]) – A pytree of operators (all with identical input structure).

Parameters:

blocks (PyTree[furax.core._base.AbstractLinearOperator])

Example

>>> x = jnp.array([1, 2], jnp.float32)
>>> I = IdentityOperator(in_structure=jax.ShapeDtypeStruct((2,), jnp.float32))
>>> op_list = BlockColumnOperator([I, I, I])
>>> op_list.as_matrix()
Array([[1., 0.],
       [0., 1.],
       [1., 0.],
       [0., 1.],
       [1., 0.],
       [0., 1.]], dtype=float32)
>>> op_list(x)
[Array([1., 2.], dtype=float32),
 Array([1., 2.], dtype=float32),
 Array([1., 2.], dtype=float32)]
>>> op_dict = BlockColumnOperator({'a': I, 'b': I, 'c': I})
>>> op_dict(x)
{'a': Array([1., 2.], dtype=float32),
 'b': Array([1., 2.], dtype=float32),
 'c': Array([1., 2.], dtype=float32)}
mv(vector)[source]#
Parameters:

vector (PyTree[jaxtyping.Inexact[Array, '_b']])

Return type:

PyTree[jaxtyping.Inexact[Array, '_a']]

transpose()[source]#
Return type:

AbstractLinearOperator

as_matrix()[source]#

Returns the operator as a dense matrix.

Input and output PyTrees are flattened and concatenated.

Return type:

Inexact[Array, 'a b']

Dense Operators#

class furax.core._dense.DenseBlockDiagonalOperator(blocks, subscripts='ij...,j...->i...', *, in_structure=None)[source]#

Bases: AbstractLinearOperator

Operator that applies block diagonal dense matrices via einsum.

Only the diagonal blocks are stored, making this more memory-efficient than a full dense matrix. The operation is defined by einsum subscripts.
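The memory argument can be made concrete with NumPy's einsum, used here as a stand-in for the operator. The subscripts 'imn,in->im' apply block i to input chunk i; the explicit 6x12 matrix is built only to check the result.

```python
import numpy as np

blocks = np.arange(24).reshape(3, 2, 4)   # three 2x4 blocks
x = np.arange(12).reshape(3, 4)           # three input chunks of 4 elements

# einsum applies block i to chunk i without forming the 6x12 matrix
y = np.einsum('imn,in->im', blocks, x)

# same result from explicit per-block products
expected = np.stack([blocks[i] @ x[i] for i in range(3)])
assert np.array_equal(y, expected)

# the dense equivalent stores 72 entries, of which only 24 can be non-zero
dense = np.zeros((6, 12), dtype=blocks.dtype)
for i in range(3):
    dense[2 * i:2 * i + 2, 4 * i:4 * i + 4] = blocks[i]
assert np.array_equal(dense @ x.ravel(), y.ravel())
```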

Variables:
  • blocks (jaxtyping.Inexact[Array, '...']) – The dense blocks as an array (at least 2D).

  • subscripts (str) – Einsum subscripts defining the operation (default: 'ij...,j...->i...').

Parameters:
  • blocks (Inexact[Array, '...'])

  • subscripts (str)

  • in_structure (PyTree[jax._src.api.ShapeDtypeStruct])

Example

For a matrix made of three 2x4 diagonal blocks, applied to an input made of three chunks of four elements each, the operator can be written as:

>>> blocks = jnp.arange(24).reshape(3, 2, 4)
>>> op = DenseBlockDiagonalOperator(
...     blocks,
...     subscripts='imn,in->im',
...     in_structure=jax.ShapeDtypeStruct((3, 4), jnp.int32),
... )
>>> op.as_matrix()
Array([[ 0,  1,  2,  3,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 4,  5,  6,  7,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  8,  9, 10, 11,  0,  0,  0,  0],
       [ 0,  0,  0,  0, 12, 13, 14, 15,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0,  0, 16, 17, 18, 19],
       [ 0,  0,  0,  0,  0,  0,  0,  0, 20, 21, 22, 23]], dtype=int32)

The axes along which the operator is block diagonal need not be the leading dimensions. With the default subscripts, the diagonal axes are the trailing ones ("on the right"). Block diagonality should be understood in a tensor sense: the 2D matrix representation of the operator, which relies on the row-major layout, may not be block diagonal.

>>> blocks = jnp.arange(24).reshape(3, 2, 4)
>>> op = DenseBlockDiagonalOperator(
...     blocks, in_structure=jax.ShapeDtypeStruct((2, 4), jnp.int32)
... )
>>> op.as_matrix()
Array([[ 0,  0,  0,  0,  4,  0,  0,  0],
       [ 0,  1,  0,  0,  0,  5,  0,  0],
       [ 0,  0,  2,  0,  0,  0,  6,  0],
       [ 0,  0,  0,  3,  0,  0,  0,  7],
       [ 8,  0,  0,  0, 12,  0,  0,  0],
       [ 0,  9,  0,  0,  0, 13,  0,  0],
       [ 0,  0, 10,  0,  0,  0, 14,  0],
       [ 0,  0,  0, 11,  0,  0,  0, 15],
       [16,  0,  0,  0, 20,  0,  0,  0],
       [ 0, 17,  0,  0,  0, 21,  0,  0],
       [ 0,  0, 18,  0,  0,  0, 22,  0],
       [ 0,  0,  0, 19,  0,  0,  0, 23]], dtype=int32)
blocks: Inexact[Array, '...']#
subscripts: str = 'ij...,j...->i...'#
mv(x)[source]#
Parameters:

x (PyTree[jax.Array, '...'])

Return type:

PyTree[jax.Array]

transpose()[source]#
Return type:

AbstractLinearOperator

Toeplitz Operators#

class furax.core._toeplitz.SymmetricBandToeplitzOperator(band_values, *, in_structure, method='overlap_save_parallel', fft_size=None)[source]#

Bases: AbstractLinearOperator

Operator for symmetric band Toeplitz convolution.

A Toeplitz matrix has constant diagonals. This operator is symmetric and exploits the band structure for efficient computation. For multidimensional band values, the operator is block diagonal.

Available methods:
  • dense: dense matrix multiplication

  • direct: direct convolution

  • fft: FFT on the whole input

  • overlap_save_sequential: sequential FFT on chunks

  • overlap_save_parallel: batched FFT on chunks (default)
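As a rough illustration of why several methods can give the same result, the NumPy sketch below compares an explicit dense product with a direct convolution for one row of band values. It is not Furax's implementation; the helper symmetric_band_toeplitz is hypothetical, and the boundaries are assumed to be zero-padded.

```python
import numpy as np

def symmetric_band_toeplitz(band_values, n):
    """Dense n x n symmetric band Toeplitz matrix; band_values[0] is the diagonal."""
    matrix = np.zeros((n, n))
    for k, value in enumerate(band_values):
        idx = np.arange(n - k)
        matrix[idx, idx + k] = value   # k-th upper diagonal
        matrix[idx + k, idx] = value   # k-th lower diagonal
    return matrix

band = np.array([1.0, 0.5])
x = np.ones(5)

# dense strategy: explicit matrix-vector product
y_dense = symmetric_band_toeplitz(band, 5) @ x

# direct strategy: convolution with the symmetric kernel [0.5, 1.0, 0.5]
kernel = np.concatenate([band[:0:-1], band])
y_direct = np.convolve(x, kernel, mode='same')
```

Both strategies compute the same product; they differ in cost, which is what motivates the FFT-based and chunked variants for long inputs.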

Variables:
  • band_values (jaxtyping.Float[Array, '...']) – The band values (first element is the diagonal).

  • method (str) – The computation method.

  • fft_size (int | None) – FFT size for the fft and overlap methods.

Parameters:
  • band_values (Float[Array, '...'])

  • in_structure (PyTree[jax._src.api.ShapeDtypeStruct])

  • method (str)

  • fft_size (int | None)

Example

>>> tod = jnp.ones((2, 5))
>>> op = SymmetricBandToeplitzOperator(
...     jnp.array([[1., 0.5], [1, 0.25]]),
...     in_structure=jax.ShapeDtypeStruct(tod.shape, tod.dtype))
>>> op(tod)
Array([[1.5 , 2.  , 2.  , 2.  , 1.5 ],
       [1.25, 1.5 , 1.5 , 1.5 , 1.25]], dtype=float64)
>>> op.as_matrix()
Array([[1.  , 0.5 , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.5 , 1.  , 0.5 , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.5 , 1.  , 0.5 , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.5 , 1.  , 0.5 , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.5 , 1.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 1.  , 0.25, 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.25, 1.  , 0.25, 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.25, 1.  , 0.25, 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.25, 1.  , 0.25],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.25, 1.  ]],      dtype=float64)
METHODS: ClassVar[tuple[str, ...]] = ('dense', 'direct', 'fft', 'overlap_save_parallel', 'overlap_save_sequential')#
band_values: Float[Array, '...']#
method: str#
fft_size: int | None#
mv(x)[source]#
Parameters:

x (Float[Array, '...'])

Return type:

Float[Array, '...']

as_matrix()[source]#

Returns the operator as a dense matrix.

Input and output PyTrees are flattened and concatenated.

Return type:

Inexact[Array, 'a a']

class_tags: ClassVar[OperatorTag] = 3#
property out_structure#
transpose()#

Index and Reshape Operators#

class furax.core._indices.IndexOperator(indices, *, in_structure, out_structure=None, unique_indices=None, _out_structure=None)[source]#

Bases: AbstractLinearOperator

Operator that extracts elements by indexing: y = x[indices].

Supports integer indices, slices, boolean masks, and advanced indexing. When the indices are unique, the operator P satisfies P @ P.T = Identity.
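The role of unique indices can be seen in a small NumPy sketch, where the transpose of an indexing operation is modelled as a scatter-add into a zero vector. The helper names gather and scatter_add are hypothetical and this is not Furax code.

```python
import numpy as np

n = 6
indices = np.array([4, 0, 2])        # unique indices into a length-6 vector

def gather(x, idx):
    """Forward operation: y = x[idx]."""
    return x[idx]

def scatter_add(y, idx, size):
    """Transpose of the gather: accumulate y back into a zero vector."""
    out = np.zeros(size)
    np.add.at(out, idx, y)           # repeated indices are summed
    return out

# unique indices: gather after scatter-add returns the original values
y = np.array([10.0, 20.0, 30.0])
roundtrip = gather(scatter_add(y, indices, n), indices)

# repeated indices: contributions accumulate, so the round trip is not the identity
repeated = np.array([4, 4])
roundtrip_repeated = gather(scatter_add(np.ones(2), repeated, n), repeated)
```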

Variables:
  • indices (tuple[int | slice | jaxtyping.Bool[Array, '...'] | jaxtyping.Integer[Array, '...'] | ellipsis, ...]) – The indexing tuple (integers, slices, arrays, or Ellipsis).

  • unique_indices (bool) – Whether the indices select unique elements (enables optimizations).

Parameters:
  • indices (tuple[int | slice | Bool[Array, '...'] | Integer[Array, '...'] | ellipsis, ...])

  • in_structure (PyTree[jax._src.api.ShapeDtypeStruct])

  • out_structure (PyTree[jax._src.api.ShapeDtypeStruct] | None)

  • unique_indices (bool)

  • _out_structure (PyTree[jax._src.api.ShapeDtypeStruct])

Example: To extract the second element of the first axis:

>>> op = IndexOperator(1, in_structure=jax.ShapeDtypeStruct((10, 4), jax.numpy.float32))

To extract values from the penultimate axis given an array of indices:

>>> indices = jax.numpy.array([2, 4, 4, 5, 7])
>>> in_structure = jax.ShapeDtypeStruct((9, 8, 3), jax.numpy.float32)
>>> op = IndexOperator((..., indices, slice(None)), in_structure=in_structure)

In order to extract values using a boolean mask, it is required to specify an output structure:

>>> indices = jax.numpy.array([True, False, True, False])
>>> in_structure = jax.ShapeDtypeStruct((4,), jax.numpy.float32)
>>> out_structure = jax.ShapeDtypeStruct((2,), jax.numpy.float32)
>>> op = IndexOperator(indices, in_structure=in_structure, out_structure=out_structure)

Because the output structure of a boolean mask must be given explicitly, it is usually better to convert the mask to integer indices:

>>> op = IndexOperator(jnp.where(indices), in_structure=in_structure)
indices: tuple[int | slice | Bool[Array, '...'] | Integer[Array, '...'] | ellipsis, ...]#
unique_indices: bool#
mv(x)[source]#
Parameters:

x (PyTree[jaxtyping.Inexact[Array, '_a']])

Return type:

PyTree[jaxtyping.Inexact[Array, '_b']]

property out_structure: PyTree[jax._src.api.ShapeDtypeStruct]#
reduce()[source]#

Returns a linear operator with a reduced structure.

Return type:

AbstractLinearOperator

property indexed_axes: list[int]#

Returns the list of axes for which an indexing is performed.

Example: for an indexing of (slice(None), 3, ..., jnp.array([1, 2])), it returns [1, -1].

class furax.core._axes.ReshapeOperator(shape, *, in_structure=None)[source]#

Bases: AbstractRavelOrReshapeOperator

Operator that reshapes pytree leaves: y = x.reshape(shape).

This operator is orthogonal: its transpose restores the original shape.

Variables:

shape (tuple[int, ...]) – The new shape of the pytree leaves. Use -1 for one inferred dimension.

Parameters:
  • shape (tuple[int, ...])

  • in_structure (PyTree[jax._src.api.ShapeDtypeStruct])

shape: tuple[int, ...]#
mv(x)[source]#
Parameters:

x (PyTree[jaxtyping.Inexact[Array, '_a']])

Return type:

PyTree[jaxtyping.Inexact[Array, '_b']]

class furax.core._axes.MoveAxisOperator(source, destination, *, in_structure)[source]#

Bases: AbstractLinearOperator

Operator that moves axes of pytree leaves: y = jnp.moveaxis(x, source, destination).

This operator is orthogonal: its transpose and inverse are the reverse axis move.
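The orthogonality claim can be checked numerically with NumPy's moveaxis, used here in place of the operator. The sketch verifies that the reverse move restores the input and that it acts as the transpose, through the identity <A x, z> = <x, A^T z>; the arrays x and z are arbitrary test data.

```python
import numpy as np

x = np.arange(24.0).reshape(2, 3, 4)
z = np.arange(24.0)[::-1].reshape(3, 4, 2)

forward = np.moveaxis(x, 0, 2)          # shape (3, 4, 2)
restored = np.moveaxis(forward, 2, 0)   # reverse move, shape (2, 3, 4)

# adjoint identity <A x, z> == <x, A^T z>, with A^T the reverse move
lhs = np.vdot(forward, z)
rhs = np.vdot(x, np.moveaxis(z, 2, 0))
```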

Variables:
  • source (tuple[int, ...]) – The source axis or axes to move.

  • destination (tuple[int, ...]) – The destination axis or axes.

Parameters:
  • source (tuple[int, ...])

  • destination (tuple[int, ...])

  • in_structure (PyTree[jax._src.api.ShapeDtypeStruct])

Example

>>> in_structure = jax.ShapeDtypeStruct((2, 3), jnp.float32)
>>> op = MoveAxisOperator(0, 1, in_structure=in_structure)
>>> op(jnp.array([[1., 1, 1], [2, 2, 2]]))
Array([[1., 2.],
       [1., 2.],
       [1., 2.]], dtype=float32)
source: tuple[int, ...]#
destination: tuple[int, ...]#
mv(x)[source]#
Parameters:

x (PyTree[jax.Array, '...'])

Return type:

PyTree[jax.Array, '...']

transpose()[source]#
Return type:

AbstractLinearOperator

inverse()#
Return type:

AbstractLinearOperator

class furax.core._axes.RavelOperator(first_axis=0, last_axis=-1, *, in_structure=None)[source]#

Bases: AbstractRavelOrReshapeOperator

Operator that flattens pytree leaves: y = x.ravel().

By default, all dimensions are flattened. Use first_axis and last_axis to flatten only a subset of contiguous axes.

This operator is orthogonal: its transpose restores the original shape.
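The effect of first_axis and last_axis on a leaf's shape can be written down in a few lines of plain Python. The function ravel_shape is a hypothetical helper that mirrors the description above; it is not part of Furax.

```python
from math import prod

def ravel_shape(shape, first_axis=0, last_axis=-1):
    """Shape obtained by merging axes first_axis..last_axis (inclusive)."""
    ndim = len(shape)
    # normalise negative axes to their non-negative equivalents
    first = first_axis % ndim
    last = last_axis % ndim
    if first > last:
        raise ValueError('first_axis must not come after last_axis')
    merged = prod(shape[first:last + 1])
    return shape[:first] + (merged,) + shape[last + 1:]

all_axes = ravel_shape((2, 3))               # every axis flattened
leading = ravel_shape((2, 2, 8), 0, 1)       # first two axes merged
trailing = ravel_shape((2, 2, 3), -2, -1)    # last two axes merged
```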

Variables:
  • first_axis (int) – The first axis to flatten (default: 0).

  • last_axis (int) – The last axis to flatten (default: -1).

Parameters:
  • first_axis (int)

  • last_axis (int)

  • in_structure (PyTree[jax._src.api.ShapeDtypeStruct])

Example

To flatten the leaves of a pytree:

>>> in_structure = jax.ShapeDtypeStruct((2, 3), jnp.float32)
>>> op = RavelOperator(in_structure=in_structure)
>>> op.out_structure
ShapeDtypeStruct(shape=(6,), dtype=float32)

To flatten the first two axes of the leaves of a pytree:

>>> import furax as fx
>>> x = [jnp.ones((2, 2)), jnp.ones((2, 2, 8))]
>>> op = RavelOperator(0, 1, in_structure=fx.tree.as_structure(x))
>>> op.out_structure
[ShapeDtypeStruct(shape=(4,), dtype=float32),
 ShapeDtypeStruct(shape=(4, 8), dtype=float32)]

To flatten the last two axes of the leaves of a pytree:

>>> import furax as fx
>>> x = [jnp.ones((2, 2, 3)), jnp.ones((2, 8))]
>>> op = RavelOperator(-2, -1, in_structure=fx.tree.as_structure(x))
>>> op.out_structure
[ShapeDtypeStruct(shape=(2, 6), dtype=float32),
 ShapeDtypeStruct(shape=(16,), dtype=float32)]
first_axis: int = 0#
last_axis: int = -1#
mv(x)[source]#
Parameters:

x (PyTree[jaxtyping.Inexact[Array, '_a']])

Return type:

PyTree[jaxtyping.Inexact[Array, '_b']]

Tree Operators#

class furax.core._trees.TreeOperator(tree, *, in_structure, inner_treedef=None, outer_treedef=None)[source]#

Bases: AbstractLinearOperator

Operator defined by a generalized matrix as a pytree of pytrees.

More memory-efficient than dense matrices when the structure allows XLA optimizations (symmetries, zeros, shared elements, broadcasting).

The structure of the generalized matrix is the tree product: outer_treedef x inner_treedef, analogous to a matrix where rows are represented by an outer tree and columns by the inner tree.

Leaves within the same row must be broadcastable; no such constraint applies across rows.
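The "tree of trees" idea can be sketched with nested dictionaries, where the outer keys play the role of rows and the inner keys the role of columns. This is a conceptual illustration with scalar leaves and hypothetical helpers (tree_mv, tree_transpose), not Furax's implementation.

```python
# outer keys are the output leaves (rows), inner keys the input leaves (columns)
tree = {
    'sum': {'a': 1.0, 'b': 1.0},
    'diff': {'a': 1.0, 'b': -1.0},
}

def tree_mv(tree, x):
    """Generalized matrix-vector product: one output leaf per row."""
    return {
        row: sum(coeff * x[col] for col, coeff in cols.items())
        for row, cols in tree.items()
    }

def tree_transpose(tree):
    """Swap the outer (row) and inner (column) structures."""
    cols = next(iter(tree.values())).keys()
    return {col: {row: tree[row][col] for row in tree} for col in cols}

y = tree_mv(tree, {'a': 3.0, 'b': 2.0})
x_t = tree_mv(tree_transpose(tree), y)
```

Here the input has structure {'a', 'b'} and the output has structure {'sum', 'diff'}, so the operator changes the tree structure much as a rectangular matrix changes the vector size.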

Variables:
  • tree (jaxtyping.PyTree[jax.Array]) – The generalized matrix as a pytree of pytrees.

  • outer_treedef (jaxlib._jax.pytree.PyTreeDef) – PyTreeDef for rows.

  • inner_treedef (jaxlib._jax.pytree.PyTreeDef) – PyTreeDef for columns.

  • tree_shape – (num_rows, num_cols) in terms of tree leaves.

Parameters:
  • tree (PyTree[jax.Array])

  • in_structure (PyTree[jax._src.api.ShapeDtypeStruct])

  • inner_treedef (PyTreeDef)

  • outer_treedef (PyTreeDef)

Example

To represent the Mueller matrix of a quarter-wave plate with a vertical fast axis:

>>> from furax.obs.stokes import StokesIQUV
>>> op = TreeOperator(
...     StokesIQUV(
...         StokesIQUV(1, 0, 0,  0),
...         StokesIQUV(0, 1, 0,  0),
...         StokesIQUV(0, 0, 0, -1),
...         StokesIQUV(0, 0, 1,  0),
...     ),
...     in_structure=StokesIQUV.structure_for((), jnp.float32)
... )
>>> op.as_matrix()
Array([[ 1.,  0.,  0.,  0.],
       [ 0.,  1.,  0.,  0.],
       [ 0.,  0.,  0., -1.],
       [ 0.,  0.,  1.,  0.]], dtype=float32)
>>> op(StokesIQUV(1., 1., 1., 1.))
StokesIQUV(i=1.0, q=1.0, u=-1.0, v=1.0)
tree: PyTree[jax.Array]#
inner_treedef: PyTreeDef#
outer_treedef: PyTreeDef#
property tree_shape: tuple[int, int]#

Return the number of leaves of the outer and inner structures.

mv(x)[source]#
Parameters:

x (PyTree[jaxtyping.Inexact[Array, '_a']])

Return type:

PyTree[jaxtyping.Inexact[Array, '_b']]

transpose()[source]#
Return type:

AbstractLinearOperator

inverse()[source]#
Return type:

AbstractLinearOperator

Configuration#

class furax._config.Config(**kwargs)[source]#

Bases: object

Parameters:

kwargs (Any)

classmethod instance()[source]#
Return type:

ConfigState