Skip to content

Commit

Permalink
refactoring correction and image generation code
Browse files Browse the repository at this point in the history
  • Loading branch information
ramanakumars committed Oct 24, 2023
1 parent cd1917a commit 8bec8b8
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 59 deletions.
78 changes: 36 additions & 42 deletions projection/frameletdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


class FrameletData:
def __init__(self, metadata, imgfolder):
def __init__(self, metadata: dict, imgfolder: str):
self.nframelets = int(metadata["LINES"] / FRAME_HEIGHT)
# number of RGB frames
self.nframes = int(self.nframelets / 3)
Expand Down Expand Up @@ -64,71 +64,63 @@ def __init__(self, metadata, imgfolder):

self.framelets.append(Framelet(self.start_et, frame_delay, n, c, framei))

def get_backplane(self, num_procs):
def get_backplane(self, num_procs: int):
tmid = np.mean([frame.et for frame in self.framelets])

'''
inpargs = [(frame, tmid) for frame in self.framelets]
with multiprocessing.Pool(
processes=num_procs, initializer=initializer
) as pool:
try:
pool.starmap(Framelet.project_to_midplane, tqdm.tqdm(inpargs))
pool.close()
except KeyboardInterrupt:
pool.terminate()
pool.join()
sys.exit()
pool.join()
'''
for frame in tqdm.tqdm(self.framelets):
frame.project_to_midplane(tmid)

def update_jitter(self, jitter):
if num_procs > 1:
inpargs = [(frame, tmid) for frame in self.framelets]
with multiprocessing.Pool(
processes=num_procs, initializer=initializer
) as pool:
try:
with tqdm.tqdm(inpargs, desc='Projecting framelets') as pbar:
self.framelets = pool.starmap(Framelet.project_to_midplane, pbar)
pool.close()
except KeyboardInterrupt:
pool.terminate()
sys.exit()
finally:
pool.join()
else:
for frame in tqdm.tqdm(self.framelets):
frame.project_to_midplane(tmid)

def update_jitter(self, jitter: float):
self.jitter = jitter
for framelet in self.framelets:
framelet.jitter = jitter

def update_image(self, new_image):
assert new_image.shape == self.image.shape, f"New image must be the same shape as the old one. Got {new_image.shape} instead of {self.image.shape}"

# first back up the original
self.image_old = self.image
# then store the new image data
self.image[:] = new_image[:]

@property
def tmid(self):
def tmid(self) -> float:
return np.mean([frame.et for frame in self.framelets])

@property
def image(self):
def image(self) -> np.ndarray:
return np.stack([frame.image for frame in self.framelets], axis=0).reshape((self.nframes, 3, FRAME_HEIGHT, FRAME_WIDTH))

@property
def coords(self):
def coords(self) -> np.ndarray:
return np.stack([frame.coords for frame in self.framelets], axis=0).reshape((self.nframes, 3, FRAME_HEIGHT, FRAME_WIDTH, 2))

@property
def emission(self):
def emission(self) -> np.ndarray:
return np.stack([frame.emission for frame in self.framelets], axis=0).reshape((self.nframes, 3, FRAME_HEIGHT, FRAME_WIDTH))

@property
def incidence(self):
def incidence(self) -> np.ndarray:
return np.stack([frame.incidence for frame in self.framelets], axis=0).reshape((self.nframes, 3, FRAME_HEIGHT, FRAME_WIDTH))

@property
def longitude(self):
def longitude(self) -> np.ndarray:
return np.stack([frame.lon for frame in self.framelets], axis=0).reshape((self.nframes, 3, FRAME_HEIGHT, FRAME_WIDTH))

@property
def latitude(self):
def latitude(self) -> np.ndarray:
return np.stack([frame.lat for frame in self.framelets], axis=0).reshape((self.nframes, 3, FRAME_HEIGHT, FRAME_WIDTH))


class Framelet:
def __init__(self, start_et, frame_delay, frame_no, color, img):
def __init__(self, start_et: float, frame_delay: float, frame_no: int, color: int, img: np.ndarray):
self.start_et = start_et
self.frame_delay = frame_delay
self.frame_no = frame_no
Expand All @@ -138,7 +130,7 @@ def __init__(self, start_et, frame_delay, frame_no, color, img):
self.camera = CameraModel(color)

@property
def et(self):
def et(self) -> float:
jitter = 0 if self.jitter is None else self.jitter
return (
self.start_et
Expand All @@ -147,7 +139,7 @@ def et(self):
+ (self.frame_delay + self.camera.iframe_delay) * self.frame_no
)

def project_to_midplane(self, tmid):
def project_to_midplane(self, tmid: float) -> None:
coords = np.nan * np.zeros((FRAME_HEIGHT, FRAME_WIDTH, 2))
lat = np.nan * np.zeros((FRAME_HEIGHT, FRAME_WIDTH))
lon = np.nan * np.zeros((FRAME_HEIGHT, FRAME_WIDTH))
Expand All @@ -169,6 +161,8 @@ def project_to_midplane(self, tmid):
# geometry correction
self.fluxcal = fluxcal

return self


# for decompanding -- taken from Kevin Gill's github page
SQROOT = np.array(
Expand Down Expand Up @@ -197,7 +191,7 @@ def project_to_midplane(self, tmid):
)


def decompand(image):
def decompand(image: np.ndarray) -> np.ndarray:
"""
Decompands the image from the 8-bit in the public release
to the original 12-bit shot by JunoCam
Expand All @@ -215,7 +209,7 @@ def decompand(image):
data = np.array(255 * image, dtype=np.uint8)
ny, nx = data.shape

def get_sqrt(x):
def get_sqrt(x: float) -> float:
return SQROOT[x]
v_get_sqrt = np.vectorize(get_sqrt)

Expand Down
53 changes: 53 additions & 0 deletions projection/mosaic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import numpy as np
import healpy as hp
from skimage import color
import tqdm


class Mosaic:
def __init__(self, fnames, n_side=512):
self.fnames = fnames
self.nframes = len(fnames)
self.n_side = n_side

def load_data(self):
self.maps = np.zeros((len(self.fnames), hp.nside2npix(self.n_side), 3))

for i, fname in enumerate(tqdm.tqdm(self.fnames)):
map = np.load(fname)
if hp.nside2npix(self.n_side) != map.shape[0]:
map = hp.ud_grade(map.T, self.n_side).T

map_lab = color.rgb2lab(map)

# clean up the fringes. these are negative a* and b* in the CIELAB space
map_lab[(map_lab[:, 1] < -0.001) | (map_lab[:, 2] < -0.001), 0] = 0

self.maps[i, :] = color.lab2rgb(map_lab)
self.maps[~np.isfinite(self.maps)] = 0.

def stack(self, radius=2):
m_normed = self.maps.copy()
# first, norma
for j in range(self.nframes):
m_normed[j, :] = m_normed[j, :] / np.percentile(m_normed[j, :], 99)

count = np.sum(np.min(m_normed, axis=-1) > 1.e-6, axis=0)
m_hsv = color.rgb2hsv(m_normed)
v_ave = np.mean(m_hsv[:, count > 0, 2])

pix_inds = np.where(count > 0)[0]
v_loc = np.zeros_like(m_normed[:, :, 0])
vecs = np.asarray(hp.pix2vec(self.n_side, ipix=pix_inds)).T

for i, vec in tqdm.tqdm(zip(pix_inds, vecs), total=len(pix_inds)):
neighbours = hp.query_disc(self.n_side, vec=vec, radius=np.radians(radius))
m_hsv_neigh = m_hsv[:, neighbours, 2]
v_loc[:, i] = np.mean(m_hsv_neigh[:, count[neighbours] > 0], axis=1)

m_new = np.zeros_like(m_normed[0, :])
for i in range(3):
m_new[:, i] = np.sum(m_normed[:, :, i] * (v_ave / v_loc), axis=0) / count
m_new[~np.isfinite(m_new)] = 0

return m_new
2 changes: 1 addition & 1 deletion projection/project.c
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ void project_midplane(double eti, int cam, double tmid, double *lon,
pixvec[2] / sqrtf(pixvec[0] * pixvec[0] + pixvec[1] * pixvec[1] +
pixvec[2] * pixvec[2]);
fluxcal[jj * FRAME_WIDTH + ii] =
(M_PI / 4.) *
(M_PI / 4.) * pow(disti, 2.) *
pow((aperture / focal_length) * cosalpha * cosalpha, 2);
}
}
Expand Down
83 changes: 68 additions & 15 deletions projection/projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,7 @@
import matplotlib.pyplot as plt
import spiceypy as spice
import tqdm
import multiprocessing
import time
import sys
from .globals import FRAME_HEIGHT, FRAME_WIDTH, initializer
from .cython_utils import furnish_c, project_midplane_c, get_pixel_from_coords_c
from .cython_utils import furnish_c, get_pixel_from_coords_c
from .camera_funcs import CameraModel
from .spice_utils import get_kernels
from .frameletdata import FrameletData
Expand Down Expand Up @@ -188,23 +184,35 @@ def get_limb(self, eti, cami):

return limbs_jcam

def process(self, nside=512, num_procs=8, apply_LS=True, n_neighbor=5):
def process(self, nside=512, num_procs=8, apply_correction='minneart', n_neighbor=5, minneart_k=1.25):
print(f"Projecting {self.fname} to a HEALPix grid with n_side={nside}")

self.framedata.get_backplane(num_procs)

if apply_LS:
self.framedata.update_image(
apply_lommel_seeliger(self.framedata.image, self.framedata.incidence, self.framedata.emission)
)

coords_new = np.transpose(self.framedata.coords, (1, 0, 2, 3, 4)).reshape(3, -1, 2)
imgvals_new = np.transpose(self.framedata.image, (1, 0, 2, 3)).reshape(3, -1)
self.apply_correction(apply_correction, minneart_k)

map = self.project_to_healpix(nside, coords_new, imgvals_new, n_neighbor=n_neighbor)
map = self.project_to_healpix(nside, self.framecoords, self.imagevalues, n_neighbor=n_neighbor)

return map

def apply_correction(self, correction_type, minneart_k=1.25):
if correction_type == 'ls':
print("Applying Lommel-Seeliger correction")
for frame in self.framedata.framelets:
frame.image = apply_lommel_seeliger(frame.image, frame.incidence, frame.emission)
elif correction_type == 'minneart':
print("Applying Minneart correction")
for frame in self.framedata.framelets:
frame.image = apply_minneart(frame.image, frame.incidence, frame.emission, k=minneart_k)

@property
def framecoords(self):
return np.transpose(self.framedata.coords, (1, 0, 2, 3, 4)).reshape(3, -1, 2)

@property
def imagevalues(self):
return np.transpose(self.framedata.image, (1, 0, 2, 3)).reshape(3, -1)

def project_to_healpix(self, nside, coords, imgvals, n_neighbor=4):
# get the image extents in pixel coordinate space
# clip half a pixel to avoid edge artifacts
Expand Down Expand Up @@ -242,7 +250,6 @@ def apply_lommel_seeliger(imgvals, incidence, emission):
'''
Apply the Lommel-Seeliger correction for incidence
'''
print("Applying Lommel-Seeliger correction")
# apply Lommel-Seeliger correction
mu0 = np.cos(incidence)
mu = np.cos(emission)
Expand All @@ -253,13 +260,58 @@ def apply_lommel_seeliger(imgvals, incidence, emission):
return imgvals


def apply_minneart(imgvals, incidence, emission, k=1.25):
# apply Minneart correction
mu0 = np.cos(incidence)
mu = np.cos(emission)
corr = (mu ** k) * (mu0 ** (k - 1))
# log(mu * mu0) < -4 is usually pretty noisy
corr[np.log(np.cos(incidence) * np.cos(emission)) < -4] = np.inf
imgvals = imgvals / corr

return imgvals


def create_image_from_grid(coords, imgvals, pix, inds, img_shape, n_neighbor=5, min_dist=25.):
'''
Reproject an irregular spaced image onto a regular grid from a list of coordinate
locations and corresponding image values. This uses an inverse lookup-table defined
by `pix`, where pix gives the coordinates in the original image where the corresponding
pixel coordinate on the new image should be. The coordinate on the new image is given by
the inds variable.
'''
nchannels, ncoords, _ = coords.shape
mask = np.ones(coords.shape[1], dtype=bool)
for n in range(nchannels):
maski = np.isfinite(coords[n, :, 0] * coords[n, :, 1])
mask = mask & maski

newvals = np.zeros((nchannels, pix.shape[0]))
print("Calculating image values at new locations")
for n in range(nchannels):
neighbors = NearestNeighbors().fit(coords[n][mask])
dist, indi = neighbors.kneighbors(pix, n_neighbor)
weight = 1. / (dist + 1.e-16)
weight = weight / np.sum(weight, axis=1, keepdims=True)
weight[dist > min_dist] = 0.

newvals[n, :] = np.sum(np.take(imgvals[n][mask], indi, axis=0) * weight, axis=1)

IMG = np.zeros((*img_shape, nchannels))
# loop through each point observed by JunoCam and assign the pixel value
for k, ind in enumerate(tqdm.tqdm(inds, desc='Building image')):
if len(img_shape) == 2:
j, i = np.unravel_index(ind, img_shape)

# do the weighted average for each filter
for n in range(nchannels):
IMG[j, i, n] = newvals[n, k]
else:
for n in range(nchannels):
IMG[ind, n] = newvals[n, k]

IMG[~np.isfinite(IMG)] = 0.

'''
# break up the image and coordinate data into the different filters
Rcoords = coords[2, :]
Expand Down Expand Up @@ -328,5 +380,6 @@ def create_image_from_grid(coords, imgvals, pix, inds, img_shape, n_neighbor=5,
IMG[ind, 2] = B_vals[k]
IMG[~np.isfinite(IMG)] = 0.
'''

return IMG
2 changes: 1 addition & 1 deletion projection/spice_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def download_kernel(kernel, KERNEL_DATAFOLDER):
f.write(response.content)
else:
total_length = int(total_length)
with tqdm.tqdm(total=total_length, ascii=True, bytes=True, desc=f'Downloading {kernel}') as pbar:
with tqdm.tqdm(total=total_length, unit='B', unit_scale=True, unit_divisor=1024, ascii=True, desc=f'Downloading {kernel}') as pbar:
for data in tqdm.tqdm(response.iter_content(chunk_size=4096)):
f.write(data)
pbar.update(len(data))
Expand Down

0 comments on commit 8bec8b8

Please sign in to comment.