Skip to content

Commit

Permalink
Merge branch 'develop' of github.com:Mouse-Imaging-Centre/minc-stuffs…
Browse files Browse the repository at this point in the history
… into develop
  • Loading branch information
bcdarwin committed Jun 20, 2016
2 parents 5c5fa0e + 67debdc commit 6b94f44
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 60 deletions.
271 changes: 213 additions & 58 deletions python/rotational_minctracc.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,72 @@ def get_centre_of_gravity(file):
cog = array(line.strip().split(" ")).astype("float32")
return cog

def get_coordinates_from_tag_file(tag_file):
#
# Tag file content looks as follows:
#
# MNI Tag Point File
# Volumes = 2;
# %Mon Apr 25 16:58:13 2016>>> find_peaks -pos_only -min_distance 2 //dev/shm//rot_23705/rot_5.mnc //dev/shm//rot_23705/rot_0..tag
#
# Points =
# -28.6989 -49.7709 -28.283 -28.6989 -49.7709 -28.283 "12"
# -28.5549 -49.7649 -25.445 -28.5549 -49.7649 -25.445 "11"
# -28.1424 -49.6824 -13.28 -28.1424 -49.6824 -13.28 "11"
# -27.7989 -49.2069 -22.667 -27.7989 -49.2069 -22.667 "10";
#
all_coordinates = []
with open(tag_file) as f:
content = f.readlines()
found_coordinates = False
for line in content:
if not found_coordinates:
if "Points" in line:
found_coordinates = True
else:
# each line with coordinates starts with a space
# so the very first element when splitting this
# line is the empty string. We don't need that one
current_coor = array([line.split(' ')[1],
line.split(' ')[2],
line.split(' ')[3]]).astype("float32")
all_coordinates.append(current_coor)
return all_coordinates


def get_distance_transform_peaks(input_file, peak_distance):
#
# for solid input files (brains, embryos) we can
# calculate the distance transform for the input file
# and use peaks from that distance transform to
# widen our search space a bit.
#
# the .decode() in the end removes the b from the front
# of the returned string (Python3 default)
bimodalt_value = subprocess.check_output(["mincstats",
"-quiet",
"-biModalT",
input_file]).rstrip().decode()
max_value = subprocess.check_output(["mincstats",
"-quiet",
"-max",
input_file]).rstrip().decode()
distance_transform = get_tempfile('.mnc')
subprocess.check_call(("mincmorph -successive B[%s:%s]F %s %s" %
(bimodalt_value,max_value,input_file,distance_transform)).split())
peak_tags = get_tempfile('.tag')
subprocess.check_call(("find_peaks -pos_only -min_distance %s %s %s" % (peak_distance, distance_transform, peak_tags)).split())
all_coors = get_coordinates_from_tag_file(peak_tags)
return all_coors

def get_blur_peaks(input_file, blur_kernel, peak_distance):
blurred_input = get_tempfile('_blur.mnc')
subprocess.check_call(("mincblur -no_apo -fwhm %s %s %s" % (blur_kernel, input_file, blurred_input.split('_blur.mnc')[0])).split())
peak_tags = get_tempfile('.tag')
subprocess.check_call(("find_peaks -pos_only -min_distance %s %s %s" % (peak_distance, blurred_input, peak_tags)).split())
all_coors = get_coordinates_from_tag_file(peak_tags)
return all_coors

def compute_xcorr(sourcefile, targetvol, maskvol):
try:
sourcevol = volumeFromFile(sourcefile)
Expand Down Expand Up @@ -81,74 +147,143 @@ def resample_volume(source, target, transform):
def minctracc(source, target, mask, stepsize, wtranslations, simplex):
wtrans_decomp = array(wtranslations.split(',')).astype("float")
tmp_transform = get_tempfile('.xfm')
if mask is not None:
cmd = ("minctracc -identity -lsq6 -xcorr -simplex %s -step %s %s %s %s %s %s -source_mask %s -model_mask %s -w_translations %s %s %s"
% (simplex, stepsize, stepsize, stepsize, source, target, tmp_transform, mask, mask,
wtrans_decomp[0], wtrans_decomp[1], wtrans_decomp[2]))
print(cmd)
subprocess.check_call(cmd.split())
else:
cmd = ("minctracc -identity -lsq6 -xcorr -simplex %s -step %s %s %s %s %s %s -w_translations %s %s %s"
% (simplex, stepsize, stepsize, stepsize, source, target, tmp_transform,
wtrans_decomp[0], wtrans_decomp[1], wtrans_decomp[2]))
print(cmd)
subprocess.check_call(cmd.split())
cmd = ("minctracc -identity -lsq6 -xcorr -simplex %s -step %s %s %s %s %s %s -w_translations %s %s %s "
% (simplex, stepsize, stepsize, stepsize, source, target, tmp_transform,
wtrans_decomp[0], wtrans_decomp[1], wtrans_decomp[2]))
if mask:
cmd += ("-source_mask %s -model_mask %s " % (mask, mask))
print(cmd)
subprocess.check_call(cmd.split())

return tmp_transform

def concat_transforms(t1, t2):
tmp_transform = get_tempfile('.xfm')
subprocess.check_call(("xfmconcat %s %s %s" % (t1, t2, tmp_transform)).split())
return tmp_transform

def loop_rotations(stepsize, source, target, mask, simplex, start=50, interval=10, wtranslations="0.2,0.2,0.2"):

# get the centre of gravity for both volumes
cog1 = get_centre_of_gravity(source)
cog2 = get_centre_of_gravity(target)
cogdiff = cog2 - cog1
print("\n\nCOG diff: %s\n\n" % cogdiff.tolist())

def get_cross_correlation_from_coordinate_pair(source_img, target_img, target_vol, mask, coordinate_pair):
# Generate a transformation based on the coordinate_pair provided. Apply this
# transformation to the source_img and calculate the cross correlation between
# the source_img and target_img after this initial alignment
#
#
# source coordinates : coordinate_pair[0]
# target_coordinates : coordinate_pair[1]
transform_from_coordinates = create_transform(coordinate_pair[1] - coordinate_pair[0], 0, 0, 0, coordinate_pair[0])
resampled_source = resample_volume(source_img, target_img, transform_from_coordinates)
xcorr = compute_xcorr(resampled_source, target_vol, mask)
os.remove(resampled_source)
return float(xcorr)

def loop_rotations(stepsize, source, target, mask, simplex, start=50, interval=10,
wtranslations="0.2,0.2,0.2", use_multiple_seeds=True, max_number_seeds=5):
# load the target and mask volumes
targetvol = volumeFromFile(target)
maskvol = volumeFromFile(mask) if mask is not None else None

# 1) The default way of aligning files is by using the centre
# of gravity of the input files
cog_source = get_centre_of_gravity(source)
cog_target = get_centre_of_gravity(target)
list_of_coordinate_pairs = [[cog_source, cog_target]]

# 2) If we are using multiple seeds, calculate possible
# seeds for both source and target images. The distance
# between peaks is based on the stepsize used for the
# registrations
if use_multiple_seeds:
list_source_peaks = get_distance_transform_peaks(input_file=source, peak_distance=stepsize)
print("\n\nPeaks found in the source image (Distance Transform):")
for coor_src in list_source_peaks:
print(coor_src)
# also add peaks from the blurred version of the input file
blurred_peaks_source = get_blur_peaks(input_file=source, blur_kernel=stepsize, peak_distance=stepsize)
print("\n\nPeaks found in the source image (blurrred image):")
for coor_src in blurred_peaks_source:
print(coor_src)
list_source_peaks.append(coor_src)
# also add the center of gravity of the source image
list_source_peaks.append(cog_source)
list_target_peaks = get_distance_transform_peaks(input_file=target, peak_distance=stepsize)
print("\n\nPeaks found in the target image (Distance Transform):")
for coor_trgt in list_target_peaks:
print(coor_trgt)
blurred_peaks_target = get_blur_peaks(input_file=target, blur_kernel=stepsize, peak_distance=stepsize)
print("\n\nPeaks found in the target image (blurrred image):")
for coor_target in blurred_peaks_target:
print(coor_target)
list_target_peaks.append(coor_target)
# same for the target; add the center of gravity:
list_target_peaks.append(cog_target)
for source_coor in list_source_peaks:
for target_coor in list_target_peaks:
list_of_coordinate_pairs.append([source_coor, target_coor])

# 3) If we have more coordinates pairs than we'll be using, we'll have
# to determine the initial cross correlation of each pair and sort
# the pairs based on that
pairs_with_xcorr = []
if len(list_of_coordinate_pairs) > max_number_seeds:
for coor_pair in list_of_coordinate_pairs:
xcorr_coor_pair = get_cross_correlation_from_coordinate_pair(source, target, targetvol, maskvol, coor_pair)
pairs_with_xcorr.append({'xcorr': xcorr_coor_pair,
'coorpair': coor_pair})
print("Xcorr: " + str(xcorr_coor_pair))
sort_results(pairs_with_xcorr, reverse_order=True)
print("\n\n\nCoordinate pairs and their cross correlation: \n\n")
print(pairs_with_xcorr)
# empty the list and fill it up with the best matches:
list_of_coordinate_pairs = []
for i in range(max_number_seeds):
list_of_coordinate_pairs.append(pairs_with_xcorr[i]['coorpair'])
print("\n\nNew list of coordinates:")
print(list_of_coordinate_pairs)

results = []
best_xcorr = 0
for x in range(-start, start+1, interval):
for y in range(-start, start+1, interval):
for z in range(-start, start+1, interval):
# we need to include the centre of the volume as rotation centre = cog1
init_transform = create_transform(cogdiff, x, y, z, cog1)
#init_resampled = resample_volume(source,target, init_transform)
init_resampled = resample_volume(source ,target, init_transform)
transform = minctracc(init_resampled, target, mask, stepsize=stepsize,
wtranslations=wtranslations, simplex=simplex)
resampled = resample_volume(init_resampled, target, transform)
xcorr = compute_xcorr(resampled, targetvol, maskvol)
if isnan(xcorr):
xcorr = 0
conc_transform = concat_transforms(init_transform, transform)
results.append({'xcorr': xcorr, 'transform': conc_transform, \
'resampled': resampled, 'x': x, \
'y': y, 'z': z})
if xcorr > best_xcorr:
best_xcorr = xcorr
else:
for coordinates_src_target in list_of_coordinate_pairs:
coor_src = coordinates_src_target[0]
coor_trgt = coordinates_src_target[1]
for x in range(-start, start+1, interval):
for y in range(-start, start+1, interval):
for z in range(-start, start+1, interval):
# we need to include the centre of the volume as rotation centre = cog1
init_transform = create_transform(coor_trgt - coor_src, x, y, z, coor_src)
init_resampled = resample_volume(source, target, init_transform)
transform = minctracc(init_resampled, target, mask, stepsize=stepsize,
wtranslations=wtranslations, simplex=simplex)
resampled = resample_volume(init_resampled, target, transform)
conc_transform = concat_transforms(init_transform, transform)
xcorr = compute_xcorr(resampled, targetvol, maskvol)
if isnan(xcorr):
xcorr = 0
results.append({'xcorr': xcorr, 'transform': conc_transform, \
'resampled': resampled, 'x': x, \
'y': y, 'z': z})
if xcorr > best_xcorr:
best_xcorr = xcorr
# had some issues with the resampled file being gone...
# we'll just resample the final file only at the end
os.remove(resampled)
os.remove(init_resampled)
print("FINISHED: %s %s %s :: %s" % (x,y,z, xcorr))
os.remove(init_resampled)
print("FINISHED: %s %s %s :: %s" % (x,y,z, xcorr))

sort_results(results)
# resample the best result:
final_resampled = resample_volume(source, target, results[-1]["transform"])
results[-1]["resampled"] = final_resampled
targetvol.closeVolume()
if mask is not None:
maskvol.closeVolume()
sort_results(results)
return results

def dict_extract(adict, akey):
return adict[akey]

def sort_results(results):
def sort_results(results, reverse_order=False):
sort_key_func = functools.partial(dict_extract, akey="xcorr")
results.sort(key=sort_key_func)
results.sort(key=sort_key_func, reverse=reverse_order)

def downsample(infile, stepsize):
output = get_tempfile(".mnc")
Expand All @@ -166,28 +301,43 @@ def termtrapper(signum, frame):
signal.signal(signal.SIGTERM, termtrapper)

parser = ArgumentParser()

parser.add_argument("-m", "--mask", dest="mask",
help="mask to use for computing xcorr", type=str)
parser.add_argument("-s", "--stepsizeresample", dest="resamplestepsize",
help="resampled volumes to this stepsize",
help="resampled volumes to this stepsize [default = %(default)s]",
type=float, default=0.2)
parser.add_argument("-g", "--stepsizeregistration", dest="registrationstepsize",
help="use this stepsize in the minctracc registration",
help="use this stepsize in the minctracc registration [default = %(default)s]",
type=float, default=0.6)
parser.add_argument("-t", "--tempdir", dest="tmpdir",
help="temporary directory to use",
type=str)
parser.add_argument("-r", "--range", dest="range",
help="range of rotations to search across",
help="range of rotations to search across [default = %(default)s]",
type=int, default=50)
parser.add_argument("-i", "--interval", dest="interval",
help="interval (in degrees) to search across range",
help="interval (in degrees) to search across range [default = %(default)s]",
type=int, default=10)
parser.add_argument("-w", "--wtranslations", dest="wtranslations",
help="Comma separated list of optimization weights of translations in x, y, z for minctracc",
help="Comma separated list of optimization weights of translations in "
"x, y, z for minctracc [default = %(default)s]",
type=str, default="0.2,0.2,0.2")
parser.add_argument("--simplex", dest="simplex", help="Radius of minctracc simplex volume", type=float, default=1)
parser.add_argument("--simplex", dest="simplex", type=float, default=1,
help="Radius of minctracc simplex volume [default = %(default)s]")
parser.set_defaults(use_multiple_seeds=True)
parser.add_argument("--use-multiple-seeds", dest="use_multiple_seeds", action="store_true",
help="Find multiple possible starting points in the source and target for "
"the initial alignment in addition to using only the centre of gravity "
"(of the intensities) of the input files. [default = %(default)s]")
parser.add_argument("--no-use-multiple-seeds", dest="use_multiple_seeds", action="store_false",
help="Opposite of --use-multiple-seeds")
parser.set_defaults(max_number_seeds=3)
parser.add_argument("--max-number-seeds", dest="max_number_seeds", type=int,
help="Specify the maximum number of seed-pair starting points "
"to use for the rotational part of the code. The seed "
"pairs are ordered based on the cross correlation gotten "
"from the alignment based on only the translation from the "
"seed point. [default = %(default)s]")
parser.add_argument("source", help="", type=str, metavar="source.mnc")
parser.add_argument("target", help="", type=str, metavar="target.mnc")
parser.add_argument("output_xfm", help="", type=str, metavar="output.xfm")
Expand All @@ -213,12 +363,17 @@ def termtrapper(signum, frame):
# downsample the mask only if it is specified
if options.mask:
options.mask = downsample(options.mask, options.resamplestepsize)
if options.mask:
results = loop_rotations(options.registrationstepsize, source, target, options.mask, start=options.range,
interval=options.interval, wtranslations=options.wtranslations, simplex=options.simplex)
else:
results = loop_rotations(options.registrationstepsize, source, target, None, start=options.range,
interval=options.interval, wtranslations=options.wtranslations, simplex=options.simplex)

results = loop_rotations(stepsize=options.registrationstepsize,
source=source,
target=target,
mask=options.mask,
start=options.range,
interval=options.interval,
wtranslations=options.wtranslations,
simplex=options.simplex,
use_multiple_seeds=options.use_multiple_seeds,
max_number_seeds=options.max_number_seeds)

print(results)
subprocess.check_call(("cp %s %s" % (results[-1]["transform"], output_xfm)).split())
Expand Down
4 changes: 2 additions & 2 deletions python/voxel_vote
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ if __name__ == "__main__":
(1,volhandles[0].sizes[1],
volhandles[0].sizes[2]))
t.shape = (volhandles[0].sizes[1], volhandles[0].sizes[2])
sliceArray[j::] = t
sliceArray[j,:,:] = t

outfile.data[i::] = mode(sliceArray)[0]
outfile.data[i,:,:] = mode(sliceArray)[0]

outfile.writeFile()
outfile.closeVolume()
Expand Down

0 comments on commit 6b94f44

Please sign in to comment.