Skip to content

Commit

Permalink
Add string parameter to metadata conversion support.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 573243721
  • Loading branch information
qiuyiz authored and copybara-github committed Oct 13, 2023
1 parent 6edfdd0 commit 641b3a0
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 0 deletions.
68 changes: 68 additions & 0 deletions vizier/pyvizier/converters/string_converters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright 2023 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

"""Converter utils for parameters for free-form strings."""

from typing import Sequence
import copy
import json
import attrs
from vizier import pyvizier as vz

_METADATA_VERSION = '0.0.1a'
PROMPT_TUNING_NS = 'prompt_tuning'


@attrs.define
class PromptTuningConfig:
"""Variables and utils for configuring prompt tuning."""

default_prompts: dict[str, str] = attrs.field(factory=dict)

def augment_problem(
self, problem: vz.ProblemStatement
) -> vz.ProblemStatement:
"""Augments problem statement to enable for prompt tuning."""
for k, v in self.default_prompts.items():
problem.search_space.root.add_categorical_param(k, [v], default_value=v)
problem.metadata.ns(PROMPT_TUNING_NS)['version'] = _METADATA_VERSION
return problem

def to_prompt_trials(self, trials: Sequence[vz.Trial]) -> Sequence[vz.Trial]:
"""Convert to prompt Trial via metadata to string valued parameters."""
prompt_trials = copy.deepcopy(trials)
for trial in prompt_trials:
prompt_values = json.loads(trial.metadata.ns(PROMPT_TUNING_NS)['values'])
for k in self.default_prompts.keys():
if k in prompt_values:
trial.parameters[k] = prompt_values[k]
return prompt_trials

def to_valid_suggestions(
self, suggestions: Sequence[vz.TrialSuggestion]
) -> Sequence[vz.TrialSuggestion]:
"""Returns TrialSuggestions that are valid in the augmented problem."""
valid_suggestions = copy.deepcopy(suggestions)
for suggestion in valid_suggestions:
prompt_values = {}
for k, default_value in self.default_prompts.items():
prompt_values[k] = suggestion.parameters[k].value
suggestion.parameters[k] = default_value
suggestion.metadata.ns(PROMPT_TUNING_NS)['values'] = json.dumps(
prompt_values
)
suggestion.metadata.ns(PROMPT_TUNING_NS)['version'] = _METADATA_VERSION
return valid_suggestions
73 changes: 73 additions & 0 deletions vizier/pyvizier/converters/string_converters_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright 2023 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import json

from vizier import pyvizier as vz
from vizier.pyvizier.converters import string_converters

from absl.testing import absltest


class StringConvertersTest(absltest.TestCase):

def setUp(self):
super().setUp()
self.default_prompts = {'prompt1': 'def1', 'prompt2': 'def2'}
self.config = string_converters.PromptTuningConfig(
default_prompts=self.default_prompts
)

def test_augment_problem(self):
problem = vz.ProblemStatement()
tuning_problem = self.config.augment_problem(problem)
for k, v in self.default_prompts.items():
pconfig = tuning_problem.search_space.get(k)
self.assertCountEqual(pconfig.feasible_values, [v])
self.assertEqual(pconfig.default_value, v)

def test_prompt_trials(self):
trial = vz.Trial(parameters={'int': 3, 'float': 1.2, 'cat': 'test'})
trial.metadata.ns(string_converters.PROMPT_TUNING_NS)['values'] = (
json.dumps({'prompt1': 'test1', 'prompt2': 'test2'})
)
results = self.config.to_prompt_trials([trial])
self.assertLen(results, 1)

self.assertStartsWith(results[0].parameters['prompt1'].value, 'test1')
self.assertEqual(results[0].parameters['prompt2'].value, 'test2')

def test_valid_suggestions(self):
problem = vz.ProblemStatement()
tuning_problem = self.config.augment_problem(problem)
suggestion = vz.TrialSuggestion(
parameters={'prompt1': 'test1', 'prompt2': 'test2'}
)
valid_suggestions = self.config.to_valid_suggestions([suggestion])
self.assertLen(valid_suggestions, 1)
valid_suggestion = valid_suggestions[0]
self.assertTrue(
tuning_problem.search_space.contains(valid_suggestion.parameters)
)

# Test the reverse conversion retrieves original parameters.
trial = valid_suggestion.to_trial().complete(vz.Measurement())
prompt_trial = self.config.to_prompt_trials([trial])[0]
self.assertCountEqual(prompt_trial.parameters, suggestion.parameters)


if __name__ == '__main__':
absltest.main()

0 comments on commit 641b3a0

Please sign in to comment.