Skip to content

Commit

Permalink
(MA) implemented endpointOverride parameter to allow user to specif…
Browse files Browse the repository at this point in the history
…y the base URL to use and point to custom provider or a specific sub-database of a provider
  • Loading branch information
amkrajewski committed Apr 3, 2024
1 parent 785dabb commit ac819da
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions pysipfenn/core/modelAdjusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,14 @@ class OPTIMADEAdjuster(LocalAdjuster):
maxResults: The maximum number of results to be fetched from the OPTIMADE API for a given query. Default is
``10000`` which is a very high number for most re-training tasks. If you are fetching a lot of data, it's
possible the query is too broad, and you should consider narrowing it down.
endpointOverride: List of URL strings with the endpoint to be used for the OPTIMADE queries. This is an advanced
option allowing you to ignore the ``provider`` parameter and directly specify the endpoint to be used. It is
useful if you want to use a specific version of the provider's endpoint or narrow down the query to a
sub-database (Alexandria has two different endpoints for PBEsol and SCAN, for instance). You can also use it
to query unofficial endpoints. Make sure to (a) include the protocol (``http://`` or ``https://``) and (b) omit
both the version (``/v1/``) and the specific endpoint (``/structures``), as the client will add them. I.e.,
you want ``https://alexandria.icams.rub.de/pbesol`` rather than
``alexandria.icams.rub.de/pbesol/v1/structures``. Default is ``None``, which has no effect (the ``provider`` parameter is used).
"""

def __init__(
Expand Down Expand Up @@ -704,15 +712,16 @@ def __init__(
descriptor: Literal["Ward2017", "KS2022"] = "KS2022",
useClearML: bool = False,
taskName: str = "OPTIMADEFineTuning",
maxResults: int = 10000
maxResults: int = 10000,
endpointOverride: List[str] = None
) -> None:
from optimade.client import OptimadeClient

assert isinstance(calculator, Calculator), "The calculator must be an instance of the Calculator class."
assert isinstance(model, str), "The model must be a string with the name of the model to be adjusted."
assert isinstance(provider, str), "The provider must be a string with the name of the provider to be used."
assert len(provider) != 0, "The provider must not be an empty string."
assert targetPath and isinstance(targetPath, list), "The target path must be a list of strings pointing to the target data in the OPTIMADE response."
assert targetPath and isinstance(targetPath, list) or isinstance(targetPath, tuple), "The target path must be a list of strings pointing to the target data in the OPTIMADE response."
assert len(targetPath) > 0, "The target path must not be empty, i.e., it cannot point to no data."
if provider != "mp" and targetPath == ('attributes', '_mp_stability', 'gga_gga+u', 'formation_energy_per_atom'):
raise ValueError("You are utilizing the default (example) property target path specific to the Materials "
Expand All @@ -734,11 +743,20 @@ def __init__(
self.descriptor = descriptor
self.targetPath = targetPath
self.provider = provider
self.client = OptimadeClient(
use_async=False,
include_providers=[provider],
max_results_per_provider=maxResults
)
if endpointOverride is None:
self.client = OptimadeClient(
use_async=False,
include_providers=[provider],
max_results_per_provider=maxResults
)
else:
assert isinstance(endpointOverride, list) or isinstance(endpointOverride, tuple), "The endpoint override must be a list of strings."
assert len(endpointOverride) != 0, "The endpoint override must not be an empty list."
self.client = OptimadeClient(
use_async=False,
base_urls=endpointOverride,
max_results_per_provider=maxResults
)

if self.descriptor == "Ward2017":
self.descriptorData: np.ndarray = np.empty((0, 271))
Expand Down

0 comments on commit ac819da

Please sign in to comment.