import functools
import inspect
import sys
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import IntFlag, auto
from typing import Any, ClassVar, TypeVar, overload

if sys.version_info >= (3, 11):
    from typing import dataclass_transform
else:
    from typing_extensions import dataclass_transform

import jax
import jax.numpy as jnp
import lineax as lx
from jax import Array
from jax._src.typing import DType
from jax.tree_util import Partial
from jaxtyping import Inexact, PyTree, Scalar, ScalarLike

from furax._config import Config, ConfigState
from furax.tree import zeros_like

from .utils import register_dataclass_with_keys
class OperatorTag(IntFlag):
    """Bit flags describing structural properties of linear operators.

    Each property occupies its own bit, so properties combine with ``|`` and are
    tested with ``&``. ``NONE`` is the empty set of properties.
    """

    NONE = 0
    SQUARE = 1 << 0
    SYMMETRIC = 1 << 1
    ORTHOGONAL = 1 << 2
    DIAGONAL = 1 << 3
    TRIDIAGONAL = 1 << 4
    LOWER_TRIANGULAR = 1 << 5
    UPPER_TRIANGULAR = 1 << 6
    POSITIVE_SEMIDEFINITE = 1 << 7
    NEGATIVE_SEMIDEFINITE = 1 << 8
@register_dataclass_with_keys
@dataclass(frozen=True)
@dataclass_transform(frozen_default=True, field_specifiers=(field,))
class AbstractLinearOperator(ABC):
    """Base class for linear operators.

    Operators are frozen dataclasses. When a subclass is defined, ``__init_subclass__``
    also turns it into a frozen dataclass and passes it to
    ``register_dataclass_with_keys``. Subclasses therefore only declare their fields
    and implement :meth:`mv`.

    This class provides the operator algebra:
    - composition (``A @ B``);
    - addition and subtraction of operators with matching structures;
    - multiplication and division by scalars;
    - lazy transposition (``A.T``) and lazy inversion (``A.I``).

    Attributes:
        class_tags: Structural properties shared by all instances of the class. They
            are set by the class decorators ``square``, ``symmetric``, ``diagonal``,
            etc.
        in_structure: Pytree of ``jax.ShapeDtypeStruct`` describing the operator's
            input. It is keyword-only, and construction raises ``ValueError`` if it
            is not set.
    """

    # Class-level tags (set by decorators)
    class_tags: ClassVar[OperatorTag] = OperatorTag.NONE
    # Marked static through the field metadata. Presumably this keeps it out of the
    # traced pytree children in ``register_dataclass_with_keys`` (TODO confirm).
    in_structure: PyTree[jax.ShapeDtypeStruct] = field(
        kw_only=True, metadata={'static': True}, default=None
    )

    def __init_subclass__(cls, **kwargs: Any) -> None:
        # Apply to every subclass the same decorators as the base class.
        dataclass(frozen=True)(cls)
        register_dataclass_with_keys(cls)

    def __post_init__(self) -> None:
        if self.in_structure is None:
            raise ValueError('The input structure of the operator is not defined.')

    @property
    def tags(self) -> OperatorTag:
        """Get the tags for this operator instance."""
        return self.class_tags

    # Operator properties derived from the tags (False unless the tag is set).
    # Composite operators override some of them to infer the property from their
    # operands.
    @property
    def is_square(self) -> bool:
        return bool(self.tags & OperatorTag.SQUARE)

    @property
    def is_symmetric(self) -> bool:
        return bool(self.tags & OperatorTag.SYMMETRIC)

    @property
    def is_orthogonal(self) -> bool:
        return bool(self.tags & OperatorTag.ORTHOGONAL)

    @property
    def is_diagonal(self) -> bool:
        return bool(self.tags & OperatorTag.DIAGONAL)

    @property
    def is_tridiagonal(self) -> bool:
        return bool(self.tags & OperatorTag.TRIDIAGONAL)

    @property
    def is_lower_triangular(self) -> bool:
        return bool(self.tags & OperatorTag.LOWER_TRIANGULAR)

    @property
    def is_upper_triangular(self) -> bool:
        return bool(self.tags & OperatorTag.UPPER_TRIANGULAR)

    @property
    def is_positive_semidefinite(self) -> bool:
        return bool(self.tags & OperatorTag.POSITIVE_SEMIDEFINITE)

    @property
    def is_negative_semidefinite(self) -> bool:
        return bool(self.tags & OperatorTag.NEGATIVE_SEMIDEFINITE)

    @overload
    def __call__(
        self, *, solver: lx.AbstractLinearSolver, **keywords: Any
    ) -> 'AbstractLinearOperator': ...

    @overload
    def __call__(self, x: PyTree[jax.ShapeDtypeStruct]) -> PyTree[jax.ShapeDtypeStruct]: ...

    def __call__(
        self, x: PyTree[jax.ShapeDtypeStruct] | None = None, **keywords: Any
    ) -> 'AbstractLinearOperator | PyTree[jax.ShapeDtypeStruct]':
        """Apply the operator to ``x`` (equivalent to ``self.mv(x)``).

        The base class accepts no keywords. Inverse operators override this method
        to accept solver settings.

        Raises:
            TypeError: If keywords are given.
            ValueError: If ``x`` is itself an operator (use ``@`` to compose).
        """
        if keywords:
            raise TypeError('No keywords is allowed in AbstractLinearOperator __call__ method')
        if isinstance(x, AbstractLinearOperator):
            raise ValueError("Use '@' to compose operators")
        return self.mv(x)

    def __matmul__(self, other: Any) -> 'AbstractLinearOperator':
        """Compose ``self @ other``: ``other`` is applied first, then ``self``.

        Raises:
            ValueError: If ``self.in_structure`` differs from ``other.out_structure``.
        """
        if not isinstance(other, AbstractLinearOperator):
            return NotImplemented
        if self.in_structure != other.out_structure:
            msg = (
                f'Incompatible linear operator structures: '
                f'self.in_structure={self.in_structure}, '
                f'other.out_structure={other.out_structure}'
            )
            raise ValueError(msg)
        if isinstance(other, CompositionOperator):
            # Let CompositionOperator.__rmatmul__ prepend self to its operand list.
            return NotImplemented
        if isinstance(other, AbstractLazyInverseOperator):
            # A @ A^-1 simplifies to the identity (operator identity check).
            if other.operator is self:
                return IdentityOperator(in_structure=self.in_structure)
        return CompositionOperator([self, other])

    def __add__(self, other: Any) -> 'AbstractLinearOperator':
        """Sum of two operators with identical input and output structures."""
        if not isinstance(other, AbstractLinearOperator):
            return NotImplemented
        if self.in_structure != other.in_structure:
            raise ValueError('Incompatible linear operator input structures')
        if self.out_structure != other.out_structure:
            raise ValueError('Incompatible linear operator output structures')
        if isinstance(other, AdditionOperator):
            # Let AdditionOperator.__radd__ prepend self to its operands.
            return NotImplemented
        return AdditionOperator([self, other])

    def __sub__(self, other: Any) -> 'AbstractLinearOperator':
        """Difference of two operators, computed as ``self + (-other)``."""
        if not isinstance(other, AbstractLinearOperator):
            return NotImplemented
        if self.in_structure != other.in_structure:
            raise ValueError('Incompatible linear operator input structures')
        if self.out_structure != other.out_structure:
            raise ValueError('Incompatible linear operator output structures')
        result: AbstractLinearOperator = self + (-other)
        return result

    def __mul__(self, other: ScalarLike) -> 'AbstractLinearOperator':
        """Scalar multiplication ``self * k``, delegated to ``k * self``."""
        result = other * self
        assert isinstance(result, AbstractLinearOperator)  # mypy
        return result

    def __rmul__(self, other: ScalarLike) -> 'AbstractLinearOperator':
        """Scalar multiplication ``k * self``, as a homothety applied to the output."""
        other = jnp.asarray(other)
        if other.shape != ():
            raise ValueError('Can only multiply AbstractLinearOperators by scalars.')
        return HomothetyOperator(other, in_structure=self.out_structure) @ self

    def __truediv__(self, other: ScalarLike) -> 'AbstractLinearOperator':
        """Scalar division ``self / k``, as a homothety by ``1 / k`` on the output."""
        other = jnp.asarray(other)
        if other.shape != ():
            raise ValueError('Can only divide AbstractLinearOperators by scalars.')
        return HomothetyOperator(1 / other, in_structure=self.out_structure) @ self

    def __pos__(self) -> 'AbstractLinearOperator':
        return self

    def __neg__(self) -> 'AbstractLinearOperator':
        return (-1) * self

    @abstractmethod
    def mv(self, x: PyTree[Inexact[Array, ' _a']]) -> PyTree[Inexact[Array, ' _b']]:
        """Apply the operator to the input pytree ``x``."""

    def reduce(self) -> 'AbstractLinearOperator':
        """Returns a linear operator with a reduced structure."""
        return self

    def as_matrix(self) -> Inexact[Array, 'a b']:
        """Returns the operator as a dense matrix.

        Input and output PyTrees are flattened and concatenated.
        """
        # Column j is the flattened image of the j-th unit input vector. The unit
        # vectors are built one input leaf at a time from the ``zeros_like`` pytree.
        in_pytree = zeros_like(self.in_structure)
        in_leaves_ref, in_treedef = jax.tree.flatten(in_pytree)
        matrix = jnp.empty((self.out_size, self.in_size), dtype=self.out_promoted_dtype)
        jcounter = 0
        for ileaf, leaf in enumerate(in_leaves_ref):

            def body(index, carry):  # type: ignore[no-untyped-def]
                matrix, jcounter = carry
                zeros = in_leaves_ref.copy()
                # Put a single 1 at flat position ``index`` of the current leaf.
                zeros[ileaf] = leaf.ravel().at[index].set(1).reshape(leaf.shape)
                in_pytree = jax.tree.unflatten(in_treedef, zeros)
                out_pytree = self.mv(in_pytree)
                out_leaves = [leaf.ravel() for leaf in jax.tree.leaves(out_pytree)]
                matrix = matrix.at[:, jcounter].set(jnp.concatenate(out_leaves))
                jcounter += 1
                return matrix, jcounter

            # ``body`` is traced during this iteration, so it sees the current
            # ``ileaf`` and ``leaf``.
            matrix, jcounter = jax.lax.fori_loop(0, leaf.size, body, (matrix, jcounter))
        return matrix

    def transpose(self) -> 'AbstractLinearOperator':
        """Returns the (lazy) transpose of the operator."""
        return TransposeOperator(self)

    @property
    def T(self) -> 'AbstractLinearOperator':
        return self.transpose()

    def inverse(self) -> 'AbstractLinearOperator':
        """Returns the (lazy) inverse of the operator."""
        return InverseOperator(self)

    @property
    def I(self) -> 'AbstractLinearOperator':  # noqa: E743
        return self.inverse()

    @property
    def out_structure(self) -> PyTree[jax.ShapeDtypeStruct]:
        """The output structure, inferred by abstract evaluation of ``mv``."""
        return jax.eval_shape(self.mv, self.in_structure)

    @property
    def in_size(self) -> int:
        """The number of elements in the input PyTree."""
        return sum(_.size for _ in jax.tree.leaves(self.in_structure))

    @property
    def out_size(self) -> int:
        """The number of elements in the output PyTree."""
        return sum(_.size for _ in jax.tree.leaves(self.out_structure))

    @property
    def in_promoted_dtype(self) -> DType[Any]:
        """Returns the promoted data type of the operator's input leaves."""
        leaves = jax.tree.leaves(self.in_structure)
        return jnp.result_type(*leaves)

    @property
    def out_promoted_dtype(self) -> DType[Any]:
        """Returns the promoted data type of the operator's output leaves."""
        leaves = jax.tree.leaves(self.out_structure)
        return jnp.result_type(*leaves)
# Type variable for the class decorators (``square``, ``symmetric``, ...): each one
# receives an operator class and returns that same class.
T = TypeVar('T', bound=AbstractLinearOperator)
def square(cls: type[T]) -> type[T]:
    """Tag an operator class as square.

    Adds ``OperatorTag.SQUARE`` to the class tags and replaces ``out_structure`` with
    a property that returns the operator's ``in_structure``.
    """

    def _input_structure(self):  # type: ignore[no-untyped-def]
        return self.in_structure

    cls.out_structure = property(_input_structure)  # type: ignore[assignment,method-assign]
    cls.class_tags = cls.class_tags | OperatorTag.SQUARE
    return cls
def symmetric(cls: type[T]) -> type[T]:
    """Tag an operator class as symmetric, which also tags it as square.

    The class's ``transpose`` is replaced by a method returning the operator itself.
    """
    cls = square(cls)
    cls.class_tags = cls.class_tags | OperatorTag.SYMMETRIC

    def _transpose_is_self(self):  # type: ignore[no-untyped-def]
        return self

    cls.transpose = _transpose_is_self  # type: ignore[method-assign]
    return cls
def orthogonal(cls: type[T]) -> type[T]:
    """Tag an operator class as orthogonal, which also tags it as square.

    For an orthogonal operator Q, Q^-1 == Q^T. The class's ``inverse`` is set to the
    ``transpose`` method the class has at decoration time.
    """
    cls = square(cls)
    cls.class_tags = cls.class_tags | OperatorTag.ORTHOGONAL
    current_transpose = cls.transpose
    cls.inverse = current_transpose  # type: ignore[method-assign]
    return cls
def diagonal(cls: type[T]) -> type[T]:
    """Tag an operator class as diagonal.

    A diagonal operator is symmetric (and therefore square), so the class goes
    through ``symmetric`` first.
    """
    cls = symmetric(cls)
    cls.class_tags = cls.class_tags | OperatorTag.DIAGONAL
    return cls
def tridiagonal(cls: type[T]) -> type[T]:
    """Tag an operator class as tridiagonal.

    Tridiagonal operators are square, so the class goes through ``square`` first.
    """
    cls = square(cls)
    cls.class_tags = cls.class_tags | OperatorTag.TRIDIAGONAL
    return cls
def lower_triangular(cls: type[T]) -> type[T]:
    """Tag an operator class as lower triangular.

    Triangular operators are square, so the class goes through ``square`` first.
    """
    cls = square(cls)
    cls.class_tags = cls.class_tags | OperatorTag.LOWER_TRIANGULAR
    return cls
def upper_triangular(cls: type[T]) -> type[T]:
    """Tag an operator class as upper triangular.

    Triangular operators are square, so the class goes through ``square`` first.
    """
    cls = square(cls)
    cls.class_tags = cls.class_tags | OperatorTag.UPPER_TRIANGULAR
    return cls
def positive_semidefinite(cls: type[T]) -> type[T]:
    """Tag an operator class as positive semi-definite.

    Such operators are square, so the class goes through ``square`` first.
    """
    cls = square(cls)
    cls.class_tags = cls.class_tags | OperatorTag.POSITIVE_SEMIDEFINITE
    return cls
def negative_semidefinite(cls: type[T]) -> type[T]:
    """Tag an operator class as negative semi-definite.

    Such operators are square, so the class goes through ``square`` first.
    """
    cls = square(cls)
    cls.class_tags = cls.class_tags | OperatorTag.NEGATIVE_SEMIDEFINITE
    return cls
class AdditionOperator(AbstractLinearOperator):
    """Sum of linear operators, as in C = A + B.

    ``operands`` is a pytree whose leaves are the summed operators. The first leaf
    provides the input and output structures of the sum.
    """

    operands: PyTree[AbstractLinearOperator]

    def __init__(self, operands: PyTree[AbstractLinearOperator]) -> None:
        # Set the field first: ``operand_leaves`` reads it to find the input structure.
        object.__setattr__(self, 'operands', operands)
        leading = self.operand_leaves[0]
        super().__init__(in_structure=leading.in_structure)

    # Tag propagation: a property holds if the class tags say so, or if addition
    # preserves it and it holds for the operands.
    @property
    def is_square(self) -> bool:
        if super().is_square:
            return True
        return self.operand_leaves[0].is_square

    @property
    def is_symmetric(self) -> bool:
        if super().is_symmetric:
            return True
        return all(leaf.is_symmetric for leaf in self.operand_leaves)

    @property
    def is_diagonal(self) -> bool:
        if super().is_diagonal:
            return True
        return all(leaf.is_diagonal for leaf in self.operand_leaves)

    @property
    def is_positive_semidefinite(self) -> bool:
        if super().is_positive_semidefinite:
            return True
        return all(leaf.is_positive_semidefinite for leaf in self.operand_leaves)

    @property
    def is_negative_semidefinite(self) -> bool:
        if super().is_negative_semidefinite:
            return True
        return all(leaf.is_negative_semidefinite for leaf in self.operand_leaves)

    def mv(self, x: PyTree[Inexact[Array, ' _a']]) -> PyTree[Inexact[Array, ' _b']]:
        leaves = self.operand_leaves

        def accumulate(partial_sum: Any, operand: AbstractLinearOperator) -> Any:
            # Leaf-wise addition of the running sum and this operand's output.
            return jax.tree.map(jnp.add, partial_sum, operand(x))

        return functools.reduce(accumulate, leaves[1:], leaves[0](x))

    def transpose(self) -> AbstractLinearOperator:
        # (A + B)^T = A^T + B^T, keeping the operand tree structure.
        return AdditionOperator(self._tree_map(lambda op: op.T))

    def __add__(self, other: AbstractLinearOperator) -> 'AdditionOperator':
        if not isinstance(other, AbstractLinearOperator):
            return NotImplemented
        self._validate_same_structures(other)
        # Flatten nested sums so that the result is a single flat addition.
        extra = other.operand_leaves if isinstance(other, AdditionOperator) else [other]
        return AdditionOperator(self.operand_leaves + extra)

    def __radd__(self, other: AbstractLinearOperator) -> 'AdditionOperator':
        if not isinstance(other, AbstractLinearOperator):
            return NotImplemented
        self._validate_same_structures(other)
        return AdditionOperator([other, *self.operand_leaves])

    def __neg__(self) -> 'AdditionOperator':
        # -(A + B) = (-A) + (-B)
        return AdditionOperator(self._tree_map(lambda op: (-1) * op))

    @property
    def out_structure(self) -> PyTree[jax.ShapeDtypeStruct]:
        return self.operand_leaves[0].out_structure

    def as_matrix(self) -> Inexact[Array, 'a b']:
        dense_terms = map(lambda op: op.as_matrix(), self.operand_leaves)
        return functools.reduce(jnp.add, dense_terms)

    def reduce(self) -> AbstractLinearOperator:
        reduced = self._tree_map(lambda op: op.reduce())
        reduced_leaves = jax.tree.leaves(
            reduced, is_leaf=lambda node: isinstance(node, AbstractLinearOperator)
        )
        if len(reduced_leaves) != 1:
            return AdditionOperator(reduced)
        single: AbstractLinearOperator = reduced_leaves[0]
        return single

    @property
    def operand_leaves(self) -> list[AbstractLinearOperator]:
        """Returns the flat list of operators."""
        return jax.tree.leaves(
            self.operands, is_leaf=lambda node: isinstance(node, AbstractLinearOperator)
        )

    def _tree_map(self, f: Callable[..., Any], *args: Any) -> Any:
        """Map ``f`` over the operands, treating each operator as a leaf."""
        return jax.tree.map(
            f,
            self.operands,
            *args,
            is_leaf=lambda node: isinstance(node, AbstractLinearOperator),
        )

    def _validate_same_structures(self, other: AbstractLinearOperator) -> None:
        """Raise ValueError unless ``other`` has the same input and output structures."""
        if self.in_structure != other.in_structure:
            raise ValueError('Incompatible linear operator input structures')
        if self.out_structure != other.out_structure:
            raise ValueError('Incompatible linear operator output structures')
class CompositionOperator(AbstractLinearOperator):
    """An operator that chains operators, as in C = B ∘ A.

    ``operands`` follows the order of the mathematical expression: the last operand
    is applied first and the first operand last, so that
    ``CompositionOperator([B, A])(x) == B(A(x))``.
    """

    operands: list[AbstractLinearOperator]

    def __init__(self, operands: list[AbstractLinearOperator]) -> None:
        object.__setattr__(self, 'operands', operands)
        # The composition consumes the input of the first operator applied.
        super().__init__(in_structure=operands[-1].in_structure)

    # Tag propagation properties
    @property
    def is_square(self) -> bool:
        # Square when the final output structure matches the initial input structure.
        result: bool = super().is_square or (
            self.operands[0].out_structure == self.operands[-1].in_structure
        )
        return result

    @property
    def is_diagonal(self) -> bool:
        # A product of diagonal operators is diagonal.
        return super().is_diagonal or all(op.is_diagonal for op in self.operands)

    def mv(self, x: PyTree[Inexact[Array, ' _a']]) -> PyTree[Inexact[Array, ' _b']]:
        for operand in reversed(self.operands):
            x = operand.mv(x)
        return x

    def transpose(self) -> AbstractLinearOperator:
        # (B ∘ A)^T = A^T ∘ B^T
        return CompositionOperator([_.T for _ in reversed(self.operands)])

    def __matmul__(self, other: AbstractLinearOperator) -> 'CompositionOperator':
        """Append ``other`` (applied first) to the chain, flattening nested compositions."""
        if not isinstance(other, AbstractLinearOperator):
            return NotImplemented
        if self.in_structure != other.out_structure:
            msg = (
                f'Incompatible linear operator structures: '
                f'self.in_structure={self.in_structure}, '
                f'other.out_structure={other.out_structure}'
            )
            raise ValueError(msg)
        if isinstance(other, CompositionOperator):
            operands = other.operands
        else:
            operands = [other]
        return CompositionOperator(self.operands + operands)

    def __rmatmul__(self, other: AbstractLinearOperator) -> 'CompositionOperator':
        """Prepend ``other`` (applied last) to the chain, for ``other @ self``.

        Raises:
            ValueError: If ``other.in_structure`` differs from ``self.out_structure``.
        """
        if not isinstance(other, AbstractLinearOperator):
            return NotImplemented
        if self.out_structure != other.in_structure:
            # Report the structures that were actually compared.
            msg = (
                f'Incompatible linear operator structures: '
                f'self.out_structure={self.out_structure}, '
                f'other.in_structure={other.in_structure}'
            )
            raise ValueError(msg)
        return CompositionOperator([other] + self.operands)

    def reduce(self) -> AbstractLinearOperator:
        """Returns a linear operator with a reduced structure."""
        from .rules import AlgebraicReductionRule

        operands = AlgebraicReductionRule().apply([operand.reduce() for operand in self.operands])
        if len(operands) == 0:
            # Everything cancelled out.
            return IdentityOperator(in_structure=self.in_structure)
        if len(operands) == 1:
            return operands[0]
        return CompositionOperator(operands)

    @property
    def out_structure(self) -> PyTree[jax.ShapeDtypeStruct]:
        return self.operands[0].out_structure
class _AbstractLazyDualOperator(AbstractLinearOperator):
    """Base for operators derived lazily from a wrapped ``operator``.

    The derived operator maps the wrapped operator's output space back to its input
    space. Its input structure is ``operator.out_structure`` and its output structure
    is ``operator.in_structure``.
    """

    operator: AbstractLinearOperator

    def __post_init__(self) -> None:
        # __post_init__ is deliberately preferred over __init__. When a class defines
        # no __init__, dataclasses generate one from the fields, and that generated
        # constructor does not call the parent's. So once an AbstractLinearOperator
        # defines __init__, every subclass must define its own to call it. With
        # __post_init__, subclasses can add fields without writing either method.
        dual_in_structure = self.operator.out_structure
        object.__setattr__(self, 'in_structure', dual_in_structure)

    @property
    def out_structure(self) -> PyTree[jax.ShapeDtypeStruct]:
        wrapped = self.operator
        return wrapped.in_structure
class TransposeOperator(_AbstractLazyDualOperator):
    """Lazy transpose of an operator, evaluated with ``jax.linear_transpose``."""

    def mv(self, x: PyTree[Inexact[Array, ' _a']]) -> PyTree[Inexact[Array, ' _b']]:
        transposed_mv = jax.linear_transpose(self.operator.mv, self.operator.in_structure)
        # One primal was given, so the transposed function returns a 1-tuple.
        (result,) = transposed_mv(x)
        return result

    def transpose(self) -> AbstractLinearOperator:
        # (A^T)^T = A
        return self.operator
class AbstractLazyInverseOperator(_AbstractLazyDualOperator):
    """Base class for lazy inverses of ``operator``."""

    def __call__(
        self, x: PyTree[jax.ShapeDtypeStruct] | None = None, /, **keywords: Any
    ) -> AbstractLinearOperator | PyTree[jax.ShapeDtypeStruct]:
        """Apply the inverse to ``x``; without ``x``, return the operator itself.

        Raises:
            ValueError: If keywords are given together with ``x``.
        """
        if x is None:
            return self
        if keywords:
            raise ValueError(
                'The application of a vector to inverse operator cannot be parametrized. '
                'For example, instead of A.I(x, throw=True), use A.I(throw=True)(x).'
            )
        return self.mv(x)

    def __matmul__(self, other: Any) -> AbstractLinearOperator:
        # A^-1 @ A simplifies to the identity (operator identity check).
        if other is self.operator:
            return IdentityOperator(in_structure=self.in_structure)
        return super().__matmul__(other)

    def inverse(self) -> AbstractLinearOperator:
        # (A^-1)^-1 = A
        return self.operator

    def as_matrix(self) -> Inexact[Array, 'a b']:
        inverse_matrix: Array = jnp.linalg.inv(self.operator.as_matrix())
        return inverse_matrix
# Sentinel default for the ``callback`` argument of ``InverseOperator.__call__``. It
# tells "not given" apart from an explicit ``callback=None``.
MISSING = object()
class InverseOperator(AbstractLazyInverseOperator):
    """Lazy inverse of a square operator, applied by solving a linear system with lineax.

    The solver, the ``throw`` flag, the solution callback and the solver options come
    from ``self.config``. By default this is ``Config.instance()``, evaluated when the
    operator is constructed.
    """

    # Solver configuration; static field metadata (not traced).
    config: ConfigState = field(
        kw_only=True, metadata={'static': True}, default_factory=lambda: Config.instance()
    )

    def __post_init__(self) -> None:
        super().__post_init__()
        if self.operator.in_structure != self.operator.out_structure:
            raise ValueError('Only square operators can be inverted.')
        # Keep the reduced form; the squareness check above used the original operator.
        object.__setattr__(self, 'operator', self.operator.reduce())

    def __call__(
        self,
        x: PyTree[jax.ShapeDtypeStruct] | None = None,
        /,
        *,
        solver: lx.AbstractLinearSolver | None = None,
        throw: bool | None = None,
        callback: Callable[[lx.Solution], None] | object = MISSING,
        **options: Any,
    ) -> AbstractLinearOperator | PyTree[jax.ShapeDtypeStruct]:
        """Apply the inverse to ``x``, or return a reconfigured inverse when ``x`` is omitted.

        Args:
            x: The vector the inverse is applied to.
            solver: Overrides the ``solver`` configuration entry.
            throw: Overrides ``solver_throw``.
            callback: Overrides ``solver_callback``. ``MISSING`` means "not given", so
                ``None`` can be passed explicitly.
            **options: Override ``solver_options`` (e.g. ``preconditioner``).

        Returns:
            If ``x`` is given: ``self.mv(x)``.
            If ``x`` is omitted and settings are given: a new ``InverseOperator``
            built inside a ``Config`` context with those settings.
            Otherwise: ``self``.

        Raises:
            ValueError: If ``solver_options`` is passed as a keyword, or if settings
                are given together with ``x``.
        """
        config_options = {}
        if solver is not None:
            config_options['solver'] = solver
        if throw is not None:
            config_options['solver_throw'] = throw
        if callback is not MISSING:
            config_options['solver_callback'] = callback
        if options:
            if 'solver_options' in options:
                msg = 'pass solver_options (preconditioner, etc.) directly as keyword arguments'
                raise ValueError(msg)
            config_options['solver_options'] = options
        if x is None and config_options:
            with Config(**config_options):
                return InverseOperator(self.operator)
        return super().__call__(x, **config_options)

    def mv(self, x: PyTree[Inexact[Array, ' _a']]) -> PyTree[Inexact[Array, ' _b']]:
        """Solve ``self.operator @ y = x`` for ``y`` with ``lx.linear_solve``."""
        from furax.interfaces.lineax import as_lineax_operator

        solver = self.config.solver
        throw = self.config.solver_throw
        # Work on a copy: the preconditioner entry may be replaced below.
        options = self.config.solver_options.copy()
        # NOTE(review): the operator is always tagged positive semi-definite for
        # lineax here, whatever its own tags are.
        A = as_lineax_operator(self.operator, OperatorTag.POSITIVE_SEMIDEFINITE)
        preconditioner = options.get('preconditioner')
        # Compare against None explicitly. Truth-testing an arbitrary object (an array
        # or a tracer) can raise before reaching the type check below, and falsy
        # non-None values would otherwise skip validation.
        if preconditioner is not None:
            if not isinstance(preconditioner, AbstractLinearOperator):
                raise TypeError('The preconditioner must be an instance of AbstractLinearOperator.')
            options['preconditioner'] = as_lineax_operator(
                preconditioner, OperatorTag.POSITIVE_SEMIDEFINITE
            )
        solution = lx.linear_solve(A, x, solver=solver, throw=throw, options=options)
        # The callback receives the full lineax solution, not only its value.
        jax.debug.callback(self.config.solver_callback, solution)
        return solution.value
@orthogonal
class AbstractLazyInverseOrthogonalOperator(TransposeOperator, AbstractLazyInverseOperator):
    """Lazy inverse of an orthogonal operator, evaluated as its transpose.

    ``mv`` comes from ``TransposeOperator``, which is first in the MRO. The
    inverse-specific behaviour (``__call__``, ``__matmul__``) comes from
    ``AbstractLazyInverseOperator``.
    """
@orthogonal
@diagonal
@positive_semidefinite
class IdentityOperator(AbstractLinearOperator):
    """Operator that returns its input unchanged: I(x) = x.

    The class is tagged diagonal, orthogonal and positive semi-definite. Through those
    decorators, its transpose and its inverse are the operator itself.

    Example:
        >>> I = IdentityOperator(in_structure=jax.ShapeDtypeStruct((3,), jnp.float32))
        >>> x = jnp.array([1.0, 2.0, 3.0])
        >>> I(x)
        Array([1., 2., 3.], dtype=float32)
    """

    def __matmul__(self, other: Any) -> AbstractLinearOperator:
        # I @ A is A itself; no composition is built.
        # NOTE(review): unlike the base class, structure compatibility is not checked.
        if isinstance(other, AbstractLinearOperator):
            return other
        return NotImplemented

    def mv(self, x: PyTree[Inexact[Array, '...']]) -> PyTree[Inexact[Array, '...']]:
        return x

    def as_matrix(self) -> Inexact[Array, 'a b']:
        size = self.in_size
        return jnp.identity(size, dtype=self.in_promoted_dtype)
@diagonal
class HomothetyOperator(AbstractLinearOperator):
    """Operator that multiplies its input by a scalar: H(x) = k * x.

    The class is tagged diagonal, and therefore symmetric and square. Mathematically
    it is also positive semi-definite for non-negative values, but that tag is not
    set. Two consecutive homotheties with compatible structures compose by
    multiplying their scalars: H(k1) @ H(k2) = H(k1 * k2).

    Attributes:
        value: The scalar multiplier.

    Example:
        >>> H = HomothetyOperator(2.0, in_structure=jax.ShapeDtypeStruct((3,), jnp.float32))
        >>> x = jnp.array([1.0, 2.0, 3.0])
        >>> H(x)
        Array([2., 4., 6.], dtype=float32)
        >>> H.I(x)  # Inverse: divides by 2
        Array([0.5, 1. , 1.5], dtype=float32)
    """

    value: Scalar | int | float

    def __matmul__(self, other: Any) -> AbstractLinearOperator:
        """Compose with ``other``, fusing two homotheties into a single one.

        Raises:
            ValueError: If the structures of the two homotheties are incompatible
                (same check and message as the base class).
        """
        if isinstance(other, HomothetyOperator):
            # Validate before fusing, as the base class does for any composition.
            # Previously a mismatch was silently accepted and the result was built
            # on self's structure.
            if self.in_structure != other.out_structure:
                msg = (
                    f'Incompatible linear operator structures: '
                    f'self.in_structure={self.in_structure}, '
                    f'other.out_structure={other.out_structure}'
                )
                raise ValueError(msg)
            # The composition consumes other's input.
            return HomothetyOperator(self.value * other.value, in_structure=other.in_structure)
        return super().__matmul__(other)

    def mv(self, x: PyTree[Inexact[Array, '...']]) -> PyTree[Inexact[Array, '...']]:
        return jax.tree.map(lambda leaf: self.value * leaf, x)

    def inverse(self) -> AbstractLinearOperator:
        return HomothetyOperator(1 / self.value, in_structure=self.in_structure)

    def as_matrix(self) -> Inexact[Array, 'a b']:
        return self.value * jnp.identity(self.in_size, dtype=self.out_promoted_dtype)
def asoperator(
    func: Callable[..., Any],
    *,
    in_structure: PyTree[jax.ShapeDtypeStruct],
    **keywords: Any,
) -> AbstractLinearOperator:
    """Wraps a function into an operator.

    Args:
        func: The function to wrap. It is passed through ``jax.jit`` unless it
            already has a ``lower`` attribute (as jitted functions do).
        in_structure: The input structure of the resulting operator.
        **keywords: Keyword arguments bound to the function with ``Partial``.

    Usage:
        >>> op = asoperator(lambda x: 2*x, in_structure=jax.ShapeDtypeStruct((), jnp.float32))
        >>> op.I(1.)
        Array(0.5, dtype=float32)

    Returns:
        An operator wrapping the function.

    Raises:
        TypeError: If the bound function does not take exactly one required
            positional argument (see ``_check_params``).
    """

    class Operator(AbstractLinearOperator):
        def mv(self, x: PyTree[Inexact[Array, ' _a']]) -> PyTree[Inexact[Array, ' _b']]:
            # ``partial_func`` is bound below, before any instance can be created.
            return partial_func(x)

    compiled = func if hasattr(func, 'lower') else jax.jit(func)
    partial_func = Partial(compiled, **keywords)
    _check_params(partial_func)
    return Operator(in_structure=in_structure)
def _check_params(func: Callable[..., Any]) -> None:
    """Check that ``func`` can be used as the matrix-vector product of an operator.

    ``func`` must take the input vector as its only required positional argument.
    Other positional parameters and keyword-only parameters are allowed only if they
    have default values (for instance, keywords already bound by ``Partial``).

    Args:
        func: The callable to inspect with ``inspect.signature``.

    Raises:
        TypeError: If the function has no positional parameter, has keyword-only
            parameters without default values, or has more than one positional
            parameter without a default value.
    """
    params = list(inspect.signature(func).parameters.values())
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )

    if not any(p.kind in positional_kinds for p in params):
        raise TypeError('The function must have at least one positional argument.')

    # Compare defaults by identity. ``==`` dispatches to the default value's own
    # ``__eq__``; for a NumPy array that returns an array whose truth value is
    # ambiguous, so the check used to crash with ValueError.
    names = [
        p.name
        for p in params
        if p.kind == inspect.Parameter.KEYWORD_ONLY and p.default is inspect.Parameter.empty
    ]
    if names:
        raise TypeError(
            f'The function cannot have keyword-only arguments without default values: {names}.'
        )

    names = [
        p.name
        for p in params
        if p.kind in positional_kinds and p.default is inspect.Parameter.empty
    ]
    if len(names) > 1:
        raise TypeError(
            f'The function can only have one positional argument without default value: {names}.'
        )