Source code for tensossht.layers.harmonic_transforms

from pathlib import Path
from typing import Callable, Optional, Tuple, Union, cast

import numpy as np
import tensorflow as tf

from tensossht.sampling import (
    HarmonicAxes,
    HarmonicSampling,
    ImageAxes,
    ImageSamplingBase,
    ImageSamplingSchemes,
)

__doc__ = Path(__file__).with_suffix(".rst").read_text()


class BaseHarmonicTransformLayer(tf.keras.layers.Layer):
    def __init__(
        self,
        is_real: bool = False,
        sampling: Union[str, ImageSamplingSchemes, ImageSamplingBase] = "mw",
        spin: int = 0,
        dtype: Union[str, np.dtype, tf.DType] = tf.float32,
        in_axes: Optional[Union[HarmonicAxes, ImageAxes]] = None,
        out_axes: Optional[Union[HarmonicAxes, ImageAxes]] = None,
        **kwargs,
    ):
        from tensossht.sampling import image_sampling_scheme

        rdtype = tf.dtypes.as_dtype(dtype).real_dtype
        cdtype = tf.complex(tf.zeros(0, dtype=rdtype), tf.zeros(0, dtype=rdtype)).dtype

        if is_real and spin != 0:
            raise ValueError("Non-zero spins are complex")

        if kwargs.pop("trainable", False):
            raise ValueError("This layer is not trainable")

        super().__init__(trainable=False, dtype=rdtype if is_real else cdtype, **kwargs)
        self.is_real = is_real
        self.sampling = image_sampling_scheme(sampling)
        self.spin = spin
        self.transform: Callable = BaseHarmonicTransformLayer._build_not_called
        self.in_axes = in_axes
        self.out_axes = out_axes

    def __call__(self, tensor: tf.Tensor) -> tf.Tensor:
        return super().__call__(tensor)

    def call(self, inputs) -> tf.Tensor:
        in_axes = {} if self.in_axes is None else self.in_axes._asdict()
        in_axes = {f"{k}_dim": v for k, v in in_axes.items() if v is not None}
        out_axes = {} if self.out_axes is None else self.out_axes._asdict()
        out_axes = {f"{k}_dim": v for k, v in out_axes.items() if v is not None}
        if "spin_dim" in out_axes:
            out_axes["out_spin_dim"] = out_axes.pop("spin_dim")

        return self.transform(inputs, **in_axes, **out_axes)

    @staticmethod
    def _build_not_called(
        x: tf.Tensor, theta_dim: int = -2, phi_dim: int = -1, coeff_dim: int = -1
    ) -> tf.Tensor:
        raise RuntimeError("Build was not called.")

    def get_config(self):
        config = super().get_config()
        config.update(
            dict(
                is_real=self.is_real,
                sampling=self.sampling.name,
                spin=self.spin,
                in_axes=None if self.in_axes is None else self.in_axes._asdict(),
                out_axes=None if self.out_axes is None else self.out_axes._asdict(),
            )
        )
        return config

    @classmethod
    def from_config(cls, config):
        if config["in_axes"] is not None and "phi" in config["in_axes"]:
            config["in_axes"] = ImageAxes(**config.pop("in_axes"))
        elif config["in_axes"] is not None:
            config["in_axes"] = HarmonicAxes(**config.pop("in_axes"))
        if config["out_axes"] is not None and "phi" in config["out_axes"]:
            config["out_axes"] = ImageAxes(**config.pop("out_axes"))
        elif config["out_axes"] is not None:
            config["out_axes"] = HarmonicAxes(**config.pop("out_axes"))
        return cls(**config)


class BaseHarmonicSpinTransformLayer(BaseHarmonicTransformLayer):
    def __init__(
        self,
        sampling: Union[str, ImageSamplingSchemes] = "mw",
        spin: int = 0,
        dtype: Union[str, np.dtype, tf.DType] = tf.float32,
        in_axes: Optional[Union[HarmonicAxes, ImageAxes]] = None,
        out_axes: Optional[Union[HarmonicAxes, ImageAxes]] = None,
        flip_spin: bool = False,
        centered_spin: bool = True,
        **kwargs,
    ):
        super().__init__(
            is_real=False,
            sampling=sampling,
            spin=spin,
            dtype=dtype,
            out_axes=out_axes,
            in_axes=in_axes,
            **kwargs,
        )
        self.flip_spin = flip_spin
        """Flip spin direction, useful for wigner transform."""
        self.centered_spin = centered_spin
        """Defines how the spins are determined from the shape of the input.

        If `True`, then the range is :math:`[-n, n]`. If `False`, then it is
        :math:`[0, n]`
        """

    def get_config(self):
        config = super().get_config()
        config.update(dict(flip_spin=self.flip_spin, centered_spin=self.centered_spin))
        return config

    def spin_range(self, axis_size: int) -> Tuple[int, int]:
        if self.centered_spin and axis_size % 2 != 1:
            raise ValueError(
                "Expected odd number of spin channels "
                "since the channels are centered around 0."
            )
        elif self.centered_spin:
            nspins = (axis_size - 1) // 2
            return self.spin - nspins, self.spin + nspins
        return 0, axis_size - 1


[docs]class InverseLayer(BaseHarmonicTransformLayer): def __init__( self, is_real: bool = False, sampling: Union[str, ImageSamplingSchemes, ImageSamplingBase] = "mw", spin: int = 0, dtype: Union[str, np.dtype, tf.DType] = tf.float32, theta_dim: int = -2, phi_dim: int = -1, coeff_dim: int = -1, **kwargs, ): return super().__init__( is_real=is_real, sampling=sampling, spin=spin, dtype=dtype, in_axes=HarmonicAxes(coeff=coeff_dim), out_axes=ImageAxes(phi=phi_dim, theta=theta_dim), **kwargs, )
[docs] def build(self, shape: tf.TensorShape) -> None: from tensossht.transforms.inverse import inverse_transform axes = ( cast(HarmonicAxes, self.in_axes) if self.in_axes is not None else HarmonicAxes(coeff=-1) ) if not self.is_real: lmax = int(np.sqrt(shape[axes.coeff] + self.spin * self.spin)) assert lmax * lmax - self.spin * self.spin == shape[axes.coeff] mmin = None else: lmax = int(np.sqrt(1 + 8 * shape[axes.coeff]) - 1) // 2 assert lmax * (lmax + 1) == 2 * shape[axes.coeff] mmin = 0 self.transform = inverse_transform( lmax, mmin=mmin, spin=self.spin, sampling=self.sampling, dtype=self.dtype )
[docs]class InverseSpinLayer(BaseHarmonicSpinTransformLayer): def __init__( self, sampling: Union[str, ImageSamplingSchemes] = "mw", dtype: Union[str, np.dtype, tf.DType] = tf.float32, spin_dim: int = -2, coeff_dim: int = -1, out_spin_dim: int = -3, theta_dim: int = -2, phi_dim: int = -1, spin: int = 0, **kwargs, ): super().__init__( sampling=sampling, spin=spin, dtype=dtype, in_axes=HarmonicAxes(coeff=coeff_dim, spin=spin_dim), out_axes=ImageAxes(phi=phi_dim, theta=theta_dim, spin=out_spin_dim), **kwargs, )
[docs] def harmonic_sampling(self, shape: tf.TensorShape) -> HarmonicSampling: """Figures harmonic sampling from shape of the input tensor.""" from tensossht.sampling import harmonic_sampling_scheme axes = ( cast(HarmonicAxes, self.in_axes) if self.in_axes is not None else HarmonicAxes(coeff=-1) ) lmax = int(np.sqrt(shape[axes.coeff])) assert lmax * lmax == shape[axes.coeff] smin, smax = self.spin_range(shape[axes.spin]) hsampling = harmonic_sampling_scheme( lmax, mmin=None, smin=smin, smax=smax, compact_spin=False ) if self.flip_spin: labels = tf.concat((hsampling.labels[:2], -hsampling.labels[2:]), axis=0) hsampling = harmonic_sampling_scheme(labels=labels) return hsampling
[docs] def build(self, shape: tf.TensorShape) -> None: from tensossht.transforms.inverse import inverse_transform hsampling = self.harmonic_sampling(shape) self.transform = inverse_transform( hsampling, sampling=self.sampling, dtype=self.dtype )
[docs] def get_config(self): config = super().get_config() config.update(dict(flip_spin=self.flip_spin)) return config
[docs] def compute_output_shape(self, input_shape: tf.TensorShape) -> tf.TensorShape: assert isinstance(self.in_axes, HarmonicAxes) assert self.in_axes.spin is not None in_axes = self.in_axes % len(input_shape) result = [j for i, j in enumerate(input_shape) if i not in in_axes] hsampling = self.harmonic_sampling(input_shape) nthetas, nphis = self.sampling.value(hsampling.lmax).shape result += [hsampling.nspins, nthetas, nphis] assert isinstance(self.out_axes, ImageAxes) permutation = ImageAxes(spin=-3, theta=-2, phi=-1).permutation( self.out_axes, len(input_shape) + 1 ) return tf.TensorShape([result[i] for i in permutation])
[docs]class ForwardLayer(BaseHarmonicTransformLayer): def __init__( self, is_real: bool = False, sampling: Union[str, ImageSamplingSchemes] = "mw", spin: int = 0, dtype: Union[str, np.dtype, tf.DType] = tf.float32, theta_dim: int = -2, phi_dim: int = -1, coeff_dim: int = -1, **kwargs, ): return super().__init__( is_real=is_real, sampling=sampling, spin=spin, dtype=dtype, out_axes=HarmonicAxes(coeff=coeff_dim), in_axes=ImageAxes(phi=phi_dim, theta=theta_dim), **kwargs, )
[docs] def build(self, shape: tf.TensorShape) -> None: from tensossht.transforms.forward import forward_transform mmin = 0 if self.is_real else None axes = ( cast(ImageAxes, self.in_axes) if self.in_axes is not None else ImageAxes(theta=-2, phi=-1) ) if self.sampling == ImageSamplingSchemes.MW: lmax = shape[axes.theta] else: lmax = shape[axes.theta] - 1 self.transform = forward_transform( lmax, mmin=mmin, spin=self.spin, sampling=self.sampling, dtype=self.dtype )
[docs]class ForwardSpinLayer(BaseHarmonicSpinTransformLayer): def __init__( self, sampling: Union[str, ImageSamplingSchemes] = "mw", dtype: Union[str, np.dtype, tf.DType] = tf.float32, spin_dim: int = -3, theta_dim: int = -2, phi_dim: int = -1, coeff_dim: int = -1, out_spin_dim: int = -2, spin: int = 0, **kwargs, ): super().__init__( sampling=sampling, spin=spin, dtype=dtype, out_axes=HarmonicAxes(coeff=coeff_dim, spin=out_spin_dim), in_axes=ImageAxes(phi=phi_dim, theta=theta_dim, spin=spin_dim), **kwargs, )
[docs] def harmonic_sampling(self, shape: tf.TensorShape) -> HarmonicSampling: """Figures harmonic sampling from shape of the input tensor.""" from tensossht.sampling import harmonic_sampling_scheme axes = ( cast(ImageAxes, self.in_axes) if self.in_axes is not None else ImageAxes(theta=-2, phi=-1) ) if self.sampling == ImageSamplingSchemes.MW: lmax = shape[axes.theta] else: lmax = shape[axes.theta] - 1 smin, smax = self.spin_range(shape[axes.spin]) hsampling = harmonic_sampling_scheme( lmax, mmin=None, smin=smin, smax=smax, compact_spin=False ) if self.flip_spin: labels = tf.concat((hsampling.labels[:2], -hsampling.labels[2:]), axis=0) hsampling = harmonic_sampling_scheme(labels=labels) return hsampling
[docs] def build(self, shape: tf.TensorShape) -> None: from tensossht.transforms.forward import forward_transform hsampling = self.harmonic_sampling(shape) self.transform = forward_transform( hsampling, dtype=self.dtype, sampling=self.sampling )