Skip to content

Commit

Permalink
Fix scale normal distributions in aten::normal | fix(torchlib) (#1096)
Browse files Browse the repository at this point in the history
Previously I mistakenly swapped mean and std when transforming $X \sim
N(0, 1)$ to $X \sim N(mean, std^2)$

Original PR: #1094
  • Loading branch information
justinchuby authored Oct 18, 2023
1 parent 87029cf commit 778e8e8
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5867,7 +5867,7 @@ def aten_normal_float_tensor(mean: FLOAT, std: TFloat) -> TFloat:
mean_casted = op.CastLike(mean, std)
sampled = op.RandomNormalLike(mean_casted, mean=0.0, scale=1.0)
# Transform the distribution to the mean and std
return op.Add(op.Mul(mean_casted, sampled), std)
return op.Add(op.Mul(std, sampled), mean_casted)


@torch_op("aten::normal.Tensor_float")
Expand All @@ -5876,7 +5876,7 @@ def aten_normal_tensor_float(mean: TFloat, std: FLOAT) -> TFloat:

sampled = op.RandomNormalLike(mean, mean=0.0, scale=1.0)
# Transform the distribution to the mean and std
return op.Add(op.Mul(mean, sampled), op.CastLike(std, sampled))
return op.Add(op.Mul(op.CastLike(std, sampled), sampled), mean)


@torch_op("aten::normal.Tensor_Tensor")
Expand All @@ -5885,7 +5885,7 @@ def aten_normal_tensor_tensor(mean: TFloat, std: TFloat) -> TFloat:

sampled = op.RandomNormalLike(mean, mean=0.0, scale=1.0)
# Transform the distribution to the mean and std
return op.Add(op.Mul(mean, sampled), std)
return op.Add(op.Mul(std, sampled), mean)


def aten_not_equal(self: TensorType, other: TensorType) -> TensorType:
Expand Down

0 comments on commit 778e8e8

Please sign in to comment.