Source code for tensossht.specialfunctions.kostelec

"""
Kostelec Recurrence
===================

Implementation of the recurrence from [Kostelec, Rockmore (2008)].

API
---

.. autofunction:: tensossht.specialfunctions.kostelec.legendre
.. autofunction:: tensossht.specialfunctions.kostelec.wignerd
.. autofunction:: tensossht.specialfunctions.kostelec.logfact
"""
from textwrap import dedent
from typing import Optional, Tuple, Union

import numpy as np
import tensorflow as tf

from tensossht.references import REFERENCES


@tf.function
def logfact(
    n: Union[int, tf.Tensor], dtype: Union[str, tf.DType] = tf.float32
) -> tf.Tensor:
    """Log of factorial 0 <= i <= n.

    Args:
        n: Largest integer for which the log-factorial is computed.
        dtype: Requested type of the output. Only its real counterpart is used
            (``tf.dtypes.as_dtype(dtype).real_dtype``).

    Returns:
        A rank-1 tensor with ``n + 1`` entries, where entry ``i`` holds
        ``log(i!)``. For ``n == 0`` this is the single value ``log(0!) = 0``.

    Example:

        >>> from tensossht.specialfunctions.kostelec import logfact
        >>> tf.exp(logfact(5)).numpy().astype(float).round(2)
        array([  1.,   1.,   2.,   6.,  24., 120.])
    """
    real_dtype = tf.dtypes.as_dtype(dtype).real_dtype
    integers = tf.range(1, n + 1, dtype=real_dtype)
    # log(i!) = log(1) + ... + log(i): a cumulative sum over the logs. Unlike
    # ``tf.scan`` without an initializer, ``tf.cumsum`` accepts an empty input,
    # so ``n == 0`` is handled as well.
    logfacts = tf.cumsum(tf.math.log(integers))
    # prepend log(0!) = 0
    return tf.concat([tf.zeros(1, dtype=logfacts.dtype), logfacts], 0)
def legendre(
    beta: Union[float, tf.Tensor],
    lmax: Optional[int] = None,
    lmin: int = 0,
    mmin: int = 0,
    mmax: Optional[int] = None,
    labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
    scaled: bool = True,
) -> tf.Tensor:
    # NOTE: the user-facing docstring is attached right after this definition
    # (``legendre.__doc__ = ...``) because it interpolates ``REFERENCES``.
    from tensossht.sampling import legendre_labels

    if labels is None and lmax is None:
        raise ValueError("At least one of lmax or labels is required on input")

    if labels is not None:
        # Explicit labels always win: they are never replaced by labels
        # generated from ``lmax``. A simultaneously given ``lmax`` is only
        # accepted when it agrees with the largest label.
        labels = tf.convert_to_tensor(labels)
        largest_label = tf.reduce_max(labels)
        if lmax is not None and lmax != largest_label:
            raise ValueError("Only one lmax or labels should be given on input")
        if lmax is None:
            lmax = largest_label
    else:
        labels = legendre_labels(lmax=lmax, lmin=lmin, mmin=mmin, mmax=mmax)
    assert lmax is not None

    _beta: tf.Tensor = tf.convert_to_tensor(beta, name="beta")
    # table of log(i!) large enough for the ``2 * m + 1`` lookups of the
    # recurrence
    logfacts = logfact(2 * lmax + 2, dtype=_beta.dtype)
    if scaled:
        return _scaled_legendre_impl(_beta, labels, logfacts)
    return _legendre_impl(_beta, labels, logfacts)
# The normalisation below matches ``_scaled_legendre_impl``, which converts the
# recurrence output back with ``sqrt(2 / (2l + 1))``.
legendre.__doc__ = (
    dedent(
        r"""Legendre polynomials via the Kostelec recurrence.

        See [Kostelec, Rockmore (2008)], equations 4.6, 4.10, and 4.11. Specifically,
        they are implemented here with :math:`m'=0`, reproduced and simplified below:

        .. math::

            d^j_{j0}(\beta) = \sqrt{\frac{(2j)!}{j!^2}} 2^{-j}\sin^j\beta

            \tilde{d}^j_{m0} = \sqrt{\frac{2j + 1}{2}} d^j_{m0}

            \tilde{d}^{j+1}_{m0} =
                \sqrt{\frac{(2j + 3)(2j+1)}{(j + 1)^2 - m^2}}\cos\beta
                \tilde{d}^j_{m0}
                - \sqrt{\frac{(2j + 3)(j^2 - m^2)}{(2j - 1)((j + 1)^2 - m^2)}}
                \tilde{d}^{j-1}_{m0}

        %s

        Args:
            beta: Scalar or multi-dimensional vector of values. ``beta`` corresponds
                to ``x=tf.cos(beta)`` in other formulations of the legendre
                polynomial.
            lmax: If given, all values ``lmax >= l >= lmin`` and
                ``mmax >= m >= mmin`` will be computed. Otherwise, ``labels``
                should be given as input.
            lmin: Minimum ``l`` value.
            mmax: Maximum ``m``, defaults to ``lmax``.
            mmin: Minimum ``m`` value.
            labels: a 2 by n vector of ``(l, m)`` values. If given, then ``lmax``
                should not.
            scaled: If ``True``, then does not take into account the scaling
                :math:`\sqrt{\frac{(l - m)!}{(l + m)!}}`. The scaling makes the
                function numerically well-behaved.

        Example:

            There are two means of calling this function. Either the labels are
            given explicitly as a 2 by n vector, or they are specified via
            ``lmax``, ``lmin``, ``mmax``, and ``mmin``. In the latter case, the
            labels are obtained via
            :py:func:`~tensossht.specialfunctions.legendre_labels`.  The input
            ``beta`` can be a ``float`` or anything convertible to a tensorflow
            tensor:

            >>> from pytest import approx
            >>> from tensossht import legendre_labels
            >>> from tensossht.specialfunctions.kostelec import legendre
            >>> lmax = 4
            >>> legpol = legendre(beta=0.4, lmax=lmax)
            >>> legpol.numpy().round(4)
            array([ 1.    ,  0.9211, -0.2754,  0.7725, -0.4393,  0.0929,  0.5719,
                   -0.5466,  0.1913, -0.033 ], dtype=float32)

            We can verify the two methods are equivalent:

            >>> labels = legendre_labels(lmax=lmax)
            >>> labels
            <tf.Tensor: shape=(2, 10), dtype=int32, numpy=
            array([[0, 1, 1, 2, 2, 2, 3, 3, 3, 3],
                   [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]], dtype=int32)>
            >>> assert legpol.numpy() == approx(legendre(beta=0.4, labels=labels).numpy())

            We can check that the legendre polynomials are correct against a naive,
            brute-force implementation:

            >>> from tensossht.specialfunctions import naive
            >>> beta = tf.random.uniform([3], dtype=tf.float64) * np.pi
            >>> legpol = legendre(beta, labels=labels, scaled=False)
            >>> for actual, (l, m) in zip(legpol.numpy().T, labels.numpy().T):
            ...     expected = [float(naive.legendre(l, m, np.cos(b))) for b in beta]
            ...     assert actual == approx(expected)

            Note that the input to the different legendre functions in the code are
            not all equivalent. They will differ by whether they take ``beta`` or
            ``tf.cos(beta)`` as input, and potentially by a keyword-triggered
            scaling.
        """
    )
    % REFERENCES["KR"]
)


@tf.function
def _legendre_impl(
    beta: tf.Tensor, labels: tf.Tensor, logfacts: tf.Tensor
) -> tf.Tensor:
    r"""Unscaled counterpart of :py:func:`_scaled_legendre_impl`.

    Multiplies the scaled result by :math:`\sqrt{\frac{(l + m)!}{(l - m)!}}`,
    evaluated in log-space through ``lgamma`` with ``l = labels[0]`` and
    ``m = labels[1]``.
    """
    return tf.exp(
        0.5
        * (
            tf.math.lgamma(tf.cast(labels[0] + labels[1] + 1, dtype=beta.dtype))
            - tf.math.lgamma(tf.cast(labels[0] - labels[1] + 1, dtype=beta.dtype))
        )
    ) * _scaled_legendre_impl(beta, labels, logfacts)


@tf.function
def _scaled_legendre_impl(
    beta: tf.Tensor, labels: tf.Tensor, logfacts: tf.Tensor
) -> tf.Tensor:
    r"""Implementation of the Kostelec recurrence limited to m'=0.

    Args:
        beta: Angles. A trailing axis is appended so that the values broadcast
            against the labels.
        labels: ``labels[0]`` holds ``l`` and ``labels[1]`` holds ``m``
            (assumed to be a 2 by n integer tensor; confirm against callers).
        logfacts: Table where entry ``i`` is ``log(i!)``. It is indexed up to
            ``2 * m + 1``.

    Returns:
        One value per ``beta`` and per label. Labels with ``l < |m|`` yield
        zero.
    """
    dtype = beta.dtype
    valid = labels[0] >= tf.abs(labels[1])
    l = labels[0]
    # invalid labels are computed as m = 0 and masked out at the very end
    m = tf.where(valid, tf.abs(labels[1]), 0)
    sinb = tf.sin(beta)
    cosb = tf.cos(beta)

    # kostelec 4.10 + 4.6
    # (sin^2 / 4)^(m // 2) in log-space; the odd leftover power is applied below
    sinterm = tf.cast((m // 2), dtype=dtype) * (
        tf.math.log(tf.expand_dims(sinb * sinb, -1))
        - tf.math.log(tf.cast(4, dtype=dtype))
    )
    midterm = tf.exp(
        0.5
        * (
            tf.gather(logfacts, 2 * m + 1, axis=0)
            - tf.math.log(tf.constant(2, dtype=dtype))
        )
        - tf.gather(logfacts, m, axis=0)
        # 0 * log(0) gives NaN: that combination stands for sin^0 = 1
        + tf.where(tf.math.is_nan(sinterm), tf.zeros_like(sinterm), sinterm)
    ) * tf.where(
        m % 2 == 1,
        -0.5 * tf.expand_dims(sinb, -1),
        tf.expand_dims(tf.ones_like(beta), -1),
    )

    # kostelec 4.10 + 4.11, assuming d^{j-1}{m0} is zero...
    highterm = tf.where(
        l == m,
        midterm,
        tf.sqrt(tf.cast(2 * m + 3, dtype=dtype))
        * tf.expand_dims(cosb, -1)
        * midterm,
    )
    # ``j`` is a per-label counter: it stops increasing once it is past ``l``,
    # at which point ``highterm`` is no longer updated for that label.
    j = m + 2
    for _ in tf.range(tf.reduce_max(l - j) + 1):
        lowterm, midterm = midterm, highterm
        coeffm = (
            tf.sqrt(tf.cast(4 * j * j - 1, dtype=dtype))
            / tf.sqrt(tf.cast(j * j - m * m, dtype))
            * tf.expand_dims(cosb, -1)
        )
        coeffl = tf.sqrt(
            tf.cast((2 * j + 1) * ((j - 1) * (j - 1) - m * m), dtype=dtype)
        ) / tf.sqrt(tf.cast((2 * j - 3) * (j * j - m * m), dtype))
        highterm = tf.where(j > l, highterm, coeffm * midterm - coeffl * lowterm)
        j = tf.where(j <= l, j + 1, j)

    # back from the normalised \tilde{d} to d
    result = highterm * tf.sqrt(
        tf.cast(2, dtype=dtype) / tf.cast(2 * l + 1, dtype=dtype)
    )
    # negative odd m flips the sign; invalid labels are zeroed
    return tf.where(
        valid,
        tf.where(tf.logical_or(labels[1] >= 0, labels[1] % 2 == 0), result, -result),
        0,
    )
def wignerd(
    beta: Union[float, tf.Tensor],
    *,
    labels: Optional[tf.Tensor] = None,
    lmax: Optional[int] = None,
    lmin: int = 0,
    mmin: Optional[int] = None,
    mmax: Optional[int] = None,
    mpmin: Optional[int] = None,
    mpmax: Optional[int] = None,
    scaled: bool = True,
) -> tf.Tensor:
    # NOTE: an f-string is not a docstring, so the documentation, which
    # interpolates ``REFERENCES``, is attached right after this definition.
    from tensossht.sampling import wignerd_labels

    if not scaled:
        raise NotImplementedError
    if labels is None and lmax is None:
        raise ValueError("Missing on of labels or (lmax, lmin...)")
    if labels is None:
        assert isinstance(lmax, int)
        labels = wignerd_labels(
            lmax=lmax, lmin=lmin, mmin=mmin, mmax=mmax, mpmin=mpmin, mpmax=mpmax
        )
    else:
        # explicit labels take precedence over any ``lmax`` given alongside
        lmax = tf.reduce_max(labels[0])
    assert labels is not None

    _beta = tf.convert_to_tensor(beta, name="beta")
    logfacts = logfact(2 * lmax + 2, dtype=_beta.dtype)  # type: ignore
    # floor used when taking the log of vanishing sine/cosine terms
    epsilon = tf.constant({tf.float16: 1e-7}.get(_beta.dtype, 1e-45), dtype=_beta.dtype)
    condition = tf.logical_and(
        labels[0] >= tf.abs(labels[1]), labels[0] >= tf.abs(labels[2])
    )
    # labels with l < |m| or l < |mp| are computed as (0, 0, 0), then zeroed
    labels = tf.where(condition, labels, 0)
    result = _wignerd_impl(_beta, labels, logfacts, epsilon)
    return tf.where(condition, result, 0)


wignerd.__doc__ = (
    "Wigner-d via the Kostelec recurrence.\n"
    "\n"
    "See [Kostelec, Rockmore (2008)], equations 4.6, 4.10, and 4.11.\n"
    "\n"
    "Args:\n"
    "    beta: Scalar or tensor of angles.\n"
    "    labels: Tensor with rows ``(l, m, mp)``. If given, it takes precedence\n"
    "        over ``lmax`` and the other bounds.\n"
    "    lmax: Used with ``lmin``, ``mmin``, ``mmax``, ``mpmin`` and ``mpmax`` to\n"
    "        generate the labels when ``labels`` is not given.\n"
    "    lmin: Minimum ``l`` value.\n"
    "    mmin: Minimum ``m`` value.\n"
    "    mmax: Maximum ``m`` value.\n"
    "    mpmin: Minimum ``mp`` value.\n"
    "    mpmax: Maximum ``mp`` value.\n"
    "    scaled: Only ``True`` is implemented.\n"
    "\n"
    "Raises:\n"
    "    NotImplementedError: if ``scaled`` is ``False``.\n"
    "    ValueError: if neither ``labels`` nor ``lmax`` is given.\n"
    "\n"
    f"{REFERENCES['KR']}\n"
)
@tf.function
def _wignerd_impl(
    beta: tf.Tensor,
    labels: tf.Tensor,
    logfacts: tf.Tensor,
    epsilon: Union[float, tf.Tensor] = 1e-45,
) -> tf.Tensor:
    r"""Implementation of the Kostelec recurrence.

    Args:
        beta: Angles. A trailing axis is appended so that the values broadcast
            against the labels.
        labels: Rows are ``(l, m, mp)`` (``_symmetries`` slices ``labels[1:]``
            as a rank-2 tensor). Presumably every label satisfies
            ``l >= |m|`` and ``l >= |mp|`` -- confirm against the caller.
        logfacts: Table where entry ``i`` is ``log(i!)``. It is indexed up to
            ``2 * m + 1``.
        epsilon: Floor applied to squared sine/cosine terms before taking their
            logarithm.

    Returns:
        One value per ``beta`` and per label, with the sign from
        ``_symmetries`` and the factor ``sqrt(2 / (2l + 1))`` applied.
    """
    dtype = beta.dtype
    # work on labels folded onto m >= |mp|; ``signs`` undoes the folding
    symlabs, signs = _symmetries(labels)
    l = symlabs[0]
    m = symlabs[1]
    mp = symlabs[2]
    sinb = tf.sin(beta / 2)
    cosb = tf.cos(beta / 2)
    factors = tf.sqrt(2 / tf.cast(2 * l + 1, dtype=dtype))
    factors = tf.where(signs >= 0, factors, -factors)

    def logzero(exponent, logterm):
        """Deals with zeros in the log-space terms.

        Returns ``exponent * log(logterm)`` with ``logterm`` clipped from
        below at ``epsilon``.
        """
        return tf.cast(exponent, dtype=beta.dtype) * tf.math.log(
            tf.where(logterm <= epsilon, epsilon, logterm)
        )

    # kostelec 4.10 + 4.6
    # Even powers of cos(beta/2) and sin(beta/2) are accumulated in log-space;
    # the leftover odd power (if any) is multiplied in directly.
    midterm = (
        tf.exp(
            0.5
            * (
                tf.gather(logfacts, 2 * m + 1, axis=0)
                - tf.gather(logfacts, m + mp, axis=0)
                - tf.gather(logfacts, m - mp, axis=0)
                - tf.math.log(tf.constant(2, dtype=dtype))
            )
            + logzero((m + mp) // 2, tf.expand_dims(cosb * cosb, -1))
            + logzero((m - mp) // 2, tf.expand_dims(sinb * sinb, -1))
        )
        * tf.where(
            (m + mp) % 2 == 1,
            tf.expand_dims(cosb, -1),
            tf.expand_dims(tf.ones_like(beta), -1),
        )
        * tf.where(
            (m - mp) % 2 == 1,
            tf.expand_dims(-sinb, -1),
            tf.expand_dims(tf.ones_like(beta), -1),
        )
    )

    # kostelec 4.10 + 4.11, assuming d^{j-1}{m0} is zero...
    highterm = tf.where(
        l == m,
        midterm,
        tf.sqrt(
            tf.cast(2 * m + 3, dtype=dtype)
            / tf.cast((m + 1) * (m + 1) - mp * mp, dtype=dtype)
        )
        * (
            tf.cast(m + 1, dtype=dtype) * tf.expand_dims(tf.cos(beta), -1)
            - tf.cast(mp, dtype=dtype)
        )
        * midterm,
    )
    # ``j`` is a per-label counter: it stops increasing once it is past ``l``,
    # at which point ``highterm`` is no longer updated for that label.
    j = m + 2
    for _ in tf.range(tf.reduce_max(l - j) + 1):
        lowterm, midterm = midterm, highterm
        coeffm = tf.sqrt(
            tf.cast((2 * j + 1) * (2 * j - 1), dtype=dtype)
            / tf.cast((j * j - m * m) * (j * j - mp * mp), dtype=dtype)
        ) * (
            tf.expand_dims(tf.cos(beta), -1) * tf.cast(j, dtype=dtype)
            - tf.cast(m * mp, dtype=dtype) / tf.cast(j - 1, dtype=dtype)
        )
        coeffl = (
            tf.sqrt(
                tf.cast(
                    (2 * j + 1)
                    * ((j - 1) * (j - 1) - m * m)
                    * ((j - 1) * (j - 1) - mp * mp),
                    dtype=dtype,
                )
                / tf.cast(
                    (2 * j - 3) * (j * j - m * m) * (j * j - mp * mp), dtype=dtype
                )
            )
            * tf.cast(j, dtype=dtype)
            / tf.cast(j - 1, dtype=dtype)
        )
        highterm = tf.where(j > l, highterm, coeffm * midterm - coeffl * lowterm)
        j = tf.where(j <= l, j + 1, j)

    return highterm * factors


def _symmetries(labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    """Computes sign factors while folding labels onto m >= |mp|.

    ``m`` and ``mp`` are swapped when ``|m| < |mp|``, then both are negated
    when the resulting ``m`` is negative, so that ``m >= |mp|`` (and hence
    ``m >= 0``) on output. ``mp`` itself may remain negative.

    Args:
        labels: Tensor whose rows are ``(l, m, mp)``.

    Returns:
        A tuple with the folded ``(l, m, mp)`` labels and a tensor of ``-1`` or
        ``1``. The factor is ``-1`` when ``m - mp`` is odd and exactly one of
        the two operations (swap, negation) was applied.
    """
    l = labels[0]
    ms = labels[1:]
    # swap m and mp wherever |m| < |mp|
    size_condition = tf.abs(ms[0]) < tf.abs(ms[1])
    ms = tf.where(size_condition, ms[::-1, :], ms)
    # negate both wherever the (possibly swapped) m is negative
    sign_condition = ms[0] < 0
    ms = tf.where(sign_condition, -ms, ms)
    # ``!=`` acts as an exclusive-or: two sign flips cancel each other
    factors = tf.where(
        tf.logical_and(size_condition, (ms[0] - ms[1]) % 2 == 1)
        != tf.logical_and(sign_condition, (ms[0] - ms[1]) % 2 == 1),
        -1,
        1,
    )
    return (
        tf.concat(
            [tf.expand_dims(l, 0), tf.expand_dims(ms[0], 0), tf.expand_dims(ms[1], 0)],
            axis=0,
        ),
        factors,
    )