Commit 2833f94

Merge pull request #16 from harvard-nrg/memfree

Memfree

danielasay authored Jan 3, 2025
2 parents 342f1b6 + 6cb202d commit 2833f94
Showing 4 changed files with 94 additions and 59 deletions.
137 changes: 83 additions & 54 deletions scanbuddy/proc/__init__.py
@@ -11,6 +11,7 @@
import numpy as np
from pubsub import pub
from sortedcontainers import SortedDict
from scanbuddy.proc.snr import SNR

logger = logging.getLogger(__name__)

@@ -19,13 +20,36 @@ def __init__(self):
self.reset()
pub.subscribe(self.reset, 'reset')
pub.subscribe(self.listener, 'incoming')
self._fdata_array = np.array([])
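# empty placeholder; re-allocated with the full (x, y, z, t) shape once the first volumes arrive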

def reset(self):
self._instances = SortedDict()
self._slice_means = SortedDict()
pub.sendMessage('plot_snr', snr_metric=str(0.0))
logger.debug('received message to reset')

def getsize(self, obj):
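# note: sys.getsizeof reports only the shallow size of obj, not the buffers of nested numpy arrays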
size_in_bytes = sys.getsizeof(obj)
return size_in_bytes

def get_size_slice_means(self):
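# total the raw numpy buffer sizes (nbytes) of every cached slice-means array, in bytes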
total_size = 0
for key in self._slice_means:
slice_means = self._slice_means[key]['slice_means']
total_size += slice_means.nbytes
return total_size

def get_size_mask(self):
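# total the raw buffer sizes of every cached mask, logging each one in MB as it is counted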
total_size = 0
for key in self._slice_means:
mask = self._slice_means[key]['mask']
if mask is not None:
mb = mask.nbytes / (1024**2)
shape = mask.shape
logger.info(f'mask for instance {key} is dtype={mask.dtype}, shape={shape}, size={mb} MB')
total_size += mask.nbytes
return total_size

def listener(self, ds, path):
key = int(ds.InstanceNumber)
self._instances[key] = {
@@ -42,22 +66,12 @@ def listener(self, ds, path):
logger.debug('current state of instances')
logger.debug(json.dumps(self._instances, default=list, indent=2))



tasks = self.check_volreg(key)
snr_tasks = self.check_snr(key)
logger.debug('publishing message to volreg topic with the following tasks')
logger.debug(json.dumps(tasks, indent=2))
pub.sendMessage('volreg', tasks=tasks)
logger.debug(f'publishing message to params topic')
pub.sendMessage('params', ds=ds)
logger.debug(f'publishing message to snr_fdata topic')
logger.debug(f'snr task sorted dict: {snr_tasks}')
pub.sendMessage('snr', nii_path=self._instances[key]['nii_path'], tasks=snr_tasks)
logger.debug('after snr calculation')

logger.debug(json.dumps(self._instances, indent=2))


logger.debug(f'after volreg')
logger.debug(json.dumps(self._instances, indent=2))
@@ -68,23 +82,43 @@ def listener(self, ds, path):
subtitle_string = f'{project}{session}{scandesc}{scannum}'
pub.sendMessage('plot', instances=self._instances, subtitle_string=subtitle_string)

snr_tasks = self.check_snr(key)
#logger.info(f'snr task sorted dict: {snr_tasks}')

snr = SNR()
nii_path = self._instances[key]['nii_path']
snr.do(nii_path, snr_tasks)
'''
size_of_snr_tasks = self.getsize(snr_tasks) / (1024**3)
size_of_slice_means = self.get_size_slice_means() / (1024**3)
size_of_fdata_array = self._fdata_array.nbytes / (1024**3)
logger.info('==============================================')
logger.info(f' SIZE OF snr_tasks IS {size_of_snr_tasks} GB')
logger.info(f' SIZE OF self._slice_means is {size_of_slice_means} GB')
logger.info(f' SIZE OF self._fdata_array is {size_of_fdata_array} GB')
logger.info('==============================================')
'''
logger.debug('after snr calculation')
logger.debug(json.dumps(self._instances, indent=2))

if key < 5:
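# bootstrap on an early volume: (0020,0105) is the DICOM Number of Temporal Positions tag,
# i.e. the expected number of volumes, used to pre-allocate the arrays once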
self._num_vols = ds[(0x0020, 0x0105)].value
self._mask_threshold, self._decrement = self.get_mask_threshold(ds)
x, y, self._z, _ = self._slice_means[key]['slice_means'].shape

self._fdata_array = np.zeros((x, y, self._z, self._num_vols), dtype=np.float64)
self._slice_intensity_means = np.zeros((self._z, self._num_vols), dtype=np.float64)



logger.info(f'shape of zeros: {self._fdata_array.shape}')
logger.info(f"shape of first slice means: {self._slice_means[key]['slice_means'].shape}")

if key >= 5:
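# instance numbers are 1-based and the first 4 volumes are never inserted, so instance 5 lands at index 0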
insert_position = key - 5
self._fdata_array[:, :, :, insert_position] = self._slice_means[key]['slice_means'].squeeze()
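# release the per-instance copy now that the data lives in the shared 4D array (this is where memory is reclaimed)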
self._slice_means[key]['slice_means'] = np.array([])


if key > 53 and (key % 4 == 0) and key < self._num_vols:
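# after roughly 50 volumes have accumulated, refresh the SNR estimate on every 4th volume in a background thread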
logger.info('launching calculate and publish snr thread')

@@ -95,21 +129,13 @@ def listener(self, ds, path):
time.sleep(2)
data_path = os.path.dirname(self._instances[key]['path'])
logger.info(f'removing dicom dir: {data_path}')
shutil.rmtree(data_path)

#if key == self._num_vols:
# logger.info('RUNNING FINAL SNR CALCULATION')
# snr_metric = round(self.calc_snr(key), 2)
# logger.info(f'final snr metric: {snr_metric}')
# pub.sendMessage('plot_snr', snr_metric=snr_metric)


shutil.rmtree(data_path)


def calculate_and_publish_snr(self, key):
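# recompute the running SNR metric, time the calculation, and publish the result for plotting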
start = time.time()
snr_metric = round(self.calc_snr(key), 2)
elapsed = time.time() - start
#self._plot_dict[self._key] = elapsed
logger.info(f'snr calculation took {elapsed} seconds')
logger.info(f'running snr metric: {snr_metric}')
if np.isnan(snr_metric):
@@ -145,8 +171,15 @@ def check_volreg(self, key):
return tasks

def calc_snr(self, key):

slice_intensity_means, slice_voxel_counts, data = self.get_mean_slice_intensitites(key)
slice_intensity_means, slice_voxel_counts, data = self.get_mean_slice_intensities(key)
'''
size_slice_int_means = self.getsize(slice_intensity_means) / (1024**3)
size_data = self.getsize(data) / (1024**2)
logger.info('==============================================')
logger.info(f' SIZE OF slice_intensity_means IS {size_slice_int_means} MB')
logger.info(f' SIZE OF data IS {size_data} MB')
logger.info('==============================================')
'''

non_zero_columns = ~np.all(slice_intensity_means == 0, axis=0)

@@ -155,6 +188,7 @@
slice_count = slice_intensity_means_2.shape[0]
volume_count = slice_intensity_means_2.shape[1]


slice_weighted_mean_mean = 0
slice_weighted_stdev_mean = 0
slice_weighted_snr_mean = 0
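# accumulate voxel-count-weighted totals per slice; the metric returned below is
# the voxel-weighted mean of the per-slice SNR values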
@@ -177,27 +211,18 @@
total_voxel_count += slice_voxel_count

logger.debug(f"Slice {slice_idx}: Mean={slice_mean}, StdDev={slice_stdev}, SNR={slice_snr}")



return slice_weighted_snr_mean / total_voxel_count


def get_mean_slice_intensitites(self, key):

def get_mean_slice_intensities(self, key):

data = self.generate_mask(key)

mask = np.ma.getmask(data)
dim_x, dim_y, dim_z, _ = data.shape

dim_t = key - 4
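# dim_t = number of volumes stored so far (instance key fills index key - 5, so key - 4 slots are populated)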

'''
if key > 55:
start = time.time()
differing_slices = self.find_mask_differences(key)
logger.info(f'finding mask differences took {time.time() - start}')
'''


slice_voxel_counts = np.zeros( (dim_z), dtype='uint32' )
slice_size = dim_x * dim_y

@@ -223,11 +248,6 @@ def get_mean_slice_intensitites(self, key):
slice_vol_mean = slice_data.mean()
self._slice_intensity_means[slice_idx,volume_idx] = slice_vol_mean

#logger.info(f'recalculating slice means at the following slices: {differing_slices}')
#logger.info(f'total of {len(differing_slices)} new slices being computed')

#if differing_slices:

if key == self._num_vols:
start = time.time()
differing_slices = self.find_mask_differences(key)
@@ -240,7 +260,8 @@ def get_mean_slice_intensitites(self, key):
slice_vol_mean = slice_data.mean()
self._slice_intensity_means[slice_idx,volume_idx] = slice_vol_mean

elif key % 6 == 0:
elif key % 2 == 0:
#elif key % 6 == 0:
logger.info(f'inside the even calculation')
start = time.time()
differing_slices = self.find_mask_differences(key)
@@ -253,7 +274,8 @@ def get_mean_slice_intensitites(self, key):
slice_vol_mean = slice_data.mean()
self._slice_intensity_means[slice_idx,volume_idx] = slice_vol_mean

elif key % 5 == 0:
else:
#elif key % 5 == 0:
logger.info(f'inside the odd calculation')
start = time.time()
differing_slices = self.find_mask_differences(key)
@@ -265,17 +287,15 @@ def get_mean_slice_intensitites(self, key):
slice_data = data[:,:,slice_idx,volume_idx]
slice_vol_mean = slice_data.mean()
self._slice_intensity_means[slice_idx,volume_idx] = slice_vol_mean



return self._slice_intensity_means[:, :dim_t], slice_voxel_counts, data


def generate_mask(self, key):
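# build a background mask from the temporal mean of all stored volumes: voxels whose
# mean intensity falls at or below the threshold are masked out across every volume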

mean_data = np.mean(self._fdata_array[...,:key-4], axis=3)

numpy_3d_mask = np.zeros(mean_data.shape, dtype=bool)

to_mask = (mean_data <= self._mask_threshold)

mask_lower_count = int(to_mask.sum())
@@ -284,27 +304,37 @@ def find_coil(self, ds):

numpy_4d_mask = np.zeros(self._fdata_array[..., :key-4].shape, dtype=bool)

numpy_4d_mask[numpy_3d_mask] = 1
numpy_4d_mask[numpy_3d_mask] = True

masked_data = np.ma.masked_array(self._fdata_array[..., :key-4], mask=numpy_4d_mask)

mask = np.ma.getmask(masked_data)

self._slice_means[key]['mask'] = mask

'''
size_mask = self.get_size_mask() / (1024**2)
logger.info(f'===============================')
logger.info(f'SHAPE OF MASK IS {mask.shape}')
logger.info(f'SIZE OF MASK IS {size_mask} MB')
logger.info(f'===============================')
'''

return masked_data

def find_mask_differences(self, key):
num_old_vols = key - 8
last_50 = num_old_vols - 50
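# compare only the trailing 50-volume window common to both masks; slices whose
# masks changed are the only ones whose means need recomputing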
prev_mask = self._slice_means[key-4]['mask']
logger.info(f'looking for mask differences between {key} and {key - 4}')
prev_mask = self._slice_means[key - 4]['mask']
current_mask = self._slice_means[key]['mask']
differences = prev_mask[:,:,:,-50:] != current_mask[:,:,:,last_50:num_old_vols]
diff_indices = np.where(differences)
differing_slices = []
for index in zip(*diff_indices):
if int(index[2]) not in differing_slices:
differing_slices.append(int(index[2]))
logger.info(f'reclaim memory for instance {key - 4 } mask')
self._slice_means[key - 4]['mask'] = np.array([])
return differing_slices


@@ -332,7 +362,6 @@ def find_coil(self, ds):

def check_snr(self, key):
tasks = list()
current = self._slice_means[key]

current_idx = self._slice_means.bisect_left(key)

9 changes: 7 additions & 2 deletions scanbuddy/proc/snr.py
@@ -22,6 +22,13 @@ def __init__(self):
pub.subscribe(self.listener, 'snr')


def do(self, nii_path, tasks):
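# synchronous entry point: lets the processor invoke fdata extraction directly
# instead of going through the 'snr' pub/sub topic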
logger.info('received tasks for fdata extraction')
self.snr_tasks = tasks
self._nii_path = nii_path

self.run()

def listener(self, nii_path, tasks):
logger.info('received tasks for fdata extraction')
self.snr_tasks = tasks
@@ -39,7 +46,6 @@ def run(self):
instance_num = int(dcm.InstanceNumber)

logger.info(f'extracting fdata for volume {instance_num}')

data_array = self.get_nii_array()

self.insert_snr(data_array, self.snr_tasks[0], None)
@@ -50,7 +56,6 @@ def run(self):

logger.info(f'extracting fdata from volume {instance_num} took {elapsed} seconds')


def get_nii_array(self):
return nib.load(self._nii_path).get_fdata()

6 changes: 4 additions & 2 deletions scanbuddy/proc/volreg.py
@@ -68,14 +68,14 @@ def get_num_tasks(self):
def create_niis(self, task_idx):

dcm1 = self.tasks[task_idx][1]['path']
nii1 = self.run_dcm2niix(dcm1, 1)
#nii1 = self.run_dcm2niix(dcm1, 1)
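# pass the real DICOM instance number to dcm2niix instead of a hard-coded volume
# index (presumably to keep output filenames unique across conversions)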
self._dcm1_instance_num = int(pydicom.dcmread(dcm1, force=True, stop_before_pixels=True).InstanceNumber)
nii1 = self.run_dcm2niix(dcm1, self._dcm1_instance_num)
if self.tasks[task_idx][1]['nii_path'] is None:
self.tasks[task_idx][1]['nii_path'] = nii1

dcm2 = self.tasks[task_idx][0]['path']
nii2 = self.run_dcm2niix(dcm2, 2)
#nii2 = self.run_dcm2niix(dcm2, 2)
self._dcm2_instance_num = int(pydicom.dcmread(dcm2, force=True, stop_before_pixels=True).InstanceNumber)
nii2 = self.run_dcm2niix(dcm2, self._dcm2_instance_num)
if self.tasks[task_idx][0]['nii_path'] is None:
@@ -100,6 +100,8 @@ def run_dcm2niix(self, dicom, num):
'-o', self.out_dir,
dicom
]
cmdstr = json.dumps(dcm2niix_cmd, indent=2)
#logger.info(f'running {cmdstr}')

output = subprocess.check_output(dcm2niix_cmd, stderr=subprocess.STDOUT)

1 change: 0 additions & 1 deletion scripts/start.py
@@ -43,7 +43,6 @@ def main():
config=config
)
volreg = VolReg(mock=args.mock)
snr = SNR()
view = View(
host=args.host,
port=args.port,
