Refactor WkwDatasetInfo to include all wkw-header-related properties (#144)

* extend the WkwDatasetInfo to include all important properties

* convert dtype to a VALID_VOXEL_TYPE

* use voxel_type of header

* consistently treat the 'dtype' parameter of lower-level methods as np.dtype

* reformatting of code

* fix test_element_class_convertion test

* Info for KnossosDatasets does not contain a header

* resolve circular dependency

* black
rschwanhold authored Oct 23, 2019
1 parent e330475 commit 66dd9f5
Showing 10 changed files with 111 additions and 71 deletions.
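The shape change at the center of this commit: WkwDatasetInfo used to carry a bare dtype next to path, layer name and mag; it now carries a full wkw.Header, so voxel type, channel count, file_len and block type travel with the info object instead of being threaded through call sites as extra keyword arguments. A minimal sketch of the new convention, assuming WkwDatasetInfo is a namedtuple (its real definition lives in wkcuber/utils.py and is not part of the diffs shown here):

from collections import namedtuple

import numpy as np
import wkw

# Assumed stand-in for the class in wkcuber/utils.py: the trailing
# dtype slot is replaced by a full wkw.Header.
WkwDatasetInfo = namedtuple(
    "WkwDatasetInfo", ("dataset_path", "layer_name", "mag", "header")
)

# Old call: WkwDatasetInfo("testdata/WT1_wkw", "color", "uint8", 1)
# New call: the header bundles voxel type, channel count, file_len, ...
info = WkwDatasetInfo(
    "testdata/WT1_wkw", "color", 1, wkw.Header(np.uint8, num_channels=3)
)
print(info.header.voxel_type, info.header.num_channels)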
46 changes: 26 additions & 20 deletions tests/test_downsampling.py
@@ -16,12 +16,12 @@
 WKW_CUBE_SIZE = 1024
 CUBE_EDGE_LEN = 256
 
-source_info = WkwDatasetInfo("testdata/WT1_wkw", "color", "uint8", 1)
-target_info = WkwDatasetInfo("testoutput/WT1_wkw", "color", "uint8", 2)
+source_info = WkwDatasetInfo("testdata/WT1_wkw", "color", 1, wkw.Header(np.uint8))
+target_info = WkwDatasetInfo("testoutput/WT1_wkw", "color", 2, wkw.Header(np.uint8))
 
 
-def read_wkw(wkw_info, offset, size, **kwargs):
-    with open_wkw(wkw_info, **kwargs) as wkw_dataset:
+def read_wkw(wkw_info, offset, size):
+    with open_wkw(wkw_info) as wkw_dataset:
         return wkw_dataset.read(offset, size)


@@ -110,11 +110,10 @@ def downsample_test_helper(use_compress):
     block_type = (
         wkw.Header.BLOCK_TYPE_LZ4HC if use_compress else wkw.Header.BLOCK_TYPE_RAW
     )
+    target_info.header.block_type = block_type
 
     target_buffer = read_wkw(
-        target_info,
-        tuple(a * WKW_CUBE_SIZE for a in offset),
-        (CUBE_EDGE_LEN,) * 3,
-        block_type=block_type,
+        target_info, tuple(a * WKW_CUBE_SIZE for a in offset), (CUBE_EDGE_LEN,) * 3
     )[0]
     assert np.any(target_buffer != 0)

@@ -133,14 +132,6 @@ def test_compressed_downsample_cube_job():
 
 
 def test_downsample_multi_channel():
-    source_info = WkwDatasetInfo("testoutput/multi-channel-test", "color", "uint8", 1)
-    target_info = WkwDatasetInfo("testoutput/multi-channel-test", "color", "uint8", 2)
-    try:
-        shutil.rmtree(source_info.dataset_path)
-        shutil.rmtree(target_info.dataset_path)
-    except:
-        pass
-
     offset = (0, 0, 0)
     num_channels = 3
     size = (32, 32, 10)
@@ -149,9 +140,25 @@ def test_downsample_multi_channel():
     ).astype("uint8")
     file_len = 32
 
-    with open_wkw(
-        source_info, num_channels=num_channels, file_len=file_len
-    ) as wkw_dataset:
+    source_info = WkwDatasetInfo(
+        "testoutput/multi-channel-test",
+        "color",
+        1,
+        wkw.Header(np.uint8, num_channels, file_len=file_len),
+    )
+    target_info = WkwDatasetInfo(
+        "testoutput/multi-channel-test",
+        "color",
+        2,
+        wkw.Header(np.uint8, file_len=file_len),
+    )
+    try:
+        shutil.rmtree(source_info.dataset_path)
+        shutil.rmtree(target_info.dataset_path)
+    except:
+        pass
+
+    with open_wkw(source_info) as wkw_dataset:
         print("writing source_data shape", source_data.shape)
         wkw_dataset.write(offset, source_data)
         assert np.any(source_data != 0)
@@ -180,7 +187,6 @@ def test_downsample_multi_channel():
         target_info,
         tuple(a * WKW_CUBE_SIZE for a in offset),
         list(map(lambda x: x // 2, size)),
-        file_len=file_len,
     )
     assert np.any(target_buffer != 0)
 
13 changes: 8 additions & 5 deletions tests/test_metadata.py
@@ -1,5 +1,6 @@
 import numpy as np
 import os
+import wkw
 
 from wkcuber.cubing import ensure_wkw
 from wkcuber.utils import WkwDatasetInfo, open_wkw
@@ -16,9 +17,9 @@ def test_element_class_convertion():
     test_wkw_path = os.path.join("testoutput", "test_metadata")
     prediction_layer_name = "prediction"
     prediction_wkw_info = WkwDatasetInfo(
-        test_wkw_path, prediction_layer_name, np.float32, 1
+        test_wkw_path, prediction_layer_name, 1, wkw.Header(np.float32, num_channels=3)
     )
-    ensure_wkw(prediction_wkw_info, num_channels=3)
+    ensure_wkw(prediction_wkw_info)
 
     write_custom_layer(test_wkw_path, "prediction", np.float32, num_channels=3)
     write_webknossos_metadata(
@@ -61,9 +62,11 @@ def write_custom_layer(target_path, layer_name, dtype, num_channels):
         .reshape((num_channels, 4, 4, 4))
         .astype(dtype)
     )
-    prediction_wkw_info = WkwDatasetInfo(target_path, layer_name, dtype, 1)
-    ensure_wkw(prediction_wkw_info, num_channels=num_channels)
-    with open_wkw(prediction_wkw_info, num_channels=num_channels) as dataset:
+    prediction_wkw_info = WkwDatasetInfo(
+        target_path, layer_name, 1, wkw.Header(dtype, num_channels)
+    )
+    ensure_wkw(prediction_wkw_info)
+    with open_wkw(prediction_wkw_info) as dataset:
         dataset.write(off=(0, 0, 0), data=data)


7 changes: 4 additions & 3 deletions wkcuber/compress.py
@@ -18,6 +18,7 @@
     setup_logging,
 )
 from .metadata import detect_resolutions
+from .metadata import convert_element_class_to_dtype
 from typing import List
 

@@ -80,10 +81,10 @@ def compress_mag(source_path, layer_name, target_path, mag: Mag, args=None):
         exit(1)
 
     if args is not None and hasattr(args, "dtype"):
-        dtype = args.dtype
+        header = wkw.Header(convert_element_class_to_dtype(args.dtype))
     else:
-        dtype = None
-    source_wkw_info = WkwDatasetInfo(source_path, layer_name, dtype, mag)
+        header = None
+    source_wkw_info = WkwDatasetInfo(source_path, layer_name, mag, header)
     target_mag_path = path.join(target_path, layer_name, str(mag))
     logging.info("Compressing mag {0} in '{1}'".format(str(mag), target_mag_path))
 
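When no dtype is supplied, compress_mag now passes header=None; judging from downsample_mag further down, which also opens its source with a None header and then reads source.header.voxel_type, open_wkw appears to fall back to the header already on disk in that case. A hedged sketch of the new branch (the Namespace stands in for parsed CLI arguments):

import argparse

import wkw
from wkcuber.metadata import convert_element_class_to_dtype

args = argparse.Namespace(dtype="uint8")  # illustrative CLI arguments
if args is not None and hasattr(args, "dtype"):
    # element-class string ("uint8", "float", ...) -> numpy type -> header
    header = wkw.Header(convert_element_class_to_dtype(args.dtype))
else:
    header = None  # open_wkw is then expected to reuse the on-disk header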
5 changes: 4 additions & 1 deletion wkcuber/convert_knossos.py
@@ -18,6 +18,7 @@
     setup_logging,
 )
 from .knossos import KnossosDataset, CUBE_EDGE_LEN
+from .metadata import convert_element_class_to_dtype
 
 
 def create_parser():
@@ -76,7 +77,9 @@ def convert_knossos(
     source_path, target_path, layer_name, dtype, mag=1, jobs=1, args=None
 ):
     source_knossos_info = KnossosDatasetInfo(source_path, dtype)
-    target_wkw_info = WkwDatasetInfo(target_path, layer_name, dtype, mag)
+    target_wkw_info = WkwDatasetInfo(
+        target_path, layer_name, mag, wkw.Header(convert_element_class_to_dtype(dtype))
+    )
 
     ensure_wkw(target_wkw_info)
 
37 changes: 23 additions & 14 deletions wkcuber/cubing.py
@@ -1,6 +1,7 @@
 import time
 import logging
 import numpy as np
+import wkw
 from argparse import ArgumentParser
 from os import path
 from natsort import natsorted
@@ -22,6 +23,7 @@
     setup_logging,
 )
 from .image_readers import image_reader
+from .metadata import convert_element_class_to_dtype
 
 BLOCK_LEN = 32
 
@@ -121,15 +123,14 @@ def cubing_job(args):
         source_file_batches,
         batch_size,
         image_size,
-        num_channels,
         pad,
     ) = args
     if len(z_batches) == 0:
         return
 
     downsampling_needed = target_mag != Mag(1)
 
-    with open_wkw(target_wkw_info, num_channels=num_channels) as target_wkw:
+    with open_wkw(target_wkw_info) as target_wkw:
         # Iterate over batches of continuous z sections
         # The batches have a maximum size of `batch_size`
         # Batched iterations allows to utilize IO more efficiently
@@ -144,7 +145,9 @@
                 # Iterate over each z section in the batch
                 for z, file_name in zip(z_batch, source_file_batch):
                     # Image shape will be (x, y, channel_count, z=1)
-                    image = read_image_file(file_name, target_wkw_info.dtype)
+                    image = read_image_file(
+                        file_name, target_wkw_info.header.voxel_type
+                    )
                     if not pad:
                         assert (
                             image.shape[0:2] == image_size
@@ -171,7 +174,9 @@
                         for _slice in slices
                     ]
 
-                buffer = prepare_slices_for_wkw(slices, num_channels)
+                buffer = prepare_slices_for_wkw(
+                    slices, target_wkw_info.header.num_channels
+                )
                 if downsampling_needed:
                     buffer = downsample_unpadded_data(
                         buffer, target_mag, interpolation_mode
@@ -195,8 +200,20 @@ def cubing_job(args):
 
 def cubing(source_path, target_path, layer_name, dtype, batch_size, args=None) -> dict:
 
+    source_files = find_source_filenames(source_path)
+
+    # All images are assumed to have equal dimensions
+    num_x, num_y = image_reader.read_dimensions(source_files[0])
+    num_channels = image_reader.read_channel_count(source_files[0])
+    num_z = len(source_files)
+
     target_mag = Mag(args.target_mag)
-    target_wkw_info = WkwDatasetInfo(target_path, layer_name, dtype, target_mag)
+    target_wkw_info = WkwDatasetInfo(
+        target_path,
+        layer_name,
+        target_mag,
+        wkw.Header(convert_element_class_to_dtype(dtype), num_channels),
+    )
     interpolation_mode = parse_interpolation_mode(
         args.interpolation_mode, target_wkw_info.layer_name
     )
@@ -205,16 +222,9 @@ def cubing(source_path, target_path, layer_name, dtype, batch_size, args=None) -> dict:
         f"Downsampling the cubed image to {target_mag} in memory with interpolation mode {interpolation_mode}."
     )
 
-    source_files = find_source_filenames(source_path)
-
-    # All images are assumed to have equal dimensions
-    num_x, num_y = image_reader.read_dimensions(source_files[0])
-    num_channels = image_reader.read_channel_count(source_files[0])
-    num_z = len(source_files)
-
     logging.info("Found source files: count={} size={}x{}".format(num_z, num_x, num_y))
 
-    ensure_wkw(target_wkw_info, num_channels=num_channels)
+    ensure_wkw(target_wkw_info)
 
     with get_executor_for_args(args) as executor:
         job_args = []
@@ -233,7 +243,6 @@ def cubing(source_path, target_path, layer_name, dtype, batch_size, args=None) -> dict:
                     source_files[z:max_z],
                     batch_size,
                     (num_x, num_y),
-                    num_channels,
                     args.pad,
                 )
             )
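Net effect in cubing.py: the channel count is read from the source images once, baked into the target header, and cubing_job recovers both voxel type and channel count from target_wkw_info.header instead of extra tuple slots. A small sketch of the two reads (the path is illustrative):

import numpy as np
import wkw
from wkcuber.utils import WkwDatasetInfo

target_wkw_info = WkwDatasetInfo(
    "testoutput/tiff", "color", 1, wkw.Header(np.uint8, num_channels=3)
)
print(target_wkw_info.header.voxel_type)    # replaces target_wkw_info.dtype
print(target_wkw_info.header.num_channels)  # replaces the dropped num_channels argument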
29 changes: 22 additions & 7 deletions wkcuber/downsampling.py
@@ -38,6 +38,11 @@ def determine_buffer_edge_len(dataset):
     return min(DEFAULT_EDGE_LEN, dataset.header.file_len * dataset.header.block_len)
 
 
+def extend_wkw_dataset_info_header(wkw_info, **kwargs):
+    for key, value in kwargs.items():
+        setattr(wkw_info.header, key, value)
+
+
 def calculate_virtual_scale_for_target_mag(target_mag):
     "This scale is not the actual scale of the dataset"
     "The virtual scale is used for downsample_mags_anisotropic."
@@ -182,13 +187,16 @@ def downsample(
     header_block_type = (
         wkw.Header.BLOCK_TYPE_LZ4HC if compress else wkw.Header.BLOCK_TYPE_RAW
     )
-    ensure_wkw(
+
+    extend_wkw_dataset_info_header(
         target_wkw_info,
-        block_type=header_block_type,
         num_channels=num_channels,
         file_len=source_wkw.header.file_len,
+        block_type=header_block_type,
     )
+
+    ensure_wkw(target_wkw_info)
 
     with get_executor_for_args(args) as executor:
         job_args = []
         for target_cube_xyz in target_cube_addresses:
@@ -230,12 +238,16 @@ def downsample_cube_job(args):
         with open_wkw(source_wkw_info) as source_wkw:
             num_channels = source_wkw.header.num_channels
             source_dtype = source_wkw.header.voxel_type
-            with open_wkw(
+
+            extend_wkw_dataset_info_header(
                 target_wkw_info,
-                block_type=header_block_type,
+                voxel_type=source_dtype,
                 num_channels=num_channels,
                 file_len=source_wkw.header.file_len,
-            ) as target_wkw:
+                block_type=header_block_type,
+            )
+
+            with open_wkw(target_wkw_info) as target_wkw:
                 wkw_cubelength = (
                     source_wkw.header.file_len * source_wkw.header.block_len
                 )
@@ -470,10 +482,13 @@ def downsample_mag(
 ):
     interpolation_mode = parse_interpolation_mode(interpolation_mode, layer_name)
 
-    source_wkw_info = WkwDatasetInfo(path, layer_name, None, source_mag.to_layer_name())
+    source_wkw_info = WkwDatasetInfo(path, layer_name, source_mag.to_layer_name(), None)
     with open_wkw(source_wkw_info) as source:
         target_wkw_info = WkwDatasetInfo(
-            path, layer_name, source.header.voxel_type, target_mag.to_layer_name()
+            path,
+            layer_name,
+            target_mag.to_layer_name(),
+            wkw.Header(source.header.voxel_type),
         )
 
     downsample(
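The new extend_wkw_dataset_info_header helper mutates the wkw.Header attached to a WkwDatasetInfo in place; this is how properties that are only known at runtime (the source's file_len, the compression block type, and in downsample_cube_job even the voxel type) are copied onto the target before ensure_wkw and open_wkw run. A usage sketch with illustrative values:

import numpy as np
import wkw
from wkcuber.utils import WkwDatasetInfo

def extend_wkw_dataset_info_header(wkw_info, **kwargs):
    for key, value in kwargs.items():
        setattr(wkw_info.header, key, value)

target_info = WkwDatasetInfo("testoutput/WT1_wkw", "color", 2, wkw.Header(np.uint8))
extend_wkw_dataset_info_header(
    target_info,
    file_len=32,  # in the real code, copied from source_wkw.header.file_len
    block_type=wkw.Header.BLOCK_TYPE_LZ4HC,
)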
4 changes: 2 additions & 2 deletions wkcuber/image_readers.py
@@ -12,7 +12,7 @@
 
 class PillowImageReader:
     def read_array(self, file_name, dtype):
-        this_layer = np.array(Image.open(file_name), np.dtype(dtype))
+        this_layer = np.array(Image.open(file_name), dtype)
         this_layer = this_layer.swapaxes(0, 1)
         this_layer = this_layer.reshape(this_layer.shape + (1,))
         return this_layer
@@ -34,7 +34,7 @@ def read_channel_count(self, file_name):
 
 def to_target_datatype(data: np.ndarray, target_dtype) -> np.ndarray:
 
     factor = (1 + np.iinfo(data.dtype).max) / (1 + np.iinfo(target_dtype).max)
-    return (data / factor).astype(np.dtype(target_dtype))
+    return (data / factor).astype(target_dtype)
 
 
 class Dm3ImageReader:
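For reference, the factor in to_target_datatype maps the full integer range of the source type onto that of the target, and dropping the np.dtype(...) wrapper is safe because numpy accepts a plain scalar type everywhere it accepted the dtype object. A worked example for uint16 to uint8, where the factor is 65536 / 256 = 256:

import numpy as np

data = np.array([0, 256, 65535], dtype=np.uint16)
factor = (1 + np.iinfo(data.dtype).max) / (1 + np.iinfo(np.uint8).max)  # 256.0
print((data / factor).astype(np.uint8))  # [  0   1 255]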
9 changes: 4 additions & 5 deletions wkcuber/metadata.py
@@ -163,7 +163,7 @@ def read_metadata_for_layer(wkw_path, layer_name):
     return layer_info, dtype, bounding_box, origin
 
 
-def convert_dype_to_element_class(dtype):
+def convert_dtype_to_element_class(dtype):
     element_class_to_dtype_map = {
         "float": np.float32,
         "double": np.float64,
@@ -173,20 +173,19 @@ def convert_dype_to_element_class(dtype):
         "uint64": np.uint64,
     }
     conversion_map = {v: k for k, v in element_class_to_dtype_map.items()}
-    return conversion_map.get(dtype.type, str(dtype))
+    return conversion_map.get(dtype, str(dtype))
 
 
 def detect_dtype(dataset_path, layer, mag: Mag = Mag(1)):
     layer_path = path.join(dataset_path, layer, str(mag))
     if path.exists(layer_path):
         with wkw.Dataset.open(layer_path) as dataset:
-            voxel_type = dataset.header.voxel_type
+            voxel_size = dataset.header.voxel_type
             num_channels = dataset.header.num_channels
-            voxel_size = np.dtype(voxel_type)
             if voxel_size == np.uint8 and num_channels > 1:
                 return "uint" + str(8 * num_channels)
             else:
-                return convert_dype_to_element_class(voxel_size)
+                return convert_dtype_to_element_class(voxel_size)
 
 
 def detect_cubeLength(dataset_path, layer, mag: Mag = Mag(1)):
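Besides the typo fix, convert_dtype_to_element_class now receives the header's voxel type directly and drops the .type indirection in the lookup. That works because numpy dtype objects hash and compare equal to their scalar types, so either spelling finds the same dictionary key; a quick check of that assumption:

import numpy as np

conversion_map = {np.float32: "float", np.uint8: "uint8"}
print(conversion_map.get(np.uint8))           # "uint8" (scalar type)
print(conversion_map.get(np.dtype("uint8")))  # "uint8" (dtype object hashes equal)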