diff --git a/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py b/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py index 7988dbaa898..9dd317869bf 100644 --- a/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py +++ b/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py @@ -164,9 +164,9 @@ def tune( self, # TODO (andreyvelich): How to be consistent with other APIs (name) ? name: str, - objective: Callable, + objective: Union[Callable, str], parameters: Dict[str, Any], - base_image: str = constants.BASE_IMAGE_TENSORFLOW, + #base_image: str = constants.BASE_IMAGE_TENSORFLOW, namespace: Optional[str] = None, env_per_trial: Optional[ Union[Dict[str, str], List[Union[client.V1EnvVar, client.V1EnvFromSource]]] @@ -294,20 +294,12 @@ def tune( if max_failed_trial_count is not None: experiment.spec.max_failed_trial_count = max_failed_trial_count - # Validate objective function. - utils.validate_objective_function(objective) - - # Extract objective function implementation. - objective_code = inspect.getsource(objective) - - # Objective function might be defined in some indented scope - # (e.g. in another function). We need to dedent the function code. - objective_code = textwrap.dedent(objective_code) - # Iterate over input parameters. input_params = {} experiment_params = [] trial_params = [] + base_image = constants.BASE_IMAGE_TENSORFLOW, + for p_name, p_value in parameters.items(): # If input parameter value is Katib Experiment parameter sample. if isinstance(p_value, models.V1beta1ParameterSpec): @@ -326,33 +318,49 @@ def tune( # Otherwise, add value to the function input. input_params[p_name] = p_value - # Wrap objective function to execute it from the file. For example - # def objective(parameters): - # print(f'Parameters are {parameters}') - # objective({'lr': '${trialParameters.lr}', 'epochs': '${trialParameters.epochs}', 'is_dist': False}) - objective_code = f"{objective_code}\n{objective.__name__}({input_params})\n" - - # Prepare execute script template. - exec_script = textwrap.dedent( - """ - program_path=$(mktemp -d) - read -r -d '' SCRIPT << EOM\n - {objective_code} - EOM - printf "%s" "$SCRIPT" > $program_path/ephemeral_objective.py - python3 -u $program_path/ephemeral_objective.py""" - ) - - # Add objective code to the execute script. - exec_script = exec_script.format(objective_code=objective_code) - - # Install Python packages if that is required. - if packages_to_install is not None: - exec_script = ( - utils.get_script_for_python_packages(packages_to_install, pip_index_url) - + exec_script + # Handle different types of objective input + if callable(objective): + # Validate objective function. + utils.validate_objective_function(objective) + + # Extract objective function implementation. + objective_code = inspect.getsource(objective) + + # Objective function might be defined in some indented scope + # (e.g. in another function). We need to dedent the function code. + objective_code = textwrap.dedent(objective_code) + + # Wrap objective function to execute it from the file. For example + # def objective(parameters): + # print(f'Parameters are {parameters}') + # objective({'lr': '${trialParameters.lr}', 'epochs': '${trialParameters.epochs}', 'is_dist': False}) + objective_code = f"{objective_code}\n{objective.__name__}({input_params})\n" + + # Prepare execute script template. + exec_script = textwrap.dedent( + """ + program_path=$(mktemp -d) + read -r -d '' SCRIPT << EOM\n + {objective_code} + EOM + printf "%s" "$SCRIPT" > $program_path/ephemeral_objective.py + python3 -u $program_path/ephemeral_objective.py""" ) + # Add objective code to the execute script. + exec_script = exec_script.format(objective_code=objective_code) + + # Install Python packages if that is required. + if packages_to_install is not None: + exec_script = ( + utils.get_script_for_python_packages(packages_to_install, pip_index_url) + + exec_script + ) + elif isinstance(objective, str): + base_image=objective + else: + raise ValueError("The objective must be a callable function or a docker image.") + if isinstance(resources_per_trial, dict): if "gpu" in resources_per_trial: resources_per_trial["nvidia.com/gpu"] = resources_per_trial.pop("gpu") @@ -395,8 +403,8 @@ def tune( client.V1Container( name=constants.DEFAULT_PRIMARY_CONTAINER_NAME, image=base_image, - command=["bash", "-c"], - args=[exec_script], + command=["bash", "-c"] if callable(objective) else None, + args=[exec_script] if callable(objective) else None, env=env if env else None, env_from=env_from if env_from else None, resources=resources_per_trial,