Skip to content

Commit

Permalink
Fix mypy due to variable name overwrite
Browse files Browse the repository at this point in the history
  • Loading branch information
alvarobartt committed Mar 5, 2024
1 parent 003c265 commit 2a756dd
Showing 1 changed file with 7 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,15 @@ def load(self, artifacts_uri: str) -> None:
model_kwargs = os.getenv("HF_MODEL_KWARGS", None)
if model_kwargs is not None:
try:
model_kwargs = eval(model_kwargs)
self._logger.info(f"HF_MODEL_KWARGS value is {model_kwargs}")
model_kwargs.pop("device", None)
model_kwargs.pop("device_map", None)
model_kwargs_dict = eval(model_kwargs)
self._logger.info(f"HF_MODEL_KWARGS value is {model_kwargs_dict}")
model_kwargs_dict.pop("device", None)
model_kwargs_dict.pop("device_map", None)
except Exception:
self._logger.error(
f"Failed to parse `HF_MODEL_KWARGS` environment variable: {model_kwargs}"
)
model_kwargs = {}
model_kwargs_dict = {}

task = os.getenv("HF_TASK", "")
if task != "":
Expand All @@ -65,7 +65,7 @@ def load(self, artifacts_uri: str) -> None:
task,
model=model_path,
device_map="auto",
**model_kwargs, # type: ignore
**model_kwargs_dict,
)
except ValueError as ve:
self._logger.error(
Expand All @@ -82,7 +82,7 @@ def load(self, artifacts_uri: str) -> None:
task,
model=model_path,
device=device,
**model_kwargs, # type: ignore
**model_kwargs_dict,
)

self._logger.info(
Expand Down

0 comments on commit 2a756dd

Please sign in to comment.