Skip to content

Commit

Permalink
chore: use black formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
alx committed Jul 7, 2023
1 parent 9870612 commit 3087c55
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 15 deletions.
18 changes: 4 additions & 14 deletions options/predict_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,18 @@ def initialize(self, parser):
"--model-in-file",
type=str,
help="file path to generator model (.pth file)",
required=True
required=True,
)

parser.add_argument(
"--img-in",
type=str,
help="image to transform",
required=True
"--img-in", type=str, help="image to transform", required=True
)

parser.add_argument(
"--img-out",
type=str,
help="transformed image",
required=True
"--img-out", type=str, help="transformed image", required=True
)

parser.add_argument(
"--cpu",
action="store_true",
help="whether to use CPU"
)
parser.add_argument("--cpu", action="store_true", help="whether to use CPU")

parser.add_argument(
"--gpuid",
Expand Down
5 changes: 4 additions & 1 deletion scripts/gen_single_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def load_model(modelpath, model_in_file, cpu, gpuid):
model = model.to(device)
return model, opt, device


def launch_predict(args):
# loading model
modelpath = args.model_in_file.replace(os.path.basename(args.model_in_file), "")
Expand Down Expand Up @@ -97,7 +98,9 @@ def launch_predict(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model-in-file", help="file path to generator model (.pth file)", required=True
"--model-in-file",
help="file path to generator model (.pth file)",
required=True,
)

parser.add_argument("--img-in", help="image to transform", required=True)
Expand Down
3 changes: 3 additions & 0 deletions server/joligen_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from options.predict_options import PredictOptions

import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "../scripts"))
from gen_single_image import launch_predict

Expand Down Expand Up @@ -128,6 +129,7 @@ def stop_training(process):
def is_alive(process):
return process.is_alive()


@app.post(
"/predict",
status_code=201,
Expand All @@ -154,6 +156,7 @@ async def predict(request: Request):

return {"message": "ok", "name": name, "status": "running"}


@app.post(
"/train/{name}",
status_code=201,
Expand Down

0 comments on commit 3087c55

Please sign in to comment.