Skip to content

Commit

Permalink
Merge pull request #324 from SNEWS2/Sheshuk/Flavor_Matrix_implementation
Browse files Browse the repository at this point in the history
Adding the implementation of the FlavorScheme and FlavorMatrix
  • Loading branch information
Sheshuk authored Jun 7, 2024
2 parents 262cf2a + 5b2fb84 commit 65bc4d5
Show file tree
Hide file tree
Showing 6 changed files with 408 additions and 50 deletions.
171 changes: 171 additions & 0 deletions python/snewpy/flavor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import enum
import numpy as np
import typing

class EnumMeta(enum.EnumMeta):
def __getitem__(cls, key):
#if this is an iterable: apply to each value, and construct a tuple
if isinstance(key, slice):
return slice(cls[key.start],cls[key.stop],key.step)
if isinstance(key, typing.Iterable) and not isinstance(key, str):
return tuple(map(cls.__getitem__, key))
#if this is from a flavor scheme: check that it's ours
if isinstance(key, FlavorScheme):
if not isinstance(key, cls):
raise TypeError(f'Value {repr(key)} is not from {cls.__name__} sheme!')
return key
#if this is a string find it by name
if isinstance(key, str):
try:
return super().__getitem__(key)
except KeyError as e:
raise KeyError(
f'Cannot find key "{key}" in {cls.__name__} sheme! Valid options are {list(cls)}'
)

if key is None:
return None
#if this is anything else - treat it as a slice
return np.array(list(cls.__members__.values()),dtype=object)[key]

class FlavorScheme(enum.IntEnum, metaclass=EnumMeta):
def to_tex(self):
"""LaTeX-compatible string representations of flavor."""
base = r'\nu'
if self.is_antineutrino:
base = r'\overline{\nu}'
lepton = self.lepton.lower()
if self.is_muon or self.is_tauon:
lepton = '\\'+lepton
return f"${base}_{{{lepton}}}$"

@property
def is_neutrino(self):
return not self.is_antineutrino

@property
def is_antineutrino(self):
return '_BAR' in self.name

@property
def is_electron(self):
return self.lepton=='E'

@property
def is_muon(self):
return self.lepton=='MU'

@property
def is_tauon(self):
return self.lepton=='TAU'

@property
def is_sterile(self):
return self.lepton=='S'

@property
def lepton(self):
return self.name.split('_')[1]

@classmethod
def from_lepton_names(cls, name:str, leptons:list):
enum_class = cls(name, start=0, names = [f'NU_{L}{BAR}' for L in leptons for BAR in ['','_BAR']])
return enum_class

@classmethod
def take(cls, index):
return cls[index]

TwoFlavor = FlavorScheme.from_lepton_names('TwoFlavor',['E','X'])
ThreeFlavor = FlavorScheme.from_lepton_names('ThreeFlavor',['E','MU','TAU'])
FourFlavor = FlavorScheme.from_lepton_names('FourFlavor',['E','MU','TAU','S'])

class FlavorMatrix:
def __init__(self,
array:np.ndarray,
flavor:FlavorScheme,
from_flavor:FlavorScheme = None
):
self.array = np.asarray(array)
self.flavor_out = flavor
self.flavor_in = from_flavor or flavor
expected_shape = (len(self.flavor_out), len(self.flavor_in))
if(self.array.shape != expected_shape):
raise ValueError(f"FlavorMatrix array shape {self.array.shape} mismatch expected {expected_shape}")

def _convert_index(self, index):
if isinstance(index, str) or (not isinstance(index,typing.Iterable)):
index = [index]
new_idx = [flavors[idx] for idx,flavors in zip(index, [self.flavor_out, self.flavor_in])]
return tuple(new_idx)

def __getitem__(self, index):
return self.array[self._convert_index(index)]

def __setitem__(self, index, value):
self.array[self._convert_index(index)] = value

def _repr_short(self):
return f'{self.__class__.__name__}:<{self.flavor_in.__name__}->{self.flavor_out.__name__}> shape={self.shape}'

def __repr__(self):
s=self._repr_short()+'\n'+repr(self.array)
return s
def __eq__(self,other):
return self.flavor_in==other.flavor_in and self.flavor_out==other.flavor_out and np.allclose(self.array,other.array)

def __matmul__(self, other):
if isinstance(other, FlavorMatrix):
try:
data = np.tensordot(self.array, other.array, axes=[1,0])
return FlavorMatrix(data, self.flavor_out, from_flavor = other.flavor_in)
except Exception as e:
raise ValueError(f"Cannot multiply {self._repr_short()} by {other._repr_short()}") from e
elif hasattr(other, '__rmatmul__'):
return other.__rmatmul__(self)
raise TypeError(f"Cannot multiply object of {self.__class__} by {other.__class__}")
#properties
@property
def shape(self):
return self.array.shape
@property
def flavor(self):
return self.flavor_out

@classmethod
def eye(cls, flavor:FlavorScheme, from_flavor:FlavorScheme = None):
from_flavor = from_flavor or flavor
shape = (len(from_flavor), len(flavor))
data = np.eye(*shape)
return cls(data, flavor, from_flavor)

@classmethod
def from_function(cls, flavor:FlavorScheme, from_flavor:FlavorScheme = None):
"""A decorator for creating the flavor matrix from the given function"""
from_flavor = from_flavor or flavor
def _decorator(function):
data = [[function(f1,f2)
for f2 in from_flavor]
for f1 in flavor]

return cls(np.array(data,dtype=float), flavor, from_flavor)
return _decorator
#flavor conversion utils

def conversion_matrix(from_flavor:FlavorScheme, to_flavor:FlavorScheme):
@FlavorMatrix.from_function(to_flavor, from_flavor)
def convert(f1, f2):
if f1.name == f2.name:
return 1.
if (f1.is_neutrino == f2.is_neutrino) and (f2.lepton == 'X' and f1.lepton in ['MU', 'TAU']):
# convert from TwoFlavor to more flavors
return 1.
if (f1.is_neutrino == f2.is_neutrino) and (f1.lepton == 'X' and f2.lepton in ['MU', 'TAU']):
# convert from more flavors to TwoFlavor
return 0.5
return 0.
return convert

FlavorScheme.conversion_matrix = classmethod(conversion_matrix)
EnumMeta.__rshift__ = conversion_matrix
EnumMeta.__lshift__ = lambda f1,f2:conversion_matrix(f2,f1)
86 changes: 66 additions & 20 deletions python/snewpy/flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@
"""
from typing import Union, Optional, Set, List
from snewpy.neutrino import Flavor
#from snewpy.neutrino import Flavor
from snewpy.flavor import FlavorScheme, FlavorMatrix
from astropy import units as u

import numpy as np
Expand Down Expand Up @@ -86,11 +87,12 @@ class _ContainerBase:
unit = None
def __init__(self,
data: u.Quantity,
flavor: List[Flavor],
flavor: List[FlavorScheme],
time: u.Quantity[u.s],
energy: u.Quantity[u.MeV],
*,
integrable_axes: Optional[Set[Axes]] = None
integrable_axes: Optional[Set[Axes]] = None,
flavor_scheme:Optional[FlavorScheme] = None
):
"""A container class storing the physical quantity (flux, fluence, rate...), which depends on flavor, time and energy.
Expand All @@ -112,7 +114,11 @@ def __init__(self,
integrable_axes: set of :class:`Axes` or None
List of axes which can be integrated.
If None (default) this set will be derived from the axes shapes
If None (default) this set will be derived from the axes shapes
flavor_scheme: a subclass of :class:`snewpy.flavor.FlavorSchemes` or None
A class which lists all the allowed flavors.
If None (default) this value will be retrieved from the ``flavor`` arguemnt.
"""
if self.unit is not None:
#try to convert to the unit
Expand All @@ -121,8 +127,20 @@ def __init__(self,
self.array = u.Quantity(data)
self.time = u.Quantity(time, ndmin=1)
self.energy = u.Quantity(energy, ndmin=1)
self.flavor = np.sort(np.array(flavor, ndmin=1))

self.flavor = np.array(flavor,ndmin=1, dtype=object)
self.flavor_scheme = flavor_scheme
if not flavor_scheme:
#guess the flavor scheme
if isinstance(flavor, type):#issubclass without isinstance(type) check raises TypeError
if issubclass(flavor, FlavorScheme):
self.flavor_scheme = flavor
else:
#get schemes from the data
flavor_schemes = set(f.__class__ for f in self.flavor)
if len(flavor_schemes)!=1:
raise ValueError(f"Flavors {flavor} must be from a single flavor scheme, but are from {flavor_schemes}")
else:
self.flavor_scheme = flavor_schemes.pop()
Nf,Nt,Ne = len(self.flavor), len(self.time), len(self.energy)
#list all valid shapes of the input array
expected_shapes=[(nf,nt,ne) for nf in (Nf,Nf-1) for nt in (Nt,Nt-1) for ne in (Ne,Ne-1)]
Expand Down Expand Up @@ -162,21 +180,28 @@ def shape(self):

def __getitem__(self, args)->'Container':
"""Slice the flux array and produce a new Flux object"""
try:
iter(args)
except TypeError:
if not isinstance(args,tuple):
args = [args]
args = [a if isinstance(a, slice) else slice(a, a + 1) for a in args]
#expand args to match axes
args+=[slice(None)]*(len(Axes)-len(args))
array = self.array.__getitem__(tuple(args))
newaxes = [ax.__getitem__(arg) for arg, ax in zip(args, self.axes)]
return self.__class__(array, *newaxes)
args = list(args)
arg_slices = [slice(None)]*len(Axes)
if isinstance(args[0],str) or isinstance(args[0],FlavorScheme):
args[0] = self.flavor_scheme[args[0]]
for n,arg in enumerate(args):
if not isinstance(arg, slice):
arg = slice(arg, arg + 1)
arg_slices[n] = arg

array = self.array.__getitem__(tuple(arg_slices))
newaxes = [ax.__getitem__(arg) for arg, ax in zip(arg_slices, self.axes)]
return self.__class__(array, *newaxes, flavor_scheme=self.flavor_scheme)

def __repr__(self) -> str:
"""print information about the container"""
s = [f"{len(values)} {label.name}({values.min()};{values.max()})"
for label, values in zip(Axes,self.axes)]
if label!=Axes.flavor
else f"{len(values)} {label.name}[{self.flavor_scheme}]({values.min()};{values.max()})"
for label, values in zip(Axes,self.axes)
]
return f"{self.__class__.__name__} {self.array.shape} [{self.array.unit}]: <{' x '.join(s)}>"

def sum(self, axis: Union[Axes,str])->'Container':
Expand Down Expand Up @@ -317,7 +342,7 @@ def __mul__(self, factor) -> 'Container':
array = self.array*factor
axes = list(self.axes)
return Container(array, *axes)

def save(self, fname:str)->None:
"""Save container data to a given file (using `numpy.savez`)"""
def _save_quantity(name):
Expand All @@ -329,8 +354,9 @@ def _save_quantity(name):
except:
return {name:values}
data_dict = {}
for name in ['array','time','energy','flavor']:
for name in ['array','time','energy']:
data_dict.update(_save_quantity(name))
data_dict['flavor'] = np.array(self.flavor, dtype=object)
np.savez(fname,
_class_name=self.__class__.__name__,
**data_dict,
Expand All @@ -340,7 +366,7 @@ def _save_quantity(name):
@classmethod
def load(cls, fname:str)->'Container':
"""Load container from a given file"""
with np.load(fname) as f:
with np.load(fname, allow_pickle=True) as f:
def _load_quantity(name):
array = f[name]
try:
Expand All @@ -361,9 +387,29 @@ def __eq__(self, other:'Container')->bool:
result = self.__class__==other.__class__ and \
self.unit == other.unit and \
np.allclose(self.array, other.array) and \
all([np.allclose(self.axes[ax], other.axes[ax]) for ax in Axes])
self.flavor_scheme==other.flavor_scheme and \
len(self.flavor)==len(other.flavor) and \
all(self.flavor==other.flavor) and \
all([np.allclose(self.axes[ax], other.axes[ax]) for ax in list(Axes)[1:]])
return result

def _is_full_flavor(self):
return all(self.flavor==list(self.flavor_scheme))

def convert_to_flavor(self, flavor:FlavorScheme):
if(self.flavor_scheme==flavor):
return self
return (self.flavor_scheme>>flavor)@self
def __rshift__(self, flavor:FlavorScheme):
return self.convert_to_flavor(flavor)

def __rmatmul__(self, matrix:FlavorMatrix):
if not self._is_full_flavor():
raise RuntimeError(f"Cannot multiply flavor matrix object {self}, expected {len(self.flavor_scheme)} flavors")
if matrix.flavor_in!=self.flavor_scheme:
raise ValueError(f"Cannot multiply flavor matrix {matrix} by {self} - flavor scheme mismatch!")
array = np.tensordot(matrix.array,self.array, axes=[1,0])
return Container(array, flavor=matrix.flavor_out, time=self.time, energy=self.energy)
class Container(_ContainerBase):
#a dictionary holding classes for each unit
_unit_classes = {}
Expand Down
4 changes: 2 additions & 2 deletions python/snewpy/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,8 @@ def get_flux (self, t, E, distance, flavor_xform=NoTransformation()):
factor = 1/(4*np.pi*(distance.to('cm'))**2)
f = self.get_transformed_spectra(t, E, flavor_xform)

array = np.stack([f[flv] for flv in sorted(Flavor)])
return Flux(data=array*factor, flavor=np.sort(Flavor), time=t, energy=E)
array = np.stack([f[flv] for flv in Flavor])
return Flux(data=array*factor, flavor=Flavor, time=t, energy=E)



Expand Down
27 changes: 1 addition & 26 deletions python/snewpy/neutrino.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Optional
import numpy as np
from collections.abc import Mapping
from .flavor import TwoFlavor as Flavor

class MassHierarchy(IntEnum):
"""Neutrino mass ordering: ``NORMAL`` or ``INVERTED``."""
Expand All @@ -23,32 +24,6 @@ def derive_from_dm2(cls, dm12_2, dm32_2, dm31_2):
else:
return MassHierarchy.INVERTED

class Flavor(IntEnum):
"""Enumeration of CCSN Neutrino flavors."""
NU_E = 0
NU_X = 1
NU_E_BAR = 2
NU_X_BAR = 3

def to_tex(self):
"""LaTeX-compatible string representations of flavor."""
if '_BAR' in self.name:
return r'$\overline{{\nu}}_{0}$'.format(self.name[3].lower())
return r'$\{0}$'.format(self.name.lower())

@property
def is_electron(self):
"""Return ``True`` for ``Flavor.NU_E`` and ``Flavor.NU_E_BAR``."""
return self.value in (Flavor.NU_E.value, Flavor.NU_E_BAR.value)

@property
def is_neutrino(self):
"""Return ``True`` for ``Flavor.NU_E`` and ``Flavor.NU_X``."""
return self.value in (Flavor.NU_E.value, Flavor.NU_X.value)

@property
def is_antineutrino(self):
return self.value in (Flavor.NU_E_BAR.value, Flavor.NU_X_BAR.value)

@dataclass
class MixingParameters3Flavor(Mapping):
Expand Down
Loading

0 comments on commit 65bc4d5

Please sign in to comment.