diff --git a/python/snewpy/flavor.py b/python/snewpy/flavor.py new file mode 100644 index 000000000..9ed92eaaf --- /dev/null +++ b/python/snewpy/flavor.py @@ -0,0 +1,171 @@ +import enum +import numpy as np +import typing + +class EnumMeta(enum.EnumMeta): + def __getitem__(cls, key): + #if this is an iterable: apply to each value, and construct a tuple + if isinstance(key, slice): + return slice(cls[key.start],cls[key.stop],key.step) + if isinstance(key, typing.Iterable) and not isinstance(key, str): + return tuple(map(cls.__getitem__, key)) + #if this is from a flavor scheme: check that it's ours + if isinstance(key, FlavorScheme): + if not isinstance(key, cls): + raise TypeError(f'Value {repr(key)} is not from {cls.__name__} sheme!') + return key + #if this is a string find it by name + if isinstance(key, str): + try: + return super().__getitem__(key) + except KeyError as e: + raise KeyError( + f'Cannot find key "{key}" in {cls.__name__} sheme! Valid options are {list(cls)}' + ) + + if key is None: + return None + #if this is anything else - treat it as a slice + return np.array(list(cls.__members__.values()),dtype=object)[key] + +class FlavorScheme(enum.IntEnum, metaclass=EnumMeta): + def to_tex(self): + """LaTeX-compatible string representations of flavor.""" + base = r'\nu' + if self.is_antineutrino: + base = r'\overline{\nu}' + lepton = self.lepton.lower() + if self.is_muon or self.is_tauon: + lepton = '\\'+lepton + return f"${base}_{{{lepton}}}$" + + @property + def is_neutrino(self): + return not self.is_antineutrino + + @property + def is_antineutrino(self): + return '_BAR' in self.name + + @property + def is_electron(self): + return self.lepton=='E' + + @property + def is_muon(self): + return self.lepton=='MU' + + @property + def is_tauon(self): + return self.lepton=='TAU' + + @property + def is_sterile(self): + return self.lepton=='S' + + @property + def lepton(self): + return self.name.split('_')[1] + + @classmethod + def from_lepton_names(cls, name:str, leptons:list): + enum_class = cls(name, start=0, names = [f'NU_{L}{BAR}' for L in leptons for BAR in ['','_BAR']]) + return enum_class + + @classmethod + def take(cls, index): + return cls[index] + +TwoFlavor = FlavorScheme.from_lepton_names('TwoFlavor',['E','X']) +ThreeFlavor = FlavorScheme.from_lepton_names('ThreeFlavor',['E','MU','TAU']) +FourFlavor = FlavorScheme.from_lepton_names('FourFlavor',['E','MU','TAU','S']) + +class FlavorMatrix: + def __init__(self, + array:np.ndarray, + flavor:FlavorScheme, + from_flavor:FlavorScheme = None + ): + self.array = np.asarray(array) + self.flavor_out = flavor + self.flavor_in = from_flavor or flavor + expected_shape = (len(self.flavor_out), len(self.flavor_in)) + if(self.array.shape != expected_shape): + raise ValueError(f"FlavorMatrix array shape {self.array.shape} mismatch expected {expected_shape}") + + def _convert_index(self, index): + if isinstance(index, str) or (not isinstance(index,typing.Iterable)): + index = [index] + new_idx = [flavors[idx] for idx,flavors in zip(index, [self.flavor_out, self.flavor_in])] + return tuple(new_idx) + + def __getitem__(self, index): + return self.array[self._convert_index(index)] + + def __setitem__(self, index, value): + self.array[self._convert_index(index)] = value + + def _repr_short(self): + return f'{self.__class__.__name__}:<{self.flavor_in.__name__}->{self.flavor_out.__name__}> shape={self.shape}' + + def __repr__(self): + s=self._repr_short()+'\n'+repr(self.array) + return s + def __eq__(self,other): + return self.flavor_in==other.flavor_in and self.flavor_out==other.flavor_out and np.allclose(self.array,other.array) + + def __matmul__(self, other): + if isinstance(other, FlavorMatrix): + try: + data = np.tensordot(self.array, other.array, axes=[1,0]) + return FlavorMatrix(data, self.flavor_out, from_flavor = other.flavor_in) + except Exception as e: + raise ValueError(f"Cannot multiply {self._repr_short()} by {other._repr_short()}") from e + elif hasattr(other, '__rmatmul__'): + return other.__rmatmul__(self) + raise TypeError(f"Cannot multiply object of {self.__class__} by {other.__class__}") + #properties + @property + def shape(self): + return self.array.shape + @property + def flavor(self): + return self.flavor_out + + @classmethod + def eye(cls, flavor:FlavorScheme, from_flavor:FlavorScheme = None): + from_flavor = from_flavor or flavor + shape = (len(from_flavor), len(flavor)) + data = np.eye(*shape) + return cls(data, flavor, from_flavor) + + @classmethod + def from_function(cls, flavor:FlavorScheme, from_flavor:FlavorScheme = None): + """A decorator for creating the flavor matrix from the given function""" + from_flavor = from_flavor or flavor + def _decorator(function): + data = [[function(f1,f2) + for f2 in from_flavor] + for f1 in flavor] + + return cls(np.array(data,dtype=float), flavor, from_flavor) + return _decorator + #flavor conversion utils + +def conversion_matrix(from_flavor:FlavorScheme, to_flavor:FlavorScheme): + @FlavorMatrix.from_function(to_flavor, from_flavor) + def convert(f1, f2): + if f1.name == f2.name: + return 1. + if (f1.is_neutrino == f2.is_neutrino) and (f2.lepton == 'X' and f1.lepton in ['MU', 'TAU']): + # convert from TwoFlavor to more flavors + return 1. + if (f1.is_neutrino == f2.is_neutrino) and (f1.lepton == 'X' and f2.lepton in ['MU', 'TAU']): + # convert from more flavors to TwoFlavor + return 0.5 + return 0. + return convert + +FlavorScheme.conversion_matrix = classmethod(conversion_matrix) +EnumMeta.__rshift__ = conversion_matrix +EnumMeta.__lshift__ = lambda f1,f2:conversion_matrix(f2,f1) diff --git a/python/snewpy/flux.py b/python/snewpy/flux.py index 0fb330374..fc9469153 100644 --- a/python/snewpy/flux.py +++ b/python/snewpy/flux.py @@ -52,7 +52,8 @@ """ from typing import Union, Optional, Set, List -from snewpy.neutrino import Flavor +#from snewpy.neutrino import Flavor +from snewpy.flavor import FlavorScheme, FlavorMatrix from astropy import units as u import numpy as np @@ -86,11 +87,12 @@ class _ContainerBase: unit = None def __init__(self, data: u.Quantity, - flavor: List[Flavor], + flavor: List[FlavorScheme], time: u.Quantity[u.s], energy: u.Quantity[u.MeV], *, - integrable_axes: Optional[Set[Axes]] = None + integrable_axes: Optional[Set[Axes]] = None, + flavor_scheme:Optional[FlavorScheme] = None ): """A container class storing the physical quantity (flux, fluence, rate...), which depends on flavor, time and energy. @@ -112,7 +114,11 @@ def __init__(self, integrable_axes: set of :class:`Axes` or None List of axes which can be integrated. - If None (default) this set will be derived from the axes shapes + If None (default) this set will be derived from the axes shapes + + flavor_scheme: a subclass of :class:`snewpy.flavor.FlavorSchemes` or None + A class which lists all the allowed flavors. + If None (default) this value will be retrieved from the ``flavor`` arguemnt. """ if self.unit is not None: #try to convert to the unit @@ -121,8 +127,20 @@ def __init__(self, self.array = u.Quantity(data) self.time = u.Quantity(time, ndmin=1) self.energy = u.Quantity(energy, ndmin=1) - self.flavor = np.sort(np.array(flavor, ndmin=1)) - + self.flavor = np.array(flavor,ndmin=1, dtype=object) + self.flavor_scheme = flavor_scheme + if not flavor_scheme: + #guess the flavor scheme + if isinstance(flavor, type):#issubclass without isinstance(type) check raises TypeError + if issubclass(flavor, FlavorScheme): + self.flavor_scheme = flavor + else: + #get schemes from the data + flavor_schemes = set(f.__class__ for f in self.flavor) + if len(flavor_schemes)!=1: + raise ValueError(f"Flavors {flavor} must be from a single flavor scheme, but are from {flavor_schemes}") + else: + self.flavor_scheme = flavor_schemes.pop() Nf,Nt,Ne = len(self.flavor), len(self.time), len(self.energy) #list all valid shapes of the input array expected_shapes=[(nf,nt,ne) for nf in (Nf,Nf-1) for nt in (Nt,Nt-1) for ne in (Ne,Ne-1)] @@ -162,21 +180,28 @@ def shape(self): def __getitem__(self, args)->'Container': """Slice the flux array and produce a new Flux object""" - try: - iter(args) - except TypeError: + if not isinstance(args,tuple): args = [args] - args = [a if isinstance(a, slice) else slice(a, a + 1) for a in args] - #expand args to match axes - args+=[slice(None)]*(len(Axes)-len(args)) - array = self.array.__getitem__(tuple(args)) - newaxes = [ax.__getitem__(arg) for arg, ax in zip(args, self.axes)] - return self.__class__(array, *newaxes) + args = list(args) + arg_slices = [slice(None)]*len(Axes) + if isinstance(args[0],str) or isinstance(args[0],FlavorScheme): + args[0] = self.flavor_scheme[args[0]] + for n,arg in enumerate(args): + if not isinstance(arg, slice): + arg = slice(arg, arg + 1) + arg_slices[n] = arg + + array = self.array.__getitem__(tuple(arg_slices)) + newaxes = [ax.__getitem__(arg) for arg, ax in zip(arg_slices, self.axes)] + return self.__class__(array, *newaxes, flavor_scheme=self.flavor_scheme) def __repr__(self) -> str: """print information about the container""" s = [f"{len(values)} {label.name}({values.min()};{values.max()})" - for label, values in zip(Axes,self.axes)] + if label!=Axes.flavor + else f"{len(values)} {label.name}[{self.flavor_scheme}]({values.min()};{values.max()})" + for label, values in zip(Axes,self.axes) + ] return f"{self.__class__.__name__} {self.array.shape} [{self.array.unit}]: <{' x '.join(s)}>" def sum(self, axis: Union[Axes,str])->'Container': @@ -317,7 +342,7 @@ def __mul__(self, factor) -> 'Container': array = self.array*factor axes = list(self.axes) return Container(array, *axes) - + def save(self, fname:str)->None: """Save container data to a given file (using `numpy.savez`)""" def _save_quantity(name): @@ -329,8 +354,9 @@ def _save_quantity(name): except: return {name:values} data_dict = {} - for name in ['array','time','energy','flavor']: + for name in ['array','time','energy']: data_dict.update(_save_quantity(name)) + data_dict['flavor'] = np.array(self.flavor, dtype=object) np.savez(fname, _class_name=self.__class__.__name__, **data_dict, @@ -340,7 +366,7 @@ def _save_quantity(name): @classmethod def load(cls, fname:str)->'Container': """Load container from a given file""" - with np.load(fname) as f: + with np.load(fname, allow_pickle=True) as f: def _load_quantity(name): array = f[name] try: @@ -361,9 +387,29 @@ def __eq__(self, other:'Container')->bool: result = self.__class__==other.__class__ and \ self.unit == other.unit and \ np.allclose(self.array, other.array) and \ - all([np.allclose(self.axes[ax], other.axes[ax]) for ax in Axes]) + self.flavor_scheme==other.flavor_scheme and \ + len(self.flavor)==len(other.flavor) and \ + all(self.flavor==other.flavor) and \ + all([np.allclose(self.axes[ax], other.axes[ax]) for ax in list(Axes)[1:]]) return result + def _is_full_flavor(self): + return all(self.flavor==list(self.flavor_scheme)) + + def convert_to_flavor(self, flavor:FlavorScheme): + if(self.flavor_scheme==flavor): + return self + return (self.flavor_scheme>>flavor)@self + def __rshift__(self, flavor:FlavorScheme): + return self.convert_to_flavor(flavor) + + def __rmatmul__(self, matrix:FlavorMatrix): + if not self._is_full_flavor(): + raise RuntimeError(f"Cannot multiply flavor matrix object {self}, expected {len(self.flavor_scheme)} flavors") + if matrix.flavor_in!=self.flavor_scheme: + raise ValueError(f"Cannot multiply flavor matrix {matrix} by {self} - flavor scheme mismatch!") + array = np.tensordot(matrix.array,self.array, axes=[1,0]) + return Container(array, flavor=matrix.flavor_out, time=self.time, energy=self.energy) class Container(_ContainerBase): #a dictionary holding classes for each unit _unit_classes = {} diff --git a/python/snewpy/models/base.py b/python/snewpy/models/base.py index 0d14e30f4..0d81efc0c 100644 --- a/python/snewpy/models/base.py +++ b/python/snewpy/models/base.py @@ -190,8 +190,8 @@ def get_flux (self, t, E, distance, flavor_xform=NoTransformation()): factor = 1/(4*np.pi*(distance.to('cm'))**2) f = self.get_transformed_spectra(t, E, flavor_xform) - array = np.stack([f[flv] for flv in sorted(Flavor)]) - return Flux(data=array*factor, flavor=np.sort(Flavor), time=t, energy=E) + array = np.stack([f[flv] for flv in Flavor]) + return Flux(data=array*factor, flavor=Flavor, time=t, energy=E) diff --git a/python/snewpy/neutrino.py b/python/snewpy/neutrino.py index bcbb53e12..e1612360f 100644 --- a/python/snewpy/neutrino.py +++ b/python/snewpy/neutrino.py @@ -7,6 +7,7 @@ from typing import Optional import numpy as np from collections.abc import Mapping +from .flavor import TwoFlavor as Flavor class MassHierarchy(IntEnum): """Neutrino mass ordering: ``NORMAL`` or ``INVERTED``.""" @@ -23,32 +24,6 @@ def derive_from_dm2(cls, dm12_2, dm32_2, dm31_2): else: return MassHierarchy.INVERTED -class Flavor(IntEnum): - """Enumeration of CCSN Neutrino flavors.""" - NU_E = 0 - NU_X = 1 - NU_E_BAR = 2 - NU_X_BAR = 3 - - def to_tex(self): - """LaTeX-compatible string representations of flavor.""" - if '_BAR' in self.name: - return r'$\overline{{\nu}}_{0}$'.format(self.name[3].lower()) - return r'$\{0}$'.format(self.name.lower()) - - @property - def is_electron(self): - """Return ``True`` for ``Flavor.NU_E`` and ``Flavor.NU_E_BAR``.""" - return self.value in (Flavor.NU_E.value, Flavor.NU_E_BAR.value) - - @property - def is_neutrino(self): - """Return ``True`` for ``Flavor.NU_E`` and ``Flavor.NU_X``.""" - return self.value in (Flavor.NU_E.value, Flavor.NU_X.value) - - @property - def is_antineutrino(self): - return self.value in (Flavor.NU_E_BAR.value, Flavor.NU_X_BAR.value) @dataclass class MixingParameters3Flavor(Mapping): diff --git a/python/snewpy/test/test_flavors.py b/python/snewpy/test/test_flavors.py new file mode 100644 index 000000000..ead9e0dc6 --- /dev/null +++ b/python/snewpy/test/test_flavors.py @@ -0,0 +1,139 @@ +import pytest +import numpy as np +import snewpy.flavor +from snewpy.flavor import TwoFlavor,ThreeFlavor,FourFlavor, FlavorMatrix, FlavorScheme + +flavor_schemes = TwoFlavor,ThreeFlavor,FourFlavor + +class TestFlavorScheme: + @staticmethod + def test_flavor_scheme_lengths(): + assert len(TwoFlavor)==4 + assert len(ThreeFlavor)==6 + assert len(FourFlavor)==8 + + @staticmethod + + def test_getitem_string(): + assert TwoFlavor['NU_E'] == TwoFlavor.NU_E + assert TwoFlavor['NU_X'] == TwoFlavor.NU_X + with pytest.raises(KeyError): + TwoFlavor['NU_MU'] + + @staticmethod + def test_getitem_enum(): + assert TwoFlavor[TwoFlavor.NU_E] == TwoFlavor.NU_E + assert TwoFlavor[TwoFlavor.NU_X] == TwoFlavor.NU_X + with pytest.raises(TypeError): + TwoFlavor[ThreeFlavor.NU_E] + + @staticmethod + def test_values_from_different_enums(): + assert TwoFlavor.NU_E==ThreeFlavor.NU_E + assert TwoFlavor.NU_E_BAR==ThreeFlavor.NU_E_BAR + + @staticmethod + def test_makeFlavorScheme(): + TestFlavor = FlavorScheme.from_lepton_names('TestFlavor',leptons=['A','B','C']) + assert len(TestFlavor)==6 + assert [f.name for f in TestFlavor]==['NU_A','NU_A_BAR','NU_B','NU_B_BAR','NU_C','NU_C_BAR'] + + @staticmethod + def test_flavor_properties(): + f = ThreeFlavor.NU_E + assert f.is_neutrino + assert f.is_electron + assert not f.is_muon + assert not f.is_tauon + assert f.lepton=='E' + + f = ThreeFlavor.NU_MU + assert f.is_neutrino + assert not f.is_electron + assert f.is_muon + assert not f.is_tauon + assert f.lepton=='MU' + + f = ThreeFlavor.NU_E_BAR + assert not f.is_neutrino + assert f.is_electron + assert not f.is_muon + assert not f.is_tauon + assert f.lepton=='E' + + f = ThreeFlavor.NU_MU_BAR + assert not f.is_neutrino + assert not f.is_electron + assert f.is_muon + assert not f.is_tauon + assert f.lepton=='MU' + + f = ThreeFlavor.NU_TAU + assert f.is_neutrino + assert not f.is_electron + assert not f.is_muon + assert f.is_tauon + assert f.lepton=='TAU' + + f = ThreeFlavor.NU_TAU_BAR + assert not f.is_neutrino + assert not f.is_electron + assert not f.is_muon + assert f.is_tauon + assert f.lepton=='TAU' + +class TestFlavorMatrix: + @staticmethod + def test_init_square_matrix(): + m = FlavorMatrix(array=np.ones(shape=(4,4)), flavor=TwoFlavor) + assert m.shape == (4,4) + assert m.flavor_in == TwoFlavor + assert m.flavor_out == TwoFlavor + + @staticmethod + def test_getitem(): + m = FlavorMatrix.eye(TwoFlavor,TwoFlavor) + assert m[TwoFlavor.NU_E, TwoFlavor.NU_E]==1 + assert m['NU_E','NU_E']==1 + assert m['NU_E','NU_X']==0 + assert np.allclose(m['NU_E'], [1,0,0,0]) + assert np.allclose(m['NU_E'], m['NU_E',:]) + assert np.allclose(m[:,:], m.array) + + @staticmethod + def test_setitem(): + m = FlavorMatrix.eye(TwoFlavor,TwoFlavor) + m['NU_E']=[2,3,4,5] + assert m['NU_E','NU_E']==2 + assert m['NU_E','NU_X']==4 + #check that nothing changed in other parts + assert m['NU_X','NU_E']==0 + assert m['NU_X','NU_X']==1 + m['NU_E','NU_E_BAR']=123 + assert m['NU_E','NU_E_BAR']==123 + + @staticmethod + def test_init_square_matrix_with_wrong_shape_raises_ValueError(): + with pytest.raises(ValueError): + m = FlavorMatrix(array=np.ones(shape=(4,5)), flavor=TwoFlavor) + with pytest.raises(ValueError): + m = FlavorMatrix(array=np.ones(shape=(5,5)), flavor=TwoFlavor) + with pytest.raises(ValueError): + m = FlavorMatrix(array=np.ones(shape=(5,4)), flavor=TwoFlavor) + + @staticmethod + def test_conversion_matrices_for_same_flavor_are_unity(): + for flavor in [TwoFlavor,ThreeFlavor,FourFlavor]: + matrix = flavor>>flavor + assert isinstance(matrix, FlavorMatrix) + assert np.allclose(matrix.array, np.eye(len(flavor))) + + @staticmethod + @pytest.mark.parametrize('flavor_in',flavor_schemes) + @pytest.mark.parametrize('flavor_out',flavor_schemes) + def test_conversion_matrices(flavor_in, flavor_out): + M = flavor_in>>flavor_out + assert M==flavor_out<