Skip to content

Commit

Permalink
1. add tests for spin_type=SpinType.SPIN_ORBIT;
Browse files Browse the repository at this point in the history
2. minor fix in `Wannier90WorkChain` and `Wannier90BandsWorkChain`.
  • Loading branch information
Yuhao Jiang committed Dec 12, 2023
1 parent 997f1a6 commit 85176f3
Show file tree
Hide file tree
Showing 20 changed files with 3,591 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/aiida_wannier90_workflows/workflows/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def get_builder_from_protocol( # pylint: disable=arguments-differ
overrides=kwargs.pop("overrides", None),
)

if run_open_grid and kwargs.get("electronic_type", None) == SpinType.SPIN_ORBIT:
if run_open_grid and kwargs.get("spin_type", None) == SpinType.SPIN_ORBIT:
raise ValueError("open_grid.x does not support spin orbit coupling")

if run_open_grid:
Expand Down
5 changes: 5 additions & 0 deletions src/aiida_wannier90_workflows/workflows/wannier90.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,11 @@ def get_builder_from_protocol( # pylint: disable=unused-argument
# Note: if overrides are specified, they take precedence!
protocol_overrides = cls.get_protocol_overrides()

# If recursive_merge get an arg = None, the arg.copy() will raise an error.
# When overrides is not given (default value None), it should be set to an empty dict.
if overrides is None:
overrides = {}

if plot_wannier_functions:
overrides = recursive_merge(
protocol_overrides["plot_wannier_functions"], overrides
Expand Down
131 changes: 124 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,19 +82,20 @@ def _serializer(node):


@pytest.fixture(scope="session", autouse=True)
def sssp(aiida_profile, generate_upf_data):
"""Create an SSSP pseudo potential family from scratch."""
def pseudos(aiida_profile, generate_upf_data, generate_upf_data_soc):
"""Create pseudo potential families from scratch."""
from aiida.common.constants import elements
from aiida.plugins import GroupFactory

aiida_profile.clear_profile()

# Create an SSSP pseudo potential family from scratch.
SsspFamily = GroupFactory("pseudo.family.sssp")

stringency = "standard"
label = "SSSP/1.1/PBE/efficiency"
family = SsspFamily(label=label)
family.store()
sssp = SsspFamily(label=label)
sssp.store()

cutoffs = {}
upfs = []
Expand All @@ -115,10 +116,40 @@ def sssp(aiida_profile, generate_upf_data):
"cutoff_rho": 240.0,
}

family.add_nodes(upfs)
family.set_cutoffs(cutoffs, stringency, unit="Ry")
sssp.add_nodes(upfs)
sssp.set_cutoffs(cutoffs, stringency, unit="Ry")

return family
# Create an pseudoDojo pseudo potential family from scratch.
DojoFamily = GroupFactory("pseudo.family.pseudo_dojo")

stringency = "standard"
label = "PseudoDojo/0.4/PBE/FR/standard/upf"
dojo = DojoFamily(label=label)
dojo.store()

cutoffs = {}
upfs = []

for values in elements.values():
element = values["symbol"]
if element in ["X"]:
continue
try:
upf = generate_upf_data_soc(element)
except ValueError:
continue

upfs.append(upf)

cutoffs[element] = {
"cutoff_wfc": 40.0,
"cutoff_rho": 300.0,
}

dojo.add_nodes(upfs)
dojo.set_cutoffs(cutoffs, stringency, unit="Ry")

return sssp, dojo


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -175,6 +206,71 @@ def _generate_upf_data(element):
return _generate_upf_data


@pytest.fixture(scope="session")
def generate_upf_data_soc(filepath_fixtures):
"""Return a `UpfData` instance for the given element a file for which should exist in `tests/fixtures/pseudos`."""

def _generate_upf_data_soc(element):
"""Return `UpfData` node."""
from aiida_pseudo.data.pseudo import PseudoPotentialData, UpfData
import yaml

yaml_file = (
filepath_fixtures / "pseudos" / "PseudoDojo_0.4_PBE_FR_standard_upf.yaml"
)
with open(yaml_file, encoding="utf-8") as file:
upf_metadata = yaml.load(file, Loader=yaml.FullLoader)

if element not in upf_metadata:
raise ValueError(f"Element {element} not found in {yaml_file}")

filename = upf_metadata[element]["filename"]
md5 = upf_metadata[element]["md5"]
z_valence = upf_metadata[element]["z_valence"]
number_of_wfc = upf_metadata[element]["number_of_wfc"]
has_so = upf_metadata[element]["has_so"]
ppspinorb = upf_metadata[element]["ppspinorb"]
jchi = ppspinorb["jchi"]
lchi = ppspinorb["lchi"]
nn = ppspinorb["nn"]
pprelwfc = ""

for i, l in enumerate(lchi): # pylint: disable=invalid-name
pprelwfc += (
f'<PP_RELWFC.{i+1} index="{i+1}" '
f'lchi="{l}" '
f'jchi="{jchi[i]}" '
f'nn="{nn[i]}"/>\n'
)

content = (
'<UPF version="2.0.1">\n'
"<PP_HEADER\n"
f'element="{element}"\n'
f'z_valence="{z_valence}"\n'
f'has_so="{has_so}"\n'
f'number_of_wfc="{number_of_wfc}"\n'
"/>\n"
"<PP_SPIN_ORB>\n"
f"{pprelwfc}"
"</PP_SPIN_ORB>\n"
"</UPF>\n"
)
stream = io.BytesIO(content.encode("utf-8"))
upf = UpfData(stream, filename=f"{filename}")

# I need to hack the md5
# upf.md5 = md5
upf.set_attribute(upf._key_md5, md5) # pylint: disable=protected-access
# UpfData.store will check md5
# `PseudoPotentialData` is the parent class of `UpfData`, this will skip md5 check
super(PseudoPotentialData, upf).store()

return upf

return _generate_upf_data_soc


@pytest.fixture(scope="session")
def get_sssp_upf():
"""Returen a SSSP pseudo with a given element name."""
Expand All @@ -196,6 +292,27 @@ def _get_sssp_upf(element):
return _get_sssp_upf


@pytest.fixture(scope="session")
def get_dojo_upf():
"""Returen a pseudoDojo pseudo with a given element name."""

def _get_dojo_upf(element):
"""Returen pesudoDojo pseudo."""
from aiida.orm import QueryBuilder
from aiida.plugins import GroupFactory

DojoFamily = GroupFactory("pseudo.family.pseudo_dojo")

label = "PseudoDojo/0.4/PBE/FR/standard/upf"
pseudo_family = (
QueryBuilder().append(DojoFamily, filters={"label": label}).one()[0]
)

return pseudo_family.get_pseudo(element=element)

return _get_dojo_upf


@pytest.fixture
def generate_calc_job():
"""Fixture to construct a new `CalcJob` instance and call `prepare_for_submission` for testing `CalcJob` classes.
Expand Down
Loading

0 comments on commit 85176f3

Please sign in to comment.