Getting Started
Welcome to Furax! This guide will help you get up and running with CMB analysis using Furax’s composable linear operators and specialized data structures.
Installation
Basic Installation
Install the latest Furax release from PyPI using pip:
# Install the latest release from PyPI into the environment of the active interpreter.
python -m pip install furax
Development Installation
For development work, or to get the latest unreleased features, install Furax from source in editable mode:
# Clone the repository and install it in editable mode together with the development extras.
git clone https://github.com/CMBSciPol/furax.git
cd furax
# Quote the requirement: shells such as zsh treat an unquoted [...] as a glob pattern and abort.
pip install -e '.[dev]'
Component Separation Features
For advanced component separation capabilities, install the optional comp_sep extra from the root of a source checkout:
# Run from the root of a source checkout (`-e .` installs the current directory in editable mode).
# Quoted so that shells such as zsh do not expand the brackets as a glob pattern.
pip install -e '.[comp_sep]'
This extra pulls in additional dependencies such as PySM3 for foreground modeling.
Dependencies
Furax relies on the JAX ecosystem and scientific Python packages:
Core: JAX, lineax (iterative linear solvers)
Astronomy: jax-healpy, astropy
Development: pytest, pre-commit, ruff, mypy
First Steps
Enable 64-bit precision for better numerical accuracy. Do this once, at the start of your program, before any JAX arrays are created:
import jax

# Make float64/int64 the default JAX dtypes. JAX expects this flag to be set at
# start-up, before any arrays are created.
use_x64 = True
jax.config.update('jax_enable_x64', use_x64)
Create Your First Sky Map
import jax.random as jr
from furax.obs.landscapes import HealpixLandscape

# A HEALPix pixelization that carries the three Stokes parameters I, Q and U.
landscape = HealpixLandscape(nside=32, stokes='IQU')

# Draw a random map on that landscape. It stands in for a CMB sky here.
key = jr.key(42)
cmb_map = landscape.normal(key)

n_pixels = landscape.shape[0]
print(f'Map shape: {cmb_map.shape}')
print(f'Stokes parameters: {cmb_map.stokes}')
print(f'Number of pixels: {n_pixels}')
Basic Linear Operators
Furax provides composable linear operators that can be combined through addition, composition, or block assembly. Their main advantage is that they are matrix-free. Instead of storing the underlying matrix densely, each operator relies on a compact representation of it, for example a diagonal or a sparse one.
import jax.numpy as jnp
import jax.random as jr
from furax import DiagonalOperator
from furax.tree import as_structure
from furax.obs.landscapes import HealpixLandscape

# An IQU landscape and a random map on it, as in the first example.
landscape = HealpixLandscape(nside=32, stokes='IQU')
cmb_map = landscape.normal(jr.key(42))
n_pixel = landscape.shape[0]

# Noise weighting: a diagonal operator whose entries are all 1/sqrt(n_pixel).
# Its input structure is the whole IQU map, so the same diagonal is applied to
# the I, Q and U Stokes parameters.
sqrt_n = jnp.full(n_pixel, jnp.sqrt(n_pixel))
noise_weights = DiagonalOperator(1.0 / sqrt_n, in_structure=landscape.structure)

weighted_map = noise_weights(cmb_map)

# Compare the structures of the map before and after weighting.
for label, tree in (('Input', cmb_map), ('Output', weighted_map)):
    print(f'{label} type: {as_structure(tree)}')
Operator Composition
The power of Furax comes from composing operators. Here, a per-component scaling is chained after a noise weighting:
import jax.numpy as jnp
import jax.random as jr
from furax import BlockDiagonalOperator, DiagonalOperator
from furax.obs.landscapes import HealpixLandscape
from furax.obs.stokes import StokesIQU

landscape = HealpixLandscape(nside=32, stokes='IQU')
cmb_map = landscape.normal(jr.key(42))
n_pixel = landscape.shape[0]

# One diagonal operator per Stokes component
i_processor = DiagonalOperator(1.0 * jnp.ones(n_pixel))  # No change to I
q_processor = DiagonalOperator(2.0 * jnp.ones(n_pixel))  # Amplify Q
u_processor = DiagonalOperator(0.5 * jnp.ones(n_pixel))  # Reduce U

# Combine into a block diagonal operator. The blocks are arranged in a StokesIQU
# container, one per component of the map.
component_processor = BlockDiagonalOperator(StokesIQU(i_processor, q_processor, u_processor))

# The noise weights apply the same diagonal matrix to I, Q and U
noise_weights = DiagonalOperator(1.0 / jnp.full(n_pixel, jnp.sqrt(n_pixel)), in_structure=landscape.structure)

# Compose with noise weighting. As in matrix notation, the right-hand operator
# (noise_weights) is applied first, then component_processor.
full_pipeline = component_processor @ noise_weights

# Apply the full pipeline
processed_map = full_pipeline(cmb_map)
print('Pipeline applied successfully!')  # plain string: there is nothing to interpolate
Working with Real Data
Reconstruction Problem
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
from furax import HomothetyOperator, IndexOperator
from furax.tree import as_structure

# Toy map-making problem: recover a small map from noisy samples of its pixels.
n_pixel = 10
pixels = jnp.arange(n_pixel, dtype=jnp.int32)
obs_key, map_key, noise_key = jr.split(jr.key(0), 3)
# 100 random permutations of the pixel indices, concatenated: every pixel is
# observed exactly 100 times, in a random order.
observed_pixels = jnp.concatenate([jr.permutation(key, pixels) for key in jr.split(obs_key, 100)])
actual_map = jr.normal(map_key, (n_pixel,))
σ_noise = 0.01
noise = jr.normal(noise_key, observed_pixels.shape) * σ_noise

# Acquisition operator A: sample i reads the map value at pixel observed_pixels[i].
acquisition_op = IndexOperator(observed_pixels, in_structure=as_structure(actual_map))
observed_values = acquisition_op(actual_map) + noise
# White-noise covariance N = σ_noise² · Id.
noise_op = HomothetyOperator(σ_noise ** 2, in_structure=as_structure(observed_values))

# Maximum-likelihood estimator (Aᵀ N⁻¹ A)⁻¹ Aᵀ N⁻¹. The normal operator Aᵀ N⁻¹ A
# is built once and reused by both solver setups below.
normal_op = acquisition_op.T @ noise_op.I @ acquisition_op
ml = normal_op.I @ acquisition_op.T @ noise_op.I

# Using default setup (using CG)
maximum_likelihood_map = ml(observed_values)
print('Actual map:', actual_map)
print('Reconstructed map:', maximum_likelihood_map)
# The residual is limited by the noise, about σ_noise / sqrt(100) = 1e-3 per pixel,
# so a tighter solver cannot push it much lower.
print('Difference (default solver):', abs(actual_map - maximum_likelihood_map))

# Use high-precision solver for critical calculations
solver = lx.CG(rtol=1e-10, atol=1e-10, max_steps=2000)
high_precision_ml = normal_op.I(solver=solver) @ acquisition_op.T @ noise_op.I
high_precision_map = high_precision_ml(observed_values)
print('Difference (high-precision solver):', abs(actual_map - high_precision_map))
Pixel Masking
import jax.numpy as jnp
import jax.random as jr
import jax_healpy as hp
from furax import IndexOperator
from furax.obs.landscapes import HealpixLandscape
from furax.tree import as_structure

# Pixels with |latitude| <= this value are masked out.
GALACTIC_MAX_LATITUDE = 5.  # degrees
landscape = HealpixLandscape(nside=128, stokes='IQU')
n_pixel = landscape.shape[0]
pixels = jnp.arange(n_pixel, dtype=jnp.int32)

# Create a galactic plane mask (simplified): keep only the pixels away from the
# plane. This assumes the map is in galactic coordinates. With lonlat=True the
# angles are returned in degrees (healpy convention).
_, lat = hp.pix2ang(landscape.nside, pixels, lonlat=True)
good_pixels = jnp.abs(lat) > GALACTIC_MAX_LATITUDE
# flatnonzero yields the 1-D integer array of kept pixel indices, the same kind of
# index array IndexOperator receives elsewhere in this guide. A single-argument
# jnp.where would return a tuple of arrays instead.
mask_operator = IndexOperator(jnp.flatnonzero(good_pixels), in_structure=landscape.structure)

# Apply mask
cmb_map = landscape.normal(jr.key(0))
masked_map = mask_operator(cmb_map)
print(f'Input map: {as_structure(cmb_map)}')
print(f'Output map: {as_structure(masked_map)}')
Frequency Analysis
import jax.numpy as jnp
import jax.random as jr
from furax import IndexOperator
from furax.obs.landscapes import HealpixLandscape
from furax.tree import as_structure

# Multi-frequency analysis setup
frequencies = jnp.array([70., 150., 353.])  # GHz
landscape = HealpixLandscape(nside=128, stokes='IQU')
n_pixel = landscape.shape[0]

# One random IQU map per frequency, all on the same landscape
obs_key, *keys = jr.split(jr.key(0), len(frequencies) + 1)
freq_maps = [landscape.normal(key) for key in keys]

# 100 random pixel indices. jr.randint's maxval is exclusive, so passing n_pixel
# (not n_pixel - 1) lets every pixel, including the last one, be drawn.
pixels = jr.randint(obs_key, (100,), 0, n_pixel)
projection = IndexOperator(pixels, in_structure=landscape.structure)

# get the observed pixels (noiseless)
tod = projection(freq_maps)
# The tod is a list of StokesIQU
print(f'Multi-frequency tod structure: {as_structure(tod)}')
Error Handling and Debugging
Check Operator Properties
# Inspect operator properties
op = ...  # placeholder: replace with the operator to inspect, e.g. one built in the sections above
print(f'Operator is symmetric: {op.is_symmetric}')
# The label matches the property actually queried (semidefinite, not definite).
print(f'Operator is positive semidefinite: {op.is_positive_semidefinite}')
print(f'Operator input structure: {op.in_structure}')
print(f'Operator output structure: {op.out_structure}')
Matrix Visualization
For small problems, you can materialize an operator as an explicit dense matrix to inspect it:
import jax.numpy as jnp
from furax import DiagonalOperator
from furax.obs.landscapes import HealpixLandscape

# A dense matrix grows as (number of parameters)**2, so only do this for tiny operators!
small_landscape = HealpixLandscape(nside=2, stokes='I')  # 12 * 2**2 = 48 pixels
n_small = small_landscape.shape[0]
diagonal = jnp.arange(n_small) + 1.
small_weights = DiagonalOperator(diagonal)

# Materialize the operator as an explicit matrix for debugging
weight_matrix = small_weights.as_matrix()
print(f'Weight matrix shape: {weight_matrix.shape}')
print('Diagonal elements:', jnp.diag(weight_matrix))
Performance Tips
Use JAX Transformations
import jax
import jax.random as jr
from furax import DiagonalOperator
from furax.obs.landscapes import HealpixLandscape


# JIT compile for repeated operations
@jax.jit
def process_many_maps(operator, maps):
    """Apply ``operator`` to every map of a batch.

    ``maps`` is batched along the leading axis of each of its leaves. ``jax.vmap``
    maps the operator over that axis, and ``jax.jit`` compiles the batched call.
    The operator is callable, so it is vmapped directly (no wrapping lambda needed).
    """
    return jax.vmap(operator)(maps)


batch_size = 10
landscape = HealpixLandscape(nside=128, stokes='IQU')
op_key = jr.key(0)
map_keys = jr.split(jr.key(1), batch_size)
op = DiagonalOperator(1 + 0.01 * jr.normal(op_key, landscape.shape))

# Generate batch of maps
map_batch = jax.vmap(landscape.normal)(map_keys)

# Process batch efficiently
processed_batch = process_many_maps(op, map_batch)
print(f'Processed {batch_size} maps in batch: {processed_batch.structure}')
Memory Management
import jax.numpy as jnp
import jax.random as jr
from furax import DiagonalOperator
from furax.obs.landscapes import HealpixLandscape

# For large problems, avoid creating explicit matrices.
# nside=256 -> 12 * 256**2 = 786,432 pixels; x 3 Stokes parameters = ~2.4M parameters
landscape = HealpixLandscape(nside=256, stokes='IQU')

# Good: matrix-free operations. The operator is defined by a single
# length-n_pixel diagonal array.
large_weights = DiagonalOperator(jnp.ones(landscape.shape[0]), in_structure=landscape.structure)
large_map = landscape.normal(jr.key(0))
result = large_weights(large_map)

# Avoid: large_weights.as_matrix() - a dense ~2.4M x 2.4M matrix would need
# ~45 TB in float64 (~22 TB in float32)!
Next Steps
Now that you’ve learned the basics:
Data Structures: Explore data_structures.md for advanced Stokes parameter usage
Linear Operators: Learn more about operator composition in operators.md
Examples: Try the component_separation.md and mapmaking.md tutorials
API Reference: Browse the complete API reference for all available functions
Happy analyzing!