Group nodes support.
shiimizu committed Dec 2, 2023
commit a9a4238 · 1 parent b1defa0
Showing 3 changed files with 209 additions and 146 deletions.
nodes.py (5 changes: 3 additions & 2 deletions)
@@ -48,7 +48,7 @@ def INPUT_TYPES(s):
"text_l": ("STRING", {"multiline": True, "placeholder": "CLIP_L"}),
},
"optional": {
"steps": ("INT", {"default": 1, "min": 1, "max": 0xffffffffffffffff}),
"smZ_steps": ("INT", {"default": 1, "min": 1, "max": 0xffffffffffffffff}),
},
}
RETURN_TYPES = ("CONDITIONING",)
@@ -58,8 +58,9 @@ def INPUT_TYPES(s):
     def encode(self, clip: comfy.sd.CLIP, text, parser, mean_normalization,
                multi_conditioning, use_old_emphasis_implementation,
                with_SDXL, ascore, width, height, crop_w,
-               crop_h, target_width, target_height, text_g, text_l, steps=1):
+               crop_h, target_width, target_height, text_g, text_l, smZ_steps=1):
         params = locals()
+        params['steps'] = params.pop('smZ_steps', smZ_steps)
         from .modules.shared import opts
         is_sdxl = "SDXL" in type(clip.cond_stage_model).__name__
 
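The diff itself does not state the motivation, but given the commit title, the `smZ_` prefix plausibly avoids input-name collisions when ComfyUI folds several nodes into a group node, where a generic input like `steps` could clash with a sampler's input of the same name. The added `params.pop` line then re-keys the renamed argument so the rest of `encode` can keep reading `steps`. A minimal standalone sketch of that re-keying pattern (function and values here are illustrative, not taken from the repository):

    # Minimal sketch of the kwarg re-keying used in encode() above; names are illustrative.
    def encode(text, smZ_steps=1):
        params = locals()  # snapshot: {'text': ..., 'smZ_steps': ...}
        # Expose the prefixed public input under the internal name downstream code expects.
        params['steps'] = params.pop('smZ_steps', smZ_steps)
        return params

    print(encode("a photo", smZ_steps=20))  # -> {'text': 'a photo', 'steps': 20}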
smZNodes.py (7 changes: 3 additions & 4 deletions)
@@ -69,8 +69,6 @@ def set_dtype_compat(dtype):
         newv = False
         if newv:
             dtype = devices.dtype if dtype != devices.dtype else dtype
-        token_embedding_dtype = position_embedding_dtype = torch.float32
-        if newv:
             # self.transformer.text_model.embeddings.position_embedding.to(dtype)
             # self.transformer.text_model.embeddings.token_embedding.to(dtype)
             inner_model = getattr(self.transformer, self.inner_name, None)
@@ -80,14 +78,15 @@ def set_dtype_compat(dtype):
                 self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(dtype))
         def reset_dtype_compat(newv):
             if newv:
+                # token_embedding_dtype = position_embedding_dtype = torch.float32
                 # self.transformer.text_model.embeddings.token_embedding.to(token_embedding_dtype)
                 # self.transformer.text_model.embeddings.position_embedding.to(position_embedding_dtype)
                 inner_model = getattr(self.transformer, self.inner_name, None)
                 if inner_model is not None and hasattr(inner_model, "embeddings"):
                     inner_model.embeddings.to(torch.float32)
                 else:
                     self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32))
-        # set_dtype_compat()
+        # set_dtype_compat(torch.float16)
 
         backup_embeds = self.transformer.get_input_embeddings()
         device = backup_embeds.weight.device
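These helpers stop hard-coding `self.transformer.text_model` and instead resolve the wrapped module through `self.inner_name`, falling back to the generic input-embeddings accessors when no `embeddings` attribute is found, so one code path serves both CLIP-style wrappers and models with a differently named inner module. A hedged sketch of that lookup-with-fallback shape (the `transformer` argument and `inner_name` value are stand-ins, not the repository's code):

    import torch

    # Sketch of the getattr-with-fallback pattern from set/reset_dtype_compat above.
    def cast_embeddings(transformer, inner_name: str, dtype: torch.dtype):
        inner_model = getattr(transformer, inner_name, None)  # e.g. inner_name = "text_model"
        if inner_model is not None and hasattr(inner_model, "embeddings"):
            # Wrapped model: cast the whole embeddings submodule in place.
            inner_model.embeddings.to(dtype)
        else:
            # Bare model: fall back to the HF-style accessors for the input embeddings.
            transformer.set_input_embeddings(transformer.get_input_embeddings().to(dtype))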
@@ -818,7 +817,7 @@ def check_link_to_clip(node_id, clip_id, visited=None):
                 current_clip_id = clip_id
             steps = find_nearest_ksampler(clip_id)
             if steps is not None:
-                node["inputs"]["steps"] = steps
+                node["inputs"]["smZ_steps"] = steps
                 if opts.debug:
                     print(f'[smZNodes] id: {current_clip_id} | steps: {steps}')
     tmp()
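Here the value found by `find_nearest_ksampler` is written back into the prompt graph under the renamed key, so it lands on the `smZ_steps` input declared in nodes.py. ComfyUI prompt graphs are plain dicts of the shape `{node_id: {"class_type": ..., "inputs": {...}}}`; a hedged sketch of that write-back (graph, ids, and the traversal stand-in are mocked up, not the repository's code):

    # Illustrative prompt graph in ComfyUI's dict shape; ids and values are made up.
    prompt = {
        "3": {"class_type": "KSampler", "inputs": {"steps": 20}},
        "6": {"class_type": "smZ CLIPTextEncode", "inputs": {"text": "a photo"}},
    }

    def find_nearest_ksampler(clip_id):
        # Stand-in for the real graph traversal: return the steps of any KSampler found.
        for node in prompt.values():
            if node["class_type"] == "KSampler":
                return node["inputs"]["steps"]
        return None

    steps = find_nearest_ksampler("6")
    if steps is not None:
        # Write under the renamed key so it matches the smZ_steps input above.
        prompt["6"]["inputs"]["smZ_steps"] = steps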