Skip to content

Commit

Permalink
Fix the CI
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Jun 21, 2024
1 parent 53faa0a commit 76489bd
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions torchbenchmark/models/soft_actor_critic/nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch import distributions as pyd
from torch import nn

from . import utils
from . import sac_utils
from torchbenchmark.util.distribution import SquashedNormal

def weight_init(m):
Expand All @@ -30,11 +30,11 @@ def __init__(self, obs_shape, out_dim=50):
self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=1)
self.conv4 = nn.Conv2d(32, 32, kernel_size=3, stride=1)

output_height, output_width = utils.compute_conv_output(
output_height, output_width = sac_utils.compute_conv_output(
obs_shape[1:], kernel_size=(3, 3), stride=(2, 2)
)
for _ in range(3):
output_height, output_width = utils.compute_conv_output(
output_height, output_width = sac_utils.compute_conv_output(
(output_height, output_width), kernel_size=(3, 3), stride=(1, 1)
)

Expand Down Expand Up @@ -63,15 +63,15 @@ def __init__(self, obs_shape, out_dim=50):
self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

output_height, output_width = utils.compute_conv_output(
output_height, output_width = sac_utils.compute_conv_output(
obs_shape[1:], kernel_size=(8, 8), stride=(4, 4)
)

output_height, output_width = utils.compute_conv_output(
output_height, output_width = sac_utils.compute_conv_output(
(output_height, output_width), kernel_size=(4, 4), stride=(2, 2)
)

output_height, output_width = utils.compute_conv_output(
output_height, output_width = sac_utils.compute_conv_output(
(output_height, output_width), kernel_size=(3, 3), stride=(1, 1)
)

Expand Down

0 comments on commit 76489bd

Please sign in to comment.