diff --git a/naslib/search_spaces/hierarchical/graph.py b/naslib/search_spaces/hierarchical/graph.py index 4c61ff218..824a7e6f6 100755 --- a/naslib/search_spaces/hierarchical/graph.py +++ b/naslib/search_spaces/hierarchical/graph.py @@ -74,7 +74,7 @@ def __init__(self): self.add_nodes_from([i for i in range(1, 9)]) self.add_edges_from([(i, i + 1) for i in range(1, 8)]) - self.edges[1, 2].set("op", ops.Stem(16)) + self.edges[1, 2].set("op", ops.Stem(C_out=16)) self.edges[2, 3].set("op", cells[0]) self.edges[3, 4].set( "op", ops.SepConv(16, 32, kernel_size=3, stride=2, padding=1) @@ -117,7 +117,7 @@ def prepare_evaluation(self): single_instances=False, ) - self.edges[1, 2].set("op", ops.Stem(channels[0])) + self.edges[1, 2].set("op", ops.Stem(C_out=channels[0])) self.edges[2, 3].set("op", cells[0].copy()) self.edges[3, 4].set( "op", @@ -191,7 +191,7 @@ def _expand(self): # single_instances=False # ) - # self.edges[1, 2].set('op', ops.Stem(channels[0])) + # self.edges[1, 2].set('op', ops.Stem(C_out=channels[0])) # self.edges[2, 3].set('op', cells[0].copy()) # self.edges[3, 4].set('op', ops.SepConv(channels[0], channels[1], kernel_size=3, stride=2, padding=1)) # self.edges[4, 5].set('op', cells[1].copy()) @@ -400,7 +400,7 @@ def __init__(self): self.add_nodes_from([i for i in range(1, 15)]) self.add_edges_from([(i, i + 1) for i in range(1, 14)]) - self.edges[1, 2].set("op", ops.Stem(channels[0])) + self.edges[1, 2].set("op", ops.Stem(C_out=channels[0])) self.edges[2, 3].set("op", cells[0].copy()) self.edges[3, 4].set( "op", diff --git a/naslib/search_spaces/simple_cell/graph.py b/naslib/search_spaces/simple_cell/graph.py index de51ba346..580bd9ae4 100644 --- a/naslib/search_spaces/simple_cell/graph.py +++ b/naslib/search_spaces/simple_cell/graph.py @@ -139,7 +139,7 @@ def __init__( # Compile the ops self.edges[1, 2].set( - "op", ops.Stem(channels[0]) + "op", ops.Stem(C_out=channels[0]) ) # we can also set a compiled op. Will be ignored by compile() def set_channels(edge, C): diff --git a/requirements.txt b/requirements.txt index 74b689a8a..d0c7aff66 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,7 @@ xgboost==1.4.2 emcee==3.1.0 pybnn==0.0.5 grakel==0.1.8 -pyro-ppl==1.6.0 +pyro-ppl==1.8.4 # additional from setup.py prev tqdm==4.61.1 diff --git a/tests/test_nb301_search_space.py b/tests/test_nb301_search_space.py index f1c8bdce4..31470cdee 100644 --- a/tests/test_nb301_search_space.py +++ b/tests/test_nb301_search_space.py @@ -117,7 +117,7 @@ def test_forward_pass_aux_head(self): graph(torch.randn(3, 3, 32, 32)) aux_out = graph.auxiliary_logits() - self.assertEqual(aux_out.shape, (3, 512, 8, 8)) + self.assertEqual(aux_out.shape, (3, 256, 8, 8)) def test_forward_pass_aux_head_eval(self): graph = create_model()