diff --git a/navis/utils/cv.py b/navis/utils/cv.py index 4ddf2e05..6aa8fd39 100644 --- a/navis/utils/cv.py +++ b/navis/utils/cv.py @@ -56,20 +56,22 @@ def patch_cloudvolume(): # If CV not installed do nothing if not cv: - logger.info('cloud-volume appears to not be installed?') + logger.info("cloud-volume appears to not be installed?") return - for ds in [cv.datasource.graphene.mesh.sharded.GrapheneShardedMeshSource, - cv.datasource.graphene.mesh.unsharded.GrapheneUnshardedMeshSource, - cv.datasource.precomputed.mesh.unsharded.UnshardedLegacyPrecomputedMeshSource, - cv.datasource.precomputed.mesh.multilod.UnshardedMultiLevelPrecomputedMeshSource, - cv.datasource.precomputed.mesh.multilod.ShardedMultiLevelPrecomputedMeshSource, - cv.datasource.precomputed.skeleton.sharded.ShardedPrecomputedSkeletonSource, - cv.datasource.precomputed.skeleton.unsharded.UnshardedPrecomputedSkeletonSource]: + for ds in [ + cv.datasource.graphene.mesh.sharded.GrapheneShardedMeshSource, + cv.datasource.graphene.mesh.unsharded.GrapheneUnshardedMeshSource, + cv.datasource.precomputed.mesh.unsharded.UnshardedLegacyPrecomputedMeshSource, + cv.datasource.precomputed.mesh.multilod.UnshardedMultiLevelPrecomputedMeshSource, + cv.datasource.precomputed.mesh.multilod.ShardedMultiLevelPrecomputedMeshSource, + cv.datasource.precomputed.skeleton.sharded.ShardedPrecomputedSkeletonSource, + cv.datasource.precomputed.skeleton.unsharded.UnshardedPrecomputedSkeletonSource, + ]: ds.get_navis = return_navis(ds.get, only_on_kwarg=False) ds.get = return_navis(ds.get, only_on_kwarg=True) - logger.info('cloud-volume successfully patched!') + logger.info("cloud-volume successfully patched!") def return_navis(func, only_on_kwarg=False): @@ -82,12 +84,15 @@ def return_navis(func, only_on_kwarg=False): only_on_kwarg : bool If True, will look for a `as_navis=True` (default=False) keyword argument to determine if results should be converted - to navis neurons. + to navis neurons. If 'process' is set to False, the neuron + will not be processed by TriMesh (remove nan, duplicate vertices,etc) """ + @functools.wraps(func) def wrapper(*args, **kwargs): - ret_navis = kwargs.pop('as_navis', False) + ret_navis = kwargs.pop("as_navis", False) + process = kwargs.pop("process", False) res = func(*args, **kwargs) if not only_on_kwarg or ret_navis: @@ -99,18 +104,20 @@ def wrapper(*args, **kwargs): for k, v in res.items(): if isinstance(v, cv.Mesh): - n = core.MeshNeuron(v, id=k, units='nm') + n = core.MeshNeuron(v, id=k, units="nm", process=process) neurons.append(n) elif isinstance(v, cv.Skeleton): swc_str = v.to_swc() n = io.read_swc(swc_str) n.id = k - n.units = 'nm' + n.units = "nm" neurons.append(n) else: - logger.warning(f'Skipped {k}: Unable to convert {type(v)} to ' - 'navis Neuron.') + logger.warning( + f"Skipped {k}: Unable to convert {type(v)} to " "navis Neuron." + ) return core.NeuronList(neurons) return res + return wrapper