fix: Pydantic 2 deprecation warnings
chrisburr committed Jun 10, 2024
1 parent 327bb18 commit 954641d
Showing 4 changed files with 77 additions and 84 deletions.
6 changes: 3 additions & 3 deletions setup.cfg
@@ -30,8 +30,8 @@ install_requires =
certifi
cwltool
diraccfg
diracx-client
diracx-core
diracx-client >=0.0.1a17
diracx-core >=0.0.1a17
db12
fts3
gfal2-python
@@ -44,7 +44,7 @@ install_requires =
psutil
pyasn1
pyasn1-modules
pydantic
pydantic >=2.4
pyparsing
python-dateutil
pytz
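The tightened pins make Pydantic 2 (and matching diracx pre-releases) a hard install-time requirement instead of something probed at runtime, which is what allows the version check in JobModel.py below to be dropped. A minimal sketch, not part of the commit, of asserting the same floor in a running environment (assumes the packaging library is available, as it already is for DIRAC):

from importlib.metadata import version

from packaging.version import Version

installed = Version(version("pydantic"))
if installed < Version("2.4"):
    raise RuntimeError(f"pydantic {installed} is installed, but DIRAC now requires >=2.4")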
2 changes: 1 addition & 1 deletion src/DIRAC/Core/Utilities/test/Test_JDL.py
@@ -79,7 +79,7 @@ def test_jdlToBaseJobDescriptionModel_valid(jdl_monkey_business):
res = jdlToBaseJobDescriptionModel(ClassAd(jdl))
assert res["OK"], res["Message"]

data = res["Value"].dict()
data = res["Value"].model_dump()
assert JobDescriptionModel(owner="owner", ownerGroup="ownerGroup", vo="lhcb", **data)


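For context: in Pydantic 2, BaseModel.dict() still exists but emits a deprecation warning, while model_dump() is the supported replacement and returns the same plain dict for models like these. A throwaway example (the Point model is illustrative, not from DIRAC):

from pydantic import BaseModel


class Point(BaseModel):
    x: int
    y: int


p = Point(x=1, y=2)
print(p.model_dump())  # {'x': 1, 'y': 2}; p.dict() returns the same but warns on Pydantic 2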
2 changes: 1 addition & 1 deletion src/DIRAC/Resources/Computing/BatchSystems/SLURM.py
@@ -125,7 +125,7 @@ def _generateSrunWrapper(self, executableFile):
content = f.read()

# Need to escape environment variables of the executable file
content = re.sub("\$", "\\$", content)
content = re.sub(r"\$", r"\\$", content)

# Build the script to run the executable in parallel multiple times
# - Embed the content of executableFile inside the parallel library wrapper script
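This hunk is about Python string literals rather than Pydantic: "\$" is an invalid escape sequence that recent Python versions flag with a SyntaxWarning, whereas the raw strings express the same pattern and replacement without the warning. A small illustrative check, not taken from the wrapper script:

import re

content = 'echo "$HOME" && echo "$USER"'
escaped = re.sub(r"\$", r"\\$", content)  # escape each $ so the shell will not expand it
print(escaped)  # echo "\$HOME" && echo "\$USER"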
151 changes: 72 additions & 79 deletions src/DIRAC/WorkloadManagementSystem/Utilities/JobModel.py
@@ -3,11 +3,9 @@
# pylint: disable=no-self-argument, no-self-use, invalid-name, missing-function-docstring

from collections.abc import Iterable
from typing import Any, Annotated
from typing import Any, Annotated, TypeAlias, Self

import pydantic
from packaging.version import Version
from pydantic import BaseModel, root_validator, validator
from pydantic import BaseModel, BeforeValidator, model_validator, field_validator, ConfigDict

from DIRAC import gLogger
from DIRAC.ConfigurationSystem.Client.Helpers.Operations import Operations
@@ -16,71 +14,71 @@

# HACK: Convert appropriate iterables into sets
def default_set_validator(value):
if not isinstance(value, Iterable):
if value is None:
return set()
elif not isinstance(value, Iterable):
return value
elif isinstance(value, (str, bytes, bytearray)):
return value
else:
return set(value)


if Version(pydantic.__version__) > Version("2.0.0a0"):
CoercibleSetStr = Annotated[
set[str] | None, pydantic.BeforeValidator(default_set_validator) # pylint: disable=no-member
]
else:
CoercibleSetStr = set[str]
CoercibleSetStr: TypeAlias = Annotated[set[str], BeforeValidator(default_set_validator)]


class BaseJobDescriptionModel(BaseModel):
"""Base model for the job description (not parametric)"""

class Config:
validate_assignment = True
model_config = ConfigDict(validate_assignment=True)

arguments: str = None
bannedSites: CoercibleSetStr = None
arguments: str = ""
bannedSites: CoercibleSetStr = set()
# TODO: This should use a field factory
cpuTime: int = Operations().getValue("JobDescription/DefaultCPUTime", 86400)
executable: str
executionEnvironment: dict = None
gridCE: str = None
inputSandbox: CoercibleSetStr = None
inputData: CoercibleSetStr = None
inputDataPolicy: str = None
jobConfigArgs: str = None
jobGroup: str = None
gridCE: str = ""
inputSandbox: CoercibleSetStr = set()
inputData: CoercibleSetStr = set()
inputDataPolicy: str = ""
jobConfigArgs: str = ""
jobGroup: str = ""
jobType: str = "User"
jobName: str = "Name"
# TODO: This should be an StrEnum
logLevel: str = "INFO"
# TODO: This can't be None with this type hint
maxNumberOfProcessors: int = None
minNumberOfProcessors: int = 1
outputData: CoercibleSetStr = None
outputPath: str = None
outputSandbox: CoercibleSetStr = None
outputSE: str = None
platform: str = None
outputData: CoercibleSetStr = set()
outputPath: str = ""
outputSandbox: CoercibleSetStr = set()
outputSE: str = ""
platform: str = ""
# TODO: This should use a field factory
priority: int = Operations().getValue("JobDescription/DefaultPriority", 1)
sites: CoercibleSetStr = None
sites: CoercibleSetStr = set()
stderr: str = "std.err"
stdout: str = "std.out"
tags: CoercibleSetStr = None
extraFields: dict[str, Any] = None
tags: CoercibleSetStr = set()
extraFields: dict[str, Any] = {}

@validator("cpuTime")
@field_validator("cpuTime")
def checkCPUTimeBounds(cls, v):
minCPUTime = Operations().getValue("JobDescription/MinCPUTime", 100)
maxCPUTime = Operations().getValue("JobDescription/MaxCPUTime", 500000)
if not minCPUTime <= v <= maxCPUTime:
raise ValueError(f"cpuTime out of bounds (must be between {minCPUTime} and {maxCPUTime})")
return v

@validator("executable")
@field_validator("executable")
def checkExecutableIsNotAnEmptyString(cls, v: str):
if not v:
raise ValueError("executable must not be an empty string")
return v

@validator("jobType")
@field_validator("jobType")
def checkJobTypeIsAllowed(cls, v: str):
jobTypes = Operations().getValue("JobDescription/AllowedJobTypes", ["User", "Test", "Hospital"])
transformationTypes = Operations().getValue("Transformations/DataProcessing", [])
@@ -89,15 +87,15 @@ def checkJobTypeIsAllowed(cls, v: str):
raise ValueError(f"jobType '{v}' is not allowed for this kind of user (must be in {allowedTypes})")
return v

@validator("inputData")
@field_validator("inputData")
def checkInputDataDoesntContainDoubleSlashes(cls, v):
if v:
for lfn in v:
if lfn.find("//") > -1:
raise ValueError("Input data contains //")
return v

@validator("inputData")
@field_validator("inputData")
def addLFNPrefixIfStringStartsWithASlash(cls, v: set[str]):
if v:
v = {lfn.strip() for lfn in v if lfn.strip()}
@@ -108,30 +106,30 @@ def addLFNPrefixIfStringStartsWithASlash(cls, v: set[str]):
raise ValueError("Input data files must start with LFN:/")
return v

@root_validator(skip_on_failure=True)
def checkNumberOfInputDataFiles(cls, values):
if "inputData" in values and values["inputData"]:
@model_validator(mode="after")
def checkNumberOfInputDataFiles(self) -> Self:
if self.inputData:
maxInputDataFiles = Operations().getValue("JobDescription/MaxInputData", 500)
if values["jobType"] == "User" and len(values["inputData"]) >= maxInputDataFiles:
if self.jobType == "User" and len(self.inputData) >= maxInputDataFiles:
raise ValueError(f"inputData contains too many files (must contain at most {maxInputDataFiles})")
return values
return self

@validator("inputSandbox")
@field_validator("inputSandbox")
def checkLFNSandboxesAreWellFormated(cls, v: set[str]):
for inputSandbox in v:
if inputSandbox.startswith("LFN:") and not inputSandbox.startswith("LFN:/"):
raise ValueError("LFN files must start by LFN:/")
return v

@validator("logLevel")
@field_validator("logLevel")
def checkLogLevelIsValid(cls, v: str):
v = v.upper()
possibleLogLevels = gLogger.getAllPossibleLevels()
if v not in possibleLogLevels:
raise ValueError(f"Log level {v} not in {possibleLogLevels}")
return v

@validator("minNumberOfProcessors")
@field_validator("minNumberOfProcessors")
def checkMinNumberOfProcessorsBounds(cls, v):
minNumberOfProcessors = Operations().getValue("JobDescription/MinNumberOfProcessors", 1)
maxNumberOfProcessors = Operations().getValue("JobDescription/MaxNumberOfProcessors", 1024)
@@ -141,7 +139,7 @@ def checkMinNumberOfProcessorsBounds(cls, v):
)
return v

@validator("maxNumberOfProcessors")
@field_validator("maxNumberOfProcessors")
def checkMaxNumberOfProcessorsBounds(cls, v):
minNumberOfProcessors = Operations().getValue("JobDescription/MinNumberOfProcessors", 1)
maxNumberOfProcessors = Operations().getValue("JobDescription/MaxNumberOfProcessors", 1024)
@@ -151,27 +149,22 @@ def checkMaxNumberOfProcessorsBounds(cls, v):
)
return v

@root_validator(skip_on_failure=True)
def checkThatMaxNumberOfProcessorsIsGreaterThanMinNumberOfProcessors(cls, values):
if "maxNumberOfProcessors" in values and values["maxNumberOfProcessors"]:
if values["maxNumberOfProcessors"] < values["minNumberOfProcessors"]:
@model_validator(mode="after")
def checkThatMaxNumberOfProcessorsIsGreaterThanMinNumberOfProcessors(self) -> Self:
if self.maxNumberOfProcessors:
if self.maxNumberOfProcessors < self.minNumberOfProcessors:
raise ValueError("maxNumberOfProcessors must be greater than minNumberOfProcessors")
return values

@root_validator(skip_on_failure=True)
def addTagsDependingOnNumberOfProcessors(cls, values):
if "maxNumberOfProcessors" in values and values["minNumberOfProcessors"] == values["maxNumberOfProcessors"]:
if values["tags"] is None:
values["tags"] = set()
values["tags"].add(f"{values['minNumberOfProcessors']}Processors")
if values["minNumberOfProcessors"] > 1:
if values["tags"] is None:
values["tags"] = set()
values["tags"].add("MultiProcessor")

return values

@validator("sites")
return self

@model_validator(mode="after")
def addTagsDependingOnNumberOfProcessors(self) -> Self:
if self.minNumberOfProcessors == self.maxNumberOfProcessors:
self.tags.add(f"{self.minNumberOfProcessors}Processors")
if self.minNumberOfProcessors > 1:
self.tags.add("MultiProcessor")
return self

@field_validator("sites")
def checkSites(cls, v: set[str]):
if v:
res = getSites()
Expand All @@ -182,16 +175,16 @@ def checkSites(cls, v: set[str]):
raise ValueError(f"Invalid sites: {' '.join(invalidSites)}")
return v

@root_validator(skip_on_failure=True)
def checkThatSitesAndBannedSitesAreNotMutuallyExclusive(cls, values):
if "sites" in values and values["sites"] and "bannedSites" in values and values["bannedSites"]:
values["sites"] -= values["bannedSites"]
values["bannedSites"] = None
if not values["sites"]:
@model_validator(mode="after")
def checkThatSitesAndBannedSitesAreNotMutuallyExclusive(self) -> Self:
if self.sites and self.bannedSites:
self.sites -= self.bannedSites
self.bannedSites = set()
if not self.sites:
raise ValueError("sites and bannedSites are mutually exclusive")
return values
return self

@validator("platform")
@field_validator("platform")
def checkPlatform(cls, v: str):
if v:
res = getDIRACPlatforms()
Expand All @@ -201,7 +194,7 @@ def checkPlatform(cls, v: str):
raise ValueError("Invalid platform")
return v

@validator("priority")
@field_validator("priority")
def checkPriorityBounds(cls, v):
minPriority = Operations().getValue("JobDescription/MinPriority", 0)
maxPriority = Operations().getValue("JobDescription/MaxPriority", 10)
Expand All @@ -217,10 +210,10 @@ class JobDescriptionModel(BaseJobDescriptionModel):
ownerGroup: str
vo: str

@root_validator(skip_on_failure=True)
def checkLFNMatchesREGEX(cls, values):
if "inputData" in values and values["inputData"]:
for lfn in values["inputData"]:
if not lfn.startswith(f"LFN:/{values['vo']}/"):
raise ValueError(f"Input data not correctly specified (must start with LFN:/{values['vo']}/)")
return values
@model_validator(mode="after")
def checkLFNMatchesREGEX(self) -> Self:
if self.inputData:
for lfn in self.inputData:
if not lfn.startswith(f"LFN:/{self.vo}/"):
raise ValueError(f"Input data not correctly specified (must start with LFN:/{self.vo}/)")
return self
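Taken together, the JobModel rewrite swaps each Pydantic 1 construct for its Pydantic 2 equivalent: class Config becomes model_config = ConfigDict(...), @validator becomes @field_validator, @root_validator(skip_on_failure=True) becomes @model_validator(mode="after") operating on self, and the version-gated CoercibleSetStr collapses into a plain Annotated[set[str], BeforeValidator(...)]. A minimal self-contained sketch of these idioms, assuming pydantic >=2.4 (the DemoJob model and to_set helper are illustrative, not DIRAC code):

from collections.abc import Iterable
from typing import Annotated

from pydantic import BaseModel, BeforeValidator, ConfigDict, field_validator, model_validator


def to_set(value):
    # Coerce None and non-string iterables into sets, mirroring default_set_validator
    if value is None:
        return set()
    if not isinstance(value, Iterable) or isinstance(value, (str, bytes, bytearray)):
        return value
    return set(value)


CoercibleSetStr = Annotated[set[str], BeforeValidator(to_set)]


class DemoJob(BaseModel):
    model_config = ConfigDict(validate_assignment=True)

    executable: str
    minNumberOfProcessors: int = 1
    maxNumberOfProcessors: int = 1
    tags: CoercibleSetStr = set()

    @field_validator("executable")
    def check_executable_not_empty(cls, v: str):
        if not v:
            raise ValueError("executable must not be an empty string")
        return v

    @model_validator(mode="after")
    def add_processor_tags(self):
        # In-place mutation avoids re-triggering assignment validation
        if self.minNumberOfProcessors == self.maxNumberOfProcessors:
            self.tags.add(f"{self.minNumberOfProcessors}Processors")
        return self


job = DemoJob(executable="dirac-jobexec", tags=["Test"])
print(job.model_dump())  # tags coerced to a set and tagged with "1Processors"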
