Skip to content

Commit

Permalink
Apply black
Browse files Browse the repository at this point in the history
  • Loading branch information
dexter2206 committed Nov 29, 2023
1 parent b845211 commit 0999128
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 8 deletions.
17 changes: 13 additions & 4 deletions src/omnisolver/common/cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,28 +40,37 @@ def main(argv=None):
"--vartype", help="Variable type", choices=["SPIN", "BINARY"], default="BINARY"
)

solver_commands = root_parser.add_subparsers(title="Solvers", dest="solver", required=True)
solver_commands = root_parser.add_subparsers(
title="Solvers", dest="solver", required=True
)

all_plugins = get_all_plugins()

for plugin in all_plugins.values():
sub_parser = solver_commands.add_parser(
plugin.name, parents=[common_parser], add_help=False, description=plugin.description
plugin.name,
parents=[common_parser],
add_help=False,
description=plugin.description,
)
plugin.populate_parser(sub_parser)

args = root_parser.parse_args(argv)

chosen_plugin = all_plugins[args.solver]
sampler = chosen_plugin.create_sampler(
**omnisolver.common.plugin.filter_namespace_by_iterable(args, chosen_plugin.init_args)
**omnisolver.common.plugin.filter_namespace_by_iterable(
args, chosen_plugin.init_args
)
)

bqm = bqm_from_coo(args.input, vartype=args.vartype)

result = sampler.sample(
bqm,
**omnisolver.common.plugin.filter_namespace_by_iterable(args, chosen_plugin.sample_args),
**omnisolver.common.plugin.filter_namespace_by_iterable(
args, chosen_plugin.sample_args
),
)

result.to_pandas_dataframe().to_csv(args.output, index=False)
12 changes: 9 additions & 3 deletions src/omnisolver/common/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,18 +65,24 @@ def filter_namespace_by_iterable(
return dictionary containing mapping attribute name -> attribute value for every
attribute of a signature such that its name is in attribute_filter.
"""
return {key: value for key, value in vars(namespace).items() if key in attribute_filter}
return {
key: value for key, value in vars(namespace).items() if key in attribute_filter
}


TYPE_MAP = {"str": str, "int": int, "float": float}


def add_argument(parser: argparse.ArgumentParser, specification: Dict[str, Any]) -> None:
def add_argument(
parser: argparse.ArgumentParser, specification: Dict[str, Any]
) -> None:
"""Given specification of the argument, add it to parser."""
specification = copy.deepcopy(specification)
arg_name = f"--{specification.pop('name')}"
if "type" in specification:
specification["type"] = TYPE_MAP.get(specification["type"], specification["type"])
specification["type"] = TYPE_MAP.get(
specification["type"], specification["type"]
)
parser.add_argument(arg_name, **specification)


Expand Down
3 changes: 2 additions & 1 deletion src/omnisolver/random/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def sample(self, bqm: BQM, num_reads=1, **parameters) -> SampleSet:
energies = [bqm.energy(sample) for sample in samples]

return cast(
SampleSet, SampleSet.from_samples(samples, energy=energies, vartype=bqm.vartype)
SampleSet,
SampleSet.from_samples(samples, energy=energies, vartype=bqm.vartype),
)

@property
Expand Down

0 comments on commit 0999128

Please sign in to comment.