From 100e8cb3b6fef8e6562e643ec83f8fadd1fe3cec Mon Sep 17 00:00:00 2001 From: Xuchen Pan <32844285+pan-x-c@users.noreply.github.com> Date: Mon, 12 Aug 2024 17:20:52 +0800 Subject: [PATCH] [Hot Fix] Fix `_ModelConfig` state get and set (#397) --------- Co-authored-by: DavdGao --- src/agentscope/manager/_model.py | 79 +++++++++++--------------------- 1 file changed, 26 insertions(+), 53 deletions(-) diff --git a/src/agentscope/manager/_model.py b/src/agentscope/manager/_model.py index 422293958..0f63f14be 100644 --- a/src/agentscope/manager/_model.py +++ b/src/agentscope/manager/_model.py @@ -100,16 +100,16 @@ def load_model_configs( f"list of dicts), but got {type(model_configs)}", ) - format_configs = _ModelConfig.format_configs(configs=cfgs) + formatted_configs = _format_configs(configs=cfgs) # check if name is unique - for cfg in format_configs: - if cfg.config_name in self.model_configs: + for cfg in formatted_configs: + if cfg["config_name"] in self.model_configs: logger.warning( - f"config_name [{cfg.config_name}] already exists.", + f"config_name [{cfg['config_name']}] already exists.", ) continue - self.model_configs[cfg.config_name] = cfg + self.model_configs[cfg["config_name"]] = cfg # print the loaded model configs logger.info( @@ -137,7 +137,7 @@ def get_model_by_config_name(self, config_name: str) -> ModelWrapperBase: f"Cannot find [{config_name}] in loaded configurations.", ) - model_type = config.model_type + model_type = config["model_type"] kwargs = {k: v for k, v in config.items() if k != "model_type"} @@ -164,55 +164,28 @@ def flush(self) -> None: self.clear_model_configs() -class _ModelConfig(dict): - """Base class for model config.""" +def _format_configs( + configs: Union[Sequence[dict], dict], +) -> Sequence: + """Check the format of model configs. - __getattr__ = dict.__getitem__ - __setattr__ = dict.__setitem__ + Args: + configs (Union[Sequence[dict], dict]): configs in dict format. - def __init__( - self, - config_name: str, - model_type: str = None, - **kwargs: Any, - ): - """Initialize the config with the given arguments, and checking the - type of the arguments. - - Args: - config_name (`str`): A unique name of the model config. - model_type (`str`, optional): The class name (or its model type) of - the generated model wrapper. Defaults to None. - - Raises: - `ValueError`: If `config_name` is not provided. - """ - if config_name is None: - raise ValueError("The `config_name` field is required for Cfg") - if model_type is None: + Returns: + Sequence[dict]: converted ModelConfig list. + """ + if isinstance(configs, dict): + configs = [configs] + for config in configs: + if "config_name" not in config: + raise ValueError( + "The `config_name` field is required for Cfg", + ) + if "model_type" not in config: logger.warning( - f"`model_type` is not provided in config [{config_name}]," + "`model_type` is not provided in config" + f"[{config['config_name']}]," " use `PostAPIModelWrapperBase` by default.", ) - super().__init__( - config_name=config_name, - model_type=model_type, - **kwargs, - ) - - @classmethod - def format_configs( - cls, - configs: Union[Sequence[dict], dict], - ) -> Sequence: - """Covert config dicts into a list of _ModelConfig. - - Args: - configs (Union[Sequence[dict], dict]): configs in dict format. - - Returns: - Sequence[_ModelConfig]: converted ModelConfig list. - """ - if isinstance(configs, dict): - return [_ModelConfig(**configs)] - return [_ModelConfig(**cfg) for cfg in configs] + return configs