Skip to content

Commit

Permalink
2 general base classes for workflows (#3)
Browse files Browse the repository at this point in the history
* Added base classes for DFTMethod and DFTOutputs, and for GWOutputs

* Reformatted and using inheritance gw.py

* Added TBOutputs and DMFTOutputs to general.py

Reformatted dmft.py

* Added BeyondDFT2Tasks base class and MaxEntOutputs in general.py

Reformatted and inheritance in maxent.py

* Reformatted xs.py and photon_polarization.py

* Changed inheritance by composition

Reformat tb.py

* Added composition to XS workflow

* Defining ElectronicStructureOutputs for generic complex workflows

* Fix bug name

* Fixing XS and tests

* Changed name to electronic workflows to add Plus

* Fix testing

* Generalize method in BeyondDFT class

* Improved extraction of workflow_name

* Adding nmr workflow (#6)

* Added MagneticOutputs in general.py

* Added more imports for usage in __init__

* Fix testing
  • Loading branch information
JosePizarro3 authored Feb 7, 2024
1 parent c8db9bd commit 2f26453
Show file tree
Hide file tree
Showing 9 changed files with 891 additions and 770 deletions.
18 changes: 14 additions & 4 deletions simulationworkflowschema/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@
SimulationWorkflow,
SimulationWorkflowMethod,
SimulationWorkflowResults,
ParallelSimulation,
SerialSimulation,
BeyondDFT,
DFTMethod,
ElectronicStructureOutputs,
MagneticOutputs,
)
from .single_point import SinglePoint, SinglePointMethod, SinglePointResults
from .geometry_optimization import (
Expand All @@ -44,11 +50,15 @@
ChemicalReactionResults,
)
from .elastic import Elastic, ElasticMethod, ElasticResults
from .tb import TB, TBMethod, TBResults
from .gw import GW, GWMethod, GWResults
from .tb import (
FirstPrinciplesPlusTB,
FirstPrinciplesPlusTBMethod,
FirstPrinciplesPlusTBResults,
)
from .gw import DFTPlusGW, DFTPlusGWMethod, DFTPlusGWResults
from .xs import XS, XSMethod, XSResults
from .dmft import DMFT, DMFTMethod, DMFTResults
from .max_ent import MaxEnt, MaxEntMethod, MaxEntResults
from .dmft import DFTPlusTBPlusDMFT, DFTPlusTBPlusDMFTMethod, DFTPlusTBPlusDMFTResults
from .max_ent import DMFTPlusMaxEnt, DMFTPlusMaxEntMethod, DMFTPlusMaxEntResults
from .photon_polarization import (
PhotonPolarization,
PhotonPolarizationMethod,
Expand Down
132 changes: 22 additions & 110 deletions simulationworkflowschema/dmft.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,109 +17,41 @@
#
from nomad.metainfo import SubSection, Quantity, Reference
from runschema.method import (
XCFunctional,
BasisSetContainer,
TB as TBMethodology,
DMFT as DMFTMethodology,
)
from runschema.calculation import BandGap, Dos, BandStructure, GreensFunctions
from .general import (
SimulationWorkflowResults,
SimulationWorkflowMethod,
SerialSimulation,
ElectronicStructureOutputs,
DFTMethod,
BeyondDFT,
)


class DMFTResults(SimulationWorkflowResults):
"""Groups DFT, TB and DMFT outputs: band gaps (all), DOS (DFT, TB), band
class DFTPlusTBPlusDMFTResults(SimulationWorkflowResults):
"""
Groups DFT, TB and DMFT outputs: band gaps (all), DOS (DFT, TB), band
structures (DFT, TB), Greens functions (DMFT). The ResultsNormalizer takes care
of adding a label 'DFT', 'PROJECTION, or 'DMFT' in the method `get_dmft_workflow_properties`.
"""

band_gap_dft = Quantity(
type=Reference(BandGap),
shape=["*"],
description="""
DFT band gap.
""",
)

band_gap_tb = Quantity(
type=Reference(BandGap),
shape=["*"],
description="""
TB band gap.
""",
)

band_gap_dmft = Quantity(
type=Reference(BandGap),
shape=["*"],
description="""
DMFT band gap.
""",
)

band_structure_dft = Quantity(
type=Reference(BandStructure),
shape=["*"],
description="""
Ref to the DFT band structure.
""",
)

dos_dft = Quantity(
type=Reference(Dos),
shape=["*"],
description="""
Ref to the DFT density of states.
""",
)

band_structure_tb = Quantity(
type=Reference(BandStructure),
shape=["*"],
description="""
Ref to the TB band structure.
""",
dft_outputs = SubSection(
sub_section=ElectronicStructureOutputs.m_def, repeats=False
)

dos_tb = Quantity(
type=Reference(Dos),
shape=["*"],
description="""
Ref to the TB density of states.
""",
)
tb_outputs = SubSection(sub_section=ElectronicStructureOutputs.m_def, repeats=False)

greens_functions_dmft = Quantity(
type=Reference(GreensFunctions),
shape=["*"],
description="""
Ref to the DMFT Greens functions.
""",
dmft_outputs = SubSection(
sub_section=ElectronicStructureOutputs.m_def, repeats=False
)


class DMFTMethod(SimulationWorkflowMethod):
"""Groups DFT, TB and DMFT input methodologies: starting XC functional, electrons
class DFTPlusTBPlusDMFTMethod(DFTMethod):
"""
Specifies all DFT, TB and DMFT input methodologies: starting XC functional, electrons
representation (basis set), TB method reference, DMFT method reference.
"""

starting_point = Quantity(
type=Reference(XCFunctional),
description="""
Starting point (XC functional or HF) used.
""",
)

electrons_representation = Quantity(
type=Reference(BasisSetContainer),
description="""
Basis set used.
""",
)

tb_method_ref = Quantity(
type=Reference(TBMethodology),
description="""
Expand All @@ -135,38 +67,18 @@ class DMFTMethod(SimulationWorkflowMethod):
)


class DMFT(SerialSimulation):
"""The DMFT workflow is generated in an extra EntryArchive IF both the TB SinglePoint
class DFTPlusTBPlusDMFT(BeyondDFT): # TODO implement connection with DFT task
"""
The DMFT workflow is generated in an extra EntryArchive IF both the TB SinglePoint
and the DMFT SinglePoint EntryArchives are present in the upload.
"""

# TODO extend to reference a DFT SinglePoint.

method = SubSection(sub_section=DMFTMethod)
method = SubSection(sub_section=DFTPlusTBPlusDMFTMethod)

results = SubSection(sub_section=DMFTResults)
results = SubSection(sub_section=DFTPlusTBPlusDMFTResults)

def normalize(self, archive, logger):
super().normalize(archive, logger)

if len(self.tasks) != 2:
logger.error("Expected two tasks: TB and DMFT SinglePoint tasks")
return

proj_task = self.tasks[0]
dmft_task = self.tasks[1]
if not self.results: # creates Results section if not present
self.results = DFTPlusTBPlusDMFTResults()

if not self.results:
self.results = DMFTResults()

for name, section in self.results.m_def.all_quantities.items():
calc_name = "_".join(name.split("_")[:-1])
if calc_name in ["dos", "band_structure"]:
calc_name = f"{calc_name}_electronic"
calc_section = []
if "tb" in name:
calc_section = getattr(proj_task.outputs[-1].section, calc_name)
elif "dmft" in name:
calc_section = getattr(dmft_task.outputs[-1].section, calc_name)
if calc_section:
self.results.m_set(section, calc_section)
super().normalize(archive, logger)
179 changes: 176 additions & 3 deletions simulationworkflowschema/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,24 @@
from nomad.datamodel.data import ArchiveSection
from nomad.metainfo import SubSection, Section, Quantity, Reference
from nomad.datamodel.metainfo.common import FastAccess
from nomad.datamodel.metainfo.workflow import Workflow, Link, Task
from runschema.method import Method
from nomad.datamodel.metainfo.workflow import Workflow, Link, Task, TaskReference
from runschema.method import (
Method,
XCFunctional,
BasisSetContainer,
)
from runschema.system import System
from runschema.calculation import Calculation
from runschema.calculation import (
Calculation,
BandGap,
Dos,
BandStructure,
GreensFunctions,
MagneticShielding,
ElectricFieldGradient,
SpinSpinCoupling,
MagneticSusceptibility,
)


def resolve_difference(values):
Expand Down Expand Up @@ -324,3 +338,162 @@ def normalize(self, archive, logger):
self.tasks.append(
Task(name=f"Step {n}", inputs=inputs, outputs=outputs)
)


class BeyondDFT(SerialSimulation):
"""
Base class used to normalize standard workflows beyond DFT containing two specific
SinglePoint tasks (GWWorkflow = DFT + GW, DMFTWorkflow = DFT + DMFT,
MaxEntWorkflow = DMFT + MaxEnt, and so on) and store the outputs in the self.results
section.
"""

def _resolve_outputs_section(self, output_section, task: TaskReference) -> None:
"""
Resolves the output_section of a task and stores the results in the output_section.
Args:
task (TaskReference): The task from which the outputs are got.
"""
for name, section in output_section.m_def.all_quantities.items():
name = f"{name}_electronic" if name in ["dos", "band_structure"] else name
try:
calc_section = getattr(task.outputs[-1].section, name)
if calc_section:
output_section.m_set(section, calc_section)
except Exception:
continue

def get_electronic_structure_workflow_results(self, task_map: dict) -> None:
"""
Gets the standard electronic structure workflow results section by resolving the
outputs specified in the `task_map`.
Args:
task_map (dict): The dictionary used to resolve the outputs sections.
"""
for method, task in task_map.items():
outputs = ElectronicStructureOutputs()
self._resolve_outputs_section(outputs, task)
setattr(self.results, f"{method}_outputs", outputs)

def normalize(self, archive, logger):
super().normalize(archive, logger)

if len(self.tasks) != 2:
logger.error("Expected two tasks.")
return

# We extract the workflow name from the tasks names
self.name = "+".join([task.name for task in self.tasks if task.name])
task_map = {
task.name.lower(): self.tasks[n] for n, task in enumerate(self.tasks)
}
# Resolve workflow2.results for each standard BeyondDFT workflow
if self.name == "DFT+GW":
self.get_electronic_structure_workflow_results(task_map)
elif self.name == "TB+DMFT": # TODO extend for DFT tasks
self.get_electronic_structure_workflow_results(task_map)
elif self.name == "DMFT+MaxEnt":
self.get_electronic_structure_workflow_results(task_map)
elif self.name == "FirstPrinciples+TB":
task_map["first_principles"] = task_map.pop("firstprinciples")
self.get_electronic_structure_workflow_results(task_map)


class DFTMethod(SimulationWorkflowMethod):
"""
Base class defining the DFT input methodologies: starting XC functional and electrons
representation (basis set).
"""

starting_point = Quantity(
type=Reference(XCFunctional),
description="""
Reference to the starting point (XC functional or HF) used.
""",
)

electrons_representation = Quantity(
type=Reference(BasisSetContainer),
description="""
Reference to the basis set used.
""",
)


class ElectronicStructureOutputs(SimulationWorkflowResults):
"""
Base class defining the typical output properties of any electronic structure
SinglePoint calculation: DFT, TB, DMFT, GW, MaxEnt, XS.
"""

band_gap = Quantity(
type=Reference(BandGap),
shape=["*"],
description="""
Reference to the band gap section.
""",
)

dos = Quantity(
type=Reference(Dos),
shape=["*"],
description="""
Reference to the density of states section.
""",
)

band_structure = Quantity(
type=Reference(BandStructure),
shape=["*"],
description="""
Reference to the band structure section.
""",
)

greens_functions = Quantity(
type=Reference(GreensFunctions),
shape=["*"],
description="""
Ref to the Green functions section.
""",
)


class MagneticOutputs(SimulationWorkflowResults):
"""
Base class defining the typical output properties of magnetic SinglePoint calculations.
"""

magnetic_shielding = Quantity(
type=Reference(MagneticShielding),
shape=["*"],
description="""
Reference to the magnetic shielding tensors.
""",
)

electric_field_gradient = Quantity(
type=Reference(ElectricFieldGradient),
shape=["*"],
description="""
Reference to the electric field gradient tensors.
""",
)

spin_spin_coupling = Quantity(
type=Reference(SpinSpinCoupling),
shape=["*"],
description="""
Reference to the spin-spin coupling tensors.
""",
)

magnetic_susceptibility_nmr = Quantity(
type=Reference(MagneticSusceptibility),
shape=["*"],
description="""
Reference to the magnetic susceptibility tensors.
""",
)
Loading

0 comments on commit 2f26453

Please sign in to comment.