
Are oracles trained on normalized x or unnormalized x? #15

Open
yuanqidu opened this issue Jan 18, 2024 · 2 comments


yuanqidu commented Jan 18, 2024

Dear authors,

Thanks for open-sourcing this library. I'm trying to understand whether the oracles were trained on normalized or unnormalized x. If they were trained on normalized x, which dataset was used to compute the normalization statistics? When we use the oracle, we can only normalize inputs using the (smaller) provided dataset; if the oracle's normalization was computed on a larger dataset, would this mismatch cause problems for the inputs we feed it?

Best,
Yuanqi

brandontrabucco (Owner) commented:

Hello yuanqidu,

Thanks for your interest in design-bench!

The benchmark keeps two copies of each MBO dataset: a private internal version that matches the format expected by the oracle model, and a public version exposed to the user for benchmarking their own algorithms.

The code for this separation is located in four functions:

def dataset_to_oracle_x(self, x_batch, dataset=None):

def dataset_to_oracle_y(self, y_batch, dataset=None):

def oracle_to_dataset_x(self, x_batch, dataset=None):

def oracle_to_dataset_y(self, y_batch, dataset=None):

This boilerplate code handles the conversion between the oracle's internal format and the public format, in both directions.

For a particular task from design-bench, you can find out what format the oracle expects by checking task.oracle.expect_normalized_x and task.oracle.expect_normalized_y.

Design-Bench internally manages these to ensure the format is correct when task.predict(xs) is called, where xs is a batch of designs you want to evaluate.
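To make the flag-driven conversion concrete, here is a minimal stand-alone sketch (not design-bench's actual implementation) of how a dataset_to_oracle_x-style adapter might z-score-normalize a batch when the oracle expects normalized inputs. The class name ToyOracleAdapter and the plain-list math are illustrative assumptions; the real library computes its statistics internally from the oracle's training data:

```python
# Illustrative sketch only -- NOT design-bench's actual implementation.
# Shows the idea behind dataset_to_oracle_x / oracle_to_dataset_x when
# expect_normalized_x is True: normalization statistics belong to the
# oracle's training data, not to whatever subset the user holds.

class ToyOracleAdapter:
    def __init__(self, train_x, expect_normalized_x=True):
        # statistics come from the oracle's (internal) training data
        n = len(train_x)
        self.mean = sum(train_x) / n
        var = sum((v - self.mean) ** 2 for v in train_x) / n
        self.std = var ** 0.5 or 1.0  # guard against zero variance
        self.expect_normalized_x = expect_normalized_x

    def dataset_to_oracle_x(self, x_batch):
        # convert a public-format batch into what the oracle expects
        if self.expect_normalized_x:
            return [(v - self.mean) / self.std for v in x_batch]
        return list(x_batch)

    def oracle_to_dataset_x(self, x_batch):
        # inverse conversion, back to the public format
        if self.expect_normalized_x:
            return [v * self.std + self.mean for v in x_batch]
        return list(x_batch)

adapter = ToyOracleAdapter([1.0, 2.0, 3.0, 4.0])
normed = adapter.dataset_to_oracle_x([2.0, 4.0])
restored = adapter.oracle_to_dataset_x(normed)  # round-trips to [2.0, 4.0]
```

Because the adapter owns the statistics, the user never needs to know which dataset the oracle was normalized on; that is the point of routing everything through task.predict.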

brandontrabucco (Owner) commented:

This section of the README may help too:

import design_bench
task = design_bench.make('TFBind8-Exact-v0')

# convert x to logits of a categorical probability distribution
task.map_to_logits()
discrete_x = task.to_integers(task.x)

# normalize the inputs to have zero mean and unit variance
task.map_normalize_x()
original_x = task.denormalize_x(task.x)

# normalize the outputs to have zero mean and unit variance
task.map_normalize_y()
original_y = task.denormalize_y(task.y)

# remove the normalization applied to the outputs
task.map_denormalize_y()
normalized_y = task.normalize_y(task.y)

# remove the normalization applied to the inputs
task.map_denormalize_x()
normalized_x = task.normalize_x(task.x)

# convert x back to integers
task.map_to_integers()
continuous_x = task.to_logits(task.x)
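For the discrete conversions in the snippet above, the following stand-alone toy sketch shows the general idea of an integers-to-logits mapping and its inverse (argmax). The function names mirror the task API but the implementation is an assumption for illustration, not design-bench's code:

```python
# Toy illustration of an integers <-> logits mapping for a discrete task.
# Each integer token becomes a row of categorical logits; argmax maps
# logits back to tokens. Not design-bench's exact implementation.

def to_logits(tokens, num_classes, on=0.0, off=-1e9):
    # one row of logits per token; the true class gets the high logit
    return [[on if c == t else off for c in range(num_classes)]
            for t in tokens]

def to_integers(logits):
    # argmax over each row recovers the integer token
    return [max(range(len(row)), key=row.__getitem__) for row in logits]

tokens = [0, 3, 1, 2]
logits = to_logits(tokens, num_classes=4)
recovered = to_integers(logits)  # recovers the original tokens
```

The same round-trip property holds for the real task.to_logits / task.to_integers pair, which is why the README example can freely switch between the two representations.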

@brandontrabucco brandontrabucco self-assigned this Jan 29, 2024