Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding the implementation of the FlavorScheme and FlavorMatrix #324

Merged
merged 31 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
693d5eb
Adding the implementation of the FlavorScheme and FlavorMatrix
Sheshuk Apr 30, 2024
698b077
Fix
Sheshuk Apr 30, 2024
ae917a4
Adding 'FlavorMatrix.from_function' and '__matmul__' methods
Sheshuk Apr 30, 2024
024ec35
Using TwoFlavor as neutrino.Flavor
Sheshuk Apr 30, 2024
94c7b98
Fixing the to_tex
Sheshuk Apr 30, 2024
38a6879
fix typo
Sheshuk Apr 30, 2024
47ef94f
remove type annotation breaking python 3.8
Sheshuk Apr 30, 2024
e80bc23
Update type annotations
Sheshuk May 3, 2024
cd6e7fe
Using custom metaclass with __getitem__ method
Sheshuk May 3, 2024
b2e18d7
Adding tests for FlavorMatrix
Sheshuk May 4, 2024
f12c747
Add convenience properties for the matrix
Sheshuk May 4, 2024
c3130f9
Make conversion matrix dict
Sheshuk May 4, 2024
105d466
Cleanup the interface: change method and argument names etc
Sheshuk May 10, 2024
0efaf6f
Update tests
Sheshuk May 10, 2024
5f200ee
Implement slicing for the flavors
Sheshuk May 10, 2024
3a7e2f5
Using FlavorScheme in the flux.Container
Sheshuk May 10, 2024
e44422c
Remove trailing character
Sheshuk May 10, 2024
017faeb
Update the flavor checks
Sheshuk May 10, 2024
6f01cef
Update the flavor checks
Sheshuk May 10, 2024
d8cf667
Using >> and << operators for generating conversion matrices
Sheshuk May 10, 2024
15a1093
Deriving the flavor_scheme
Sheshuk May 10, 2024
168b22c
Fixing the chechs in getitem
Sheshuk May 10, 2024
ffaef43
Adding tests for access by value and by string
Sheshuk May 10, 2024
32992b2
Matrix mutiplication for the flux
Sheshuk May 10, 2024
f9a2ff8
Remove debug prints
Sheshuk May 10, 2024
b52f8ee
initialize flux with the flavor enum, not its sorted version
Sheshuk May 10, 2024
7b30054
Correcting the Three->Two flavor conversion
Sheshuk May 10, 2024
00aa2ea
Update python/snewpy/flavor.py
Sheshuk May 14, 2024
76f150a
Add tests for the FlavorMatrix.__getitem__
Sheshuk May 14, 2024
07a090d
Fixing the _convert_index and __getitem__
Sheshuk May 14, 2024
5b2fb84
Adding comment for the isinstance and issubclass checks
Sheshuk May 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 181 additions & 0 deletions python/snewpy/flavor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import enum
import numpy as np
import typing

class EnumMeta(enum.EnumMeta):
def __getitem__(cls, key):
#if this is an iterable: apply to each value, and construct a tuple
if isinstance(key, typing.Iterable) and not isinstance(key, str):
return tuple(map(cls.__getitem__, key))
#if this is from a flavor scheme: check that it's ours
if isinstance(key, FlavorScheme):
if not isinstance(key, cls):
raise TypeError(f'Value {repr(key)} is not from {cls.__name__} sheme!')
return key
#if this is a string find it by name
if isinstance(key, str):
try:
return super().__getitem__(key)
except KeyError as e:
raise KeyError(
f'Cannot find key "{key}" in {cls.__name__} sheme! Valid options are {list(cls)}'
)

#if this is anything else - treat it as a slice
return np.array(list(cls.__members__.values()),dtype=object)[key]

class FlavorScheme(enum.IntEnum, metaclass=EnumMeta):
JostMigenda marked this conversation as resolved.
Show resolved Hide resolved
def to_tex(self):
"""LaTeX-compatible string representations of flavor."""
base = r'\nu'
if self.is_antineutrino:
base = r'\overline{\nu}'
lepton = self.lepton.lower()
if self.is_muon or self.is_tauon:
lepton = '\\'+lepton
return f"${base}_{{{lepton}}}$"

@property
def is_neutrino(self):
return not self.is_antineutrino

@property
def is_antineutrino(self):
return '_BAR' in self.name

@property
def is_electron(self):
return self.lepton=='E'

@property
def is_muon(self):
return self.lepton=='MU'

@property
def is_tauon(self):
return self.lepton=='TAU'

@property
def is_sterile(self):
return self.lepton=='S'

@property
def lepton(self):
return self.name.split('_')[1]

@classmethod
def from_lepton_names(cls, name:str, leptons:list):
enum_class = cls(name, start=0, names = [f'NU_{L}{BAR}' for L in leptons for BAR in ['','_BAR']])
return enum_class

@classmethod
def take(cls, index):
return cls[index]

TwoFlavor = FlavorScheme.from_lepton_names('TwoFlavor',['E','X'])
ThreeFlavor = FlavorScheme.from_lepton_names('ThreeFlavor',['E','MU','TAU'])
FourFlavor = FlavorScheme.from_lepton_names('FourFlavor',['E','MU','TAU','S'])

class FlavorMatrix:
def __init__(self,
array:np.ndarray,
flavor:FlavorScheme,
from_flavor:FlavorScheme = None
):
self.array = np.asarray(array)
self.flavor_out = flavor
self.flavor_in = from_flavor or flavor
expected_shape = (len(self.flavor_out), len(self.flavor_in))
if(self.array.shape != expected_shape):
raise ValueError(f"FlavorMatrix array shape {self.array.shape} mismatch expected {expected_shape}")

def _convert_index(self, index):
if isinstance(index, str) or (not isinstance(index,typing.Iterable)):
index = [index]
new_idx = [flavors[idx] for idx,flavors in zip(index, self.flavors)]
Sheshuk marked this conversation as resolved.
Show resolved Hide resolved
return tuple(new_idx)

def __getitem__(self, index):
return self.array[self._convert_index(index)]

def __setitem__(self, index, value):
self.array[self._convert_index(index)] = value

def _repr_short(self):
return f'{self.__class__.__name__}:<{self.flavor_in.__name__}->{self.flavor_out.__name__}> shape={self.shape}'

def __repr__(self):
s=self._repr_short()+'\n'+repr(self.array)
return s
def __eq__(self,other):
return self.flavor_in==other.flavor_in and self.flavor_out==other.flavor_out and np.allclose(self.array,other.array)

def __matmul__(self, other):
if isinstance(other, FlavorMatrix):
try:
data = np.tensordot(self.array, other.array, axes=[1,0])
return FlavorMatrix(data, self.flavor_out, from_flavor = other.flavor_in)
except Exception as e:
raise ValueError(f"Cannot multiply {self._repr_short()} by {other._repr_short()}") from e
elif hasattr(other, '__rmatmul__'):
return other.__rmatmul__(self)
raise TypeError(f"Cannot multiply object of {self.__class__} by {other.__class__}")
#properties
@property
def shape(self):
return self.array.shape
@property
def flavor(self):
return self.flavor_out

@classmethod
def zeros(cls, flavor:FlavorScheme, from_flavor:FlavorScheme = None):
from_flavor = from_flavor or flavor
shape = (len(from_flavor), len(flavor))
data = np.zeros(shape)
return cls(data, flavor, from_flavor)
Sheshuk marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def eye(cls, flavor:FlavorScheme, from_flavor:FlavorScheme = None):
from_flavor = from_flavor or flavor
shape = (len(from_flavor), len(flavor))
data = np.eye(*shape)
return cls(data, flavor, from_flavor)

@classmethod
def from_function(cls, flavor:FlavorScheme, from_flavor:FlavorScheme = None):
"""A decorator for creating the flavor matrix from the given function"""
from_flavor = from_flavor or flavor
def _decorator(function):
data = [[function(f1,f2)
for f2 in from_flavor]
for f1 in flavor]

return cls(np.array(data,dtype=float), flavor, from_flavor)
return _decorator
#flavor conversion utils

def conversion_matrix(from_flavor:FlavorScheme, to_flavor:FlavorScheme):
if(from_flavor==TwoFlavor):
#define special cases
@FlavorMatrix.from_function(to_flavor, from_flavor)
def convert_2toN(f1,f2):
if (f1.name==f2.name):
return 1.
if (f1.is_neutrino==f2.is_neutrino)and(f2.lepton=='X' and f1.lepton in ['MU','TAU']):
return 1.
return 0
return convert_2toN
else:
@FlavorMatrix.from_function(to_flavor, from_flavor)
def convert_Nto2(f1,f2):
if (f1.name==f2.name):
return 1.
if (f1.is_neutrino==f2.is_neutrino)and(f1.lepton=='X' and f2.lepton in ['MU','TAU']):
return 0.5
return 0.
return convert_Nto2
Sheshuk marked this conversation as resolved.
Show resolved Hide resolved

FlavorScheme.conversion_matrix = classmethod(conversion_matrix)
EnumMeta.__rshift__ = conversion_matrix
EnumMeta.__lshift__ = lambda f1,f2:conversion_matrix(f2,f1)
85 changes: 65 additions & 20 deletions python/snewpy/flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@

"""
from typing import Union, Optional, Set, List
from snewpy.neutrino import Flavor
#from snewpy.neutrino import Flavor
from snewpy.flavor import FlavorScheme, FlavorMatrix
from astropy import units as u

import numpy as np
Expand Down Expand Up @@ -86,11 +87,12 @@ class _ContainerBase:
unit = None
def __init__(self,
data: u.Quantity,
flavor: List[Flavor],
flavor: List[FlavorScheme],
time: u.Quantity[u.s],
energy: u.Quantity[u.MeV],
*,
integrable_axes: Optional[Set[Axes]] = None
integrable_axes: Optional[Set[Axes]] = None,
flavor_scheme:Optional[FlavorScheme] = None
):
"""A container class storing the physical quantity (flux, fluence, rate...), which depends on flavor, time and energy.

Expand All @@ -112,7 +114,11 @@ def __init__(self,

integrable_axes: set of :class:`Axes` or None
List of axes which can be integrated.
If None (default) this set will be derived from the axes shapes
If None (default) this set will be derived from the axes shapes

flavor_scheme: a subclass of :class:`snewpy.flavor.FlavorSchemes` or None
A class which lists all the allowed flavors.
If None (default) this value will be retrieved from the ``flavor`` arguemnt.
"""
if self.unit is not None:
#try to convert to the unit
Expand All @@ -121,8 +127,19 @@ def __init__(self,
self.array = u.Quantity(data)
self.time = u.Quantity(time, ndmin=1)
self.energy = u.Quantity(energy, ndmin=1)
self.flavor = np.sort(np.array(flavor, ndmin=1))

self.flavor = np.array(flavor,ndmin=1, dtype=object)
self.flavor_scheme = flavor_scheme
if not flavor_scheme:
#guess the flavor scheme
if isinstance(flavor, type) and issubclass(flavor, FlavorScheme):
Sheshuk marked this conversation as resolved.
Show resolved Hide resolved
self.flavor_scheme = flavor
else:
#get schemes from the data
flavor_schemes = set(f.__class__ for f in self.flavor)
if len(flavor_schemes)!=1:
raise ValueError(f"Flavors {flavor} must be from a single flavor scheme, but are from {flavor_schemes}")
else:
self.flavor_scheme = flavor_schemes.pop()
Nf,Nt,Ne = len(self.flavor), len(self.time), len(self.energy)
#list all valid shapes of the input array
expected_shapes=[(nf,nt,ne) for nf in (Nf,Nf-1) for nt in (Nt,Nt-1) for ne in (Ne,Ne-1)]
Expand Down Expand Up @@ -162,21 +179,28 @@ def shape(self):

def __getitem__(self, args)->'Container':
"""Slice the flux array and produce a new Flux object"""
try:
iter(args)
except TypeError:
if not isinstance(args,tuple):
args = [args]
args = [a if isinstance(a, slice) else slice(a, a + 1) for a in args]
#expand args to match axes
args+=[slice(None)]*(len(Axes)-len(args))
array = self.array.__getitem__(tuple(args))
newaxes = [ax.__getitem__(arg) for arg, ax in zip(args, self.axes)]
return self.__class__(array, *newaxes)
args = list(args)
arg_slices = [slice(None)]*len(Axes)
if isinstance(args[0],str) or isinstance(args[0],FlavorScheme):
args[0] = self.flavor_scheme[args[0]]
for n,arg in enumerate(args):
if not isinstance(arg, slice):
arg = slice(arg, arg + 1)
arg_slices[n] = arg

array = self.array.__getitem__(tuple(arg_slices))
newaxes = [ax.__getitem__(arg) for arg, ax in zip(arg_slices, self.axes)]
return self.__class__(array, *newaxes, flavor_scheme=self.flavor_scheme)

def __repr__(self) -> str:
"""print information about the container"""
s = [f"{len(values)} {label.name}({values.min()};{values.max()})"
for label, values in zip(Axes,self.axes)]
if label!=Axes.flavor
else f"{len(values)} {label.name}[{self.flavor_scheme}]({values.min()};{values.max()})"
for label, values in zip(Axes,self.axes)
]
return f"{self.__class__.__name__} {self.array.shape} [{self.array.unit}]: <{' x '.join(s)}>"

def sum(self, axis: Union[Axes,str])->'Container':
Expand Down Expand Up @@ -317,7 +341,7 @@ def __mul__(self, factor) -> 'Container':
array = self.array*factor
axes = list(self.axes)
return Container(array, *axes)

def save(self, fname:str)->None:
"""Save container data to a given file (using `numpy.savez`)"""
def _save_quantity(name):
Expand All @@ -329,8 +353,9 @@ def _save_quantity(name):
except:
return {name:values}
data_dict = {}
for name in ['array','time','energy','flavor']:
for name in ['array','time','energy']:
data_dict.update(_save_quantity(name))
data_dict['flavor'] = np.array(self.flavor, dtype=object)
np.savez(fname,
_class_name=self.__class__.__name__,
**data_dict,
Expand All @@ -340,7 +365,7 @@ def _save_quantity(name):
@classmethod
def load(cls, fname:str)->'Container':
"""Load container from a given file"""
with np.load(fname) as f:
with np.load(fname, allow_pickle=True) as f:
def _load_quantity(name):
array = f[name]
try:
Expand All @@ -361,9 +386,29 @@ def __eq__(self, other:'Container')->bool:
result = self.__class__==other.__class__ and \
self.unit == other.unit and \
np.allclose(self.array, other.array) and \
all([np.allclose(self.axes[ax], other.axes[ax]) for ax in Axes])
self.flavor_scheme==other.flavor_scheme and \
len(self.flavor)==len(other.flavor) and \
all(self.flavor==other.flavor) and \
all([np.allclose(self.axes[ax], other.axes[ax]) for ax in list(Axes)[1:]])
return result

def _is_full_flavor(self):
return all(self.flavor==list(self.flavor_scheme))

def convert_to_flavor(self, flavor:FlavorScheme):
if(self.flavor_scheme==flavor):
return self
return (self.flavor_scheme>>flavor)@self
def __rshift__(self, flavor:FlavorScheme):
return self.convert_to_flavor(flavor)
Sheshuk marked this conversation as resolved.
Show resolved Hide resolved

def __rmatmul__(self, matrix:FlavorMatrix):
if not self._is_full_flavor():
raise RuntimeError(f"Cannot multiply flavor matrix object {self}, expected {len(self.flavor_scheme)} flavors")
if matrix.flavor_in!=self.flavor_scheme:
raise ValueError(f"Cannot multiply flavor matrix {matrix} by {self} - flavor scheme mismatch!")
array = np.tensordot(matrix.array,self.array, axes=[1,0])
return Container(array, flavor=matrix.flavor_out, time=self.time, energy=self.energy)
class Container(_ContainerBase):
#a dictionary holding classes for each unit
_unit_classes = {}
Expand Down
4 changes: 2 additions & 2 deletions python/snewpy/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,8 @@ def get_flux (self, t, E, distance, flavor_xform=NoTransformation()):
factor = 1/(4*np.pi*(distance.to('cm'))**2)
f = self.get_transformed_spectra(t, E, flavor_xform)

array = np.stack([f[flv] for flv in sorted(Flavor)])
return Flux(data=array*factor, flavor=np.sort(Flavor), time=t, energy=E)
array = np.stack([f[flv] for flv in Flavor])
return Flux(data=array*factor, flavor=Flavor, time=t, energy=E)



Expand Down
27 changes: 1 addition & 26 deletions python/snewpy/neutrino.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Optional
import numpy as np
from collections.abc import Mapping
from .flavor import TwoFlavor as Flavor

class MassHierarchy(IntEnum):
"""Neutrino mass ordering: ``NORMAL`` or ``INVERTED``."""
Expand All @@ -23,32 +24,6 @@ def derive_from_dm2(cls, dm12_2, dm32_2, dm31_2):
else:
return MassHierarchy.INVERTED

class Flavor(IntEnum):
"""Enumeration of CCSN Neutrino flavors."""
NU_E = 0
NU_X = 1
NU_E_BAR = 2
NU_X_BAR = 3

def to_tex(self):
"""LaTeX-compatible string representations of flavor."""
if '_BAR' in self.name:
return r'$\overline{{\nu}}_{0}$'.format(self.name[3].lower())
return r'$\{0}$'.format(self.name.lower())

@property
def is_electron(self):
"""Return ``True`` for ``Flavor.NU_E`` and ``Flavor.NU_E_BAR``."""
return self.value in (Flavor.NU_E.value, Flavor.NU_E_BAR.value)

@property
def is_neutrino(self):
"""Return ``True`` for ``Flavor.NU_E`` and ``Flavor.NU_X``."""
return self.value in (Flavor.NU_E.value, Flavor.NU_X.value)

@property
def is_antineutrino(self):
return self.value in (Flavor.NU_E_BAR.value, Flavor.NU_X_BAR.value)

@dataclass
class MixingParameters3Flavor(Mapping):
Expand Down
Loading