Skip to content

Commit

Permalink
Merge branch 'sensitivity-evolution' of github.com:jeremy-baier/hasas…
Browse files Browse the repository at this point in the history
…ia into sensitivity-evolution
  • Loading branch information
jeremy-baier committed Dec 24, 2024
2 parents 9dafb9b + 6fafccc commit 158b4db
Show file tree
Hide file tree
Showing 3 changed files with 292 additions and 98 deletions.
39 changes: 26 additions & 13 deletions hasasia/sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import scipy.linalg as sl
import os, pickle
from astropy import units as u
import jax.numpy as jnp
import jax.scipy as jsc

import hasasia
from .utils import create_design_matrix
Expand Down Expand Up @@ -62,14 +64,14 @@ def R_matrix(designmatrix, N):
n,m = M.shape
L = np.linalg.cholesky(N)
Linv = np.linalg.inv(L)
U,s,_ = np.linalg.svd(np.matmul(Linv,M), full_matrices=True)
U,s,_ = np.linalg.svd(jnp.matmul(Linv,M), full_matrices=True)
Id = np.eye(M.shape[0])
S = np.zeros_like(M)
S[:m,:m] = np.diag(s)
inner = np.linalg.inv(np.matmul(S.T,S))
outer = np.matmul(S,np.matmul(inner,S.T))
inner = np.linalg.inv(jnp.matmul(S.T,S))
outer = jnp.matmul(S,jnp.matmul(inner,S.T))

return Id - np.matmul(L,np.matmul(np.matmul(U,outer),np.matmul(U.T,Linv)))
return Id - jnp.matmul(L,jnp.matmul(jnp.matmul(U,outer),jnp.matmul(U.T,Linv)))

def G_matrix(designmatrix):
"""
Expand All @@ -96,7 +98,7 @@ def get_Tf(designmatrix, toas, N=None, nf=200, fmin=None, fmax=2e-7,
freqs=None, exact_astro_freqs = False,
from_G=True, twofreqs=False, Gmatrix=None):
"""
Calculate the transmission function for a given pulsar design matrix, TOAs
the transmission function for a given pulsar design matrix, TOAs
and TOA errors.
Parameters
Expand Down Expand Up @@ -169,7 +171,7 @@ def get_Tf(designmatrix, toas, N=None, nf=200, fmin=None, fmax=2e-7,
m = G.shape[1]
Gtilde = np.zeros((ff.size,G.shape[1]),dtype='complex128')
Gtilde = np.dot(np.exp(1j*2*np.pi*ff[:,np.newaxis]*toas),G)
Tmat = np.matmul(np.conjugate(Gtilde),Gtilde.T)/N_TOA
Tmat = jnp.matmul(np.conjugate(Gtilde),Gtilde.T)/N_TOA
if twofreqs:
Tmat = np.real(Tmat)
else:
Expand Down Expand Up @@ -261,10 +263,14 @@ def get_NcalInv(psr, nf=200, fmin=None, fmax=2e-7, freqs=None,
Gtilde = np.dot(np.exp(1j*2*np.pi*ff[:,np.newaxis]*toas),G)
# N_freq x N_TOA-N_par

Ncal = np.matmul(G.T,np.matmul(psr.N,G)) #N_TOA-N_par x N_TOA-N_par
NcalInv = np.linalg.inv(Ncal) #N_TOA-N_par x N_TOA-N_par

TfN = np.matmul(np.conjugate(Gtilde),np.matmul(NcalInv,Gtilde.T)) / 2
L = jsc.linalg.cholesky(psr.N)
A = jnp.matmul(L,G)
del L
Ncal = jnp.matmul(A.T,A)
del A
NcalInv = jnp.linalg.inv(Ncal)

TfN = jnp.matmul(np.conjugate(Gtilde),jnp.matmul(NcalInv,Gtilde.T)) / 2
if return_Gtilde_Ncal:
return np.real(TfN), Gtilde, Ncal
elif full_matrix:
Expand Down Expand Up @@ -751,6 +757,13 @@ def to_pickle(self, filepath):
with open(filepath, "wb") as fout:
pickle.dump(self, fout)

def fidx(self,f):
"""Get the indices of a frequencies in the frequency array."""
if isinstance(f, int) or isinstance(f, float):
f = np.array([f])
f = np.asarray(f)
return np.array([np.argmin(abs(ff-self.freqs)) for ff in f])

@property
def S_eff(self):
"""Strain power sensitivity. """
Expand Down Expand Up @@ -983,7 +996,7 @@ def get_NcalInvIJ(psrs, A_GWB, freqs, full_matrix=False,
# C_h = sl.block_diag(*[corr_from_psd(freqs=freqs, psd=psd,
# toas=p.toas, fast=True) for p in psrs])
C = C_n + C_h
Ncal = np.matmul(G.T, np.matmul(C, G)) #N_TOA-N_par x N_TOA-N_par
Ncal = jnp.matmul(G.T, jnp.matmul(C, G)) #N_TOA-N_par x N_TOA-N_par
NcalInv = np.linalg.inv(Ncal) #N_TOA-N_par x N_TOA-N_par

TfN = NcalInv#np.matmul(G, np.matmul(NcalInv, G.T))
Expand Down Expand Up @@ -1243,7 +1256,7 @@ def corr_from_psd(freqs, psd, toas, fast=True):
df = np.diff(freqs)
df = np.append(df,df[-1])
tm = np.sqrt(psd*df)*np.exp(1j*2*np.pi*freqs*toas[:,np.newaxis])
integrand = np.matmul(tm, np.conjugate(tm.T))
integrand = jnp.matmul(tm, np.conjugate(tm.T))
return np.real(integrand)
else: #Makes much larger arrays, but uses np.trapz
t1, t2 = np.meshgrid(toas, toas, indexing='ij')
Expand Down Expand Up @@ -1284,7 +1297,7 @@ def corr_from_psdIJ(freqs, psd, toasI, toasJ, fast=True):
df = np.append(df,df[-1])
tmI = np.sqrt(psd*df)*np.exp(1j*2*np.pi*freqs*toasI[:,np.newaxis])
tmJ = np.sqrt(psd*df)*np.exp(1j*2*np.pi*freqs*toasJ[:,np.newaxis])
integrand = np.matmul(tmI, np.conjugate(tmJ.T))
integrand = jnp.matmul(tmI, np.conjugate(tmJ.T))
return np.real(integrand)
else: #Makes much larger arrays, but uses np.trapz
t1, t2 = np.meshgrid(toasI, toasJ, indexing='ij')
Expand Down
Loading

0 comments on commit 158b4db

Please sign in to comment.