import operator
import sys
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, ClassVar, Literal, cast, get_args, overload
import jax
import numpy as np
from equinox.internal._omega import _Metaω
from jax import Array
from jax.tree_util import register_dataclass
from jax.typing import ArrayLike
from jaxtyping import DTypeLike, Float, Integer, Key, PyTree, ScalarLike
if sys.version_info < (3, 11):
from typing_extensions import Self
else:
from typing import Self
from furax.exceptions import StructureError
from furax.tree import (
add,
as_promoted_dtype,
dot,
full_like,
mul,
normal_like,
ones_like,
power,
sub,
truediv,
uniform_like,
zeros_like,
)
__all__ = ['Stokes', 'StokesI', 'StokesQU', 'StokesIQU', 'StokesIQUV', 'ValidStokesType']
ValidStokesType = Literal['I', 'QU', 'IQU', 'IQUV']
[docs]
@register_dataclass
@dataclass
class Stokes(ABC):
stokes: ClassVar[ValidStokesType]
@property
def shape(self) -> tuple[int, ...]:
return cast(tuple[int, ...], getattr(self, self.stokes[0].lower()).shape)
@property
def dtype(self) -> DTypeLike:
return cast(DTypeLike, getattr(self, self.stokes[0].lower()).dtype)
@property
def structure(self) -> PyTree[jax.ShapeDtypeStruct]:
return self.structure_for(self.shape, self.dtype)
def __getitem__(self, index: Integer[Array, '...']) -> Self:
arrays = [getattr(self, stoke.lower())[index] for stoke in self.stokes]
return type(self)(*arrays)
[docs]
def __matmul__(self, other: Any) -> Any:
"""Scalar product between Stokes pytrees."""
if not isinstance(other, type(self)):
return NotImplemented
return dot(self, other)
def __abs__(self) -> Self:
result: Self = jax.tree.map(operator.abs, self)
return result
def __pos__(self) -> Self:
return self
def __neg__(self) -> Self:
result: Self = jax.tree.map(operator.neg, self)
return result
def __add__(self, other: Any) -> Self:
try:
result: Self = add(self, other)
except StructureError:
return NotImplemented
return result
def __sub__(self, other: Any) -> Self:
try:
result: Self = sub(self, other)
except StructureError:
return NotImplemented
return result
def __mul__(self, other: Any) -> Self:
try:
result: Self = mul(self, other)
except StructureError:
return NotImplemented
return result
def __truediv__(self, other: Any) -> Self:
try:
result: Self = truediv(self, other)
except StructureError:
return NotImplemented
return result
def __pow__(self, other: Any) -> Self:
if isinstance(other, _Metaω):
return NotImplemented
try:
result: Self = power(self, other)
except StructureError:
return NotImplemented
return result
def __radd__(self, other: Any) -> Self:
try:
result: Self = add(self, other)
except StructureError:
return NotImplemented
return result
def __rsub__(self, other: Any) -> Self:
try:
result: Self = sub(other, self)
except StructureError:
return NotImplemented
return result
def __rmul__(self, other: Any) -> Self:
try:
result: Self = mul(self, other)
except StructureError:
return NotImplemented
return result
def __rtruediv__(self, other: Any) -> Self:
try:
result: Self = truediv(other, self)
except StructureError:
return NotImplemented
return result
def __rpow__(self, other: Any) -> Self:
try:
result: Self = power(other, self)
except StructureError:
return NotImplemented
return result
[docs]
def ravel(self) -> Self:
"""Ravels each Stokes component."""
return jax.tree.map(lambda x: x.ravel(), self) # type: ignore[no-any-return]
[docs]
def reshape(self, shape: tuple[int, ...]) -> Self:
"""Reshape each Stokes component."""
return jax.tree.map(lambda x: x.reshape(shape), self) # type: ignore[no-any-return]
@classmethod
@overload
def class_for(cls, stokes: Literal['I']) -> type['StokesI']: ...
@classmethod
@overload
def class_for(cls, stokes: Literal['QU']) -> type['StokesQU']: ...
@classmethod
@overload
def class_for(cls, stokes: Literal['IQU']) -> type['StokesIQU']: ...
@classmethod
@overload
def class_for(cls, stokes: Literal['IQUV']) -> type['StokesIQUV']: ...
[docs]
@classmethod
def class_for(cls, stokes: str) -> type['StokesPyTreeType']:
"""Returns the StokesPyTree subclass associated to the specified Stokes types."""
if stokes not in get_args(ValidStokesType):
raise ValueError(f'Invalid Stokes parameters: {stokes!r}')
requested_cls = {
'I': StokesI,
'QU': StokesQU,
'IQU': StokesIQU,
'IQUV': StokesIQUV,
}[stokes]
return cast(type[StokesPyTreeType], requested_cls)
[docs]
@classmethod
def structure_for(
cls,
shape: tuple[int, ...],
dtype: DTypeLike = np.float64,
) -> Self:
stokes_arrays = len(cls.stokes) * [jax.ShapeDtypeStruct(shape, dtype)]
return cls(*stokes_arrays)
@classmethod
@overload
def from_stokes(cls, i: ArrayLike) -> 'StokesI': ...
@classmethod
@overload
def from_stokes(cls, i: jax.ShapeDtypeStruct) -> 'StokesI': ...
@classmethod
@overload
def from_stokes(cls, q: ArrayLike, u: ArrayLike) -> 'StokesQU': ...
@classmethod
@overload
def from_stokes(cls, q: jax.ShapeDtypeStruct, u: jax.ShapeDtypeStruct) -> 'StokesQU': ...
@classmethod
@overload
def from_stokes(cls, i: ArrayLike, q: ArrayLike, u: ArrayLike) -> 'StokesIQU': ...
@classmethod
@overload
def from_stokes(
cls, i: jax.ShapeDtypeStruct, q: jax.ShapeDtypeStruct, u: jax.ShapeDtypeStruct
) -> 'StokesIQU': ...
@classmethod
@overload
def from_stokes(
cls, i: ArrayLike, q: ArrayLike, u: ArrayLike, v: ArrayLike
) -> 'StokesIQUV': ...
@classmethod
@overload
def from_stokes(
cls,
i: jax.ShapeDtypeStruct,
q: jax.ShapeDtypeStruct,
u: jax.ShapeDtypeStruct,
v: jax.ShapeDtypeStruct,
) -> 'StokesIQUV': ...
[docs]
@classmethod
def from_stokes(
cls,
*args: Any,
**keywords: Any,
) -> 'Stokes':
"""Returns a StokesPyTree according to the specified Stokes vectors.
Examples:
>>> tod_i = Stokes.from_stokes(i)
>>> tod_qu = Stokes.from_stokes(q, u)
>>> tod_iqu = Stokes.from_stokes(i, q, u)
>>> tod_iquv = Stokes.from_stokes(i, q, u, v)
"""
if args and keywords:
raise TypeError(
'The Stokes parameters should be specified either through positional or keyword '
'arguments.'
)
if keywords:
stokes = ''.join(sorted(keywords))
if stokes not in get_args(ValidStokesType):
raise TypeError(
f"Invalid Stokes vectors: {stokes!r}. Use 'I', 'QU', 'IQU' or 'IQUV'."
)
args = tuple(keywords[stoke] for stoke in stokes)
args = as_promoted_dtype(args)
if len(args) == 1:
return StokesI(*args)
if len(args) == 2:
return StokesQU(*args)
if len(args) == 3:
return StokesIQU(*args)
if len(args) == 4:
return StokesIQUV(*args)
raise TypeError(f'Unexpected number of Stokes parameters: {len(args)}.')
[docs]
@classmethod
@abstractmethod
def from_iquv(
cls,
i: Float[Array, '...'],
q: Float[Array, '...'],
u: Float[Array, '...'],
v: Float[Array, '...'],
) -> Self:
"""Returns a StokesPyTree ignoring the Stokes components not in the type."""
[docs]
@classmethod
def zeros(cls, shape: tuple[int, ...], dtype: DTypeLike = float) -> Self:
return zeros_like(cls.structure_for(shape, dtype))
[docs]
@classmethod
def ones(cls, shape: tuple[int, ...], dtype: DTypeLike = float) -> Self:
return ones_like(cls.structure_for(shape, dtype))
[docs]
@classmethod
def full(cls, shape: tuple[int, ...], fill_value: ScalarLike, dtype: DTypeLike = float) -> Self:
return full_like(cls.structure_for(shape, dtype), fill_value)
[docs]
@classmethod
def normal(cls, key: Key[Array, ''], shape: tuple[int, ...], dtype: DTypeLike = float) -> Self:
return normal_like(cls.structure_for(shape, dtype), key)
[docs]
@register_dataclass
@dataclass
class StokesI(Stokes):
stokes: ClassVar[ValidStokesType] = 'I'
i: Array
[docs]
@classmethod
def from_iquv(
cls,
i: Float[Array, '...'],
q: Float[Array, '...'],
u: Float[Array, '...'],
v: Float[Array, '...'],
) -> Self:
return cls(i)
[docs]
@register_dataclass
@dataclass
class StokesQU(Stokes):
stokes: ClassVar[ValidStokesType] = 'QU'
q: Array
u: Array
[docs]
@classmethod
def from_iquv(
cls,
i: Float[Array, '...'],
q: Float[Array, '...'],
u: Float[Array, '...'],
v: Float[Array, '...'],
) -> Self:
q, u = as_promoted_dtype((q, u))
return cls(q, u)
[docs]
@register_dataclass
@dataclass
class StokesIQU(Stokes):
stokes: ClassVar[ValidStokesType] = 'IQU'
i: Array
q: Array
u: Array
[docs]
@classmethod
def from_iquv(
cls,
i: Float[Array, '...'],
q: Float[Array, '...'],
u: Float[Array, '...'],
v: Float[Array, '...'],
) -> Self:
i, q, u = as_promoted_dtype((i, q, u))
return cls(i, q, u)
[docs]
@register_dataclass
@dataclass
class StokesIQUV(Stokes):
stokes: ClassVar[ValidStokesType] = 'IQUV'
i: Array
q: Array
u: Array
v: Array
[docs]
@classmethod
def from_iquv(
cls,
i: Float[Array, '...'],
q: Float[Array, '...'],
u: Float[Array, '...'],
v: Float[Array, '...'],
) -> Self:
i, q, u, v = as_promoted_dtype((i, q, u, v))
return cls(i, q, u, v)
StokesPyTreeType = StokesI | StokesQU | StokesIQU | StokesIQUV