From 099912846ce13470c0b73220f73275aa68f9a97a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Konrad=20Ja=C5=82owiecki?= Date: Wed, 29 Nov 2023 01:20:27 +0100 Subject: [PATCH] Apply black --- src/omnisolver/common/cmd.py | 17 +++++++++++++---- src/omnisolver/common/plugin.py | 12 +++++++++--- src/omnisolver/random/sampler.py | 3 ++- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/src/omnisolver/common/cmd.py b/src/omnisolver/common/cmd.py index 0a4b676..5904e0e 100644 --- a/src/omnisolver/common/cmd.py +++ b/src/omnisolver/common/cmd.py @@ -40,13 +40,18 @@ def main(argv=None): "--vartype", help="Variable type", choices=["SPIN", "BINARY"], default="BINARY" ) - solver_commands = root_parser.add_subparsers(title="Solvers", dest="solver", required=True) + solver_commands = root_parser.add_subparsers( + title="Solvers", dest="solver", required=True + ) all_plugins = get_all_plugins() for plugin in all_plugins.values(): sub_parser = solver_commands.add_parser( - plugin.name, parents=[common_parser], add_help=False, description=plugin.description + plugin.name, + parents=[common_parser], + add_help=False, + description=plugin.description, ) plugin.populate_parser(sub_parser) @@ -54,14 +59,18 @@ def main(argv=None): chosen_plugin = all_plugins[args.solver] sampler = chosen_plugin.create_sampler( - **omnisolver.common.plugin.filter_namespace_by_iterable(args, chosen_plugin.init_args) + **omnisolver.common.plugin.filter_namespace_by_iterable( + args, chosen_plugin.init_args + ) ) bqm = bqm_from_coo(args.input, vartype=args.vartype) result = sampler.sample( bqm, - **omnisolver.common.plugin.filter_namespace_by_iterable(args, chosen_plugin.sample_args), + **omnisolver.common.plugin.filter_namespace_by_iterable( + args, chosen_plugin.sample_args + ), ) result.to_pandas_dataframe().to_csv(args.output, index=False) diff --git a/src/omnisolver/common/plugin.py b/src/omnisolver/common/plugin.py index 2e6947e..747f34b 100644 --- a/src/omnisolver/common/plugin.py +++ b/src/omnisolver/common/plugin.py @@ -65,18 +65,24 @@ def filter_namespace_by_iterable( return dictionary containing mapping attribute name -> attribute value for every attribute of a signature such that its name is in attribute_filter. """ - return {key: value for key, value in vars(namespace).items() if key in attribute_filter} + return { + key: value for key, value in vars(namespace).items() if key in attribute_filter + } TYPE_MAP = {"str": str, "int": int, "float": float} -def add_argument(parser: argparse.ArgumentParser, specification: Dict[str, Any]) -> None: +def add_argument( + parser: argparse.ArgumentParser, specification: Dict[str, Any] +) -> None: """Given specification of the argument, add it to parser.""" specification = copy.deepcopy(specification) arg_name = f"--{specification.pop('name')}" if "type" in specification: - specification["type"] = TYPE_MAP.get(specification["type"], specification["type"]) + specification["type"] = TYPE_MAP.get( + specification["type"], specification["type"] + ) parser.add_argument(arg_name, **specification) diff --git a/src/omnisolver/random/sampler.py b/src/omnisolver/random/sampler.py index 8fb290c..521f7a8 100644 --- a/src/omnisolver/random/sampler.py +++ b/src/omnisolver/random/sampler.py @@ -32,7 +32,8 @@ def sample(self, bqm: BQM, num_reads=1, **parameters) -> SampleSet: energies = [bqm.energy(sample) for sample in samples] return cast( - SampleSet, SampleSet.from_samples(samples, energy=energies, vartype=bqm.vartype) + SampleSet, + SampleSet.from_samples(samples, energy=energies, vartype=bqm.vartype), ) @property