From d88a9ce150d36bfb419803457c7f53e488475102 Mon Sep 17 00:00:00 2001 From: LucR31 <94859181+LucR31@users.noreply.github.com> Date: Wed, 4 Sep 2024 07:08:40 +0000 Subject: [PATCH] extra attributes to nc files --- aiida_flexpart/data/nc_data.py | 6 +++++- aiida_flexpart/workflows/inspect.py | 9 ++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/aiida_flexpart/data/nc_data.py b/aiida_flexpart/data/nc_data.py index 7ffa3ae..8d2d82b 100644 --- a/aiida_flexpart/data/nc_data.py +++ b/aiida_flexpart/data/nc_data.py @@ -10,12 +10,16 @@ def __init__( """ Data plugin for Netcdf files. """ - super().__init__(**kwargs) + super(NetCdfData,self).__init__() if filepath is not None: filename = os.path.basename(filepath) self.set_remote_path(remote_path) self.set_filename(filename) self.set_global_attributes(g_att, nc_dimensions) + if kwargs: + for k,v in kwargs.items(): + self.base.attributes.set(k, v) + def set_filename(self, val): self.base.attributes.set("filename", val) diff --git a/aiida_flexpart/workflows/inspect.py b/aiida_flexpart/workflows/inspect.py index 054a2b3..5c83f1a 100644 --- a/aiida_flexpart/workflows/inspect.py +++ b/aiida_flexpart/workflows/inspect.py @@ -27,7 +27,7 @@ def check(nc_file, version): return True @calcfunction -def store(remote_dir, file): +def store(remote_dir, file, time_label, nc_type): with tempfile.TemporaryDirectory() as td: remote_path = Path(remote_dir.get_remote_path()) / file.value temp_path = Path(td) / file.value @@ -46,6 +46,8 @@ def store(remote_dir, file): computer=remote_dir.computer, g_att=global_att, nc_dimensions=nc_dimensions, + time_label = time_label.value, + nc_type = nc_type.value ) if "history" in node.attributes["global_attributes"].keys(): @@ -68,12 +70,13 @@ def define(cls, spec): spec.input_namespace( "remotes_cs", valid_type=orm.RemoteStashFolderData, required=False ) + spec.input('time_label', valid_type=orm.Str, required=False) + spec.input('nc_type', valid_type=orm.Str, required=False) spec.outputs.dynamic = True spec.outline( cls.fill_remote_data, cls.inspect, ) - def fill_remote_data(self): self.ctx.dict_remote_data = {} if "remotes" in self.inputs: @@ -88,4 +91,4 @@ def inspect(self): for _, i in self.ctx.dict_remote_data.items(): for file in i.listdir(): if ".nc" in file: - store(i, file) + store(i, file, self.inputs.time_label, self.inputs.nc_type)