Commit 9f910a2

update comment
tianleiwu committed Nov 3, 2023
1 parent d28fbd3 commit 9f910a2
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion onnxruntime/python/tools/transformers/float16.py
@@ -145,7 +145,7 @@ def make_value_info_from_tensor(tensor):


 # Some operators has data type fixed as float for some inputs. Key is op_type, value is list of input indices
-# Note that DirectML allows float16 gamma and beta in GroupNorm. Use force_fp16_inputs parameter to overwrite it.
+# Note that DirectML allows float16 gamma and beta in GroupNorm. Use force_fp16_inputs parameter could overwrite this.
 ALWAYS_FLOAT_INPUTS = {"Resize": [2], "GroupNorm": [1, 2], "SkipGroupNorm": [1, 2]}


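Note: by default, ALWAYS_FLOAT_INPUTS keeps GroupNorm's gamma (input 1) and beta (input 2) in float32 during fp16 conversion; the force_fp16_inputs parameter named in the comment overrides that per op type. A minimal usage sketch, assuming convert_float_to_float16 accepts a force_fp16_inputs mapping as the comment describes (model paths are placeholders):

import onnx
from onnxruntime.transformers.float16 import convert_float_to_float16

model = onnx.load("model_fp32.onnx")  # placeholder path
# Keep GroupNorm gamma/beta (inputs 1 and 2) in float16 as well,
# overriding the ALWAYS_FLOAT_INPUTS default, e.g. for DirectML targets.
model_fp16 = convert_float_to_float16(
    model,
    force_fp16_inputs={"GroupNorm": [1, 2]},
)
onnx.save(model_fp16, "model_fp16.onnx")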
2 changes: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ class FusionSkipGroupNorm(Fusion):
     def __init__(self, model: OnnxModel):
         super().__init__(model, "SkipGroupNorm", "GroupNorm")
         # Update shape inference is needed since other fusions might add new edge which does not have shape info yet.
-        self.shape_infer_helper = self.model.infer_runtime_shape(update=True)
+        self.shape_infer_helper = self.model.infer_runtime_shape({"batch_size": 4, "seq_len": 7}, update=True)
 
         if self.shape_infer_helper is None:
             logger.warning("SkipGroupNorm fusion will be skipped since symbolic shape inference disabled or failed.")
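Note: the new call pins the symbolic axes batch_size and seq_len to arbitrary concrete values (4 and 7) so symbolic shape inference can resolve edges to numeric shapes, and update=True re-runs inference because earlier fusions may have added edges that have no shape info yet. A minimal sketch of the same call pattern, assuming OnnxModel.infer_runtime_shape has the signature used in the hunk above (the model path is a placeholder):

import onnx
from onnxruntime.transformers.onnx_model import OnnxModel

model = OnnxModel(onnx.load("unet.onnx"))  # placeholder path
# Pin dynamic axes to concrete values and re-run inference (update=True),
# since prior fusions may have created edges without shape info.
shape_infer_helper = model.infer_runtime_shape({"batch_size": 4, "seq_len": 7}, update=True)
if shape_infer_helper is None:
    print("symbolic shape inference disabled or failed")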
4 changes: 2 additions & 2 deletions onnxruntime/python/tools/transformers/onnx_model_unet.py
@@ -128,9 +128,9 @@ def optimize(self, options: Optional[FusionOptions] = None):

         self.fuse_reshape()
 
-        group_norm_channels_last = (options is None) or options.group_norm_channels_last
         if (options is None) or options.enable_group_norm:
-            group_norm_fusion = FusionGroupNorm(self, group_norm_channels_last)
+            channels_last = (options is None) or options.group_norm_channels_last
+            group_norm_fusion = FusionGroupNorm(self, channels_last)
             group_norm_fusion.apply()
 
             insert_transpose_fusion = FusionInsertTranspose(self)
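Note: the refactor computes the channels-last flag only when GroupNorm fusion actually runs, renaming it to the shorter channels_last since its scope is now the if block. A hedged sketch of how these options are typically driven (FusionOptions and optimize_model are the usual onnxruntime.transformers entry points; the two attribute names come from the diff, while the constructor argument and file paths are assumptions):

from onnxruntime.transformers.fusion_options import FusionOptions
from onnxruntime.transformers.optimizer import optimize_model

options = FusionOptions("unet")
options.enable_group_norm = True          # let FusionGroupNorm run at all
options.group_norm_channels_last = True   # fuse to channels-last GroupNorm

optimized = optimize_model("unet.onnx", model_type="unet", optimization_options=options)
optimized.save_model_to_file("unet_opt.onnx")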
