Linear Operators API
This section provides detailed API documentation for Furax linear operators.
Core Operators
Abstract Base Class
- class furax.core._base.AbstractLinearOperator(*, in_structure=None)
Bases: ABC
Base class for linear operators.
- Parameters:
in_structure (PyTree[jax._src.api.ShapeDtypeStruct])
- in_structure: PyTree[jax._src.api.ShapeDtypeStruct] = None
- property tags: OperatorTag
Get the tags for this operator instance.
- abstract mv(x)
- Parameters:
x (PyTree[jaxtyping.Inexact[Array, '_a']])
- Return type:
PyTree[jaxtyping.Inexact[Array, '_b']]
- as_matrix()
Returns the operator as a dense matrix.
Input and output PyTrees are flattened and concatenated.
- Return type:
Inexact[Array, 'a b']
- property T: AbstractLinearOperator
- property I: AbstractLinearOperator
- property out_structure: PyTree[jax._src.api.ShapeDtypeStruct]
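The as_matrix description (input and output pytrees flattened and concatenated) can be illustrated without Furax. The NumPy sketch below assumes a linear map written as a plain function on a list of arrays, a simplified stand-in for a pytree; the helpers as_dense_matrix and block_row_mv are hypothetical and do not reflect how Furax implements as_matrix. Each column is the flattened output for one unit input vector, so rows index output elements and columns index input elements.

```python
import numpy as np


def as_dense_matrix(mv, in_shapes):
    """Dense matrix of a linear map acting on a list of arrays (hypothetical helper).

    Each column is the flattened, concatenated output for one unit input vector.
    """
    sizes = [int(np.prod(shape)) for shape in in_shapes]
    n = sum(sizes)
    columns = []
    for k in range(n):
        flat = np.zeros(n)
        flat[k] = 1.0
        # Split the flat unit vector back into input leaves of the given shapes.
        leaves, start = [], 0
        for shape, size in zip(in_shapes, sizes):
            leaves.append(flat[start:start + size].reshape(shape))
            start += size
        # Flatten and concatenate the output leaves (row-major order).
        columns.append(np.concatenate([np.ravel(leaf) for leaf in mv(leaves)]))
    return np.stack(columns, axis=1)


def block_row_mv(leaves):
    """Example linear map [I | 2I] acting on two leaves of shape (2,)."""
    a, b = leaves
    return [a + 2 * b]


M = as_dense_matrix(block_row_mv, [(2,), (2,)])
# M == [[1, 0, 2, 0],
#       [0, 1, 0, 2]]
```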
Diagonal Operators
- class furax.core._diagonal.DiagonalOperator(diagonal, *, axis_destination=-1, in_structure=None)
Bases: BroadcastDiagonalOperator
Operator that performs element-wise multiplication: D(x) = d * x.
The diagonal operator is symmetric and square. Its inverse divides by the diagonal values (zeros are handled by returning zero).
- The multiplication axes can be specified via axis_destination:
  - axis_destination=0: diagonal[:, None] * x (multiply along the first axis)
  - axis_destination=-1: diagonal * x (standard broadcasting, default)
- Variables:
diagonal – The diagonal values.
axis_destination (int | tuple[int, ...]) – The axes along which the diagonal values are applied to the input. If the type is a sequence, there should be as many axes as the number of dimensions of diagonal. If the type is a non-negative scalar integer, the dimensions will be (axis, ..., axis + diagonal.ndim - 1). If the type is a negative scalar integer, the dimensions will be (axis - diagonal.ndim + 1, ..., axis).
Example
>>> import jax
>>> import furax as fx
>>> from numpy.testing import assert_allclose
>>> key_gain, key_tod, key_common = jax.random.split(jax.random.PRNGKey(0), 3)
>>> detector_count = 3
>>> sample_count = 10
>>> x = {
...     'tod': jax.random.normal(key_tod, (detector_count, sample_count)),
...     'ground': jax.random.normal(key_common, (detector_count,)),
... }
>>> detector_gains = jax.random.normal(key_gain, (detector_count,)) / 100 + 1
>>> op = DiagonalOperator(
...     detector_gains, axis_destination=0, in_structure=fx.tree.as_structure(x)
... )
>>> y = op(x)
>>> assert_allclose(x['tod'] * detector_gains[:, None], y['tod'])
>>> assert_allclose(x['ground'] * detector_gains, y['ground'])
- as_matrix()
Returns the operator as a dense matrix.
Input and output PyTrees are flattened and concatenated.
- Return type:
Inexact[Array, 'a b']
- class_tags: ClassVar[OperatorTag] = 11
- property out_structure
- transpose()
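A NumPy sketch of how axis_destination aligns the diagonal with the input, and of the zero-safe inverse, assuming a single array input rather than a pytree and destination axes listed in increasing order. A negative axis_destination is treated so that the last dimension of the diagonal lands on that axis, which makes the default axis_destination=-1 behave like diagonal * x. The helpers destination_axes, apply_diagonal and inverse_diagonal are hypothetical illustrations, not Furax functions.

```python
import numpy as np


def destination_axes(axis_destination, diag_ndim):
    """Input axes receiving the dimensions of the diagonal (hypothetical helper)."""
    if isinstance(axis_destination, tuple):
        return axis_destination
    if axis_destination >= 0:
        return tuple(range(axis_destination, axis_destination + diag_ndim))
    # Negative: the last diagonal dimension lands on axis_destination.
    return tuple(range(axis_destination - diag_ndim + 1, axis_destination + 1))


def apply_diagonal(diagonal, x, axis_destination=-1):
    """Element-wise product d * x with d aligned on the destination axes.

    Assumes the destination axes are in increasing order.
    """
    axes = [a % x.ndim for a in destination_axes(axis_destination, diagonal.ndim)]
    shape = [1] * x.ndim
    for size, axis in zip(diagonal.shape, axes):
        shape[axis] = size
    # Singleton dimensions elsewhere let NumPy broadcasting do the rest.
    return diagonal.reshape(shape) * x


def inverse_diagonal(diagonal):
    """Reciprocal of the diagonal, with zeros mapped to zero instead of inf."""
    return np.divide(1.0, diagonal, out=np.zeros_like(diagonal, dtype=float), where=diagonal != 0)


gains = np.array([1.0, 2.0, 0.0])
tod = np.arange(12.0).reshape(3, 4)
y = apply_diagonal(gains, tod, axis_destination=0)  # same as gains[:, None] * tod
x_back = apply_diagonal(inverse_diagonal(gains), y, axis_destination=0)
```

The row multiplied by a zero gain cannot be recovered, so x_back matches tod on the first two rows and is zero on the last one.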
- class furax.core._diagonal.BroadcastDiagonalOperator(diagonal, *, axis_destination=-1, in_structure=None)
Bases: AbstractLinearOperator
Operator that performs element-wise multiplication with broadcasting.
Unlike DiagonalOperator, this operator can change the output shape through broadcasting, making it non-square. Depending on the broadcasting direction:
Left broadcasting (extending dimensions on the left): equivalent to a block column operator with diagonal blocks.
Right broadcasting (extending dimensions on the right): equivalent to a block diagonal operator with column blocks.
- Variables:
diagonal – The diagonal values.
axis_destination (int | tuple[int, ...]) – The axes along which the generalized diagonal values are applied to the input. If the type is a sequence, there should be as many axes as the number of dimensions of diagonal. If the type is a non-negative scalar integer, the dimensions will be (axis, ..., axis + diagonal.ndim - 1). If the type is a negative scalar integer, the dimensions will be (axis - diagonal.ndim + 1, ..., axis).
Example
>>> import furax as fx
>>> import jax.numpy as jnp
>>> from numpy.testing import assert_allclose
>>> x = jnp.array([1, 2, 3])
>>> values = jnp.array([[1, 1, 1], [2, 1, 0]])
>>> op = BroadcastDiagonalOperator(
...     values, in_structure=fx.tree.as_structure(x), axis_destination=-1
... )
>>> assert_allclose(op(x), jnp.array([[1, 2, 3], [2, 2, 0]]))
>>> op.as_matrix()
Array([[1, 0, 0],
       [0, 1, 0],
       [0, 0, 1],
       [2, 0, 0],
       [0, 1, 0],
       [0, 0, 0]], dtype=int32)
>>> x = jnp.array([1, 2])
>>> values = jnp.array([[2, 3, 1], [1, 0, 1]])
>>> op = BroadcastDiagonalOperator(
...     values, in_structure=fx.tree.as_structure(x), axis_destination=0
... )
>>> assert_allclose(op(x), jnp.array([[2, 3, 1], [2, 0, 2]]))
>>> op.as_matrix()
Array([[2, 0],
       [3, 0],
       [1, 0],
       [0, 1],
       [0, 0],
       [0, 1]], dtype=int32)
>>> x = jnp.array([[0, 1, 2], [2, 3, 4]])
>>> values = jnp.array([2, 1])
>>> op = BroadcastDiagonalOperator(
...     values, in_structure=fx.tree.as_structure(x), axis_destination=0
... )
>>> assert_allclose(op(x), jnp.array([[0, 2, 4], [2, 3, 4]]))
>>> op.as_matrix()
Array([[2, 0, 0, 0, 0, 0],
       [0, 2, 0, 0, 0, 0],
       [0, 0, 2, 0, 0, 0],
       [0, 0, 0, 1, 0, 0],
       [0, 0, 0, 0, 1, 0],
       [0, 0, 0, 0, 0, 1]], dtype=int32)
- property diagonal: Inexact[Array, '...']
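The dense structure of the two broadcasting directions can be checked with a short NumPy sketch, assuming a 2D array of values and a 1D input (an illustration, not Furax code). Left broadcasting stacks diag(values[0]) and diag(values[1]) vertically, which reproduces the first as_matrix output in the examples; right broadcasting yields one column block per input element.

```python
import numpy as np

values = np.array([[1, 1, 1], [2, 1, 0]])

# Left broadcasting: x has shape (3,), the output has shape (2, 3).
x_left = np.array([1, 2, 3])
y_left = values * x_left
# Dense form: the diagonal blocks diag(values[0]) and diag(values[1]) stacked vertically.
M_left = np.vstack([np.diag(row) for row in values])
assert np.array_equal(M_left @ x_left, y_left.ravel())

# Right broadcasting: x has shape (2,), the output has shape (2, 3).
x_right = np.array([1, 2])
y_right = values * x_right[:, None]
# Dense form: block diagonal, each block a column holding one row of `values`.
M_right = np.zeros((values.size, x_right.size), dtype=values.dtype)
for i, row in enumerate(values):
    M_right[i * row.size:(i + 1) * row.size, i] = row
assert np.array_equal(M_right @ x_right, y_right.ravel())
```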
Block Operators
- class furax.core._blocks.BlockDiagonalOperator(blocks)
Bases: AbstractBlockOperator
Operator with independent diagonal blocks: diag(A, B, C).
Applies each block independently to the corresponding part of a pytree input. No constraints are placed on the block input/output structures.
The inverse is the block diagonal of the individual inverses (if all blocks are square and invertible).
- Variables:
blocks (jaxtyping.PyTree[furax.core._base.AbstractLinearOperator]) – A pytree of operators.
- Parameters:
blocks (PyTree[furax.core._base.AbstractLinearOperator])
Example
>>> x = jnp.array([1, 2], jnp.float32)
>>> H = DenseBlockDiagonalOperator(
...     jnp.array([[0, 1], [1, 0]]),
...     in_structure=jax.ShapeDtypeStruct((2,), jnp.float32),
... )
>>> H.as_matrix()
Array([[0., 1.],
       [1., 0.]], dtype=float32)
>>> op_list = BlockDiagonalOperator([H, 2*H, 3*H])
>>> op_list.as_matrix()
Array([[0., 1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 2., 0., 0.],
       [0., 0., 2., 0., 0., 0.],
       [0., 0., 0., 0., 0., 3.],
       [0., 0., 0., 0., 3., 0.]], dtype=float32)
>>> op_list([x, x, x])
[Array([2., 1.], dtype=float32), Array([4., 2.], dtype=float32), Array([6., 3.], dtype=float32)]
>>> op_dict = BlockDiagonalOperator({'a': H, 'b': 2*H, 'c': 3*H})
>>> op_dict({'a': x, 'b': x, 'c': x})
{'a': Array([2., 1.], dtype=float32), 'b': Array([4., 2.], dtype=float32), 'c': Array([6., 3.], dtype=float32)}
- mv(vector)
- Parameters:
vector (PyTree[jaxtyping.Inexact[Array, '_b']])
- Return type:
PyTree[jaxtyping.Inexact[Array, '_a']]
- as_matrix()
Returns the operator as a dense matrix.
Input and output PyTrees are flattened and concatenated.
- Return type:
Inexact[Array, 'a b']
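Both properties of BlockDiagonalOperator (each block acts on its own part of the input; the inverse is the block diagonal of the inverses) can be checked on dense matrices. The NumPy sketch below reuses the H, 2H, 3H blocks of the example; block_diag is a hypothetical helper (scipy.linalg.block_diag does the same if SciPy is available).

```python
import numpy as np


def block_diag(*blocks):
    """Dense block-diagonal matrix built from 2D blocks (hypothetical helper)."""
    rows = sum(b.shape[0] for b in blocks)
    cols = sum(b.shape[1] for b in blocks)
    out = np.zeros((rows, cols))
    r = c = 0
    for b in blocks:
        out[r:r + b.shape[0], c:c + b.shape[1]] = b
        r += b.shape[0]
        c += b.shape[1]
    return out


H = np.array([[0.0, 1.0], [1.0, 0.0]])
blocks = [H, 2 * H, 3 * H]
x = np.array([1.0, 2.0])

# Applying each block to its own part of the input ...
y_blocks = [A @ x for A in blocks]
# ... matches the dense block-diagonal matrix applied to the concatenated input.
D = block_diag(*blocks)
assert np.allclose(D @ np.concatenate([x, x, x]), np.concatenate(y_blocks))

# The inverse is the block diagonal of the individual inverses.
D_inv = block_diag(*[np.linalg.inv(A) for A in blocks])
assert np.allclose(D_inv @ D, np.eye(6))
```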
- class furax.core._blocks.BlockRowOperator(blocks)
Bases: AbstractBlockOperator
Operator that horizontally concatenates block operators: [A | B | C].
Applies each block to the corresponding part of a pytree input and sums the results. All blocks must have the same output structure.
Transpose: BlockRowOperator.T = BlockColumnOperator
- Variables:
blocks (jaxtyping.PyTree[furax.core._base.AbstractLinearOperator]) – A pytree of operators (all with identical output structure).
- Parameters:
blocks (PyTree[furax.core._base.AbstractLinearOperator])
Example
>>> x = jnp.array([1, 2], jnp.float32)
>>> I = IdentityOperator(in_structure=jax.ShapeDtypeStruct((2,), jnp.float32))
>>> op_list = BlockRowOperator([I, 2*I, 3*I])
>>> op_list.as_matrix()
Array([[1., 0., 2., 0., 3., 0.],
       [0., 1., 0., 2., 0., 3.]], dtype=float32)
>>> op_list([x, x, x])
Array([ 6., 12.], dtype=float32)
>>> op_dict = BlockRowOperator({'a': I, 'b': 2*I, 'c': 3*I})
>>> op_dict({'a': x, 'b': x, 'c': x})
Array([ 6., 12.], dtype=float32)
- mv(x)
- Parameters:
x (PyTree[jaxtyping.Inexact[Array, '_b']])
- Return type:
PyTree[jaxtyping.Inexact[Array, '_a']]
- property out_structure: PyTree[jax._src.api.ShapeDtypeStruct]
- class furax.core._blocks.BlockColumnOperator(blocks)
Bases: AbstractBlockOperator
Operator that vertically stacks block operators: [A; B; C].
Applies each block to the same input and returns a pytree of outputs. All blocks must have the same input structure.
Transpose: BlockColumnOperator.T = BlockRowOperator
- Variables:
blocks (jaxtyping.PyTree[furax.core._base.AbstractLinearOperator]) – A pytree of operators (all with identical input structure).
- Parameters:
blocks (PyTree[furax.core._base.AbstractLinearOperator])
Example
>>> x = jnp.array([1, 2], jnp.float32)
>>> I = IdentityOperator(in_structure=jax.ShapeDtypeStruct((2,), jnp.float32))
>>> op_list = BlockColumnOperator([I, I, I])
>>> op_list.as_matrix()
Array([[1., 0.],
       [0., 1.],
       [1., 0.],
       [0., 1.],
       [1., 0.],
       [0., 1.]], dtype=float32)
>>> op_list(x)
[Array([1., 2.], dtype=float32), Array([1., 2.], dtype=float32), Array([1., 2.], dtype=float32)]
>>> op_dict = BlockColumnOperator({'a': I, 'b': I, 'c': I})
>>> op_dict(x)
{'a': Array([1., 2.], dtype=float32), 'b': Array([1., 2.], dtype=float32), 'c': Array([1., 2.], dtype=float32)}
- mv(vector)
- Parameters:
vector (PyTree[jaxtyping.Inexact[Array, '_b']])
- Return type:
PyTree[jaxtyping.Inexact[Array, '_a']]
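The transpose relations between BlockRowOperator and BlockColumnOperator are the usual block-matrix identity [A | B | C]^T = [A^T; B^T; C^T]. A NumPy sketch on random dense blocks (plain matrices standing in for operators) that also checks the "apply each block and sum" rule of the block row:

```python
import numpy as np

rng = np.random.default_rng(0)
# Three blocks sharing the same output size (2), as BlockRowOperator requires.
A, B, C = (rng.standard_normal((2, 3)) for _ in range(3))

row = np.hstack([A, B, C])        # [A | B | C]: maps (a, b, c) to A a + B b + C c
col = np.vstack([A.T, B.T, C.T])  # [A^T; B^T; C^T]: feeds one input to every block

a, b, c = (rng.standard_normal(3) for _ in range(3))
y_row = row @ np.concatenate([a, b, c])

z = rng.standard_normal(2)
y_col = col @ z                   # concatenation of A^T z, B^T z and C^T z
```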
Dense Operators
- class furax.core._dense.DenseBlockDiagonalOperator(blocks, subscripts='ij...,j...->i...', *, in_structure=None)
Bases: AbstractLinearOperator
Operator that applies block diagonal dense matrices via einsum.
Only the diagonal blocks are stored, making this more memory-efficient than a full dense matrix. The operation is defined by einsum subscripts.
- Variables:
blocks (jaxtyping.Inexact[Array, '...']) – The dense blocks as an array (at least 2D).
subscripts (str) – Einsum subscripts defining the operation (default: 'ij...,j...->i...').
- Parameters:
blocks (Inexact[Array, '...'])
subscripts (str)
in_structure (PyTree[jax._src.api.ShapeDtypeStruct])
Example
For a matrix made of three 2x4 diagonal blocks acting on an input made of three blocks of four elements each, the operator can be written as:
>>> blocks = jnp.arange(24).reshape(3, 2, 4)
>>> op = DenseBlockDiagonalOperator(
...     blocks, subscripts='imn,in->im', in_structure=jax.ShapeDtypeStruct((3, 4), jnp.int32)
... )
>>> op.as_matrix()
Array([[ 0,  1,  2,  3,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 4,  5,  6,  7,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  8,  9, 10, 11,  0,  0,  0,  0],
       [ 0,  0,  0,  0, 12, 13, 14, 15,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0,  0, 16, 17, 18, 19],
       [ 0,  0,  0,  0,  0,  0,  0,  0, 20, 21, 22, 23]], dtype=int32)
The axes along which the operator is block diagonal need not be the leading dimensions. In fact, with the default subscripts, the diagonal axes are assumed to be “on the right”, i.e. the trailing axes matched by the ellipsis. Block diagonality should be understood in a tensor sense: the representation of this operator as a 2D matrix, which relies on the row-major layout, may not be block diagonal.
>>> blocks = jnp.arange(24).reshape(3, 2, 4)
>>> op = DenseBlockDiagonalOperator(blocks, in_structure=jax.ShapeDtypeStruct((2, 4), jnp.int32))
>>> op.as_matrix()
Array([[ 0,  0,  0,  0,  4,  0,  0,  0],
       [ 0,  1,  0,  0,  0,  5,  0,  0],
       [ 0,  0,  2,  0,  0,  0,  6,  0],
       [ 0,  0,  0,  3,  0,  0,  0,  7],
       [ 8,  0,  0,  0, 12,  0,  0,  0],
       [ 0,  9,  0,  0,  0, 13,  0,  0],
       [ 0,  0, 10,  0,  0,  0, 14,  0],
       [ 0,  0,  0, 11,  0,  0,  0, 15],
       [16,  0,  0,  0, 20,  0,  0,  0],
       [ 0, 17,  0,  0,  0, 21,  0,  0],
       [ 0,  0, 18,  0,  0,  0, 22,  0],
       [ 0,  0,  0, 19,  0,  0,  0, 23]], dtype=int32)
- blocks: Inexact[Array, '...']
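The tensor sense of block diagonality can be made concrete with numpy.einsum, which accepts the same subscript strings as the examples. In this NumPy sketch (not Furax code), the default subscripts apply a small 3x2 matrix independently to each trailing-index column x[:, k], while 'imn,in->im' applies one 2x4 block per leading index. The operator stores blocks.size = 24 numbers, whereas the dense matrix of the second example has 12 x 8 = 96 entries.

```python
import numpy as np

blocks = np.arange(24).reshape(3, 2, 4)
x = np.arange(8).reshape(2, 4)

# Default subscripts: contract j and keep the trailing axis as the "diagonal" axis.
y = np.einsum('ij...,j...->i...', blocks, x)  # shape (3, 4)

# For each trailing index k, a small 3x2 matrix acts on the column x[:, k].
for k in range(x.shape[1]):
    assert np.array_equal(y[:, k], blocks[:, :, k] @ x[:, k])

# With subscripts 'imn,in->im' the leading axis is the diagonal one instead.
x2 = np.arange(12).reshape(3, 4)
y2 = np.einsum('imn,in->im', blocks, x2)      # shape (3, 2)
for i in range(blocks.shape[0]):
    assert np.array_equal(y2[i], blocks[i] @ x2[i])
```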
Toeplitz Operators
- class furax.core._toeplitz.SymmetricBandToeplitzOperator(band_values, *, in_structure, method='overlap_save_parallel', fft_size=None)
Bases: AbstractLinearOperator
Operator for symmetric band Toeplitz convolution.
A Toeplitz matrix has constant diagonals. This operator is symmetric and exploits the band structure for efficient computation. For multidimensional band values, the operator is block diagonal.
- Available methods (N = matrix size, K = number of bands):
  - dense: dense matrix multiplication
  - direct: direct convolution
  - fft: FFT on the whole input
  - overlap_save_sequential: sequential FFT on chunks
  - overlap_save_parallel: batched FFT on chunks (default)
Example
>>> tod = jnp.ones((2, 5))
>>> op = SymmetricBandToeplitzOperator(
...     jnp.array([[1., 0.5], [1, 0.25]]),
...     in_structure=jax.ShapeDtypeStruct(tod.shape, tod.dtype))
>>> op(tod)
Array([[1.5 , 2.  , 2.  , 2.  , 1.5 ],
       [1.25, 1.5 , 1.5 , 1.5 , 1.25]], dtype=float64)
>>> op.as_matrix()
Array([[1.  , 0.5 , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.5 , 1.  , 0.5 , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.5 , 1.  , 0.5 , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.5 , 1.  , 0.5 , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.5 , 1.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 1.  , 0.25, 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.25, 1.  , 0.25, 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.25, 1.  , 0.25, 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.25, 1.  , 0.25],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.25, 1.  ]], dtype=float64)
- METHODS: ClassVar[tuple[str, ...]] = ('dense', 'direct', 'fft', 'overlap_save_parallel', 'overlap_save_sequential')
- band_values: Float[Array, '...']
- as_matrix()
Returns the operator as a dense matrix.
Input and output PyTrees are flattened and concatenated.
- Return type:
Inexact[Array, 'a a']
- class_tags: ClassVar[OperatorTag] = 3
- property out_structure
- transpose()
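The direct method listed above amounts to a one-dimensional convolution with the band mirrored around its central value. The NumPy sketch below checks that equivalence against the dense Toeplitz matrix for a 1D band; the helper names are hypothetical, it assumes the input has at least 2K - 1 samples, and the FFT-based methods (which compute the same product differently) are not shown.

```python
import numpy as np


def band_toeplitz_dense(band_values, n):
    """Dense n x n symmetric band Toeplitz matrix.

    Entry (i, j) is band_values[|i - j|] for |i - j| < len(band_values), zero elsewhere.
    """
    m = np.zeros((n, n))
    for k, value in enumerate(band_values):
        m += value * np.eye(n, k=k)
        if k > 0:
            m += value * np.eye(n, k=-k)
    return m


def band_toeplitz_direct(band_values, x):
    """Same product as a direct convolution with the band mirrored around its center.

    Assumes len(x) >= 2 * len(band_values) - 1 so that mode='same' keeps len(x) samples.
    """
    kernel = np.concatenate([band_values[:0:-1], band_values])
    return np.convolve(x, kernel, mode='same')


band = np.array([1.0, 0.5])
x = np.ones(5)
y_direct = band_toeplitz_direct(band, x)   # [1.5, 2.0, 2.0, 2.0, 1.5]
y_dense = band_toeplitz_dense(band, x.size) @ x
```

The first row of op(tod) in the example above shows the same values, with the edge samples receiving only one neighbour.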
Index and Reshape Operators
- class furax.core._indices.IndexOperator(indices, *, in_structure, out_structure=None, unique_indices=None, _out_structure=None)
Bases: AbstractLinearOperator
Operator that extracts elements by indexing: y = x[indices].
Supports integer indices, slices, boolean masks, and advanced indexing. When the indices are unique, the operator satisfies op @ op.T = Identity.
Example
To extract the second entry along the first axis:
>>> op = IndexOperator(1, in_structure=jax.ShapeDtypeStruct((10, 4), jax.numpy.float32))
To extract values from the penultimate axis given an array of indices:
>>> indices = jax.numpy.array([2, 4, 4, 5, 7])
>>> in_structure = jax.ShapeDtypeStruct((9, 8, 3), jax.numpy.float32)
>>> op = IndexOperator((..., indices, slice(None)), in_structure=in_structure)
To extract values using a boolean mask, an output structure must be specified, because the number of selected elements cannot be inferred from the input structure:
>>> indices = jax.numpy.array([True, False, True, False])
>>> in_structure = jax.ShapeDtypeStruct((4,), jax.numpy.float32)
>>> out_structure = jax.ShapeDtypeStruct((2,), jax.numpy.float32)
>>> op = IndexOperator(indices, in_structure=in_structure, out_structure=out_structure)
It is therefore usually simpler to convert the mask into integer indices:
>>> op = IndexOperator(jax.numpy.where(indices), in_structure=in_structure)
- mv(x)
- Parameters:
x (PyTree[jaxtyping.Inexact[Array, '_a']])
- Return type:
PyTree[jaxtyping.Inexact[Array, '_b']]
- property out_structure: PyTree[jax._src.api.ShapeDtypeStruct]
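For a 1D input, indexing is a selection matrix P whose rows are unit vectors. The NumPy sketch below, with a hypothetical index_matrix helper (not Furax code), shows why P P^T is the identity only for unique indices, and that the transpose scatters values back into the input shape, accumulating them where an index is repeated.

```python
import numpy as np


def index_matrix(indices, n):
    """Dense matrix of x -> x[indices] for a 1D input of length n (hypothetical helper)."""
    p = np.zeros((len(indices), n))
    p[np.arange(len(indices)), indices] = 1.0
    return p


n = 6
unique = np.array([4, 0, 2])
P = index_matrix(unique, n)       # unique indices: P @ P.T is the identity

repeated = np.array([2, 4, 4])
Q = index_matrix(repeated, n)     # rows 1 and 2 of Q are equal, so Q @ Q.T is not

# The transpose scatters values back and sums those that share an index.
y = np.array([1.0, 2.0, 3.0])
x_back = np.zeros(n)
np.add.at(x_back, repeated, y)    # unbuffered accumulation into x_back
```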
- class furax.core._axes.ReshapeOperator(shape, *, in_structure=None)
Bases: AbstractRavelOrReshapeOperator
Operator that reshapes pytree leaves: y = x.reshape(shape).
This operator is orthogonal: its transpose restores the original shape.
- class furax.core._axes.MoveAxisOperator(source, destination, *, in_structure)
Bases: AbstractLinearOperator
Operator that moves axes of pytree leaves: y = jnp.moveaxis(x, source, destination).
This operator is orthogonal: its transpose and inverse are the reverse axis move.
Example
>>> in_structure = jax.ShapeDtypeStruct((2, 3), jnp.float32)
>>> op = MoveAxisOperator(0, 1, in_structure=in_structure)
>>> op(jnp.array([[1., 1, 1], [2, 2, 2]]))
Array([[1., 2.],
       [1., 2.],
       [1., 2.]], dtype=float32)
- inverse()
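The orthogonality of these axis operators can be checked on their dense matrices. In the NumPy sketch below, the dense_matrix helper is hypothetical and assumes row-major flattening, as in as_matrix: a reshape has the identity as its matrix, and moveaxis gives a permutation matrix whose transpose performs the reverse move.

```python
import numpy as np


def dense_matrix(f, in_shape):
    """Dense matrix of a linear map on arrays of shape in_shape (row-major flattening)."""
    n = int(np.prod(in_shape))
    columns = [np.ravel(f(np.eye(n)[k].reshape(in_shape))) for k in range(n)]
    return np.stack(columns, axis=1)


in_shape = (2, 3)

# Reshaping keeps the row-major order of the elements: its matrix is the identity.
R = dense_matrix(lambda x: x.reshape(3, 2), in_shape)

# Moving an axis permutes the elements: its matrix is a permutation matrix.
M = dense_matrix(lambda x: np.moveaxis(x, 0, 1), in_shape)

x = np.arange(6.0).reshape(in_shape)
y = np.moveaxis(x, 0, 1)                      # shape (3, 2)
x_back = (M.T @ y.ravel()).reshape(in_shape)  # transpose = reverse axis move
```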
- class furax.core._axes.RavelOperator(first_axis=0, last_axis=-1, *, in_structure=None)
Bases: AbstractRavelOrReshapeOperator
Operator that flattens pytree leaves: y = x.ravel().
By default, all dimensions are flattened. Use first_axis and last_axis to flatten only a subset of contiguous axes.
This operator is orthogonal: its transpose restores the original shape.
Example
To flatten the leaves of a pytree:
>>> in_structure = jax.ShapeDtypeStruct((2, 3), jnp.float32)
>>> op = RavelOperator(in_structure=in_structure)
>>> op.out_structure
ShapeDtypeStruct(shape=(6,), dtype=float32)
To flatten the first two axes of the leaves of a pytree:
>>> import furax as fx
>>> x = [jnp.ones((2, 2)), jnp.ones((2, 2, 8))]
>>> op = RavelOperator(0, 1, in_structure=fx.tree.as_structure(x))
>>> op.out_structure
[ShapeDtypeStruct(shape=(4,), dtype=float32), ShapeDtypeStruct(shape=(4, 8), dtype=float32)]
To flatten the last two axes of the leaves of a pytree:
>>> import furax as fx
>>> x = [jnp.ones((2, 2, 3)), jnp.ones((2, 8))]
>>> op = RavelOperator(-2, -1, in_structure=fx.tree.as_structure(x))
>>> op.out_structure
[ShapeDtypeStruct(shape=(2, 6), dtype=float32), ShapeDtypeStruct(shape=(16,), dtype=float32)]
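The output shapes in these examples follow from merging the axes first_axis through last_axis (inclusive) into a single axis. A NumPy sketch of that shape rule on single arrays, using a hypothetical ravel_axes helper; RavelOperator applies the rule to every leaf of the pytree.

```python
import numpy as np


def ravel_axes(x, first_axis=0, last_axis=-1):
    """Flatten the contiguous axes first_axis..last_axis (inclusive) of x (hypothetical helper)."""
    first = first_axis % x.ndim
    last = last_axis % x.ndim
    merged = int(np.prod(x.shape[first:last + 1]))
    return x.reshape(x.shape[:first] + (merged,) + x.shape[last + 1:])


# Same leaf shapes as in the RavelOperator examples.
shapes_01 = [ravel_axes(np.ones(s), 0, 1).shape for s in [(2, 2), (2, 2, 8)]]
shapes_last = [ravel_axes(np.ones(s), -2, -1).shape for s in [(2, 2, 3), (2, 8)]]
# shapes_01 == [(4,), (4, 8)] and shapes_last == [(2, 6), (16,)]
```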
Tree Operators
- class furax.core._trees.TreeOperator(tree, *, in_structure, inner_treedef=None, outer_treedef=None)
Bases: AbstractLinearOperator
Operator defined by a generalized matrix as a pytree of pytrees.
More memory-efficient than dense matrices when the structure allows XLA optimizations (symmetries, zeros, shared elements, broadcasting).
The structure of the generalized matrix is the tree product: outer_treedef x inner_treedef, analogous to a matrix where rows are represented by an outer tree and columns by the inner tree.
Leaves within the same row must be broadcastable; no such constraint applies across rows.
- Variables:
tree (jaxtyping.PyTree[jax.Array]) – The generalized matrix as a pytree of pytrees.
outer_treedef (jaxlib._jax.pytree.PyTreeDef) – PyTreeDef for rows.
inner_treedef (jaxlib._jax.pytree.PyTreeDef) – PyTreeDef for columns.
tree_shape – (num_rows, num_cols) in terms of tree leaves.
- Parameters:
tree (PyTree[jax.Array])
in_structure (PyTree[jax._src.api.ShapeDtypeStruct])
inner_treedef (PyTreeDef)
outer_treedef (PyTreeDef)
Example
To represent the Mueller matrix of a quarter-wave plate with a vertical fast axis:
>>> from furax.obs.stokes import StokesIQUV
>>> op = TreeOperator(
...     StokesIQUV(
...         StokesIQUV(1, 0, 0, 0),
...         StokesIQUV(0, 1, 0, 0),
...         StokesIQUV(0, 0, 0, -1),
...         StokesIQUV(0, 0, 1, 0),
...     ),
...     in_structure=StokesIQUV.structure_for((), jnp.float32),
... )
>>> op.as_matrix()
Array([[ 1.,  0.,  0.,  0.],
       [ 0.,  1.,  0.,  0.],
       [ 0.,  0.,  0., -1.],
       [ 0.,  0.,  1.,  0.]], dtype=float32)
>>> op(StokesIQUV(1., 1., 1., 1.))
StokesIQUV(i=1.0, q=1.0, u=-1.0, v=1.0)
- inner_treedef: PyTreeDef
- outer_treedef: PyTreeDef
- property tree_shape: tuple[int, int]
Return the number of leaves of the outer and inner structures.
- mv(x)
- Parameters:
x (PyTree[jaxtyping.Inexact[Array, '_a']])
- Return type:
PyTree[jaxtyping.Inexact[Array, '_b']]
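The generalized matrix-vector product can be sketched with plain dictionaries: entry r of the output is the sum over columns c of tree[r][c] * x[c], with NumPy broadcasting between scalar entries and array leaves. The dict-of-dicts layout and the tree_matvec helper are assumptions for illustration only; TreeOperator itself accepts arbitrary pytrees such as StokesIQUV.

```python
import numpy as np


def tree_matvec(tree, x):
    """y[r] = sum over c of tree[r][c] * x[c] for a dict-of-dicts matrix (hypothetical)."""
    return {
        r: sum(np.asarray(entry) * x[c] for c, entry in row.items())
        for r, row in tree.items()
    }


# Quarter-wave plate with a vertical fast axis (same matrix as the example above).
qwp = {
    'i': {'i': 1, 'q': 0, 'u': 0, 'v': 0},
    'q': {'i': 0, 'q': 1, 'u': 0, 'v': 0},
    'u': {'i': 0, 'q': 0, 'u': 0, 'v': -1},
    'v': {'i': 0, 'q': 0, 'u': 1, 'v': 0},
}
# Scalar matrix entries broadcast against leaves holding three samples each.
stokes = {key: np.ones(3) for key in 'iquv'}
out = tree_matvec(qwp, stokes)
# out['u'] is all -1 and out['v'] is all 1, as in op(StokesIQUV(1., 1., 1., 1.))
```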