Skip to content

Commit

Permalink
extra attributes to nc files
Browse files Browse the repository at this point in the history
  • Loading branch information
LucR31 committed Sep 4, 2024
1 parent a1a132f commit d88a9ce
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
6 changes: 5 additions & 1 deletion aiida_flexpart/data/nc_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@ def __init__(
"""
Data plugin for Netcdf files.
"""
super().__init__(**kwargs)
super(NetCdfData,self).__init__()
if filepath is not None:
filename = os.path.basename(filepath)
self.set_remote_path(remote_path)
self.set_filename(filename)
self.set_global_attributes(g_att, nc_dimensions)
if kwargs:
for k,v in kwargs.items():
self.base.attributes.set(k, v)


def set_filename(self, val):
self.base.attributes.set("filename", val)
Expand Down
9 changes: 6 additions & 3 deletions aiida_flexpart/workflows/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def check(nc_file, version):
return True

@calcfunction
def store(remote_dir, file):
def store(remote_dir, file, time_label, nc_type):
with tempfile.TemporaryDirectory() as td:
remote_path = Path(remote_dir.get_remote_path()) / file.value
temp_path = Path(td) / file.value
Expand All @@ -46,6 +46,8 @@ def store(remote_dir, file):
computer=remote_dir.computer,
g_att=global_att,
nc_dimensions=nc_dimensions,
time_label = time_label.value,
nc_type = nc_type.value
)

if "history" in node.attributes["global_attributes"].keys():
Expand All @@ -68,12 +70,13 @@ def define(cls, spec):
spec.input_namespace(
"remotes_cs", valid_type=orm.RemoteStashFolderData, required=False
)
spec.input('time_label', valid_type=orm.Str, required=False)
spec.input('nc_type', valid_type=orm.Str, required=False)
spec.outputs.dynamic = True
spec.outline(
cls.fill_remote_data,
cls.inspect,
)

def fill_remote_data(self):
self.ctx.dict_remote_data = {}
if "remotes" in self.inputs:
Expand All @@ -88,4 +91,4 @@ def inspect(self):
for _, i in self.ctx.dict_remote_data.items():
for file in i.listdir():
if ".nc" in file:
store(i, file)
store(i, file, self.inputs.time_label, self.inputs.nc_type)

0 comments on commit d88a9ce

Please sign in to comment.