diff --git a/projection/frameletdata.py b/projection/frameletdata.py index 63af871..cefb904 100644 --- a/projection/frameletdata.py +++ b/projection/frameletdata.py @@ -12,7 +12,7 @@ class FrameletData: - def __init__(self, metadata, imgfolder): + def __init__(self, metadata: dict, imgfolder: str): self.nframelets = int(metadata["LINES"] / FRAME_HEIGHT) # number of RGB frames self.nframes = int(self.nframelets / 3) @@ -64,71 +64,63 @@ def __init__(self, metadata, imgfolder): self.framelets.append(Framelet(self.start_et, frame_delay, n, c, framei)) - def get_backplane(self, num_procs): + def get_backplane(self, num_procs: int): tmid = np.mean([frame.et for frame in self.framelets]) - ''' - inpargs = [(frame, tmid) for frame in self.framelets] - - with multiprocessing.Pool( - processes=num_procs, initializer=initializer - ) as pool: - try: - pool.starmap(Framelet.project_to_midplane, tqdm.tqdm(inpargs)) - pool.close() - except KeyboardInterrupt: - pool.terminate() - pool.join() - sys.exit() - - pool.join() - ''' - for frame in tqdm.tqdm(self.framelets): - frame.project_to_midplane(tmid) - - def update_jitter(self, jitter): + if num_procs > 1: + inpargs = [(frame, tmid) for frame in self.framelets] + with multiprocessing.Pool( + processes=num_procs, initializer=initializer + ) as pool: + try: + with tqdm.tqdm(inpargs, desc='Projecting framelets') as pbar: + self.framelets = pool.starmap(Framelet.project_to_midplane, pbar) + pool.close() + except KeyboardInterrupt: + pool.terminate() + sys.exit() + finally: + pool.join() + else: + for frame in tqdm.tqdm(self.framelets): + frame.project_to_midplane(tmid) + + def update_jitter(self, jitter: float): + self.jitter = jitter for framelet in self.framelets: framelet.jitter = jitter - def update_image(self, new_image): - assert new_image.shape == self.image.shape, f"New image must be the same shape as the old one. Got {new_image.shape} instead of {self.image.shape}" - - # first back up the original - self.image_old = self.image - # then store the new image data - self.image[:] = new_image[:] - @property - def tmid(self): + def tmid(self) -> float: return np.mean([frame.et for frame in self.framelets]) @property - def image(self): + def image(self) -> np.ndarray: return np.stack([frame.image for frame in self.framelets], axis=0).reshape((self.nframes, 3, FRAME_HEIGHT, FRAME_WIDTH)) @property - def coords(self): + def coords(self) -> np.ndarray: return np.stack([frame.coords for frame in self.framelets], axis=0).reshape((self.nframes, 3, FRAME_HEIGHT, FRAME_WIDTH, 2)) @property - def emission(self): + def emission(self) -> np.ndarray: return np.stack([frame.emission for frame in self.framelets], axis=0).reshape((self.nframes, 3, FRAME_HEIGHT, FRAME_WIDTH)) @property - def incidence(self): + def incidence(self) -> np.ndarray: return np.stack([frame.incidence for frame in self.framelets], axis=0).reshape((self.nframes, 3, FRAME_HEIGHT, FRAME_WIDTH)) @property - def longitude(self): + def longitude(self) -> np.ndarray: return np.stack([frame.lon for frame in self.framelets], axis=0).reshape((self.nframes, 3, FRAME_HEIGHT, FRAME_WIDTH)) @property - def latitude(self): + def latitude(self) -> np.ndarray: return np.stack([frame.lat for frame in self.framelets], axis=0).reshape((self.nframes, 3, FRAME_HEIGHT, FRAME_WIDTH)) class Framelet: - def __init__(self, start_et, frame_delay, frame_no, color, img): + def __init__(self, start_et: float, frame_delay: float, frame_no: int, color: int, img: np.ndarray): self.start_et = start_et self.frame_delay = frame_delay self.frame_no = frame_no @@ -138,7 +130,7 @@ def __init__(self, start_et, frame_delay, frame_no, color, img): self.camera = CameraModel(color) @property - def et(self): + def et(self) -> float: jitter = 0 if self.jitter is None else self.jitter return ( self.start_et @@ -147,7 +139,7 @@ def et(self): + (self.frame_delay + self.camera.iframe_delay) * self.frame_no ) - def project_to_midplane(self, tmid): + def project_to_midplane(self, tmid: float) -> None: coords = np.nan * np.zeros((FRAME_HEIGHT, FRAME_WIDTH, 2)) lat = np.nan * np.zeros((FRAME_HEIGHT, FRAME_WIDTH)) lon = np.nan * np.zeros((FRAME_HEIGHT, FRAME_WIDTH)) @@ -169,6 +161,8 @@ def project_to_midplane(self, tmid): # geometry correction self.fluxcal = fluxcal + return self + # for decompanding -- taken from Kevin Gill's github page SQROOT = np.array( @@ -197,7 +191,7 @@ def project_to_midplane(self, tmid): ) -def decompand(image): +def decompand(image: np.ndarray) -> np.ndarray: """ Decompands the image from the 8-bit in the public release to the original 12-bit shot by JunoCam @@ -215,7 +209,7 @@ def decompand(image): data = np.array(255 * image, dtype=np.uint8) ny, nx = data.shape - def get_sqrt(x): + def get_sqrt(x: float) -> float: return SQROOT[x] v_get_sqrt = np.vectorize(get_sqrt) diff --git a/projection/mosaic.py b/projection/mosaic.py new file mode 100644 index 0000000..6e39520 --- /dev/null +++ b/projection/mosaic.py @@ -0,0 +1,53 @@ +import numpy as np +import healpy as hp +from skimage import color +import tqdm + + +class Mosaic: + def __init__(self, fnames, n_side=512): + self.fnames = fnames + self.nframes = len(fnames) + self.n_side = n_side + + def load_data(self): + self.maps = np.zeros((len(self.fnames), hp.nside2npix(self.n_side), 3)) + + for i, fname in enumerate(tqdm.tqdm(self.fnames)): + map = np.load(fname) + if hp.nside2npix(self.n_side) != map.shape[0]: + map = hp.ud_grade(map.T, self.n_side).T + + map_lab = color.rgb2lab(map) + + # clean up the fringes. these are negative a* and b* in the CIELAB space + map_lab[(map_lab[:, 1] < -0.001) | (map_lab[:, 2] < -0.001), 0] = 0 + + self.maps[i, :] = color.lab2rgb(map_lab) + self.maps[~np.isfinite(self.maps)] = 0. + + def stack(self, radius=2): + m_normed = self.maps.copy() + # first, norma + for j in range(self.nframes): + m_normed[j, :] = m_normed[j, :] / np.percentile(m_normed[j, :], 99) + + count = np.sum(np.min(m_normed, axis=-1) > 1.e-6, axis=0) + m_hsv = color.rgb2hsv(m_normed) + v_ave = np.mean(m_hsv[:, count > 0, 2]) + + pix_inds = np.where(count > 0)[0] + v_loc = np.zeros_like(m_normed[:, :, 0]) + vecs = np.asarray(hp.pix2vec(self.n_side, ipix=pix_inds)).T + + for i, vec in tqdm.tqdm(zip(pix_inds, vecs), total=len(pix_inds)): + neighbours = hp.query_disc(self.n_side, vec=vec, radius=np.radians(radius)) + m_hsv_neigh = m_hsv[:, neighbours, 2] + v_loc[:, i] = np.mean(m_hsv_neigh[:, count[neighbours] > 0], axis=1) + + m_new = np.zeros_like(m_normed[0, :]) + for i in range(3): + m_new[:, i] = np.sum(m_normed[:, :, i] * (v_ave / v_loc), axis=0) / count + m_new[~np.isfinite(m_new)] = 0 + + return m_new diff --git a/projection/project.c b/projection/project.c index 96210b9..daaad7c 100755 --- a/projection/project.c +++ b/projection/project.c @@ -241,7 +241,7 @@ void project_midplane(double eti, int cam, double tmid, double *lon, pixvec[2] / sqrtf(pixvec[0] * pixvec[0] + pixvec[1] * pixvec[1] + pixvec[2] * pixvec[2]); fluxcal[jj * FRAME_WIDTH + ii] = - (M_PI / 4.) * + (M_PI / 4.) * pow(disti, 2.) * pow((aperture / focal_length) * cosalpha * cosalpha, 2); } } diff --git a/projection/projector.py b/projection/projector.py index d4565ed..e6abbe5 100755 --- a/projection/projector.py +++ b/projection/projector.py @@ -6,11 +6,7 @@ import matplotlib.pyplot as plt import spiceypy as spice import tqdm -import multiprocessing -import time -import sys -from .globals import FRAME_HEIGHT, FRAME_WIDTH, initializer -from .cython_utils import furnish_c, project_midplane_c, get_pixel_from_coords_c +from .cython_utils import furnish_c, get_pixel_from_coords_c from .camera_funcs import CameraModel from .spice_utils import get_kernels from .frameletdata import FrameletData @@ -188,23 +184,35 @@ def get_limb(self, eti, cami): return limbs_jcam - def process(self, nside=512, num_procs=8, apply_LS=True, n_neighbor=5): + def process(self, nside=512, num_procs=8, apply_correction='minneart', n_neighbor=5, minneart_k=1.25): print(f"Projecting {self.fname} to a HEALPix grid with n_side={nside}") self.framedata.get_backplane(num_procs) - if apply_LS: - self.framedata.update_image( - apply_lommel_seeliger(self.framedata.image, self.framedata.incidence, self.framedata.emission) - ) - - coords_new = np.transpose(self.framedata.coords, (1, 0, 2, 3, 4)).reshape(3, -1, 2) - imgvals_new = np.transpose(self.framedata.image, (1, 0, 2, 3)).reshape(3, -1) + self.apply_correction(apply_correction, minneart_k) - map = self.project_to_healpix(nside, coords_new, imgvals_new, n_neighbor=n_neighbor) + map = self.project_to_healpix(nside, self.framecoords, self.imagevalues, n_neighbor=n_neighbor) return map + def apply_correction(self, correction_type, minneart_k=1.25): + if correction_type == 'ls': + print("Applying Lommel-Seeliger correction") + for frame in self.framedata.framelets: + frame.image = apply_lommel_seeliger(frame.image, frame.incidence, frame.emission) + elif correction_type == 'minneart': + print("Applying Minneart correction") + for frame in self.framedata.framelets: + frame.image = apply_minneart(frame.image, frame.incidence, frame.emission, k=minneart_k) + + @property + def framecoords(self): + return np.transpose(self.framedata.coords, (1, 0, 2, 3, 4)).reshape(3, -1, 2) + + @property + def imagevalues(self): + return np.transpose(self.framedata.image, (1, 0, 2, 3)).reshape(3, -1) + def project_to_healpix(self, nside, coords, imgvals, n_neighbor=4): # get the image extents in pixel coordinate space # clip half a pixel to avoid edge artifacts @@ -242,7 +250,6 @@ def apply_lommel_seeliger(imgvals, incidence, emission): ''' Apply the Lommel-Seeliger correction for incidence ''' - print("Applying Lommel-Seeliger correction") # apply Lommel-Seeliger correction mu0 = np.cos(incidence) mu = np.cos(emission) @@ -253,6 +260,18 @@ def apply_lommel_seeliger(imgvals, incidence, emission): return imgvals +def apply_minneart(imgvals, incidence, emission, k=1.25): + # apply Minneart correction + mu0 = np.cos(incidence) + mu = np.cos(emission) + corr = (mu ** k) * (mu0 ** (k - 1)) + # log(mu * mu0) < -4 is usually pretty noisy + corr[np.log(np.cos(incidence) * np.cos(emission)) < -4] = np.inf + imgvals = imgvals / corr + + return imgvals + + def create_image_from_grid(coords, imgvals, pix, inds, img_shape, n_neighbor=5, min_dist=25.): ''' Reproject an irregular spaced image onto a regular grid from a list of coordinate @@ -260,6 +279,39 @@ def create_image_from_grid(coords, imgvals, pix, inds, img_shape, n_neighbor=5, by `pix`, where pix gives the coordinates in the original image where the corresponding pixel coordinate on the new image should be. The coordinate on the new image is given by the inds variable. + ''' + nchannels, ncoords, _ = coords.shape + mask = np.ones(coords.shape[1], dtype=bool) + for n in range(nchannels): + maski = np.isfinite(coords[n, :, 0] * coords[n, :, 1]) + mask = mask & maski + + newvals = np.zeros((nchannels, pix.shape[0])) + print("Calculating image values at new locations") + for n in range(nchannels): + neighbors = NearestNeighbors().fit(coords[n][mask]) + dist, indi = neighbors.kneighbors(pix, n_neighbor) + weight = 1. / (dist + 1.e-16) + weight = weight / np.sum(weight, axis=1, keepdims=True) + weight[dist > min_dist] = 0. + + newvals[n, :] = np.sum(np.take(imgvals[n][mask], indi, axis=0) * weight, axis=1) + + IMG = np.zeros((*img_shape, nchannels)) + # loop through each point observed by JunoCam and assign the pixel value + for k, ind in enumerate(tqdm.tqdm(inds, desc='Building image')): + if len(img_shape) == 2: + j, i = np.unravel_index(ind, img_shape) + + # do the weighted average for each filter + for n in range(nchannels): + IMG[j, i, n] = newvals[n, k] + else: + for n in range(nchannels): + IMG[ind, n] = newvals[n, k] + + IMG[~np.isfinite(IMG)] = 0. + ''' # break up the image and coordinate data into the different filters Rcoords = coords[2, :] @@ -328,5 +380,6 @@ def create_image_from_grid(coords, imgvals, pix, inds, img_shape, n_neighbor=5, IMG[ind, 2] = B_vals[k] IMG[~np.isfinite(IMG)] = 0. + ''' return IMG diff --git a/projection/spice_utils.py b/projection/spice_utils.py index a2e870a..83436ec 100644 --- a/projection/spice_utils.py +++ b/projection/spice_utils.py @@ -37,7 +37,7 @@ def download_kernel(kernel, KERNEL_DATAFOLDER): f.write(response.content) else: total_length = int(total_length) - with tqdm.tqdm(total=total_length, ascii=True, bytes=True, desc=f'Downloading {kernel}') as pbar: + with tqdm.tqdm(total=total_length, unit='B', unit_scale=True, unit_divisor=1024, ascii=True, desc=f'Downloading {kernel}') as pbar: for data in tqdm.tqdm(response.iter_content(chunk_size=4096)): f.write(data) pbar.update(len(data))