Skip to content

Commit

Permalink
[UPDATE] remove prediction_type
Browse files Browse the repository at this point in the history
  • Loading branch information
markkua committed May 24, 2024
1 parent 9af23c9 commit 37b3d69
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 13 deletions.
4 changes: 2 additions & 2 deletions infer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Last modified: 2024-05-17
# Last modified: 2024-05-24
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -214,7 +214,7 @@ def check_directory(directory):

pipe = pipe.to(device)
logging.info(
f"{pipe.prediction_type = }, {pipe.scale_invariant = }, {pipe.shift_invariant = }"
f"scale_invariant: {pipe.scale_invariant}, shift_invariant: {pipe.shift_invariant}"
)

# -------------------- Inference and saving --------------------
Expand Down
10 changes: 0 additions & 10 deletions marigold/marigold_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,17 +97,10 @@ def __init__(
scheduler: Union[DDIMScheduler, LCMScheduler],
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
prediction_type: str = None,
scale_invariant: bool = None,
shift_invariant: bool = None,
):
super().__init__()

if prediction_type is None:
logging.warn(
"`prediction_type` is required but not given, filled with 'depth'"
)
prediction_type = "depth"
if scale_invariant is None:
logging.warn(
"`scale_invariant` is required but not given, filled with `True`"
Expand All @@ -118,8 +111,6 @@ def __init__(
"`shift_invariant` is required but not given, filled with `True`"
)
shift_invariant = True

self.prediction_type = prediction_type
self.scale_invariant = scale_invariant
self.shift_invariant = shift_invariant

Expand All @@ -131,7 +122,6 @@ def __init__(
tokenizer=tokenizer,
)
self.register_to_config(
prediction_type=prediction_type,
scale_invariant=scale_invariant,
shift_invariant=shift_invariant,
)
Expand Down
2 changes: 1 addition & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@

pipe = pipe.to(device)
logging.info(
f"{pipe.prediction_type = }, {pipe.scale_invariant = }, {pipe.shift_invariant = }"
f"scale_invariant: {pipe.scale_invariant}, shift_invariant: {pipe.shift_invariant}"
)

# -------------------- Inference and saving --------------------
Expand Down

0 comments on commit 37b3d69

Please sign in to comment.