Skip to content

Commit

Permalink
Inspect workflow (#33)
Browse files Browse the repository at this point in the history
* inspect workflow and small additions in nc_data

* added new workflow
  • Loading branch information
LucR31 authored Jun 12, 2024
1 parent 8699253 commit 085570d
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 21 deletions.
30 changes: 10 additions & 20 deletions aiida_flexpart/data/nc_data.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
from aiida.orm import RemoteData
from netCDF4 import Dataset

class NetCDFData(RemoteData):

def __init__(self, filepath=None, remote_path=None, **kwargs):
class NetCdfData(RemoteData):

def __init__(
self, filepath=None, remote_path=None, g_att=None, nc_dimensions=None, **kwargs
):
"""
Data plugin for Netcdf files.
"""
Expand All @@ -13,32 +15,20 @@ def __init__(self, filepath=None, remote_path=None, **kwargs):
filename = os.path.basename(filepath)
self.set_remote_path(remote_path)
self.set_filename(filename)

# open and read as NetCDF
nc_file = Dataset(filepath, mode="r")
self.set_global_attributes(nc_file)
self.set_global_attributes(g_att, nc_dimensions)

def set_filename(self, val):
self.base.attributes.set("filename", val)

def set_global_attributes(self, nc_file):

g_att = {}
for a in nc_file.ncattrs():
g_att[a] = repr(nc_file.getncattr(a))
def set_global_attributes(self, g_att, nc_dimensions):
self.base.attributes.set("global_attributes", g_att)

nc_dimensions = {i: len(nc_file.dimensions[i]) for i in nc_file.dimensions}
self.base.attributes.set("dimensions", nc_dimensions)

def ncdump(self):
"""
Small python version of ncdump.
"""
"""Small python version of ncdump."""
print("dimensions:")
for k, v in self.base.attributes.get("dimensions").items():
print("\t%s =" % k, v)

print(f"\t {k} = {v}")
print("// global attributes:")
for k, v in self.base.attributes.get("global_attributes").items():
print("\t:%s =" % k, v)
print(f"\t :{k} = {v}")
90 changes: 90 additions & 0 deletions aiida_flexpart/workflows/inspect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from aiida.engine import WorkChain, calcfunction
from aiida.plugins import DataFactory
from aiida import orm
from pathlib import Path
import tempfile
from netCDF4 import Dataset

NetCDF = DataFactory("netcdf.data")


def check(nc_file, version):
"""
Checks if there is a netcdf file stored with the same name,
if so, it checks the created date, if that is a match then returns
False.
"""
qb = orm.QueryBuilder()
qb.append(
NetCDF,
project=[f"attributes.global_attributes.{version}"],
filters={"attributes.filename": nc_file.attributes["filename"]},
)
if qb.all():
for i in qb.all():
if i[0] == nc_file.attributes["global_attributes"][version]:
return False
return True


def validate_history(nc_file):
return True if "history" in nc_file.attributes["global_attributes"].keys() else None


@calcfunction
def store(remote_dir, file):
with tempfile.TemporaryDirectory() as td:
remote_path = Path(remote_dir.get_remote_path()) / file.value
temp_path = Path(td) / file.value
remote_dir.getfile(remote_path, temp_path)

# fill global attributes and dimensions
nc_file = Dataset(str(temp_path), mode="r")
nc_dimensions = {i: len(nc_file.dimensions[i]) for i in nc_file.dimensions}
global_att = {}
for a in nc_file.ncattrs():
global_att[a] = repr(nc_file.getncattr(a))

node = NetCDF(
str(temp_path),
remote_path=str(remote_path),
computer=remote_dir.computer,
g_att=global_att,
nc_dimensions=nc_dimensions,
)

if validate_history(node) == None:
return
elif check(node, "history"):
return node


class InspectWorkflow(WorkChain):
@classmethod
def define(cls, spec):
super().define(spec)
spec.input_namespace("remotes", valid_type=orm.RemoteData, required=False)
spec.input_namespace(
"remotes_cs", valid_type=orm.RemoteStashFolderData, required=False
)
spec.outputs.dynamic = True
spec.outline(
cls.fill_remote_data,
cls.inspect,
)

def fill_remote_data(self):
self.ctx.dict_remote_data = {}
if "remotes" in self.inputs:
self.ctx.dict_remote_data = self.inputs.remotes
else:
for k, v in self.inputs.remotes_cs.items():
self.ctx.dict_remote_data[k] = orm.RemoteData(
remote_path=v.target_basepath, computer=v.computer
)

def inspect(self):
for _, i in self.ctx.dict_remote_data.items():
for file in i.listdir():
if ".nc" in file:
store(i, file)
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,10 @@ docs = [

[project.entry-points."aiida.workflows"]
"flexpart.multi_dates" = "aiida_flexpart.workflows.multi_dates_workflow:FlexpartMultipleDatesWorkflow"
"inspect.workflow" = "aiida_flexpart.workflows.inspect:InspectWorkflow"

[project.entry-points."aiida.data"]
"netcdf.data" = "aiida_flexpart.data.nc_data:NetCDFData"
"netcdf.data" = "aiida_flexpart.data.nc_data:NetCdfData"


[tool.pylint.format]
Expand Down

0 comments on commit 085570d

Please sign in to comment.