Skip to content

Commit

Permalink
Datastore to raw file and job updates
Browse files Browse the repository at this point in the history
  • Loading branch information
brianreicher committed Nov 27, 2023
1 parent e1acc38 commit e530d61
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 9 deletions.
Binary file modified log/events.out.tfevents.1701113995.lee-htem-gpu0
Binary file not shown.
5 changes: 2 additions & 3 deletions src/autoseg/train/ACLSDTrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

def aclsd_train(
raw_file: str = "../../data/xpress-challenge.zarr",
raw_dataset: str = "volumes/training_raw",
out_file: str = "./raw_predictions.zarr",
iterations: int = 100000,
warmup: int = 200000,
Expand Down Expand Up @@ -95,7 +94,7 @@ def aclsd_train(
gp.ZarrSource(
raw_file,
{
raw: raw_dataset,
raw: "volumes/training_raw",
},
{
raw: gp.ArraySpec(interpolatable=True),
Expand All @@ -121,7 +120,7 @@ def aclsd_train(
gt_source = gp.ZarrSource(
raw_file,
{
raw: raw_dataset,
raw: "volumes/training_raw",
labels: f"volumes/training_gt_labels",
labels_mask: f"volumes/training_labels_mask",
unlabelled: f"volumes/training_unlabelled_mask",
Expand Down
4 changes: 2 additions & 2 deletions src/autoseg/train/MTLSDTrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
torch.backends.cudnn.benchmark = True


def mtlsd_train(iterations: int, data_store: str, voxel_size: int = 33):
def mtlsd_train(iterations: int, raw_file: str, voxel_size: int = 33):
raw = gp.ArrayKey("RAW")
labels = gp.ArrayKey("LABELS")
labels_mask = gp.ArrayKey("LABELS_MASK")
Expand Down Expand Up @@ -76,7 +76,7 @@ def mtlsd_train(iterations: int, data_store: str, voxel_size: int = 33):
request.add(pred_lsds, output_size)

source = gp.ZarrSource(
store=data_store,
store=raw_file,
datasets={
raw: f"volumes/training_raw",
labels: f"volumes/training_gt_labels",
Expand Down
10 changes: 6 additions & 4 deletions src/autoseg/train_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@

def train_model(
model_type: str = "MTLSD",
iterations: int = 10000,
data_store: str = "path/to/zarr/or/n5",
iterations: int = 100000,
warmup: int = 100000,
raw_file: str = "path/to/zarr/or/n5",
voxel_size: int = 33,
) -> None:
match model_type.lower():
case "mtlsd":
mtlsd_train(iterations=iterations, data_store=data_store)
mtlsd_train(iterations=iterations, raw_file=raw_file, voxel_size=voxel_size)
case "aclsd":
raise NotImplementedError
aclsd_train(iterations=iterations, raw_file=raw_file, warmup=warmup)
case "stelarr":
raise NotImplementedError

0 comments on commit e530d61

Please sign in to comment.