Skip to content

Commit

Permalink
rename var
Browse files Browse the repository at this point in the history
  • Loading branch information
frankaging committed Jan 21, 2024
1 parent 48d1a3b commit 78b3ac5
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
9 changes: 5 additions & 4 deletions pyvene/models/interventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,11 +532,12 @@ def __init__(self, embed_dim, **kwargs):
self.interchange_dim = embed_dim
rs = np.random.RandomState(1)
prng = lambda *shape: rs.randn(*shape)
self.noise = torch.from_numpy(
prng(1, 4, embed_dim)).to(device)
self.noise_level = kwargs["noise_leve"] \
noise_level = kwargs["noise_leve"] \
if "noise_leve" in kwargs else 0.13462981581687927

self.register_buffer('noise', torch.from_numpy(
prng(1, 4, embed_dim)))
self.register_buffer('noise_level', torch.tensor(noise_level))

def forward(self, base, source=None, subspaces=None):
base[..., : self.interchange_dim] += self.noise * self.noise_level
return base
Expand Down
4 changes: 2 additions & 2 deletions pyvene_101.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1583,7 +1583,7 @@
},
{
"cell_type": "code",
"execution_count": 100,
"execution_count": 2,
"id": "c0b6a70f",
"metadata": {},
"outputs": [
Expand All @@ -1606,7 +1606,7 @@
" [{\"component\": \"block_input\"}] + \n",
" [{\"layer\": l, \"component\": c} \n",
" for l in range(s, e)],\n",
" [pv.ZeroIntervention] +\n",
" [pv.NoiseIntervention] +\n",
" [pv.VanillaIntervention]*(e-s))\n",
" return config\n",
"\n",
Expand Down

0 comments on commit 78b3ac5

Please sign in to comment.