-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSection.py
3907 lines (3175 loc) · 149 KB
/
Section.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import collections
import glob
import os
import pickle
import time
import jax
import skimage.io
import sys
from scipy.ndimage import gaussian_filter
from scipy.interpolate import griddata
sof_path1 = r"\\tungsten-nas.fmi.ch\tungsten\scratch\gmicro_sem\gfriedri\tgan\packages\sofima-forked-dev"
sof_path2 = r"/tungstenfs/scratch/gmicro_sem/gfriedri/tgan/packages/sofima-forked-dev"
sys.path.append(sof_path1)
sys.path.append(sof_path2)
from sofima import stitch_rigid, stitch_elastic, flow_utils, mesh, warp
from Tile import Tile
import inspection_utils as utils
from inspection_utils import Num
import mask_utils as mutils
import cv2
import csv
from contextlib import suppress
import functools as ft
import jax.numpy as jnp
import json
import logging
import matplotlib.pyplot as plt
import numpy as np
from numpy import ndarray, dtype
from pathlib import Path
import re
import skimage
from skimage.metrics import structural_similarity as ssim
from skimage import filters
from typing import Tuple, List, Set, Union, Optional, Dict, Mapping, Iterable, Any
from ruyaml import YAML
from ruyaml.scalarfloat import ScalarFloat
yaml = YAML(typ="rt")
import experiment_configs as cfg
sof_path1 = r"\\tungsten-nas.fmi.ch\tungsten\scratch\gmicro_sem\gfriedri\tgan\scripts\pythonProject"
sof_path2 = r"/tungstenfs/scratch/gmicro_sem/gfriedri/tgan/scripts/pythonProject"
sys.path.append(sof_path1)
sys.path.append(sof_path2)
from SOFIMA import sofima_files as sutils
# NOTE(review): JAX/XLA reads the XLA_* variables below when its backend is
# initialised (normally lazily, on first device use). They are set *after*
# `import jax` and the sofima imports above, so they only take effect if none
# of those imports has already initialised a device — confirm, or move these
# three lines above `import jax`.
# Do not let the XLA client preallocate most of the GPU memory up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Allocate/free GPU memory on demand (see JAX "GPU memory allocation" docs).
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
# Relax XLA's strict GPU convolution-algorithm picking (flag semantics per XLA docs).
os.environ["XLA_FLAGS"] = "--xla_gpu_strict_conv_algorithm_picker=false"
# Set up logging
# logging.basicConfig(level=logging.DEBUG)
# logging.basicConfig(level=logging.INFO)
logging.basicConfig(level=logging.WARNING)
# ---- Type aliases used throughout this module ----
Vector = Union[Tuple[int, int], Tuple[int, int, int], Union[Tuple[int], Tuple[Any, ...]]] # [z]yx order
# Vector = Union[tuple[int, int], tuple[int, int, int]]
UniPath = Union[str, Path]  # str or pathlib path
TileXY = Tuple[int, int]  # (x, y) tile position, i.e. tile_id_map[y, x]
ShapeXYZ = Tuple[int, int, int]
TileFlow = Dict[TileXY, np.ndarray]  # per-tile fine-flow array
TileOffset = Dict[TileXY, Vector]  # per-tile offset accompanying a TileFlow
TileMap = Dict[TileXY, np.ndarray]  # per-tile image data
FineFlow = Union[Tuple[TileFlow, TileOffset], None]  # (flows, offsets) for one axis
FineFlows = Tuple[Optional[FineFlow], Optional[FineFlow]]  # (axis 0, axis 1) fine flows
MaskMap = Dict[TileXY, Optional[np.ndarray]]  # per-tile mask; None when a tile has no mask
TileFlowData = Tuple[np.ndarray, TileFlow, TileOffset]  # (coarse offsets, flows, offsets) for one axis
MarginOverrides = Dict[TileXY, Tuple[int, int, int, int]]  # four per-tile margin values
GridXY = Tuple[Any, Any, Any]
# Experiment description: data path, grid number, first/last section and grid shape.
ExpConfig = collections.namedtuple('ExpConfig', ['path', 'grid_num', 'first_sec', 'last_sec', 'grid_shape'])
class Section:
def __init__(self, path: Union[Path, str]):
    """Create a handle for one section acquisition directory.

    Only paths and light metadata are set up here; large data (tile images,
    flows, meshes, masks) is loaded lazily by the individual methods.

    :param path: Section directory whose name has the form ``s<section>_g<grid>``:
        the section number is the text between the first character and
        ``"_g"``, the grid number is the trailing digits after ``"_g"``.
    :raises NotADirectoryError: If ``path`` does not exist or is not a directory.
    """
    path = Path(utils.cross_platform_path(str(path)))
    if not path.is_dir():
        m = f"The path '{path}' is not a directory or does not exist."
        raise NotADirectoryError(m)
    self.path: Path = path
    self.path_stitched: Path = self.resolve_dir_stitched()
    self.path_stitched_custom = self.path.parent.parent / 'stitched-sections-derotated' / (str(self.path.name) + ".zarr")
    self.path_section_yaml = str(self.path / 'section.yaml')
    self.path_stats_yaml = str(self.path / 'stats.yaml')
    self.path_margin_masks = str(self.path / 'margin_masks.npz')
    self.path_cmesh = str(self.path / 'coarse_mesh.pkl')
    self.path_fmesh = str(self.path / 'meshes.npz')
    self.path_fflows = str(self.path / 'fflows.pkl')
    self.path_shift_to_prev = str(self.path_stitched / 'shift_to_previous.json')
    self.section_num = int(str(self.path.name).split("_g")[0][1:])
    # Read all trailing digits after "_g" so grid numbers >= 10 are parsed
    # correctly; keep the legacy last-character rule for other name layouts.
    grid_match = re.search(r'_g(\d+)$', self.path.name)
    self._grid_num = int(grid_match.group(1)) if grid_match else int(self.path.name[-1])
    self.tile_shape = utils.get_tile_shape(self.path_section_yaml)
    # Lazily populated section data (see feed_section_data / load_* / compute_*).
    self.tile_id_map: Optional[np.ndarray[int]] = None
    self.tile_dicts: Optional[Dict[int, str]] = None
    self.cxy: Optional[np.ndarray[float]] = None
    self.coarse_mesh: Optional[np.ndarray[float]] = None
    self.mesh_offsets: Optional[np.ndarray[float]] = None
    self.fflows: Optional[FineFlows] = None
    self.fflows_clean: Optional[FineFlows] = None
    self.fflows_recon: Optional[FineFlows] = None
    self.fmesh: Optional[Dict[TileXY, np.ndarray]] = None
    self.tile_map: Optional[TileMap] = None
    self.mask_map: MaskMap = {}  # TODO: make None as default (?)
    self.roi_mask_map: MaskMap = {}
    self.smr_mask_map: MaskMap = {}
    self.margin_overrides: Optional[Dict[TileXY, Tuple[int, int, int, int]]] = None
    self.margin_masks: Optional[Dict[TileXY, np.ndarray]] = None
    self.path_thumb = self.resolve_path_thumb()
    self.thumb: Optional[np.ndarray] = None
    self.roi_inspect = None
    self.roi_conv = None
    self.filtered_image = None
    self.image: Optional[np.ndarray] = None
    self.height: Optional[int] = None  # tile_shape already in attributes, refactor.
    self.width: Optional[int] = None
    self._stitched: bool = False
@property
def stitched(self) -> bool:
    """Whether a usable stitched ``.zarr`` exists for this section.

    Requires that the folder size of ``self.path_stitched`` is available,
    exceeds ``MIN_ZARR_SIZE`` (kb) and that ``utils.check_zarr_integrity``
    accepts it. The verdict is also cached in ``self._stitched``.
    """
    MIN_ZARR_SIZE = 1000  # in kb
    folder_kb = utils.get_folder_size(str(self.path_stitched))
    # Short-circuits exactly like the original sequence of early returns:
    # the integrity check only runs when the size checks pass.
    usable = bool(
        folder_kb is not None
        and folder_kb > MIN_ZARR_SIZE
        and utils.check_zarr_integrity(self.path_stitched)
    )
    self._stitched = usable
    return usable
@property
def grid_num(self) -> int:
return self._grid_num
# @property
# def stitched(self) -> bool:
# return self.path_stitched.exists()
@property
def cross_aligned(self) -> bool:
return Path(self.path_shift_to_prev).exists()
def clear_section(self):
del self.fflows
del self.fflows_recon
del self.fflows_clean
del self.tile_map
del self.mask_map
del self.roi_mask_map
del self.smr_mask_map
del self.margin_masks
return
def load_fmesh(self) -> None:
self.fmesh = np.load(self.path_fmesh)
return
def fix_coarse_offset(self,
                      tid_a: int,
                      tid_b: int,
                      est_vec_orig: bool,
                      skip_xcorr: bool,
                      refine: bool,
                      xcorr_kwargs: dict,
                      refine_kwargs: dict,
                      plot_ov=False,
                      ) -> None:
    """ Recomputes specific coarse offset vector by SOFIMA or refine pyramid

    Recompute either by both methods or refinement only. Refinement can be done
    with respect to original coarse shift vector or specific vector. Optionally,
    plot resulting overlap image to section folder. Specify SOFIMA registration
    and/or refinement parameters and pass them to the function.

    :param tid_a: ID of the first tile of the neighbouring pair.
    :param tid_b: ID of the second tile of the neighbouring pair.
    :param est_vec_orig: If True, pass the currently stored coarse offset of
        ``tid_a`` as ``est_vec`` to ``refine_pyramid``; otherwise ``est_vec`` is
        not injected and is taken from ``refine_kwargs`` (if present).
    :param skip_xcorr: Skip ``compute_coarse_offset`` (SOFIMA cross-correlation).
    :param refine: Run ``refine_pyramid``.
    :param xcorr_kwargs: Keyword arguments for ``compute_coarse_offset``.
    :param refine_kwargs: Keyword arguments for ``refine_pyramid``. The caller's
        dict is copied, never modified.
    :param plot_ov: Save the overlap image of the pair to the section folder.
    :return: None
    """
    # Copy so an injected 'est_vec' neither mutates the caller's dict nor
    # leaks into later calls that expect the estimated (average) vector.
    refine_kwargs = dict(refine_kwargs or {})
    xcorr_kwargs = xcorr_kwargs or {}

    def opt_coo(section: Section) -> None:
        """Run cross-correlation and/or pyramid refinement on ``section``."""
        if not skip_xcorr:
            shift_vec = section.compute_coarse_offset(**xcorr_kwargs)
            logging.info(f'computed coarse shift: {shift_vec}')
            print(f'computed coarse shift: {shift_vec}')
        if refine:
            best_vec = section.refine_pyramid(**refine_kwargs)
            print(f'refined vector: {best_vec}')
        return

    self.feed_section_data()
    # Get original coarse shift vector if requested,
    # otherwise average will be estimated
    is_vert = utils.pair_is_vertical(self.tile_id_map, tid_a, tid_b)
    axis = 1 if is_vert else 0
    if est_vec_orig:
        refine_kwargs['est_vec'] = self.get_coarse_offset(tid_a, axis)
    # Perform optimization
    opt_coo(self)
    # Plot OV
    if plot_ov:
        self.plot_ov(
            tid_a, tid_b, self.get_coarse_offset(tid_a, axis),
            dir_out=self.path, clahe=True, blur=1.3, show_plot=False,
            rotate_vert=True, store_to_root=True
        )
    return
def verify_tile_id_map(self, print_ids: bool = False) -> bool:
    """Check that the section .yaml and tile_id_map.json list the same tile IDs.

    ``-1`` entries (empty grid positions) of the tile-ID map are ignored.

    :param print_ids: On mismatch, log both ID sets and their symmetric difference.
    :return: True if both sources contain exactly the same tile IDs; False on
        a mismatch or if either source yields no tile IDs.
    """
    # Get tile IDs from section yaml file
    yaml_tile_ids: Set[int] = set(utils.get_tile_ids(self.path))
    if not yaml_tile_ids:
        logging.warning('Verify tile_id_map: No tile IDs found in section .yaml file.')
        return False
    # Get tile IDs from section tile_id_map.json. tile_id_map is an ndarray, so
    # test against None explicitly: `not array` raises ValueError for arrays
    # with more than one element.
    if self.tile_id_map is None:
        self.read_tile_id_map()
    if not isinstance(self.tile_id_map, np.ndarray):
        logging.warning('Verify tile_id_map: no tile IDs found in section tile_id_map.json')
        return False
    tile_id_map_ids = set(self.tile_id_map.flatten())
    tile_id_map_ids.discard(-1)
    # Perform comparison
    eq = yaml_tile_ids == tile_id_map_ids
    # Print info if tile IDs are not the same in both sets
    if not eq and print_ids:
        sec_num = self.section_num
        ids = yaml_tile_ids.symmetric_difference(tile_id_map_ids)
        logging.warning(f'section s{sec_num} yaml tile IDs: {sorted(list(yaml_tile_ids))}')
        logging.warning(f'section s{sec_num} tile_id_map IDs: {sorted(list(tile_id_map_ids))}')
        logging.warning(f'missing s{sec_num} tile ids: {sorted(list(ids))}')
    return eq
def rotate_and_store_stitched(self, rot_angle: float, dst_dir: UniPath) -> None:
    """Rotate the stitched section image and save it as ``<section>.zarr`` in ``dst_dir``.

    :param rot_angle: Angle passed to ``utils.rotate_image``.
    :param dst_dir: Destination folder for the rotated ``.zarr``.
    """
    image = self.load_image()
    assert isinstance(image, np.ndarray), "rotate_stitched failed to load stitched .zarr file"
    zarr_name = self.path.name + '.zarr'
    utils.store_section_zarr(utils.rotate_image(image, rot_angle), zarr_name, dst_dir)
def warp_section(self,
                 stride: int,
                 margin: int = 0,
                 use_clahe: bool = False,
                 clahe_kwargs: ... = None,
                 zarr_store=True,
                 rescale_fct: Optional[float] = None,
                 parallelism: int = 1,
                 rot_angle: float = 0
                 ) -> None:
    """Render the section from its tiles and fine mesh and optionally save it.

    Missing inputs (fine mesh, tile data, margin masks) are loaded first. A
    warning is logged and the call returns early if the fine mesh or the tile
    map cannot be loaded. The rendered image is written as ``<section>.zarr``
    into the parent folder of ``self.path_stitched``.

    :param stride: Mesh stride (px), used for both axes.
    :param margin: Passed to ``warp.render_tiles``.
    :param use_clahe: Load tiles with CLAHE and pass the flag to the renderer.
    :param clahe_kwargs: Passed to ``warp.render_tiles``.
    :param zarr_store: Write the rendered image to disk.
    :param rescale_fct: Currently unused (thumbnail export is disabled).
    :param parallelism: Passed to ``warp.render_tiles``.
    :param rot_angle: Rotation applied to the rendered image if non-zero.
    """
    # Fine mesh
    if self.fmesh is None:
        self.fmesh = utils.load_mapped_npz(self.path_fmesh)
        if self.fmesh is None:
            logging.warning(
                f'Warping s{self.section_num} failed: {Path(self.path_fmesh).name} could not be loaded.')
            return

    # Tile data
    if self.tile_dicts is None:
        self.feed_section_data()
    if self.tile_map is None:
        self.load_tile_map(clahe=use_clahe)
        if self.tile_map is None:
            logging.warning(f'Warping s{self.section_num} failed: tile-map could not be loaded.')
            return

    # Margin masks (may legitimately stay None)
    if self.margin_masks is None:
        self.margin_masks = utils.load_mapped_npz(self.path_margin_masks)

    render_kwargs = dict(
        tiles=self.tile_map,
        coord_maps=self.fmesh,
        stride=(stride, stride),
        margin=margin,
        use_clahe=use_clahe,
        clahe_kwargs=clahe_kwargs,
        tile_masks=self.margin_masks,
        parallelism=parallelism,
    )
    rendered, _ = warp.render_tiles(**render_kwargs)
    if rot_angle != 0:
        rendered = utils.rotate_image(rendered, rot_angle)
    if zarr_store:
        utils.store_section_zarr(rendered, self.path.name + '.zarr', self.path_stitched.parent)
    print(f'Section {self.section_num} stitched and warped.')
def compute_fine_mesh(self, config, stride: int, store=True) -> None:
    """Relax the elastic fine mesh of the section and optionally save it.

    Loads whatever is missing (tile data, coarse mesh, fine flows), cleans and
    reconciles the fine flows, aggregates them with the coarse offsets and
    relaxes the mesh with ``mesh.relax_mesh``. The result is stored in
    ``self.fmesh`` as ``{tile_xy: array}``. If ``store`` is True, it is also
    saved to ``self.path_fmesh`` with ``str(tile_xy)`` keys. Missing inputs are
    logged as warnings and the method returns without computing.

    :param config: ``mesh.IntegrationConfig``; ``None`` selects the defaults below.
    :param stride: Stride (px) of the flow fields / mesh, used for both axes.
    :param store: Save the mesh to disk.
    """
    if self.tile_dicts is None:
        self.feed_section_data()
    if self.tile_map is None:
        self.load_tile_map()
    if not self.tile_map:
        logging.warning(f'compute_fine_mesh s{self.section_num} failed (tile-map not loaded.)')
        return
    if self.coarse_mesh is None:
        self.load_coarse_mesh()
    if self.coarse_mesh is None:
        logging.warning(f'compute_fine_mesh s{self.section_num} failed (coarse mesh not loaded.)')
        return
    if self.cxy is None:
        logging.warning(f'compute_fine_mesh s{self.section_num} failed (coarse offsets not loaded.)')
        return

    # Prepare data for mesh computation
    if self.fflows is None:
        self.load_fflows()
    if self.fflows is None:
        logging.warning(f'compute_fine_mesh s{self.section_num} failed (fine flows not loaded.)')
        return
    self.clean_fflows()
    self.reconcile_fflows()
    if self.fflows_recon is None:
        logging.warning(f'compute_fine_mesh s{self.section_num} failed (fine flows not reconciled.)')
        return

    cx, cy = np.squeeze(self.cxy)
    ffx, ffxo = self.fflows_recon[0]
    ffy, ffyo = self.fflows_recon[1]
    data_x: Tuple[np.ndarray, TileFlow, TileOffset] = (cx, ffx, ffxo)
    data_y: Tuple[np.ndarray, TileFlow, TileOffset] = (cy, ffy, ffyo)
    tile_shape = next(iter(self.tile_map.values())).shape
    fx, fy, nds, nbors, key_to_idx = stitch_elastic.aggregate_arrays(
        data_x, data_y, list(self.tile_map.keys()),
        self.coarse_mesh[:, 0, ...], stride=(stride, stride),
        tile_shape=tile_shape)

    @jax.jit
    def prev_fn(nds):
        target_fn = ft.partial(stitch_elastic.compute_target_mesh, x=nds, fx=fx,
                               fy=fy, stride=(stride, stride))
        nds = jax.vmap(target_fn)(nbors)
        return jnp.transpose(nds, [1, 0, 2, 3])

    if config is None:
        config = mesh.IntegrationConfig(
            dt=0.001, gamma=0., k0=0.01, k=0.1, stride=stride,
            num_iters=1000, max_iters=20000, stop_v_max=0.001,
            dt_max=100, prefer_orig_order=True,
            start_cap=0.1, final_cap=10., remove_drift=True
        )
    # Compute fine mesh
    res, _, _ = mesh.relax_mesh(nds, None, config, prev_fn=prev_fn)
    # Unpack meshes into a dictionary keyed by tile position.
    idx_to_key = {v: k for k, v in key_to_idx.items()}
    self.fmesh = {idx_to_key[i]: np.array(res[:, i:i + 1:, :]) for i in range(res.shape[1])}
    # Save mesh for later processing
    if store:
        meshes_to_save = {str(k): v for k, v in self.fmesh.items()}
        np.savez(self.path_fmesh, **meshes_to_save)
    return
def clean_fflows(self,
min_pr: float = 1.4,
min_ps: float = 1.4,
max_mag: float = 0.,
max_dev: float = 5., ) -> None:
if self.fflows is None:
logging.warning(f's{self.section_num} clean fflows failed: fine flows not available.')
return
fine_x, offsets_x = self.fflows[0]
fine_y, offsets_y = self.fflows[1]
kwargs = {"min_peak_ratio": min_pr, "min_peak_sharpness": min_ps, "max_deviation": max_dev,
"max_magnitude": max_mag}
fine_x = {k: flow_utils.clean_flow(v[:, np.newaxis, ...], **kwargs)[:, 0, :, :] for k, v in fine_x.items()}
fine_y = {k: flow_utils.clean_flow(v[:, np.newaxis, ...], **kwargs)[:, 0, :, :] for k, v in fine_y.items()}
ffx = fine_x, offsets_x
ffy = fine_y, offsets_y
self.fflows_clean = (ffx, ffy)
return
def reconcile_fflows(self,
max_gradient: float = -1.,
max_deviation: float = -1.,
min_patch_size: int = 10) -> None:
fine_x, offsets_x = self.fflows_clean[0]
fine_y, offsets_y = self.fflows_clean[1]
kwargs = {"min_patch_size": min_patch_size, "max_gradient": max_gradient, "max_deviation": max_deviation}
fine_x = {k: flow_utils.reconcile_flows([v[:, np.newaxis, ...]], **kwargs)[:, 0, :, :] for k, v in
fine_x.items()}
fine_y = {k: flow_utils.reconcile_flows([v[:, np.newaxis, ...]], **kwargs)[:, 0, :, :] for k, v in
fine_y.items()}
ffx = fine_x, offsets_x
ffy = fine_y, offsets_y
self.fflows_recon = (ffx, ffy)
return
def compute_fine_flows(self,
patch_size: int, stride: int, masking=False,
store=True, overwrite: bool = False,
ext: Optional[str] = None) -> None:
def load_infra() -> bool:
if self.cxy is None:
self.feed_section_data()
if self.cxy is None:
logging.warning(f'compute_fine_flows section s{self.section_num}: coarse offset array not loaded!')
return False
if self.tile_map is None:
self.load_tile_map(clahe=False)
if self.tile_map is None:
return False
if not self.mask_map and masking:
self.load_masks()
return True
def iter_compute_flows(patch_size=patch_size, ff_iter=0, max_iter=5,
min_patch_size=10, step=5) -> None:
"""Compute fine flows with iterative decrease if patch size in case of SOFIMA negative dimension error"""
def compute_flows(ps: int, axis: int) -> Tuple[TileFlow, TileOffset]:
return stitch_elastic.compute_flow_map(self.tile_map, self.cxy[axis], axis,
patch_size=(ps, ps),
stride=(stride, stride), batch_size=256,
tile_masks=self.mask_map)
logging.info(f'Computing fine flows for section s{self.section_num}')
while self.fflows is None or ff_iter == max_iter:
try:
logging.info(f's{self.section_num} fflows params: iter={ff_iter}, patch_size={patch_size}')
self.fflows = (compute_flows(patch_size, axis=0),
compute_flows(patch_size, axis=1))
except ValueError as _:
logging.warning(
f'fine flows s{self.section_num} decreasing patch size ({patch_size} -> {patch_size - step})')
patch_size -= step
ff_iter += 1
if patch_size < min_patch_size:
ff_iter = max_iter
logging.info(
f's{self.section_num} fine flows computed after {ff_iter + 1} iterations. Final patch size: {patch_size}')
return
def store_fflows(fine_flows: FineFlows, ext: Optional[str]):
ext = '' if ext is None else ext
fname = f'fflows{ext}.pkl'
with open(self.path / fname, 'wb') as f:
pickle.dump(fine_flows, f)
return
# Load necessary data first
infra_loaded = load_infra()
if not infra_loaded:
return
if not overwrite:
self.load_fflows() # Load if present
if self.fflows is not None:
logging.info(f'compute_fine_flows skipping s{self.section_num} (fflows already exists and overwrite is disabled).')
return
# Compute fine flows and fine offsets
iter_compute_flows()
# Store results
if store and self.fflows is not None:
store_fflows(self.fflows, ext)
return
def load_fflows(self, ext: Optional[str] = None) -> None:
ext = '' if ext is None else ext
fp_fflows = self.path / f'fflows{ext}.pkl'
logging.info(f's{self.section_num}: loading fine flows from {fp_fflows}.')
try:
with open(fp_fflows, 'rb') as f:
self.fflows = pickle.load(f)
if (not isinstance(self.fflows, tuple) or len(self.fflows) != 2
or not all(isinstance(item, (dict, type(None))) for item in self.fflows)):
logging.info(
f's{self.section_num} loaded fine flows from {fp_fflows}, but data appears to be corrupted or invalid.')
# print(len(self.fflows))
# print(self.fflows)
# self.fflows = None # Reset to default value
else:
logging.info(f's{self.section_num}: successfully loaded fine flows from {fp_fflows}.')
except FileNotFoundError:
logging.warning(f's{self.section_num}: fine flows file {fp_fflows} not found!')
except EOFError:
logging.warning(f's{self.section_num}: EOFError - Ran out of input while reading {fp_fflows}.')
except pickle.UnpicklingError as e:
logging.error(f's{self.section_num}: Error while unpickling {fp_fflows}: {e}')
except Exception as e:
logging.error(f"An error occurred while reading '{fp_fflows}': {e}")
def check_and_load_coarse_mesh(self) -> bool:
try:
if Path(self.path_cmesh).exists():
self.load_coarse_mesh()
return self.coarse_mesh is not None
else:
print(f"File '{self.path_cmesh}' does not exist.")
return False
except Exception as e:
print(f"An error occurred while checking and loading '{self.path_cmesh}': {e}")
return False
def compute_coarse_mesh(self, cfg: Optional[mesh.IntegrationConfig] = None, store=True, overwrite=False) -> None:
if cfg is None:
cfg = mesh.IntegrationConfig(
dt=0.001,
gamma=0.0,
k0=0.0, # unused
k=0.1,
stride=(1, 1), # unused
num_iters=1000,
max_iters=100000,
stop_v_max=0.001,
dt_max=100,
)
def store_cmesh(data):
with open(self.path_cmesh, 'wb') as f:
pickle.dump(data, f)
return
logging.info('Computing coarse mesh ...')
if self.check_and_load_coarse_mesh() and not overwrite:
return
try:
cx, cy = self.get_coarse_mat()
except TypeError as _:
logging.warning(f's{self.section_num} coarse mesh not computed')
return
if cx.ndim != 4:
cx = cx[:, np.newaxis, ...]
cy = cy[:, np.newaxis, ...]
self.coarse_mesh = stitch_rigid.optimize_coarse_mesh(cx, cy, cfg)
if self.coarse_mesh is None:
logging.warning(f'Section s{self.section_num} coarse mesh not computed.')
elif store:
logging.info(f'Storing coarse mesh.')
store_cmesh(self.coarse_mesh)
return
def load_coarse_mesh(self) -> None:
try:
with open(self.path_cmesh, 'rb') as f:
self.coarse_mesh = pickle.load(f)
logging.info(f"s{self.section_num} coarse mesh loaded")
except EOFError:
print("EOFError: Ran out of input while reading the pickled data.")
except FileNotFoundError:
print(f"File '{self.path_cmesh}' not found.")
except Exception as e:
print(f"An error occurred while reading cmesh {self.path_cmesh}: {e}")
def create_masks(self,
                 roi_thresh: int,
                 max_vert_ext: int,
                 edge_only: bool,
                 n_lines: int,
                 store: bool = False,
                 filter_size: int = 20,
                 range_limit: int = 0
                 ) -> Optional[MaskMap]:
    """
    Create ROI and smearing masks for tiles and store them if specified.

    Results are written to ``self.roi_mask_map``, ``self.smr_mask_map`` and
    ``self.mask_map`` (the combined mask), each keyed by tile position
    ``(x, y)``. ROI masks are only computed for tiles that are NOT in
    ``mutils.tile_ids_with_all_neighbors(...)``, i.e. tiles on the border
    of the tiled region.

    Args:
        roi_thresh: parameter influencing detection sensitivity of silver particles in resin
        max_vert_ext: Mask only specified number of lines from top of the image
        edge_only: If True, mask only top N lines of tile-data and do not compute smearing mask
        n_lines: Number of lines from top of the image to be fully masked
        store (bool, optional): Whether to store the masks. Defaults to False.
        filter_size: Forwarded to ``Tile.create_roi_mask``. Defaults to 20.
        range_limit: Forwarded to ``Tile.create_roi_mask``. Defaults to 0.
    Returns:
        Optional[MaskMap]: Currently always ``None``; the masks are kept on
        the instance (and written to ``*_masks.npz`` when ``store`` is True).
    """
    def get_masks(tid: int, tile_coords: Tuple[int, int]) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        # Returns (roi_mask, smearing_mask) for one tile; either may be None.
        m_roi = None
        m_smr = None
        if tid == -1:
            # -1 marks an empty grid position: no masks.
            return m_roi, m_smr
        tile_path = self.tile_dicts.get(tid, None)
        if tile_path is not None and Path(tile_path).exists():
            tile = Tile(tile_path)
            if self.tile_map is not None:
                # Reuse image data already in memory for this tile position.
                tile.img_data = self.tile_map[tile_coords]
            m_smr = tile.create_smearing_mask(max_vert_ext, n_lines, edge_only)
            if tid not in inner_ids:
                # Only border tiles get an ROI mask (inner_ids is bound in the enclosing scope below).
                m_roi = tile.create_roi_mask(roi_thresh, filter_size, range_limit)
        else:
            logging.debug(f'tile_path does not exist')
        try:
            logging.debug(f'shapes: {m_roi.shape, m_smr.shape}')
        except AttributeError as e:
            # At least one of the masks is None.
            logging.debug(f'WARNING: {tid} attribute error!')
        return m_roi, m_smr

    # Ensure necessary data is loaded
    if self.tile_id_map is None:
        self.read_tile_id_map()
        if self.tile_id_map is None:
            return None
    if self.tile_dicts is None:
        self.feed_section_data()
    # Initialize returned values
    mask_map = {}
    roi_mask_map = {}
    smr_mask_map = {}
    # Tile IDs that have all neighbours (inner tiles); every other tile lies on
    # the border of the tiled region and may contain resin in its image data.
    inner_ids: List[int] = mutils.tile_ids_with_all_neighbors(self.tile_id_map)
    h, w = self.tile_id_map.shape
    for y in range(h):
        for x in range(w):
            tile_id = int(self.tile_id_map[y, x])
            roi_mask, smr_mask = get_masks(tile_id, (x, y))
            roi_mask_map[(x, y)] = roi_mask
            smr_mask_map[(x, y)] = smr_mask
            comb_mask = None
            if roi_mask is not None and smr_mask is not None:
                # Zero-pad the smearing mask to the ROI-mask shape, then OR them.
                # NOTE(review): np.pad fails if smr_mask is larger than roi_mask.
                h1, w1 = np.shape(roi_mask)
                h2, w2 = np.shape(smr_mask)
                pad_smr = np.pad(smr_mask, ((0, h1 - h2), (0, w1 - w2)), 'constant')
                comb_mask = np.logical_or(roi_mask, pad_smr)
                # comb_mask = roi_mask.copy()
                # comb_mask[:roi_mask.shape[0]] |= smr_mask
            if roi_mask is not None and smr_mask is None:
                comb_mask = roi_mask
            if smr_mask is not None and roi_mask is None:
                # The tile image is read from disk only to obtain the full tile shape.
                img = skimage.io.imread(self.tile_dicts[tile_id])
                tile_shape = np.shape(img)
                comb_mask = np.full(tile_shape, fill_value=False)
                # Assumes smr_mask spans the full tile width (top rows only).
                comb_mask[:np.shape(smr_mask)[0], :] = smr_mask
            mask_map[(x, y)] = comb_mask
    self.smr_mask_map = smr_mask_map
    self.roi_mask_map = roi_mask_map
    self.mask_map = mask_map
    # Store masks if requested (keys stringified; None values become object arrays)
    if store:
        fns = ('roi_masks.npz', 'smr_masks.npz', 'tile_masks.npz')
        maps = (roi_mask_map, smr_mask_map, mask_map)
        for fn, i_map in zip(fns, maps):
            data = {str(k): v for k, v in i_map.items()}
            fp = self.path / fn
            logging.debug(f'saving to: {fp}')
            np.savez_compressed(fp, **data)
    return None
def load_masks(self):
"""Loads binary masks associated to each tile within section"""
fns = ('roi_masks.npz', 'smr_masks.npz', 'tile_masks.npz')
maps = (self.roi_mask_map, self.smr_mask_map, self.mask_map)
for fn, i_map in zip(fns, maps):
path_mask = self.path / fn
if not path_mask.exists():
logging.warning(f's{self.section_num} {fn} does not exist!')
continue
try:
data = np.load(path_mask, allow_pickle=True)
for key, item in data.items():
i_map[eval(key)] = item if item.size != 1 else None
except FileNotFoundError:
logging.info(f"s{self.section_num}: {path_mask} not loaded.")
except Exception as e:
print(f"An error occurred: s{self.section_num} {fn}.", e)
# self.margin_masks = utils.load_mapped_npz(self.path_margin_masks)
return
def build_mesh_offsets(self, mesh_config: Optional[mesh.IntegrationConfig] = None,
overwrite: Optional[bool] = True
) -> None:
"""Creates coarse offset matrix from coarse mesh values
Create coarse mesh offset array in form of individual tile offsets
in same notation as coarse offsets. These values will be used to create
margin overrides for section warping.
"""
def diff_mat(mat: np.ndarray, row_mode=False) -> np.ndarray:
if row_mode:
result = [[np.round(mat[i][j] - mat[i - 1][j]) for j in range(len(mat[i]))] for i in range(1, len(mat))]
else:
result = [[np.round(mat[i][j] - mat[i][j - 1]) for j in range(1, len(mat[i]))] for i in range(len(mat))]
return np.array(result)
def make_mesh_offsets() -> Optional[np.ndarray]:
if self.cxy is None:
_ = self.get_coarse_mat()
if Path(self.path_cmesh).exists():
self.load_coarse_mesh()
else:
self.compute_coarse_mesh(mesh_config, overwrite=overwrite)
if self.coarse_mesh is None:
return
cxx = diff_mat(self.coarse_mesh[0, 0, ...], row_mode=False)
cxy = diff_mat(self.coarse_mesh[1, 0, ...], row_mode=False)
cyx = diff_mat(self.coarse_mesh[0, 0, ...], row_mode=True)
cyy = diff_mat(self.coarse_mesh[1, 0, ...], row_mode=True)
mo = np.full_like(self.cxy, fill_value=np.nan)
nr, nc = cxx.shape
mo[0, 0, 0:nr, 0:nc] = cxx
mo[0, 1, 0:nr, 0:nc] = cxy
nr, nc = cyx.shape
mo[1, 0, 0:nr, 0:nc] = cyx
mo[1, 1, 0:nr, 0:nc] = cyy
mask = np.isnan(self.cxy)
mo[mask] = np.nan
return mo
self.mesh_offsets = make_mesh_offsets()
return
def feed_section_data(self):
    """Load the basic per-section metadata most other methods depend on.

    Sets ``self.tile_dicts`` (tile ID -> tile image path), reads the tile-ID
    map and calls ``get_coarse_mat``. Callers such as ``compute_fine_flows``
    rely on that call to populate ``self.cxy``; its return value is discarded.
    """
    self.tile_dicts = utils.get_tile_dicts(self.path)
    self.read_tile_id_map()
    self.get_coarse_mat()
def build_margin_masks(self,
                       grid_shape: List[int],
                       margin: int = 20,
                       rim_size: int = 60,
                       overwrite: bool = False,
                       mesh_config: Optional[mesh.IntegrationConfig] = None
                       ) -> None:
    """ Creates masks for section rendering.

    Margin masks allow to render overlap regions with better quality. Charging
    and deformations are most often related to multiple-exposed regions. Margin
    masks build on this fact and assign higher rendering priority to image-data
    acquired on fresh sample surface. Procedure requires coarse offsets to be
    computed and saved in advance.

    The masks are stored in ``self.margin_masks`` (``{tile_xy: bool array of
    self.tile_shape}``) and saved to ``self.path_margin_masks``.

    grid_shape: Total number of rows and columns in the SBEMimage grid
    margin: masks deformed borders of tiles due to elastic transformation
    rim_size: Size of safety margin added to the coarse offset
        to avoid holes in warped image. Defaults to 60 pixels.
    overwrite: Recompute even if the margin-mask file exists (otherwise it is loaded).
    mesh_config: Passed to ``build_mesh_offsets`` if mesh offsets are missing.
    """
    def create_sbem_grid() -> np.ndarray:
        """Virtual SBEMimage grid: tile index where the tile is active, -1 elsewhere."""
        grid = np.full(shape=grid_shape, fill_value=-1)
        rows, cols = grid_shape
        # Build the set of active IDs once instead of scanning the array per cell.
        active_ids = set(np.asarray(self.tile_id_map).ravel().tolist())
        for row_pos in range(rows):
            for col_pos in range(cols):
                tile_index = row_pos * cols + col_pos
                if tile_index in active_ids:
                    grid[row_pos, col_pos] = tile_index
        return grid

    def tile_mask_junction(
            tile_id: int,
            row_is_odd: bool,
            grid: np.ndarray,
            rim: int,
            min_rim: int = 5,
            n_smr_lines: int = 20
    ) -> Optional[np.ndarray[bool]]:
        """Create tile mask for rendering.

        Starts from an all-True mask of ``self.tile_shape`` and sets overlap
        regions to False (presumably excluded by ``warp.render_tiles``).

        :param tile_id: ID of the tile in ``grid``.
        :param row_is_odd: Odd SBEM rows are checked against the right
            neighbour (x + 1), even rows against the left neighbour (x - 1).
        :param grid: Virtual SBEMimage grid from ``create_sbem_grid``.
        :param rim: Size of safety margin added to the coarse offset
            to avoid holes in warped image.
        :param min_rim: minimal masked extent from each edge of a tile
        :param n_smr_lines: Number of top lines masked when the tile has an
            upper neighbour.
        """
        def eval_(xo):
            # Shift the offset by rim + margin (clamped to -min_rim when that
            # reaches zero or above); return whichever of xo and the shifted
            # value has the smaller magnitude.
            dx_new = xo + rim + margin
            dx_new = -min_rim if dx_new >= 0 else dx_new
            return xo if abs(dx_new) > abs(xo) else dx_new

        n_rows, n_cols = grid.shape

        def grid_at(row: int, col: int) -> int:
            # Out-of-grid positions count as "no tile" (-1). Explicit bounds
            # checks are required: negative indices would silently wrap around
            # to the last row/column instead of raising IndexError.
            if 0 <= row < n_rows and 0 <= col < n_cols:
                return int(grid[row, col])
            return -1

        mask = np.full(self.tile_shape, fill_value=True)
        y, x = np.where(tile_id == grid)
        y, x = int(y[0]), int(x[0])
        # Determine the horizontal neighbour based on row parity
        tid = grid_at(y, x + (1 if row_is_odd else -1))
        if row_is_odd:
            if tid != -1:
                offset = self.get_coarse_mesh_offset(tile_id)
                dx, dy = eval_(offset[0]), offset[1]
                if dy >= 0:
                    dyr = dy + rim if dy + rim > min_rim else min_rim
                    mask[dyr:, dx:] = False
                else:
                    dyr = dy - rim if abs(dy - rim) > min_rim else -min_rim
                    mask[:dyr, dx:] = False
        else:
            # Even rows: offset of the left neighbour (tile_id - 1)
            if tid != -1:
                offset = self.get_coarse_mesh_offset(tile_id - 1)
                dx, dy = eval_(offset[0]), offset[1]
                if dy >= 0:
                    dyy = int(-dy - rim)
                    dyr = dyy if abs(dyy) > min_rim else -min_rim
                    mask[:dyr, :abs(dx)] = False
                else:
                    dyy = int(abs(dy) + rim)
                    dyr = dyy if dyy > min_rim else min_rim
                    mask[dyr:, :abs(dx)] = False

        # Mask top tile-edges (overlap with the tile in the row above)
        tid_nn_y = grid_at(y - 1, x)
        if tid_nn_y != -1:
            offset = self.get_coarse_mesh_offset(tid_nn_y, axis=1)
            dx, dy = offset[0], eval_(offset[1])
            dy = min(offset[1] + rim, 0)  # Testing phase: overrides the eval_ value above
            mask[:n_smr_lines] = False  # Testing phase
            if dx >= 0:
                dxr = int(dx + rim) if int(dx + rim) != 0 else min_rim
                mask[:abs(dy), :-dxr] = False
            else:
                dxr = int(abs(dx) + rim) if int(abs(dx) + rim) != 0 else min_rim
                mask[:abs(dy), dxr:] = False
        return mask

    def create_margin_masks(tile_space: Tuple[TileXY], rim: int):
        sbem_grid = create_sbem_grid()
        self.margin_masks = {}
        for tile_xy in tile_space:
            x, y = tile_xy
            tile_id = int(self.tile_id_map[y, x])
            # Find the indices of the tile_id in sbem_grid
            indices = np.where(tile_id == sbem_grid)
            if len(indices[0]) == 0:  # Tile not found in sbem_grid
                continue
            row = indices[0][0]
            odd_row = row % 2 != 0  # Checking for odd row directly
            self.margin_masks[tile_xy] = tile_mask_junction(tile_id, odd_row, sbem_grid, rim)
        return

    def store_margin_masks():
        if self.margin_masks is None:
            logging.warning(f's{self.section_num} skipping storing None margin masks.')
            return
        data = {str(k): v for k, v in self.margin_masks.items()}
        logging.debug(f'Storing margin masks to: {self.path_margin_masks}')
        np.savez_compressed(self.path_margin_masks, **data)
        return

    if Path(self.path_margin_masks).exists() and not overwrite:
        print('Skipping margin mask computation. File exists and overwriting is disabled.')
        self.margin_masks = utils.load_mapped_npz(self.path_margin_masks)
        return
    if self.tile_id_map is None:
        self.feed_section_data()
    if self.mesh_offsets is None:
        self.build_mesh_offsets(mesh_config=mesh_config)
        if self.mesh_offsets is None:
            logging.warning(f'Section s{self.section_num} mesh offsets could not be computed.')
            return
    tiles_xy = utils.build_tiles_coords(self.tile_id_map)
    create_margin_masks(tiles_xy, rim_size)
    store_margin_masks()
    return
def build_margin_overrides(self,
grid_shape: Tuple[int, int],
rim: int = 10
) -> Optional[MarginOverrides]:
"""Builds margin overrides for each tile coordinate in section.
Args:
grid_shape: Total number of rows and columns in the SBEMimage grid
rim (int): Size of safety margin added to the coarse offset
to avoid holes in warped image. Defaults to 10 pixels.
Returns:
Optional[Dict[Tuple[int, int], Tuple[int, int, int, int]]]: A dictionary