# Source code for tensossht.sampling

"""
==================================
Image and Harmonic Space Samplings
==================================

This module contains helpers to define the different harmonic-space and image-space
samplings.

Creating Samplings
==================

.. autofunction :: tensossht.sampling.equiangular
.. autofunction :: tensossht.sampling.image_sampling_scheme
.. autofunction :: tensossht.sampling.harmonic_sampling_scheme
.. autofunction :: tensossht.sampling.wignerd_labels
.. autofunction :: tensossht.sampling.symmetric_labels
.. autofunction :: tensossht.sampling.spin_legendre_labels
.. autofunction :: tensossht.sampling.legendre_labels

Manipulating Samplings
======================

.. autofunction :: tensossht.sampling.transpose
.. autofunction :: tensossht.sampling.equiangular_shape

Data structures
===============

.. autoclass :: tensossht.sampling.HarmonicAxes
.. autoclass :: tensossht.sampling.MW
.. autoclass :: tensossht.sampling.MWSS
.. autoclass :: tensossht.sampling.EquiangularShape
.. autoclass :: tensossht.sampling.ImageSamplingBase
"""
from enum import Enum
from functools import singledispatch
from typing import (
    Any,
    Iterable,
    List,
    NamedTuple,
    Optional,
    Sequence,
    Tuple,
    Union,
    cast,
)

import numpy as np
import tensorflow as tf

from tensossht.typing import Array, TFArray

# Types accepted for ``lmax`` by the image-space sampling functions below: a plain
# integer, an array, or a TensorFlow variable.
LType = Union[int, Array, tf.Variable]


def equiangular_theta(
    lmax: LType,
    dtype: Union[str, np.dtype, tf.DType] = tf.float64,
) -> TFArray:
    """Colatitudes of the equiangular sampling.

    The ``lmax`` samples are

    .. math::

        \\theta = \\pi\\frac{2n + 1}{2l_\\text{max} - 1};
        \\quad\\text{for}\\, n \\in [0, l_\\text{max}[

    so that the last sample sits exactly at :math:`\\theta = \\pi`.

    Args:
        lmax: Number of colatitude samples; must be at least 1.
        dtype: Floating point type of the result.

    Raises:
        ValueError: if ``lmax`` is smaller than 1.

    Example:

        >>> from pytest import approx
        >>> from numpy import pi
        >>> from tensossht.sampling import equiangular_theta
        >>> for lmax in [1, 5, 10]:
        ...     expected = [(2 * t + 1) * pi / (2 * lmax - 1) for t in range(lmax)]
        ...     assert equiangular_theta(lmax).numpy() == approx(expected)

    """
    if lmax < 1:
        raise ValueError("lmax cannot be smaller than 1")

    pi = tf.constant(np.pi, dtype=dtype)
    # Half of the angular gap between two consecutive samples.
    half_gap = tf.constant(
        np.pi / (2 * tf.constant(lmax, dtype=dtype) - 1), dtype=dtype
    )
    # Start half a gap away from the north pole and stride one full gap. The
    # exclusive limit pi + half_gap keeps pi itself as the final sample.
    return tf.range(half_gap, pi + half_gap, 2 * half_gap)


def symmetric_theta(
    lmax: LType,
    dtype: Union[str, np.dtype, tf.DType] = tf.float64,
) -> TFArray:
    """Symmetric theta sampling.

    .. math::

        \\theta = \\pi\\frac{n}{l_\\text{max}};
        \\quad\\text{for}\\, n \\in [0, l_\\text{max} + 1[

    Both :math:`\\theta = 0` and :math:`\\theta = \\pi` are included, giving
    ``lmax + 1`` samples.

    Args:
        lmax: Number of intervals between the first and last sample; must be at
            least 1.
        dtype: Floating point type of the result.

    Raises:
        ValueError: if ``lmax`` is smaller than 1.

    Example:

        >>> from pytest import approx
        >>> from tensossht.sampling import symmetric_theta
        >>> for lmax in [1, 5, 10]:
        ...     expected = [t * np.pi / lmax for t in range(lmax + 1)]
        ...     assert symmetric_theta(lmax).numpy() == approx(expected)

        The number of samples is always ``lmax + 1``:

        >>> all(len(symmetric_theta(lmax)) == lmax + 1 for lmax in range(1, 100))
        True

    """
    if lmax < 1:
        raise ValueError("lmax cannot be smaller than 1")

    step = tf.constant(np.pi / tf.constant(lmax, dtype=dtype), dtype=dtype)
    # tf.range excludes its limit and returns ceil(limit / step) elements. With a
    # limit of exactly pi + step, that ratio is the integer lmax + 1, and any
    # rounding error may push the count to lmax + 2. Placing the limit half a step
    # past pi keeps the count at lmax + 1 robustly, like the phi samplings do.
    return tf.range(tf.constant(np.pi, dtype=dtype) + 0.5 * step, delta=step)


def equiangular_phi(
    lmax: LType,
    dtype: Union[str, np.dtype, tf.DType] = tf.float64,
) -> TFArray:
    """Longitudes of the equiangular sampling.

    The ``2 lmax - 1`` samples are

    .. math::

        \\phi = 2\\pi\\frac{n}{2l_\\text{max} - 1};
        \\quad\\text{for}\\, n \\in [0, 2l_\\text{max} - 1[

    Args:
        lmax: Sampling parameter; must be at least 1.
        dtype: Floating point type of the result.

    Raises:
        ValueError: if ``lmax`` is smaller than 1.

    Example:

        >>> from pytest import approx
        >>> from numpy import pi
        >>> from tensossht.sampling import equiangular_phi
        >>> for lmax in [1, 5, 10]:
        ...     expected = [(2 * pi * p) / (2 * lmax - 1) for p in range(2 * lmax - 1)]
        ...     assert equiangular_phi(lmax).numpy() == approx(expected)

    """
    if lmax < 1:
        raise ValueError("lmax cannot be smaller than 1")

    two_pi = 2 * tf.constant(np.pi, dtype=dtype)
    step = two_pi / tf.constant(2 * lmax - 1, dtype=dtype)
    # Stopping half a step short of 2 pi excludes 2 pi (which duplicates 0)
    # without being sensitive to rounding.
    upper = two_pi - step * 0.5
    return tf.range(upper, delta=step)


def symmetric_phi(
    lmax: LType,
    dtype: Union[str, np.dtype, tf.DType] = tf.float64,
) -> TFArray:
    """Longitudes of the symmetric sampling.

    The ``2 lmax`` samples are

    .. math::

        \\phi = 2\\pi\\frac{n}{2l_\\text{max}};
        \\quad\\text{for}\\, n \\in [0, 2l_\\text{max}[

    Args:
        lmax: Sampling parameter; must be at least 1.
        dtype: Floating point type of the result.

    Raises:
        ValueError: if ``lmax`` is smaller than 1.

    Example:

        >>> from pytest import approx
        >>> from tensossht.sampling import symmetric_phi
        >>> for lmax in [1, 5, 10]:
        ...     expected = [np.pi * p / lmax for p in range(2 * lmax)]
        ...     assert symmetric_phi(lmax).numpy() == approx(expected)

    """
    if lmax < 1:
        raise ValueError("lmax cannot be smaller than 1")

    step = tf.constant(np.pi / tf.constant(lmax, dtype=dtype), dtype=dtype)
    # Half a step below 2 pi: 2 pi itself is excluded, robustly.
    upper = 2 * tf.constant(np.pi, dtype=dtype) - 0.5 * step
    return tf.range(upper, delta=step)


class EquiangularShape(NamedTuple):
    """Number of samples along each axis of an equiangular image.

    See :py:func:`~tensossht.sampling.equiangular_shape`.
    """

    # number of colatitude samples, i.e. ``len(equiangular_theta(lmax))``
    theta: int
    # number of longitude samples, i.e. ``len(equiangular_phi(lmax))``
    phi: int
def equiangular_shape(lmax: int) -> EquiangularShape:
    """Shape of equiangular samples.

    The sampling has ``lmax`` colatitudes and ``2 * lmax - 1`` longitudes, matching
    the lengths of :py:func:`equiangular_theta` and :py:func:`equiangular_phi`.

    Example:

        >>> from tensossht.sampling import (
        ...     equiangular_phi, equiangular_theta, equiangular_shape
        ... )
        >>> len(equiangular_theta(5)) == equiangular_shape(5)[0]
        True
        >>> len(equiangular_theta(5)) == equiangular_shape(5).theta
        True
        >>> len(equiangular_phi(5)) == equiangular_shape(5)[1]
        True
        >>> len(equiangular_phi(5)) == equiangular_shape(5).phi
        True
    """
    ntheta = lmax
    nphi = 2 * lmax - 1
    return EquiangularShape(theta=ntheta, phi=nphi)
def equiangular(
    lmax: int, dtype: Union[str, np.dtype, tf.DType] = tf.float64
) -> TFArray:
    """Equiangular samples.

    Returns a ``(2, lmax, 2 * lmax - 1)`` tensor. The first slice holds the
    colatitude of every grid point and the second slice holds its longitude.

    Example:

        >>> from tensossht.sampling import equiangular
        >>> equiangular(3).numpy().round(2)
        array([[[0.63, 0.63, 0.63, 0.63, 0.63],
                [1.88, 1.88, 1.88, 1.88, 1.88],
                [3.14, 3.14, 3.14, 3.14, 3.14]],
        <BLANKLINE>
               [[0.  , 1.26, 2.51, 3.77, 5.03],
                [0.  , 1.26, 2.51, 3.77, 5.03],
                [0.  , 1.26, 2.51, 3.77, 5.03]]])
    """
    thetas = equiangular_theta(lmax, dtype)
    phis = equiangular_phi(lmax, dtype)
    # "ij" indexing: the first grid axis follows theta, the second follows phi.
    theta_grid, phi_grid = tf.meshgrid(thetas, phis, indexing="ij")
    return tf.stack((theta_grid, phi_grid), axis=0)
class ImageSamplingBase:
    """Common interface of the image-space samplings.

    Subclasses provide the ``thetas`` and ``phis`` properties. The grid, the flat
    list of points and the shape are derived from them.

    Args:
        lmax: Sampling parameter; subclasses derive their samples from it.
        dtype: Floating point type of the samples.
    """

    def __init__(self, lmax: int, dtype: tf.DType = tf.float64):
        self.lmax = lmax
        self.dtype = dtype

    def __repr__(self):
        name = type(self).__name__
        return f"{name}(lmax={self.lmax}, dtype={self.dtype!r})"

    @property
    def thetas(self):
        """Colatitude samples; must be implemented by subclasses."""
        raise NotImplementedError()

    @property
    def phis(self):
        """Longitude samples; must be implemented by subclasses."""
        raise NotImplementedError()

    @property
    def grid(self):
        """Image-space grid of shape ``(2, len(thetas), len(phis))``.

        ``grid[0]`` holds the colatitude and ``grid[1]`` the longitude of each
        point.
        """
        theta_grid, phi_grid = tf.meshgrid(self.thetas, self.phis, indexing="ij")
        return tf.stack((theta_grid, phi_grid), axis=0)

    @property
    def points(self):
        """Image-space points as an ``(npoints, 2)`` tensor of (theta, phi) pairs."""
        flat = tf.reshape(self.grid, (2, -1))
        return tf.transpose(flat)

    @property
    def shape(self):
        """Number of theta and phi samples."""
        return len(self.thetas), len(self.phis)
class MW(ImageSamplingBase):
    """McEwen & Wiaux equiangular image-space sampling.

    ``lmax`` colatitudes and ``2 * lmax - 1`` longitudes.
    """

    @property
    def thetas(self):
        """Colatitudes, see :py:func:`equiangular_theta`."""
        return equiangular_theta(lmax=self.lmax, dtype=self.dtype)

    @property
    def phis(self):
        """Longitudes, see :py:func:`equiangular_phi`."""
        return equiangular_phi(lmax=self.lmax, dtype=self.dtype)

    @property
    def shape(self):
        """Theta and phi shapes, computed without building the samples."""
        ntheta = self.lmax
        nphi = 2 * self.lmax - 1
        return ntheta, nphi
class MWSS(ImageSamplingBase):
    """Symmetric variant of the McEwen & Wiaux image-space sampling.

    ``lmax + 1`` colatitudes, including both poles, and ``2 * lmax`` longitudes.
    """

    @property
    def thetas(self):
        """Colatitudes, see :py:func:`symmetric_theta`."""
        return symmetric_theta(lmax=self.lmax, dtype=self.dtype)

    @property
    def phis(self):
        """Longitudes, see :py:func:`symmetric_phi`."""
        return symmetric_phi(lmax=self.lmax, dtype=self.dtype)

    @property
    def shape(self):
        """Theta and phi shapes, computed without building the samples."""
        ntheta = self.lmax + 1
        nphi = 2 * self.lmax
        return ntheta, nphi
class ImageSamplingSchemes(Enum):
    """Enumeration of the available image-space samplings.

    Each member's value is the corresponding sampling class. Factories that need a
    sampling scheme as input can normalize their argument with
    :py:func:`image_sampling_scheme`.

    Example:

        >>> from tensossht import sampling
        >>> mw = sampling.ImageSamplingSchemes.MW
        >>> assert mw == sampling.image_sampling_scheme("mw")
        >>> assert mw.value is sampling.MW
        >>> mwss = sampling.ImageSamplingSchemes['MWSS']
        >>> assert mwss == sampling.image_sampling_scheme("mwss")
        >>> assert mwss.value is sampling.MWSS
    """

    MW = MW  # equiangular sampling class
    MWSS = MWSS  # symmetric sampling class
@singledispatch
def image_sampling_scheme(sampling) -> ImageSamplingSchemes:
    """Convert ``sampling`` into an :py:class:`ImageSamplingSchemes` member.

    Implementations are registered for :py:class:`ImageSamplingSchemes` members
    (returned as is), for strings (matched case-insensitively against the member
    names) and for :py:class:`ImageSamplingBase` instances.

    Raises:
        RuntimeError: if no implementation is registered for the type of
            ``sampling``.
    """
    # The message announces a type, so report the type rather than the value.
    raise RuntimeError(f"Unexpected input type {type(sampling)}")
@image_sampling_scheme.register(ImageSamplingSchemes)
def _image_sampling_scheme0(sampling: ImageSamplingSchemes) -> ImageSamplingSchemes:
    # Already an enum member: nothing to convert.
    return sampling


@image_sampling_scheme.register(str)
def _image_sampling_scheme1(sampling: str) -> ImageSamplingSchemes:
    # Case-insensitive match against the member names; the first match wins.
    wanted = sampling.lower()
    matches = [smpl for smpl in ImageSamplingSchemes if smpl.name.lower() == wanted]
    if not matches:
        raise ValueError(f"Unkown sampling scheme {sampling}")
    return matches[0]


@image_sampling_scheme.register(ImageSamplingBase)
def _image_sampling_scheme2(sampling: ImageSamplingBase) -> ImageSamplingSchemes:
    # Find the member whose value is the class of the given instance.
    found = next(
        (smpl for smpl in ImageSamplingSchemes if isinstance(sampling, smpl.value)),
        None,
    )
    if found is None:
        raise ValueError(f"Unkown sampling scheme {sampling}")
    return found


class HarmonicSampling(NamedTuple):
    """Sampling in harmonic space.

    In other words, the set of spherical harmonics included in the basis, each
    identified by an ``(l, m, s)`` column of ``labels``.
    """

    lmax: int  # exclusive upper bound on the degree l
    lmin: int  # smallest degree
    mmax: int  # largest order
    mmin: int  # smallest order
    smin: int  # smallest spin
    smax: int  # largest spin
    is_separable_spin: bool  # whether the labels split into per-spin blocks
    labels: TFArray  # 3 by n matrix with rows (l, m, s)

    @property
    def is_real(self):
        """Non-negative orders only, and spin zero."""
        return self.mmin >= 0 and self.smin == self.smax == 0

    @property
    def is_complex(self):
        return not self.is_real

    @property
    def is_multi_spin(self):
        return self.smin != self.smax

    @property
    def is_single_spin(self):
        return not self.is_multi_spin

    @property
    def llabels(self):
        """Degree of each coefficient."""
        return self.labels[0]

    @property
    def mlabels(self):
        """Order of each coefficient."""
        return self.labels[1]

    @property
    def slabels(self):
        """Spin of each coefficient."""
        return self.labels[2]

    @property
    def ncoeffs(self):
        """Number of (l, m, s) columns in ``labels``."""
        return self.labels.shape[1]

    @property
    def nspins(self):
        """Number of integer spins between ``smin`` and ``smax`` inclusive."""
        return 1 + self.smax - self.smin

    @property
    def valid(self):
        """Whether each tuple (l, m, s) is valid.

        A tuple is valid when ``l >= 0``, ``|m| <= l`` and ``|s| <= l``.
        """
        degrees, orders, spins = self.llabels, self.mlabels, self.slabels
        nonnegative = degrees >= 0
        bounded = tf.logical_and(
            tf.math.abs(orders) <= degrees, tf.math.abs(spins) <= degrees
        )
        return tf.logical_and(nonnegative, bounded)

    @property
    def spins(self) -> Union[int, TFArray]:
        """Range of spins in the sampling.

        A single integer for single-spin samplings. For separable multi-spin
        samplings, a tensor with the spin of each per-spin block.

        Raises:
            AttributeError: for multi-spin samplings that are not separable.
        """
        if self.smin == self.smax:
            return self.smin
        if self.is_separable_spin:
            return tf.reshape(self.slabels, (self.nspins, -1))[:, 0]
        raise AttributeError("Cannot compute spin range")
def harmonic_sampling_scheme(
    lmax: Optional[Union[int, Array, HarmonicSampling]] = None,
    lmin: Optional[int] = None,
    mmax: Optional[int] = None,
    mmin: Optional[int] = None,
    smax: Optional[int] = None,
    smin: Optional[int] = None,
    spin: Optional[int] = None,
    labels: Optional[Array] = None,
    compact_spin: Optional[bool] = None,
) -> HarmonicSampling:
    """Basis functions in harmonic space.

    Args:
        lmax: Maximum degree (exclusive), :math:`l < l_\\mathrm{max}`. For practical
            purposes, the first argument of the function can also be the 2 or 3 by
            n tensor ``labels``, or a
            :py:class:`~tensossht.sampling.HarmonicSampling` instance, which is then
            returned unchanged.
        lmin: Minimum degree, :math:`l \\geq l_\\mathrm{min}`.
        mmax: Maximum order, :math:`m \\leq m_\\mathrm{max}`.
        mmin: Minimum order, :math:`m \\geq m_\\mathrm{min}`.
        smax: Maximum spin, :math:`s \\leq s_\\mathrm{max}`.
        smin: Minimum spin, :math:`s \\geq s_\\mathrm{min}`.
        spin: Shortcut for :math:`s = s_\\mathrm{min} = s_\\mathrm{max}`.
        labels: 3 by n tensor listing the :math:`(l, m, s)` triplets, or 2 by n
            tensor listing :math:`(l, m)` pairs, in which case the spin is zero.
            Alternative to specifying ``lmax`` and friends.
        compact_spin: If ``True``, the representation is memory efficient. If
            ``False``, the representation is such that the spin dimension can be
            separated from the other coefficients. The latter is generally more
            practical, especially for smaller ``smax``. Defaults to ``True`` when
            a single spin is requested or when neither ``smin`` nor ``smax`` is
            given. Ignored if ``labels`` are given.

    Returns:
        A :py:class:`HarmonicSampling`. Its bounds are recomputed from the final
        labels; for example, ``lmax`` is one more than the largest degree present.

    Raises:
        ValueError: if neither or both of ``lmax`` and ``labels`` are given, if the
            labels are not a 2 or 3 by n matrix, or if the bounds are inconsistent.
    """
    if isinstance(lmax, HarmonicSampling):
        return lmax
    # labels passed as first and only argument
    if (
        lmin is None
        and mmax is None
        and mmin is None
        and smax is None
        and smin is None
        and spin is None
        and labels is None
        and len(getattr(lmax, "shape", ())) == 2
    ):
        return harmonic_sampling_scheme(labels=lmax)
    if lmax is None and labels is None:
        raise ValueError("One of lmax or labels must be given")
    if lmax is not None and labels is not None:
        raise ValueError("Only one of lmax or labels can be given")
    # Labels passed as first argument alongside other bounds: the bounds are
    # ignored, since they are recomputed from the labels below.
    if lmax is not None and len(tf.shape(lmax)) == 2 and tf.shape(lmax)[0] in (2, 3):
        lmax, labels = None, cast(Array, lmax)
    if lmax is not None:
        labels = _figure_labels(lmax, lmin, mmax, mmin, smax, smin, spin, compact_spin)
    assert labels is not None
    if len(labels.shape) != 2 or labels.shape[0] not in (2, 3):
        raise ValueError("labels outght to be a 2 or 3 by n matrix")
    if labels.shape[0] == 2:
        # (l, m) pairs only: append a row of zero spins.
        labels = tf.concat((labels, tf.zeros_like(labels[:1])), axis=0)
    assert labels is not None
    lmax = int(tf.reduce_max(labels[0])) + 1
    lmin = int(tf.reduce_min(labels[0]))
    mmax = int(tf.reduce_max(labels[1]))
    mmin = int(tf.reduce_min(labels[1]))
    smax = int(tf.reduce_max(labels[2]))
    smin = int(tf.reduce_min(labels[2]))
    assert labels is not None
    assert lmin is not None and lmax is not None
    assert mmin is not None and mmax is not None
    assert smin is not None and smax is not None
    iss = is_separable_spin(smin, smax, labels)
    return HarmonicSampling(lmax, lmin, mmax, mmin, smin, smax, iss, labels)
def _figure_labels(
    lmax: int,
    lmin: Optional[int] = None,
    mmax: Optional[int] = None,
    mmin: Optional[int] = None,
    smax: Optional[int] = None,
    smin: Optional[int] = None,
    spin: Optional[int] = None,
    compact_spin: Optional[bool] = None,
) -> TFArray:
    """Compute ``(l, m, s)`` labels from possibly partial bounds.

    Missing bounds on ``m`` and ``s`` default to the full range allowed by
    ``lmax``. Explicit bounds are clipped to ``[1 - lmax, lmax - 1]``. When only one
    of ``smin``/``smax`` is given, the other defaults to its mirror value:
    ``smin = -|smax|`` or ``smax = |smin|``.

    Raises:
        ValueError: if ``lmax`` is not strictly positive, or if ``spin`` conflicts
            with ``smin`` or ``smax``.
    """
    if lmax <= 0:
        raise ValueError("lmax must be strictly positive.")
    if spin is not None:
        if smin is not None and spin != smin:
            raise ValueError("Only one of spin or smin should be given")
        if smax is not None and spin != smax:
            raise ValueError("Only one of spin or smax should be given")
        smin = smax = spin
    # A single spin bound implies its mirror. The second branch used to read
    # "smax is not None and smax is None", which never ran, so smax silently
    # defaulted to lmax - 1.
    if smin is None and smax is not None:
        smin = -abs(smax)
    elif smax is None and smin is not None:
        smax = abs(smin)
    if compact_spin is None:
        compact_spin = (smin is None and smax is None) or (smin == smax)
    lmin = 0 if lmin is None else lmin
    mmax = (lmax - 1) if mmax is None else max(min(mmax, lmax - 1), 1 - lmax)
    mmin = (1 - lmax) if mmin is None else max(min(mmin, lmax - 1), 1 - lmax)
    smax = (lmax - 1) if smax is None else max(min(smax, lmax - 1), 1 - lmax)
    smin = 1 - lmax if smin is None else max(min(smin, lmax - 1), 1 - lmax)
    return spin_legendre_labels(
        lmax=lmax,
        lmin=lmin,
        mmax=mmax,
        mmin=mmin,
        smin=smin,
        smax=smax,
        compact_spin=compact_spin,
    )


def is_separable_spin(smin: int, smax: int, labels: Array) -> bool:
    """True if the labels can be separated into one block per spin.

    Separable labels consist of contiguous blocks of equal size, one per spin, and
    every block lists the same ``(l, m)`` sequence. A single spin is always
    separable.
    """
    if smin == smax:
        return True
    unicity = tf.unique_with_counts(labels[2])
    nspins = len(unicity.y)
    if len(labels[2]) % nspins != 0:
        return False
    # unique indices follow order of first appearance: contiguous equal-size blocks
    # give 0...0, 1...1, and so on.
    indices = tf.reshape(unicity.idx, (nspins, -1))
    if tf.reduce_any(indices != tf.range(nspins)[:, None]):
        return False
    separated = tf.reshape(labels, (3, nspins, -1))
    # every spin block must repeat the l and m rows of the first block
    if tf.reduce_any(separated[0, 0] != separated[0, 1:]):
        return False
    if tf.reduce_any(separated[1, 0] != separated[1, 1:]):
        return False
    return True


class Axis(str, Enum):
    """Names axes of interest.

    Example:

        >>> from tensossht.sampling import Axis
        >>> Axis("phi")
        <Axis.PHI: 'phi'>
        >>> Axis("phi") == "phi"
        True
        >>> Axis("theta") == "phi"
        False
        >>> Axis("theta") == Axis("theta")
        True
    """

    PHI: str = "phi"  # type: ignore
    THETA: str = "theta"  # type: ignore
    SPIN: str = "spin"  # type: ignore
    COEFF: str = "coeff"  # type: ignore


class ImageAxesBase(NamedTuple):
    """Avoids __classcell__ issues.

    NamedTuple does not pass ``__classcell__`` to ``type.__new__``, which means
    ``super`` cannot be used in python 3.8. We add a base class to avoid calling
    ``super`` directly.
    """

    phi: int
    theta: int
    spin: Optional[int] = None


class ImageAxes(ImageAxesBase):
    """Keeps track of the dimensions of an image tensor.

    Example:

        Image axes contain three components corresponding to the phi, theta and
        spin dimensions. Spin can be missing, in which case it is `None`.

        >>> from tensossht.sampling import ImageAxes, Axis
        >>> axes = ImageAxes(phi=-2, theta=-1, spin=1)
        >>> axes
        ImageAxes(phi=-2, theta=-1, spin=1)
        >>> axes[Axis.PHI]
        -2
        >>> axes["theta"]
        -1
        >>> axes[2]
        1
        >>> axes % 5
        ImageAxes(phi=3, theta=4, spin=1)

        One or more axes can be shifted to the end (e.g. the most rapidly
        incrementing dimension in tensorflow):

        >>> axes.shift(5, Axis.PHI)
        ImageAxes(phi=-1, theta=-2, spin=1)
        >>> axes.shift(5, "spin")
        ImageAxes(phi=-3, theta=-2, spin=-1)
        >>> axes.shift(5, "spin", Axis.THETA)
        ImageAxes(phi=-3, theta=-1, spin=-2)

        Spin-less instances can still "shift" the spin axis, although nothing is
        done in that case. This makes it easier to implement spin and spin-less
        algorithms:

        >>> ImageAxes(phi=-2, theta=-1, spin=None).shift(5, Axis.PHI, Axis.SPIN)
        ImageAxes(phi=-1, theta=-2, spin=None)

        And a tensor can then be transposed accordingly:

        >>> tensor = tf.zeros((2, 3, 4, 5, 6))
        >>> assert axes.transpose(tensor).shape == tensor.shape
        >>> assert axes.transpose(tensor, "spin").shape == (2, 4, 5, 6, 3)
        >>> assert axes.transpose(tensor, "spin", Axis.THETA).shape == (2, 4, 5, 3, 6)

        Optionally, the spin axis can be None, for single spin calculations:

        >>> axes = ImageAxes(phi=-2, theta=-1)
        >>> assert axes[Axis.SPIN] is None
        >>> assert axes.shift(3, Axis.PHI) == ImageAxes(theta=-2, phi=-1)

        In that case, shifting the spin axis is ignored:

        >>> assert axes.shift(3, Axis.SPIN) == axes
        >>> assert axes.shift(3, Axis.PHI, Axis.SPIN) == ImageAxes(theta=-2, phi=-1)

        Arbitrary permutations can be applied to the axes:

        >>> axes.permutate([1, 0, 3, 2, 4])
        ImageAxes(phi=-3, theta=-1, spin=None)

        It is also possible to figure out the permutation from one `ImageAxes`
        instance to another. The permutation does not change the order of the
        unnamed axes.

        >>> atstart = ImageAxes(phi=0, theta=1)
        >>> atend = ImageAxes(phi=-1, theta=-2)
        >>> atstart.permutation(atend, ndims=5)
        [2, 3, 4, 1, 0]
        >>> atend.permutation(atstart, ndims=6)
        [5, 4, 0, 1, 2, 3]
        >>> src = ImageAxes(phi=1, theta=-1, spin=-2)
        >>> dst = ImageAxes(phi=-1, spin=-3, theta=-2)
        >>> src.permutation(dst, ndims=5)
        [0, 2, 3, 4, 1]

        The transpose can be applied to a tensor as follows:

        >>> tensor = tf.zeros((2, 3, 4, 5, 6))
        >>> assert src.transpose(tensor, dst).shape == (2, 4, 5, 6, 3)
    """

    def __getitem__(self, index: Union[int, slice, Axis, str]):
        return ImageAxesBase.__getitem__(
            cast(Any, self), _axes_normalize_index(self, index)
        )

    def __mod__(self, ndims: int):
        return ImageAxes(*_axes_mod(self, ndims))

    def shift(self, ndims: int, *shifted: Union[str, Axis]):
        """Shift given axes to end."""
        return ImageAxes(**_axes_shift(self, ndims, *shifted))

    def transpose(
        self, tensor: Array, *shifted: Union[str, Axis, "ImageAxes"]
    ) -> TFArray:
        """Transpose tensor by shifting given axes to end."""
        return transpose(tensor, self, *shifted)

    def permutate(self, permutation: Sequence[int]):
        """Axes after applying ``permutation`` (``tf.transpose`` convention)."""
        return ImageAxes(**_axes_permutate(self, permutation))

    def permutation(self, other: "ImageAxes", ndims: int):
        """Permutation taking this set of axes to input axes."""
        return _axes_permutation(ndims, self, other)


class HarmonicAxesBase(NamedTuple):
    """Avoids __classcell__ issues.

    NamedTuple does not pass ``__classcell__`` to ``type.__new__``, which means
    ``super`` cannot be used in python 3.8. We add a base class to avoid calling
    ``super`` directly.
    """

    coeff: int
    spin: Optional[int] = None
class HarmonicAxes(HarmonicAxesBase):
    """Keeps track of the dimensions of a spherical harmonic tensor.

    Example:

        Harmonic axes contain two components corresponding to the coefficient and
        spin dimensions. Spin can be missing, in which case it is `None`.

        >>> from tensossht.sampling import HarmonicAxes
        >>> axes = HarmonicAxes(coeff=-1, spin=-2)
        >>> axes
        HarmonicAxes(coeff=-1, spin=-2)
        >>> axes[Axis.COEFF]
        -1
        >>> axes["spin"]
        -2
        >>> axes[1]
        -2
        >>> axes % 5
        HarmonicAxes(coeff=4, spin=3)

        One or more axes can be shifted to the end (e.g. the most rapidly
        incrementing dimension in tensorflow):

        >>> axes.shift(3, "coeff")
        HarmonicAxes(coeff=-1, spin=-2)
        >>> axes.shift(3, "spin")
        HarmonicAxes(coeff=-2, spin=-1)
        >>> axes.shift(3)
        HarmonicAxes(coeff=-1, spin=-2)
        >>> axes.shift(5, Axis.COEFF, "spin")
        HarmonicAxes(coeff=-2, spin=-1)

        And a tensor can then be transposed accordingly:

        >>> tensor = tf.zeros((2, 3, 4, 5, 6))
        >>> assert axes.transpose(tensor).shape == tensor.shape
        >>> assert axes.transpose(tensor, "spin").shape == (2, 3, 4, 6, 5)
        >>> assert axes.transpose(tensor, "spin", Axis.COEFF).shape == (2, 3, 4, 5, 6)
        >>> assert axes.transpose(tensor, Axis.COEFF, "spin").shape == (2, 3, 4, 6, 5)

        It is also possible to compute the permutation to go from one axis ordering
        to another:

        >>> HarmonicAxes(coeff=0).permutation(HarmonicAxes(coeff=-1), ndims=3)
        [1, 2, 0]
        >>> HarmonicAxes(coeff=1, spin=-1).permutation(
        ...     HarmonicAxes(coeff=-1, spin=1), ndims=4
        ... )
        [0, 3, 2, 1]
    """

    def __getitem__(self, index: Union[int, slice, Axis, str]):
        # Axis names (or their string values) map to the matching field position.
        position = _axes_normalize_index(self, index)
        return HarmonicAxesBase.__getitem__(cast(Any, self), position)

    def __mod__(self, ndims: int):
        """Axes with every index wrapped into ``[0, ndims)``."""
        wrapped = _axes_mod(self, ndims)
        return HarmonicAxes(*wrapped)

    def shift(self, ndims: int, *shifted: Union[str, Axis]):
        """Shift given axes to end."""
        moved = _axes_shift(self, ndims, *shifted)
        return HarmonicAxes(**moved)

    def transpose(
        self, tensor: Array, *shifted: Union[str, Axis, "HarmonicAxes"]
    ) -> TFArray:
        """Transpose tensor by shifting given axes to end."""
        return transpose(tensor, self, *shifted)

    def permutate(self, permutation: Sequence[int]):
        """Axes after applying ``permutation`` (``tf.transpose`` convention)."""
        moved = _axes_permutate(self, permutation)
        return HarmonicAxes(**moved)

    def permutation(self, other: "HarmonicAxes", ndims: int):
        """Permutation taking this set of axes to input axes."""
        return _axes_permutation(ndims, self, other)
def _axes_mod(axes: Iterable, ndims: int) -> Iterable:
    """Lazily wrap each axis index into ``[0, ndims)``, leaving ``None`` as is."""
    return ((item % ndims if item is not None else None) for item in axes)


def _axes_normalize_index(
    axes: Union[HarmonicAxes, ImageAxes], index
) -> Union[int, slice]:
    """Translate an axis name into the position of the matching field.

    Strings are lower-cased and converted to :py:class:`Axis`. Integers and slices
    are returned unchanged.

    Raises:
        IndexError: if ``index`` names an axis that is not a field of ``axes``.
    """
    if isinstance(index, str):
        index = Axis(index.lower())
    if isinstance(index, Axis):
        for i, axis in enumerate(axes._fields):
            if index is Axis(axis):
                return i
        raise IndexError(f"Uknown index {index}")
    return index


def _axes_as_list(
    axes: Union[HarmonicAxes, ImageAxes], ndims: int
) -> List[Union[None, Axis]]:
    """List of length ``ndims`` naming the axis at each dimension, ``None`` elsewhere."""
    result: List[Union[None, Axis]] = [None] * ndims
    for field in axes._fields:
        index = cast(int, axes[field])
        if index is not None:
            # two named axes must never share a dimension
            assert result[index] is None
            result[index] = Axis(field)
    return result


def _axes_shift(
    axes: Union[HarmonicAxes, ImageAxes], ndims: int, *shifted_axes: Union[str, Axis]
):
    """Keyword arguments for new axes with ``shifted_axes`` moved to the end.

    Shifted axes keep their requested order. Missing (``None``) axes are ignored.
    Shifted axes and axes that had a negative index get negative indices; the
    others get non-negative ones.
    """
    dims = _axes_as_list(axes, ndims)
    shifted: List[Union[None, Axis]] = [
        Axis(x) for x in shifted_axes if axes[x] is not None  # type: ignore
    ]
    reorder = [ax for ax in dims if ax not in shifted] + shifted
    return {
        k: i - ndims if Axis(k) in shifted or axes[k] < 0 else i  # type: ignore
        for i, k in enumerate(reorder)
        if k is not None
    }


def _axes_permutate(a: Union[HarmonicAxes, ImageAxes], permutation: Sequence[int]):
    """Keyword arguments describing ``a`` after permuting its dimensions.

    ``permutation`` follows the ``tf.transpose`` convention: new dimension ``i`` is
    old dimension ``permutation[i]``. Axes with a non-negative index in ``a`` get a
    non-negative index in the result, the others a negative one.
    """
    ndims = len(permutation)
    axes = _axes_as_list(a, ndims)
    permutated = [axes[i] for i in permutation]
    # ">= 0" rather than "> 0": an axis at index 0 is non-negative and must keep
    # that convention, consistently with _axes_shift.
    return {
        v: i if a[v] >= 0 else i - ndims  # type: ignore
        for i, v in enumerate(permutated)
        if v is not None
    }
def transpose(
    tensor: Array,
    axes: Union[ImageAxes, HarmonicAxes],
    *shifted_axes: Union[str, Axis, HarmonicAxes, ImageAxes],
) -> TFArray:
    """Transpose tensor so that given axes are pushed to the end.

    Args:
        tensor: Tensor whose named dimensions are described by ``axes``.
        axes: Current position of the named dimensions.
        shifted_axes: Either one or more axis names, moved to the end in the given
            order, or a single target ``ImageAxes``/``HarmonicAxes`` tuple.

    Returns:
        The transposed tensor, or ``tensor`` itself when ``shifted_axes`` is empty.

    Raises:
        ValueError: if an axes tuple is mixed with other arguments.
    """
    if not shifted_axes:
        return tensor
    is_target = [isinstance(u, (HarmonicAxes, ImageAxes)) for u in shifted_axes]
    if len(shifted_axes) > 1 and any(is_target):
        msg = "Input ought to be a single ImageAxes tuple or one or more Axis instances"
        raise ValueError(msg)
    ndims = len(tensor.shape)
    if not is_target[0]:
        # Axis names: compute the target layout, then transpose towards it.
        target = axes.shift(ndims, *shifted_axes)  # type: ignore
        return transpose(tensor, axes, target)
    return tf.transpose(tensor, _axes_permutation(ndims, axes, shifted_axes[0]))
def _axes_permutation(
    ndims: int, src: Union[ImageAxes, HarmonicAxes], dst: Union[ImageAxes, HarmonicAxes]
) -> List[int]:
    """Permutation moving the named axes from their ``src`` to their ``dst`` position.

    Unnamed dimensions keep their relative order. Entry ``i`` of the result is the
    source dimension that ends up at ``i``, as expected by ``tf.transpose``.
    """
    # Both sides must describe the same kind of axes, with or without spin.
    assert isinstance(src, ImageAxes) == isinstance(dst, ImageAxes)
    assert isinstance(src, HarmonicAxes) == isinstance(dst, HarmonicAxes)
    assert (src.spin is None) == (dst.spin is None)

    src = src % ndims
    dst = dst % ndims

    named = [
        (name, position)
        for name, position in dst._asdict().items()
        if position is not None
    ]
    # Insert from the smallest destination so that earlier insertions stay put.
    named.sort(key=lambda pair: pair[1])

    permutation: List[int] = [dim for dim in range(ndims) if dim not in src]
    for name, _ in named:
        permutation.insert(dst[name], src[name])  # type: ignore
    return permutation
def legendre_labels(
    lmax: int,
    lmin: Optional[int] = 0,
    mmax: Optional[int] = None,
    mmin: Optional[int] = 0,
) -> TFArray:
    r"""Tensor with all (l, m) constrained by lmin, lmax and mmin, mmax.

    More specifically, we compute all values:

    .. math::

        \left\{
            (l, m);
            l \in [l_\text{min}, l_\text{max}[,
            m \in [m_\text{min}, m_\text{max}] \cap [-l, l]
        \right\}

    Generates an int32 tensor with all ls and ms, as per requirements. ls are first
    and ms are second, as illustrated below. With the default ``mmin=0`` only
    non-negative orders are included; ``mmin=None`` lifts the lower bound.

    Degrees for which no order satisfies the constraints contribute no column, and
    an empty selection yields a ``(2, 0)`` tensor.

    Raises:
        ValueError: if ``lmin > lmax`` or ``mmin > mmax``.

    Example:

        >>> from tensossht.sampling import legendre_labels
        >>> labels = legendre_labels(lmax=6, lmin=3, mmax=2, mmin=1)
        >>> labels
        <tf.Tensor: shape=(2, 6), dtype=int32, numpy=
        array([[3, 3, 4, 4, 5, 5],
               [1, 2, 1, 2, 1, 2]], dtype=int32)>
        >>> l, m = labels
        >>> l.numpy()
        array([3, 3, 4, 4, 5, 5], dtype=int32)

        Degrees too small for the requested orders are skipped:

        >>> legendre_labels(lmax=3, mmin=2, mmax=3).numpy()
        array([[2],
               [2]], dtype=int32)
        >>> legendre_labels(lmax=2, lmin=2).shape
        TensorShape([2, 0])
    """
    if mmin is None:
        mmin = -lmax
    if mmax is None:
        mmax = lmax
    if lmin is None:
        lmin = 0
    if lmin > lmax:
        raise ValueError(f"lmin > lmax ({lmin} > {lmax})")
    if mmin > mmax:
        raise ValueError(f"mmin > mmax ({mmin} > {mmax})")

    # The zero-length seeds keep the concatenations valid when no degree is selected.
    ls = [np.zeros(0, dtype=np.int32)]
    ms = [np.zeros(0, dtype=np.int32)]
    for degree in range(lmin, lmax):
        # np.arange yields an empty array, rather than failing like tf.range/tf.fill,
        # when no order of this degree lies within [mmin, mmax].
        orders = np.arange(max(-degree, mmin), min(degree, mmax) + 1, dtype=np.int32)
        ls.append(np.full(orders.size, degree, dtype=np.int32))
        ms.append(orders)
    return tf.constant(np.stack((np.concatenate(ls), np.concatenate(ms))))
def spin_legendre_labels(
    lmax: int,
    lmin: int = 0,
    mmax: Optional[int] = None,
    mmin: Optional[int] = 0,
    smax: Optional[int] = None,
    smin: Optional[int] = 0,
    dtype: Union[tf.DType, np.dtype, str] = tf.int32,
    compact_spin: bool = True,
) -> TFArray:
    r"""Tensor with all (l, m, s) constrained by lmin, lmax, mmin, mmax, smin, smax.

    More specifically, we compute all values:

    .. math::

        \left\{
            (l, m, s);
            s \in [s_\text{min}, s_\text{max}],\
            l \in [l_\text{min}, l_\text{max}[,\ l \geq |s|,\
            m \in [m_\text{min}, m_\text{max}],\ -l \leq m \leq l
        \right\}

    The constraint :math:`l \geq |s|` only applies when ``compact_spin`` is
    ``True``.

    Generates a tensor with all ls, ms and ss, as per requirements. All pairs
    ``(l, m)`` for a given ``s`` are contiguous, i.e. ``s`` is the outermost index.
    ``m`` is the innermost index, i.e. the most rapidly changing. This setup implies
    we could expand the labels into a 2-dimensional ragged tensor ``s`` vs
    ``(l, m)``.

    If ``compact_spin`` is ``True``, then the labels are arranged such that memory
    is optimized. If it is ``False``, then the tensor can be reshaped so that the
    spins are in a separate dimension.

    The result is always a 3 by n tensor with rows ``(l, m, s)``, including when no
    label satisfies the constraints (n = 0).

    Raises:
        ValueError: if ``lmax < 0``, ``lmin > lmax``, ``mmin > mmax`` or
            ``smin > smax``.

    Example:

        The following illustrates a compact label representation.

        >>> from tensossht import spin_legendre_labels
        >>> labels = spin_legendre_labels(lmax=3, smin=-1, smax=1)
        >>> labels
        <tf.Tensor: shape=(3, 16), dtype=int32, numpy=
        array([[ 1,  1,  2,  2,  2,  0,  1,  1,  2,  2,  2,  1,  1,  2,  2,  2],
               [ 0,  1,  0,  1,  2,  0,  0,  1,  0,  1,  2,  0,  1,  0,  1,  2],
               [-1, -1, -1, -1, -1,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1]],
              dtype=int32)>
        >>> l, m, s = labels
        >>> l.numpy()
        array([1, 1, 2, 2, 2, 0, 1, 1, 2, 2, 2, 1, 1, 2, 2, 2], dtype=int32)

        However, a less compact representation is available:

        >>> labels = spin_legendre_labels(lmax=3, smin=-1, smax=1, compact_spin=False)
        >>> tf.reshape(labels, (3, 3, -1))
        <tf.Tensor: shape=(3, 3, 6), dtype=int32, numpy=
        array([[[ 0,  1,  1,  2,  2,  2],
                [ 0,  1,  1,  2,  2,  2],
                [ 0,  1,  1,  2,  2,  2]],
        <BLANKLINE>
               [[ 0,  0,  1,  0,  1,  2],
                [ 0,  0,  1,  0,  1,  2],
                [ 0,  0,  1,  0,  1,  2]],
        <BLANKLINE>
               [[-1, -1, -1, -1, -1, -1],
                [ 0,  0,  0,  0,  0,  0],
                [ 1,  1,  1,  1,  1,  1]]], dtype=int32)>

        An empty selection still yields a 3 by n tensor:

        >>> spin_legendre_labels(lmax=2, smin=3, smax=3).shape
        TensorShape([3, 0])
    """
    if mmin is None:
        mmin = -lmax
    if mmax is None:
        mmax = lmax
    if smax is None:
        smax = lmax
    if smin is None:
        smin = -lmax
    if lmax < 0:
        raise ValueError(f"lmax ({lmax}) < 0")
    if lmin > lmax:
        raise ValueError(f"lmin > lmax ({lmin} > {lmax})")
    if mmin > mmax:
        raise ValueError(f"mmin > mmax ({mmin} > {mmax})")
    if smin > smax:
        raise ValueError(f"smin > smax ({smin} > {smax})")

    triplets = [
        (degree, order, s)
        for s in range(smin, smax + 1)
        for degree in range(max(abs(s), lmin) if compact_spin else lmin, lmax)
        for order in range(max(-degree, mmin), min(degree, mmax) + 1)
    ]
    # An empty list yields a 1-D constant; reshaping to (n, 3) first guarantees
    # that the transposed result is always 3 by n.
    return tf.transpose(tf.reshape(tf.constant(triplets, dtype=dtype), (-1, 3)))
def wignerd_labels(
    lmax: int,
    lmin: int = 0,
    mmax: Optional[int] = None,
    mmin: Optional[int] = None,
    mpmin: Optional[int] = None,
    mpmax: Optional[int] = None,
) -> TFArray:
    r"""Tensor with all (l, m, m') constrained by the input arguments.

    More specifically, we compute all values:

    .. math::

        \left\{
            (l, m, m');
            l \in [l_\text{min}, l_\text{max}[,
            m \in [m_\text{min}, m_\text{max}] \cap [-l, l],
            m' \in [m'_\text{min}, m'_\text{max}] \cap [-l, l]
        \right\}

    Generates an int32 tensor with all :math:`l, m, m'` as per requirements. The
    first row corresponds to :math:`l`, the second to :math:`m`, and the third to
    :math:`m'`. :math:`m'` is the most rapidly changing index, :math:`l` the
    slowest. ``mpmin`` and ``mpmax`` default to ``mmin`` and ``mmax``.

    Degrees for which no pair of orders satisfies the constraints contribute no
    column, and an empty selection yields a ``(3, 0)`` tensor.

    Raises:
        ValueError: if ``lmin > lmax`` or ``mmin > mmax``.

    Example:

        >>> from tensossht import wignerd_labels
        >>> labels = wignerd_labels(lmax=6, lmin=3, mmax=2, mmin=-1, mpmax=1, mpmin=-2)
        >>> labels
        <tf.Tensor: shape=(3, 48), dtype=int32, numpy=
        array([[ 3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
                 4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,
                 5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5],
               [-1, -1, -1, -1,  0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2,
                -1, -1, -1, -1,  0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2,
                -1, -1, -1, -1,  0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2],
               [-2, -1,  0,  1, -2, -1,  0,  1, -2, -1,  0,  1, -2, -1,  0,  1,
                -2, -1,  0,  1, -2, -1,  0,  1, -2, -1,  0,  1, -2, -1,  0,  1,
                -2, -1,  0,  1, -2, -1,  0,  1, -2, -1,  0,  1, -2, -1,  0,  1]],
              dtype=int32)>

        Degrees too small for the requested orders are skipped:

        >>> wignerd_labels(lmax=3, mmin=2, mmax=2).numpy()
        array([[2],
               [2],
               [2]], dtype=int32)
        >>> wignerd_labels(lmax=2, lmin=2).shape
        TensorShape([3, 0])
    """
    if mmin is None:
        mmin = -lmax
    if mmax is None:
        mmax = lmax
    if lmin > lmax:
        raise ValueError(f"lmin > lmax ({lmin} > {lmax})")
    if mmin > mmax:
        raise ValueError(f"mmin > mmax ({mmin} > {mmax})")
    if mpmin is None:
        mpmin = mmin
    if mpmax is None:
        mpmax = mmax

    # The zero-length seeds keep the concatenations valid when no degree is selected.
    ls = [np.zeros(0, dtype=np.int32)]
    ms = [np.zeros(0, dtype=np.int32)]
    mps = [np.zeros(0, dtype=np.int32)]
    for degree in range(lmin, lmax):
        # np.arange yields an empty array, rather than a negative count, when no
        # order of this degree lies within the requested bounds.
        orders = np.arange(max(-degree, mmin), min(degree, mmax) + 1, dtype=np.int32)
        porders = np.arange(
            max(-degree, mpmin), min(degree, mpmax) + 1, dtype=np.int32
        )
        ls.append(np.full(orders.size * porders.size, degree, dtype=np.int32))
        # m changes slowly (each repeated once per m'), m' quickly (tiled per m).
        ms.append(np.repeat(orders, porders.size))
        mps.append(np.tile(porders, orders.size))
    return tf.constant(
        np.stack((np.concatenate(ls), np.concatenate(ms), np.concatenate(mps)))
    )
def symmetric_labels(
    labels: Array, dtype: Union[str, np.dtype, tf.DType] = tf.int32
) -> Tuple[TFArray, TFArray]:
    """Symmetrize wigner-d (l, m, m') labels so m >= m' >= 0.

    Returns a tuple with the symmetrized labels and the sign factor.

    Args:
        labels: Labels with rows ``(l, m, m')``, e.g. from :py:func:`wignerd_labels`.
        dtype: Type of the returned sign factors.

    Returns:
        A tuple ``(symmetrized, factors)``. ``symmetrized`` holds the rows
        ``(l, m, m')`` after mapping each column to ``m >= m' >= 0``. ``factors``
        holds the ``+1``/``-1`` sign relating the original and symmetrized
        wigner-d values, as checked in the example below.

    Example:

        First we create the labels:

        >>> from tensossht import wignerd_labels
        >>> from tensossht.sampling import symmetric_labels
        >>> labels = wignerd_labels(8)
        >>> symlabs, factors = symmetric_labels(labels)

        Then we can verify the labels have been symmetrized:

        >>> assert tf.reduce_all(symlabs[1] >= symlabs[-1])
        >>> assert tf.reduce_all(symlabs[-1] >= 0)

        We can check that the wigner-ds are equal up to a factor, using the naive
        multi-precision implementation:

        >>> from pytest import approx
        >>> from tensossht.specialfunctions.naive import wignerd
        >>> for i in range(labels.shape[1]):
        ...     expected = wignerd(*labels[:, i].numpy())
        ...     actual = factors[i].numpy() * wignerd(*symlabs[:, i].numpy())
        ...     assert actual == approx(expected)
    """
    # Keep (m, m') where |m| >= |m'|; otherwise swap and negate, (m, m') -> (-m', -m),
    # so that the first order always has the larger magnitude.
    ms = tf.where(tf.abs(labels[1]) >= tf.abs(labels[-1]), labels[1:], -labels[2:0:-1])
    # Sign picked up when replacing each order by its absolute value below. Each
    # negative order contributes a flip depending on the parity of l minus the
    # other order; the two flips combine as an exclusive-or ("!=" on booleans).
    # The doctest above checks this rule against the naive implementation.
    factors = tf.where(
        tf.logical_and(ms[0] < 0, (labels[0] - ms[1]) % 2 == 1)
        != tf.logical_and(ms[1] < 0, (labels[0] - ms[0]) % 2 == 1),
        tf.constant(-1, dtype=dtype),
        tf.constant(1, dtype=dtype),
    )
    # Keep l and replace both orders by their absolute values.
    labels = tf.concat((labels[0:1], tf.abs(ms)), axis=0)
    return labels, factors