Source code for furax._config

import contextvars
from collections.abc import Callable
from dataclasses import asdict, dataclass, field, replace
from types import TracebackType
from typing import Any

import lineax as lx
import yaml

__all__ = ['Config']


def default_solver_callback(solution: lx.Solution) -> None:
    pass


def verbose_solver_callback(solution: lx.Solution) -> None:
    num_steps = solution.stats['num_steps']
    ok = num_steps < solution.stats['max_steps']
    if ok:
        print(f'Converged in {num_steps} iterations')
    else:
        print(f'Did not converge in {num_steps} iterations')


@dataclass(frozen=True)
class ConfigState:
    solver: lx.AbstractLinearSolver = lx.CG(rtol=1e-6, atol=1e-6, max_steps=500)
    solver_throw: bool = False
    solver_options: dict[str, Any] = field(default_factory=dict)
    solver_callback: Callable[[lx.Solution], None] = default_solver_callback

    def tree_flatten(self):  # type: ignore[no-untyped-def]
        return (), asdict(self)

    @classmethod
    def tree_unflatten(cls, aux_data, children):  # type: ignore[no-untyped-def]
        return cls(**aux_data)


_config_var = contextvars.ContextVar('config', default=ConfigState())


[docs] class Config: def __init__(self, **kwargs: Any) -> None: config = _config_var.get() self._instance = replace(config, **kwargs) def __str__(self) -> str: return yaml.dump(self._instance, indent=4) def __enter__(self) -> ConfigState: self.token = _config_var.set(self._instance) return self._instance def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: _config_var.reset(self.token)
[docs] @classmethod def instance(cls) -> ConfigState: return _config_var.get()