Source code for tensossht.transforms.transforms
from typing import Callable, Optional, Union
import numpy as np
import tensorflow as tf
from tensossht.sampling import HarmonicSampling, ImageSamplingBase, ImageSamplingSchemes
class HarmonicTransform:
    """Fast transform between image space and spherical harmonic space.

    Bundles a forward and an inverse callable together with the image-space and
    harmonic-space samplings, and exposes the parameters of both samplings as
    read-only properties.

    Args:
        forward: Callable stored as :py:attr:`forward`.
        inverse: Callable stored as :py:attr:`inverse`.
        image_sampling: Image-space sampling, stored as :py:attr:`sampling`.
        harmonic_sampling: Harmonic-space sampling, stored as
            :py:attr:`harmonic_sampling`.
    """

    def __init__(
        self,
        forward: Callable,
        inverse: Callable,
        image_sampling: ImageSamplingBase,
        harmonic_sampling: HarmonicSampling,
    ):
        self.forward = forward
        self.inverse = inverse
        self.sampling = image_sampling
        """Image-space sampling"""
        self.harmonic_sampling = harmonic_sampling
        """Harmonic-space sampling"""

    # -- Harmonic-space parameters, all delegated to ``harmonic_sampling`` ----

    @property
    def lmax(self):
        """``lmax`` of the harmonic-space sampling."""
        return self.harmonic_sampling.lmax

    @property
    def lmin(self):
        """``lmin`` of the harmonic-space sampling."""
        return self.harmonic_sampling.lmin

    @property
    def mmax(self):
        """``mmax`` of the harmonic-space sampling."""
        return self.harmonic_sampling.mmax

    @property
    def mmin(self):
        """``mmin`` of the harmonic-space sampling."""
        return self.harmonic_sampling.mmin

    @property
    def smax(self):
        """``smax`` of the harmonic-space sampling."""
        return self.harmonic_sampling.smax

    @property
    def smin(self):
        """``smin`` of the harmonic-space sampling."""
        return self.harmonic_sampling.smin

    @property
    def labels(self):
        """``labels`` of the harmonic-space sampling."""
        return self.harmonic_sampling.labels

    @property
    def llabels(self):
        """First row of :py:attr:`labels` (presumably the degrees l)."""
        return self.harmonic_sampling.labels[0]

    @property
    def mlabels(self):
        """Second row of :py:attr:`labels` (presumably the orders m)."""
        return self.harmonic_sampling.labels[1]

    @property
    def slabels(self):
        """Third row of :py:attr:`labels` (presumably the spins s)."""
        return self.harmonic_sampling.labels[2]

    @property
    def ncoeffs(self):
        """Number of harmonic coefficients, i.e. the size of axis 1 of the labels."""
        return self.harmonic_sampling.labels.shape[1]

    # -- Image-space parameters, all delegated to ``sampling`` ---------------

    @property
    def thetas(self):
        """``thetas`` of the image-space sampling."""
        return self.sampling.thetas

    @property
    def phis(self):
        """``phis`` of the image-space sampling."""
        return self.sampling.phis

    @property
    def grid(self):
        """``grid`` of the image-space sampling."""
        return self.sampling.grid

    @property
    def points(self):
        """Image-space points.

        The grid is reshaped to ``(2, npoints)`` and transposed, so each row of
        the result holds the two coordinates of one point.
        """
        return tf.transpose(tf.reshape(self.grid, (2, -1)))

    @property
    def real_dtype(self):
        """Real counterpart of the image sampling's dtype.

        ``tf.complex64`` and ``tf.complex128`` map to ``tf.float32`` and
        ``tf.float64``; any other dtype is returned unchanged.
        """
        if self.sampling.dtype == tf.complex64:
            return tf.float32
        if self.sampling.dtype == tf.complex128:
            return tf.float64
        return self.sampling.dtype

    @property
    def complex_dtype(self):
        """Complex counterpart of the image sampling's dtype.

        ``tf.float32`` and ``tf.float64`` map to ``tf.complex64`` and
        ``tf.complex128``; any other dtype is returned unchanged.
        """
        if self.sampling.dtype == tf.float32:
            return tf.complex64
        if self.sampling.dtype == tf.float64:
            return tf.complex128
        return self.sampling.dtype

    def _with_orders(self, mmin, dtype) -> "HarmonicTransform":
        """Builds a transform like this one, but with another ``mmin`` and dtype."""
        from tensossht.sampling import image_sampling_scheme

        return harmonic_transform(
            lmax=self.lmax,
            lmin=self.lmin,
            mmin=mmin,
            mmax=self.mmax,
            smax=self.smax,
            smin=self.smin,
            dtype=dtype,
            # The scheme is looked up from the class name of the image sampling.
            sampling=image_sampling_scheme(self.sampling.__class__.__name__),
        )

    @property
    def real(self) -> "HarmonicTransform":
        """Transform for real signals.

        If the transform is already for real signals (:math:`m_\\mathrm{min} \\geq 0`),
        then this property returns self. Otherwise it returns an equivalent transform
        with :math:`m_\\mathrm{min} = 0`.

        Raises:
            RuntimeError: if :math:`m_\\mathrm{min} < 0` and
                :math:`m_\\mathrm{max} < 0`, since no non-negative order is left.
        """
        if self.mmin >= 0:
            return self
        if self.mmax < 0:
            msg = f"Weird mmax={self.mmax}. Cannot infer transform for real signals."
            raise RuntimeError(msg)
        return self._with_orders(mmin=0, dtype=self.real_dtype)

    @property
    def complex(self) -> "HarmonicTransform":
        """Transform for complex signals.

        If the transform already includes negative orders
        (:math:`m_\\mathrm{min} < 0`), then this property returns self. Otherwise it
        returns an equivalent transform with :math:`m_\\mathrm{min} = -l_\\mathrm{max}`.
        """
        target_mmin = -self.lmax
        # ``mmin == 0`` is the real-signal configuration, so it must be rebuilt
        # rather than returned as-is. The second test keeps returning self when
        # the orders already start at ``-lmax`` (e.g. ``lmax == 0``), where a
        # rebuild would change nothing about the orders.
        if self.mmin < 0 or self.mmin == target_mmin:
            return self
        return self._with_orders(mmin=target_mmin, dtype=self.complex_dtype)
def harmonic_transform(
    lmax: Optional[int] = None,
    lmin: Optional[int] = 0,
    mmax: Optional[int] = None,
    mmin: Optional[int] = None,
    smax: Optional[int] = None,
    smin: Optional[int] = None,
    spin: Optional[int] = None,
    sampling: Union[str, ImageSamplingSchemes] = ImageSamplingSchemes.MW,
    dtype: Union[str, np.dtype, tf.DType] = tf.float32,
) -> HarmonicTransform:
    """Creates a harmonic transform.

    The harmonic-space sampling is built first. The image-space sampling and the
    forward and inverse transforms are then derived from it.

    Args:
        lmax: Maximum degree.
        lmin: Minimum degree, defaults to zero.
        mmax: Maximum order. If None, defaults to ``lmax``.
        mmin: Minimum order. If None, defaults to ``-lmax``. For real image-space
            signals, use ``mmin=0``.
        smax: Forwarded as-is to ``harmonic_sampling_scheme``.
        smin: Forwarded as-is to ``harmonic_sampling_scheme``.
        spin: Forwarded as-is to ``harmonic_sampling_scheme``.
        sampling: image-space sampling scheme. Can be a string, e.g. "MW", or one of
            :py:class:`~tensossht.sampling.ImageSamplingSchemes`.
        dtype: Underlying floating point type.

    Returns:
        A :py:class:`HarmonicTransform` holding the forward and inverse transforms
        and both samplings.
    """
    from tensossht.sampling import harmonic_sampling_scheme, image_sampling_scheme
    from tensossht.transforms.forward import forward_transform
    from tensossht.transforms.inverse import inverse_transform

    # The harmonic sampling is always requested with ``compact_spin=False``.
    harmonic = harmonic_sampling_scheme(
        lmax=lmax,
        lmin=lmin,
        mmax=mmax,
        mmin=mmin,
        smin=smin,
        smax=smax,
        spin=spin,
        compact_spin=False,
    )

    # The image sampling is sized from the lmax of the harmonic sampling, not
    # from the ``lmax`` argument.
    scheme = image_sampling_scheme(sampling)
    image = scheme.value(harmonic.lmax, dtype)

    return HarmonicTransform(
        forward=forward_transform(harmonic, dtype=dtype, sampling=sampling),
        inverse=inverse_transform(harmonic, dtype=dtype, sampling=sampling),
        image_sampling=image,
        harmonic_sampling=harmonic,
    )