From df1c9c52a2140d49a486751335b5f930250dd6b3 Mon Sep 17 00:00:00 2001 From: ccsuu Date: Fri, 23 Aug 2024 02:35:46 +0000 Subject: [PATCH] Fix using VAE in quantization mode --- onediff_comfy_nodes/modules/booster_cache.py | 18 +++++++++--------- .../modules/oneflow/booster_quantization.py | 13 +++++++++++++ 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/onediff_comfy_nodes/modules/booster_cache.py b/onediff_comfy_nodes/modules/booster_cache.py index 06391c467..c65efa181 100644 --- a/onediff_comfy_nodes/modules/booster_cache.py +++ b/onediff_comfy_nodes/modules/booster_cache.py @@ -5,9 +5,6 @@ from comfy.model_patcher import ModelPatcher from comfy.sd import VAE from onediff.torch_utils.module_operations import get_sub_module -from onediff.utils.import_utils import is_oneflow_available - -from .._config import is_disable_oneflow_backend @singledispatch @@ -54,14 +51,17 @@ def _(model: ModelPatcher): @get_cached_model.register def _(model: VAE): - if is_oneflow_available() and not is_disable_oneflow_backend(): - from .oneflow.utils.booster_utils import is_using_oneflow_backend + # from onediff.utils.import_utils import is_oneflow_available + # from .._config import is_disable_oneflow_backend + # if is_oneflow_available() and not is_disable_oneflow_backend(): + # from .oneflow.utils.booster_utils import is_using_oneflow_backend - if is_using_oneflow_backend(model): - return None + # if is_using_oneflow_backend(model): + # return None - # TODO(TEST) if support cache - return model.first_stage_model + # # TODO(TEST) if support cache + # # return model.first_stage_model + return None class BoosterCacheService: diff --git a/onediff_comfy_nodes/modules/oneflow/booster_quantization.py b/onediff_comfy_nodes/modules/oneflow/booster_quantization.py index 4b246013c..fc0604a3a 100644 --- a/onediff_comfy_nodes/modules/oneflow/booster_quantization.py +++ b/onediff_comfy_nodes/modules/oneflow/booster_quantization.py @@ -1,4 +1,5 @@ import os +import warnings from dataclasses import dataclass from functools import partial, singledispatchmethod from typing import Any, Dict, Optional, Union @@ -7,6 +8,7 @@ import torch.nn as nn from comfy.controlnet import ControlNet from comfy.model_patcher import ModelPatcher +from comfy.sd import VAE from onediff.infer_compiler import oneflow_compile from onediff.infer_compiler.backends.oneflow import ( OneflowDeployableModule as DeployableModule, @@ -125,6 +127,17 @@ def execute(self, model, ckpt_name=None, **kwargs): def extract_torch_module(self, model): raise NotImplementedError(f"{type(model)}") + @execute.register(VAE) + def _(self, model: VAE, **kwargs): + # TODO: VAE does not support quantization and patch compatibility + from .booster_basic import BasicOneFlowBoosterExecutor + + warnings.warn( + "TODO: VAE does not support quantization and patch compatibility", + UserWarning, + ) + return BasicOneFlowBoosterExecutor().execute(model, **kwargs) + @execute.register(ModelPatcher) @execute.register(ControlNet) def _(