Source code for tensossht.layers.wigner

from pathlib import Path
from typing import Callable, Optional, Sequence, Union, cast

import numpy as np
import tensorflow as tf

from tensossht.layers.harmonic_transforms import ForwardSpinLayer, InverseSpinLayer
from tensossht.sampling import HarmonicAxes, HarmonicSampling, ImageSamplingSchemes

__doc__ = Path(__file__).with_suffix(".rst").read_text()


[docs]class FourierLayer(tf.keras.layers.Layer): """Performs a Fourier transform. Performs a forward or inverse Fourier transform. The axis is moved back to the ``out_axis`` position, defaulting to last position to avoid unnecessary transforms. """ def __init__( self, is_forward: bool = True, is_real: bool = False, axis: int = -1, out_axis: int = -1, is_odd_spin: Optional[bool] = None, **kwargs, ): if kwargs.pop("trainable", False): raise ValueError("This layer is not trainable") super().__init__(trainable=False, **kwargs) self.is_forward = is_forward self.axis = axis self.out_axis = out_axis self.is_real = is_real self.is_odd_spin = is_odd_spin self._input_perm: Callable[[tf.Tensor], tf.Tensor] = self._build_not_called self._output_perm: Callable[[tf.Tensor], tf.Tensor] = self._build_not_called self._transform: Callable[[tf.Tensor], tf.Tensor] = self._build_not_called def __call__(self, tensor: tf.Tensor) -> tf.Tensor: return super().__call__(tensor)
[docs] def build(self, input_shape: tf.TensorShape): from functools import partial super().build(input_shape) def permutation(axis, shape): from functools import partial axis %= len(shape) permutation = [i for i in range(len(shape)) if i != axis] permutation.append(axis) if axis == len(shape) - 1: return lambda x: x else: return partial(tf.transpose, perm=permutation) self._input_perm = permutation(self.axis, input_shape) self._output_perm = permutation(self.out_axis, input_shape) if self.is_real and self.is_forward: self._transform = cast(Callable[[tf.Tensor], tf.Tensor], tf.signal.rfft) elif self.is_real and self.is_odd_spin: fft_length = tf.constant((2 * input_shape[self.axis] - 1,), dtype=tf.int32) self._transform = partial(tf.signal.irfft, fft_length=fft_length) elif self.is_real: self._transform = cast(Callable[[tf.Tensor], tf.Tensor], tf.signal.irfft) elif self.is_forward: self._transform = cast(Callable[[tf.Tensor], tf.Tensor], tf.signal.fft) else: self._transform = cast(Callable[[tf.Tensor], tf.Tensor], tf.signal.ifft)
@tf.function def call(self, inputs) -> tf.Tensor: permuted = self._input_perm(inputs) transformed = self._transform(permuted) return self._output_perm(transformed)
[docs] def get_config(self): config = super().get_config() config.update( dict( is_forward=self.is_forward, is_real=self.is_real, axis=self.axis, is_odd_spin=self.is_odd_spin, ) ) return config
[docs] def compute_output_shape(self, input_shape: tf.TensorShape) -> tf.TensorShape: if not self.is_real: nspins = input_shape[self.axis] elif self.is_forward: nspins = input_shape[self.axis] // 2 + 1 elif self.is_odd_spin: nspins = 2 * input_shape[self.axis] - 1 else: nspins = 2 * input_shape[self.axis] - 2 result = list(input_shape) result.pop(self.axis) result.insert(self.out_axis % len(input_shape), nspins) return tf.TensorShape(result)
@staticmethod def _build_not_called(_: tf.Tensor) -> tf.Tensor: raise RuntimeError("Build was not called")
[docs]class ForwardWignerLayer(tf.keras.layers.Layer): """Performs a forward Wigner transform.""" def __init__( self, fourier_transform: FourierLayer, harmonic_transform: ForwardSpinLayer, **kwargs, ): if kwargs.pop("trainable", False): raise ValueError("This layer is not trainable") super().__init__(trainable=False, **kwargs) self._fourier = fourier_transform self._harmonic = harmonic_transform self._spin_offset = tf.constant(0) self._lfactor = tf.constant(0)
[docs] def build(self, input_shape): from tensossht.sampling import Axis, HarmonicAxes, ImageAxes super().build(input_shape) assert isinstance(self._harmonic.in_axes, ImageAxes) self._harmonic.in_axes = self._harmonic.in_axes.shift( len(input_shape), Axis.SPIN ) self._fourier.build(input_shape) if input_shape is None: raise ValueError() shape = self._fourier.compute_output_shape(input_shape) self._harmonic.build(shape) hsampling = self._harmonic.harmonic_sampling(shape) rdtype = tf.dtypes.as_dtype(self.dtype).real_dtype cdtype = tf.complex(tf.zeros(0, dtype=rdtype), tf.zeros(0, dtype=rdtype)).dtype self._lfactor = ( _lfactor( hsampling=hsampling, axes=cast(HarmonicAxes, self._harmonic.out_axes), input_shape=input_shape, dtype=cdtype, ) * _sfactor( hsampling=hsampling, axes=cast(HarmonicAxes, self._harmonic.out_axes), input_shape=input_shape, dtype=cdtype, ) / input_shape[self._fourier.axis] )
@tf.function def call(self, inputs: tf.Tensor) -> tf.Tensor: gamma_fourier = self._fourier(inputs) if not self._fourier.is_real: gamma_fourier = tf.signal.fftshift(gamma_fourier, self._fourier.out_axis) return cast(Callable, self._harmonic)(gamma_fourier) * self._lfactor
[docs]class InverseWignerLayer(tf.keras.layers.Layer): """Performs an inverse Wigner transform.""" def __init__( self, harmonic_transform: InverseSpinLayer, fourier_transform: FourierLayer, **kwargs, ): if kwargs.pop("trainable", False): raise ValueError("This layer is not trainable") super().__init__(trainable=False, **kwargs) self._harmonic = harmonic_transform self._fourier = fourier_transform self._lfactor = tf.constant(0) def __call__(self, tensor: tf.Tensor) -> tf.Tensor: return super().__call__(tensor)
[docs] def build(self, input_shape): from tensossht.sampling import Axis, ImageAxes assert isinstance(self._harmonic.out_axes, ImageAxes) self._harmonic.out_axes = self._harmonic.out_axes.shift( len(input_shape), Axis.SPIN ) self._harmonic.build(input_shape) harmonic_shape = self._harmonic.compute_output_shape(input_shape) self._fourier.build(harmonic_shape) hsampling = self._harmonic.harmonic_sampling(input_shape) rdtype = tf.dtypes.as_dtype(self.dtype).real_dtype cdtype = tf.complex(tf.zeros(0, dtype=rdtype), tf.zeros(0, dtype=rdtype)).dtype nspins = self._fourier.compute_output_shape(harmonic_shape)[ self._fourier.out_axis ] self._factor = ( nspins / _lfactor( hsampling=hsampling, axes=cast(HarmonicAxes, self._harmonic.in_axes), input_shape=input_shape, dtype=cdtype, ) * _sfactor( hsampling=hsampling, axes=cast(HarmonicAxes, self._harmonic.in_axes), input_shape=input_shape, dtype=cdtype, ) )
@tf.function def call(self, inputs: tf.Tensor) -> tf.Tensor: gamma_fourier = self._harmonic.__call__(inputs * self._factor) if not self._fourier.is_real: gamma_fourier = tf.signal.ifftshift(gamma_fourier, self._fourier.axis) return self._fourier(gamma_fourier)
[docs]def wigner_layer( is_forward: bool = True, is_real: bool = False, is_odd_spin: bool = True, sampling: Union[str, ImageSamplingSchemes] = "mw", theta_dim: int = -3, phi_dim: int = -2, gamma_dim: int = -1, spin_dim: int = -2, coeff_dim: int = -1, dtype: Union[str, np.dtype, tf.DType] = tf.float64, **kwargs, ) -> tf.keras.layers.Layer: """Factory for wigner transform layers. Args: is_forward: If ``True``, performs image to Wigner space transform. If ``False``, performs Wigner to image space transform. is_real: If ``True``, the image-space signal is real. is_odd_spin: Only useful for inverse transforms to real image-space signals. If ``True``, recovers an odd number of spin functions. Otherwise, recovers an even number of spin functions. sampling: Real-space sampling. theta_dim: Index of the :math:`\\theta` dimension. phi_dim: Index of the :math:`\\phi` dimension. gamma_dim: Index of the :math:`\\gamma` dimension. spin_dim: Index of the spin dimension. coeff_dim: Index of the coefficients dimension. dtype: Underlying type of floating point operations. Only the bit-size of the floating point matters. A complex ``dtype`` can be given equivalently. """ from tensossht.layers.harmonic_transforms import ForwardSpinLayer, InverseSpinLayer rdtype = tf.dtypes.as_dtype(dtype).real_dtype cdtype = tf.complex(tf.zeros(0, dtype=rdtype), tf.zeros(0, dtype=rdtype)).dtype fft = FourierLayer( is_forward=is_forward, is_real=is_real, axis=gamma_dim if is_forward else -1, out_axis=-1 if is_forward else gamma_dim, dtype=rdtype if is_real else cdtype, is_odd_spin=is_odd_spin, **kwargs, ) HTL = ForwardSpinLayer if is_forward else InverseSpinLayer ht = HTL( sampling=sampling, theta_dim=theta_dim, phi_dim=phi_dim, spin_dim=gamma_dim if is_forward else spin_dim, out_spin_dim=spin_dim if is_forward else gamma_dim, coeff_dim=coeff_dim, flip_spin=True, centered_spin=not is_real, dtype=cdtype, **kwargs, ) WTL = ForwardWignerLayer if is_forward else InverseWignerLayer return WTL( # type: ignore fourier_transform=fft, harmonic_transform=ht, dtype=fft.dtype, **kwargs )
def _lfactor( hsampling: HarmonicSampling, axes: HarmonicAxes, input_shape: Sequence[int], dtype: Union[str, np.dtype, tf.DType] = tf.float64, ) -> tf.Tensor: assert axes.spin is not None rdtype = tf.dtypes.as_dtype(dtype).real_dtype llabels = hsampling.llabels[: hsampling.ncoeffs // hsampling.nspins] lfactor = ( tf.constant(4 * np.pi, rdtype) * tf.math.sqrt(tf.constant(np.pi, rdtype)) / tf.math.sqrt(tf.cast(2 * llabels + 1, rdtype)) ) return tf.reshape( tf.cast(lfactor, dtype), tf.one_hot( axes.coeff % (len(input_shape) - 1), len(input_shape) - 1, on_value=-1, off_value=1, ), ) def _sfactor( hsampling: HarmonicSampling, axes: HarmonicAxes, input_shape: Sequence[int], dtype: Union[str, np.dtype, tf.DType] = tf.float64, ) -> tf.Tensor: assert axes.spin is not None slabels = tf.reshape(hsampling.slabels, [hsampling.nspins, -1])[:, 0] return tf.reshape( tf.cast(1 - 2 * (slabels % 2), dtype), tf.one_hot( axes.spin % (len(input_shape) - 1), len(input_shape) - 1, on_value=-1, off_value=1, ), )