Commit

Created a demo for attribution patching with minor bug fixes to demos and code on activation patching (#168)

* Attribution Patching Progress

* Finished Attribution Patching Demo!
neelnanda-io authored Feb 4, 2023
1 parent 5f73982 commit 407f5fe
Showing 6 changed files with 1,582 additions and 5 deletions.
1 change: 1 addition & 0 deletions Attribution_Patching_Demo.ipynb

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion Exploratory_Analysis_Demo.ipynb
@@ -1183,6 +1183,8 @@
"cell_type": "markdown",
"metadata": {},
"source": [
+"**This section explains how to do activation patching conceptually by implementing it from scratch. To use it in practice with TransformerLens, see [this demonstration instead](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/activation_patching_in_TL_demo.py.ipynb)**\n",
+"\n",
"The obvious limitation to the techniques used above is that they only look at the very end of the circuit - the parts that directly affect the logits. Clearly this is not sufficient to understand the circuit! We want to understand how things compose together to produce this final output, and ideally to produce an end-to-end circuit fully explaining this behaviour. \n",
"\n",
"The technique we'll use to investigate this is called **activation patching**. This was first introduced in [David Bau and Kevin Meng's excellent ROME paper](https://rome.baulab.info/), there called causal tracing. \n",
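
Note on the hunk above: the idea of activation patching can be illustrated without any library. The following is a minimal sketch under loud assumptions: the "model" is a hypothetical two-layer toy acting on a list of floats, `run`, `metric` and `patching_grid` are names invented here, and the normalisation (0 for the corrupted run, 1 for the clean run) is one common choice rather than something this diff confirms. It is not the notebook's implementation.

```python
from typing import Callable, Dict, List, Optional, Tuple

Vector = List[float]
Layer = Callable[[Vector], Vector]
Patch = Tuple[int, int, float]  # (layer index, position, value to write)


def run(layers: List[Layer], x: Vector,
        cache: Optional[Dict[int, Vector]] = None,
        patch: Optional[Patch] = None) -> Vector:
    """Run x through the layers; optionally cache or patch each layer's input."""
    x = list(x)
    for i, layer in enumerate(layers):
        if patch is not None and patch[0] == i:
            x[patch[1]] = patch[2]  # overwrite one position before layer i
        if cache is not None:
            cache[i] = list(x)      # store the input to layer i
        x = layer(x)
    return x


def metric(output: Vector) -> float:
    """Hypothetical stand-in for a logit difference: read the last position."""
    return output[-1]


def patching_grid(layers: List[Layer], clean: Vector, corrupted: Vector) -> List[List[float]]:
    """Patch each (layer, position) from the clean run into the corrupted run."""
    clean_cache: Dict[int, Vector] = {}
    clean_score = metric(run(layers, clean, cache=clean_cache))
    corrupted_score = metric(run(layers, corrupted))
    grid = []
    for layer_idx in range(len(layers)):
        row = []
        for pos in range(len(clean)):
            patch = (layer_idx, pos, clean_cache[layer_idx][pos])
            patched_score = metric(run(layers, corrupted, patch=patch))
            # 0.0 means "no better than corrupted", 1.0 means "fully restored"
            row.append((patched_score - corrupted_score) / (clean_score - corrupted_score))
        grid.append(row)
    return grid


# Layer 0 copies position 0 into position 1; layer 1 doubles position 1.
toy_layers = [lambda x: [x[0], x[1] + x[0]], lambda x: [x[0], 2.0 * x[1]]]
effects = patching_grid(toy_layers, clean=[1.0, 0.0], corrupted=[0.0, 0.0])
print(effects)
```

In this toy, patching position 0 matters only before layer 0 and patching position 1 matters only before layer 1, because layer 0 has already moved the information across by then. The grid is indexed as `effects[layer][position]`.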
@@ -2871,7 +2873,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
-"version": "3.10.8"
+"version": "3.10.8 (main, Nov 4 2022, 13:48:29) [GCC 11.2.0]"
},
"vscode": {
"interpreter": {
1,572 changes: 1,571 additions & 1 deletion activation_patching_in_TL_demo.py.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions ioi_patching_data.json
@@ -0,0 +1 @@
{"patched_residual_stream_diff": [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.000651478767395, -0.0002479881513863802, 9.296619282395113e-06, -0.0003644475946202874, -4.842967973672785e-05], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0010520219802856, -2.721862074395176e-05, -2.0741194020956755e-05, -0.0004594610654748976, -0.0005936747184023261], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0002667903900146, 0.0008668677764944732, 0.0005157442064955831, -0.0009937314316630363, -0.0008655253332108259], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9949089884757996, 0.005429024342447519, 0.0016050260746851563, -0.0006185104721225798, -0.0016331843798980117], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9675674438476562, 0.031340714544057846, 0.002841711277142167, -0.0012312817852944136, -0.0009871532674878836], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9675208330154419, 0.030999794602394104, 0.0017816954059526324, -0.00048647832591086626, -0.0006477763527072966], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9228330254554749, 0.05134354904294014, 0.004729931708425283, 0.0009341927361674607, 0.017046811059117317], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6565515398979187, 0.023855896666646004, 0.002357548801228404, -1.7452137399232015e-05, 0.3186882734298706], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.027303969487547874, 0.031424786895513535, 0.0018208284163847566, 0.0007986366399563849, 0.9383859634399414], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.026844238862395287, 0.02098052203655243, 0.0012515531852841377, 0.0003235022013541311, 1.0048311948776245], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.005686946678906679, 0.014263397082686424, 0.0004866796953137964, -8.977782272268087e-05, 0.991420567035675]], "patched_attn_diff": [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.035460326820611954, -0.0002479881513863802, 9.296619282395113e-06, -0.0003644475946202874, -4.842967973672785e-05], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0029838120099157095, 7.850105612305924e-05, 1.9969271306763403e-05, 8.068257011473179e-05, -0.0005969973281025887], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.001912116538733244, 0.000666503852698952, 0.00039458609535358846, -0.0007047642720863223, -0.0002737636095844209], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.15463462471961975, 0.0038017802871763706, 0.0005164489848539233, -0.00012025193427689373, -0.0005613211542367935], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.005407712422311306, 0.019580895081162453, 0.001006820471957326, -0.00024275251780636609, 0.0007935016765259206], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3521000146865845, 0.0010515919420868158, 0.0002249311946798116, 0.00013216637307778, 8.202504250220954e-05], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.11986151337623596, 0.02124321088194847, 0.002727869665250182, 0.0013409617822617292, 0.017974458634853363], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.013308966532349586, 0.011509046889841557, 0.00037411341327242553, -4.165019709034823e-05, 0.29759806394577026], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0015008501941338181, 0.017351988703012466, 0.0005840760422870517, 0.00101185473613441, 0.5697361826896667], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.00012907868949696422, 0.00630119489505887, 0.00014119449770078063, 0.00031289667822420597, 0.27153223752975464], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0009376160451211035, 8.585109026171267e-05, 0.00033199333120137453, 1.5438429272762733e-06, -0.19299523532390594], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.40618857741355896]], "patched_mlp_diff": [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8507810831069946, -0.000278361578239128, 
-7.292979717021808e-05, -0.00047395977890118957, 3.95022398151923e-05], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.008865014649927616, 0.0002215750137111172, 0.00014908151933923364, -4.9168040277436376e-05, 0.0003036336274817586], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.013549369759857655, 5.725643495679833e-05, -0.00033014744985848665, -0.000639385893009603, 0.00077249197056517], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0019460810581222177, 0.0004987283609807491, 0.00017364876111969352, 0.00016871518164407462, 0.00040720534161664546], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.019789112731814384, 0.004128504544496536, -4.8631052777636796e-05, -0.00016999052604660392, 0.0007910181302577257], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.09652552008628845, -0.0018829178297892213, -0.0004840283072553575, 0.0007095971959643066, -0.00018338167865294963], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.01589762233197689, -0.0008510937332175672, 0.00012283619435038418, 2.8091228159610182e-05, -0.007237971760332584], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.010361199267208576, 0.0031500770710408688, 0.0005304442602209747, 0.00023483192489948124, 0.00849811639636755], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.012533755972981453, 2.1244621166260913e-05, -0.0003541441401466727, 8.675725985085592e-05, -0.02163410559296608], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.00033521527075208724, 0.0008088729809969664, 1.6076102838269435e-05, 0.00012991773837711662, 0.03162303566932678], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0013596221106126904, -0.00019489337864797562, -9.93764988379553e-05, -0.00014206710329744965, 0.02876400761306286], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.020446084439754486]], "patched_head_z_diff": [[0.0009491948876529932, 0.01612412929534912, 0.0018557998118922114, 0.0034389100037515163, -0.009823639877140522, 
0.011060794815421104, -0.004063830710947514, -0.0015787136508151889, -0.0012075871927663684, 0.003828898072242737, -0.0042571802623569965, -0.00114311499055475], [-0.0010770318331196904, -0.0003781408304348588, 2.6849440928344848e-06, -0.00026033888570964336, -0.00014079175889492035, 0.0038331940304487944, -0.00042818146175704896, -0.0014287594240158796, -0.0009219091152772307, 0.0006930847885087132, 0.00043304794235154986, -0.003570942208170891], [-0.0004963455139659345, 0.000805919524282217, 0.0005419224034994841, -0.0005311155109666288, -0.000715705449692905, -0.001038032933138311, -0.00094942981377244, -8.615314436610788e-05, 0.0002773211745079607, 0.002107983222231269, -0.00019781326409429312, -0.0016397960716858506], [0.11626482754945755, 0.0002489949984010309, -0.001467355526983738, -0.0003970025572925806, 0.01896204799413681, -0.0001887851394712925, 0.011169938370585442, -0.0013294836971908808, -0.0007355068810284138, -0.00030346581479534507, -0.00014582602307200432, -0.00022248119057621807], [-0.0016508714761584997, 0.00029235685360617936, -0.0014360088389366865, 0.030840208753943443, -0.007432563230395317, -0.0002836643543560058, 0.006017262116074562, -0.011007364839315414, -0.0012651120778173208, 0.0014894056366756558, -0.00017958920216187835, 0.0029440748039633036], [-0.00421096570789814, 0.0029594460502266884, 0.002045289846137166, 0.0013403912307694554, -0.001218863995745778, 0.3435027003288269, 0.0005634355475194752, -0.00012548758240882307, -0.0051532466895878315, 0.016240723431110382, 0.017091985791921616, -0.0041746520437300205], [0.03977486491203308, 0.015226955525577068, -0.0010223931167274714, 0.0008083695429377258, -0.004934491124004126, -0.0021227505058050156, -0.014274069108068943, 0.0013750941725447774, 0.0014846062986180186, 0.13027158379554749, -0.0003356515953782946, 0.0012918944703415036], [0.0003727709408849478, 0.01951444335281849, 0.0002230852987850085, 0.12425166368484497, -0.00040455395355820656, -0.007651688065379858, 
0.0013004862703382969, -0.0011243539629504085, -0.007450351025909185, 0.19223351776599884, -0.003276403760537505, -0.0005017154035158455], [-0.0010078608756884933, 3.144740912830457e-05, -0.0008587458287365735, 0.012362959794700146, -0.00040452039684168994, -0.004331318195909262, 0.31855371594429016, 0.0023293904960155487, 0.002117783296853304, 0.00014065751747693866, 0.27794405817985535, 0.0057383631356060505], [0.005890297703444958, -0.0009688956779427826, 0.009125386364758015, 0.02067665383219719, -0.03700762614607811, 0.014263329096138477, -0.04828083515167236, 0.05834205821156502, 0.0006519044400192797, 0.2636124789714813, 0.0004916132893413305, -0.002611208939924836], [0.08374394476413727, 0.020674103870987892, -0.0037426778580993414, 0.010850698687136173, -0.0010965983383357525, 0.00047395977890118957, 0.04817712679505348, -0.47993069887161255, 0.0001845563529059291, 0.011859935708343983, 0.06088154390454292, 0.000845656730234623], [0.00532709714025259, -0.011492668651044369, -0.1135050430893898, 0.006329151801764965, 0.0003161186177749187, -0.0011606006883084774, -0.022670023143291473, 0.004070979543030262, 0.007316372357308865, -0.008346282877027988, -0.27818864583969116, 0.0036351794842630625]], "patched_head_attn_diff": [[0.0006395537056960166, 0.005319780670106411, 0.0011585198808461428, -5.9337264247005805e-05, -0.0010672317584976554, 0.005080149043351412, -0.003081409726291895, -0.002051867777481675, -0.0014416807098314166, 0.0034932801499962807, -0.00256761210039258, -0.0009158680331893265], [-0.000760812486987561, 0.00016958778724074364, 0.00012320537643972784, -0.0003490762901492417, 1.5371304471045732e-05, 0.005008259788155556, -0.0002968877088278532, -0.0014443321852013469, -0.0011006592540070415, 0.00047379196621477604, 5.064475772087462e-05, -0.003494656179100275], [-0.0007240287377499044, 0.0017483349656686187, -0.00015555895515717566, 5.789410715806298e-05, -9.753059566719458e-05, -0.0004235499363858253, -0.0007922934601083398, 
0.0002720183983910829, 0.00010236349771730602, 0.0004227109020575881, 0.0001511287991888821, -0.0007437966414727271], [0.11458384990692139, 0.00021254688908811659, -0.0009415092063136399, 0.0004294568207114935, 0.020042773336172104, 0.0021046942565590143, 7.655446825083345e-05, -0.0015432722866535187, -0.0008481402765028179, -0.00058226368855685, 0.00011944645666517317, -1.9600092855398543e-05], [-0.0011277101002633572, 0.0012377927778288722, -0.0012322216061875224, -0.0005951514467597008, -0.0007551069720648229, -0.0005841096281073987, 0.004812728613615036, 0.00018187140813097358, -0.0005364854005165398, 0.0008574033272452652, -0.00029870003345422447, -1.204868658533087e-05], [-0.004241171292960644, 0.0029524988494813442, 0.0005213490221649408, 0.0009537928272038698, 0.00016240555851254612, 0.34351256489753723, -0.000304237735690549, 0.00010320253932150081, -0.0053021605126559734, 0.024866778403520584, 0.014384889975190163, -0.002327880123630166], [-0.0023880228400230408, -0.0021727238781750202, -0.0004762755415868014, 0.00043378627742640674, -0.0046735480427742004, 0.0018593238200992346, -0.002654302166774869, 0.0014368478441610932, 0.00030329800210893154, 0.13043102622032166, 8.91401432454586e-05, 0.001177515834569931], [0.0003193405573256314, 0.020570598542690277, 0.0003192398580722511, -0.0025125371757894754, -0.0002626882342156023, -0.0002465785655658692, 0.0005513197393156588, -0.00043110133265145123, 0.00025671423645690084, 0.008090643212199211, -0.0030700992792844772, -0.00042452322668395936], [0.0009762121480889618, 0.0003926395147573203, 0.0017537048552185297, 0.022597799077630043, -4.514062311500311e-05, 0.00014126162568572909, 0.009583035483956337, -0.00031524599762633443, 0.0015267934650182724, 0.0011811405420303345, -0.010774211026728153, 0.009365353733301163], [0.006315559148788452, -0.0010947524569928646, 0.01166168600320816, 0.0013476070016622543, -0.029189303517341614, 0.0038326906505972147, -0.04408983886241913, -0.005032323766499758, 
0.004822293762117624, 0.2766503691673279, -3.211864532204345e-05, -0.0006622750661335886], [0.09538958966732025, 0.025066403672099113, 0.014239265583455563, 0.014754808507859707, 9.89737527561374e-05, -9.024768223753199e-05, 0.05082616209983826, -0.505113422870636, 0.00014716849545948207, -0.0016017034649848938, 0.06882696598768234, 0.002326973946765065], [0.0013423713389784098, 0.009629988111555576, -0.07776373624801636, -0.00772992055863142, -0.0005737390019930899, -0.002957197604700923, -0.004948989953845739, 0.00045855488860979676, -0.0006327406736090779, -0.006520554888993502, -0.3204990029335022, -0.002472867025062442]]}
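
Note on the data file above: it holds five named 2-D arrays. As rendered here, the first three look like 12 × 15 grids and the last two like 12 × 12 grids; reading these as layer × position and layer × head is an assumption, not something the diff states. A small sketch for inspecting such a file follows. The helper names `grid_shape`, `strongest_entry` and `summarize` are invented for illustration, and the inline `demo_grid` is a stand-in rather than real data.

```python
import json
from typing import Dict, List, Tuple

Grid = List[List[float]]


def grid_shape(grid: Grid) -> Tuple[int, int]:
    """Return (rows, columns) of a rectangular nested list."""
    return (len(grid), len(grid[0]))


def strongest_entry(grid: Grid) -> Tuple[int, int, float]:
    """Return (row, column, value) of the entry with the largest magnitude."""
    best = (0, 0, grid[0][0])
    for r, row in enumerate(grid):
        for c, value in enumerate(row):
            if abs(value) > abs(best[2]):
                best = (r, c, value)
    return best


def summarize(data: Dict[str, Grid]) -> Dict[str, Tuple[Tuple[int, int], Tuple[int, int, float]]]:
    """Map each array name to its shape and its strongest entry."""
    return {name: (grid_shape(grid), strongest_entry(grid)) for name, grid in data.items()}


# Stand-in input so the sketch runs without the real file.
example = json.loads('{"demo_grid": [[0.0, 0.5], [-0.9, 0.1]]}')
summary = summarize(example)
print(summary)
```

With the real file, the same helper could be applied to `json.load(open("ioi_patching_data.json"))`, assuming the file is available in the working directory.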
5 changes: 4 additions & 1 deletion transformer_lens/ActivationCache.py
@@ -118,6 +118,9 @@ def items(self):
     def __iter__(self):
         return self.cache_dict.__iter__()
 
+    def __len__(self):
+        return len(self.cache_dict)
+
     def apply_slice_to_batch_dim(self, batch_slice: Union[Slice, SliceInput]):
         if not isinstance(batch_slice, Slice):
             batch_slice = Slice(batch_slice)
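
Note on the hunk above: defining `__iter__` alone does not make `len()` work on an object; Python needs `__len__` for that. The sketch below shows the pattern on a hypothetical `TinyCache` class, which is a stand-in written for this note and not the TransformerLens `ActivationCache`. The key strings are only illustrative.

```python
from typing import Dict, Iterator, List


class TinyCache:
    """Hypothetical dict-backed cache used only to illustrate __len__."""

    def __init__(self, cache_dict: Dict[str, List[float]]):
        self.cache_dict = cache_dict

    def __iter__(self) -> Iterator[str]:
        # Iterating yields the keys, as with a plain dict.
        return self.cache_dict.__iter__()

    def __len__(self) -> int:
        # Number of cached entries; without this method len(cache) raises TypeError.
        return len(self.cache_dict)


cache = TinyCache({"blocks.0.hook_resid_pre": [0.0], "blocks.1.hook_resid_pre": [1.0]})
n_entries = len(cache)
keys = list(cache)
print(n_entries, keys)
```

A side effect worth knowing: once `__len__` exists and `__bool__` does not, an empty instance is falsy in an `if cache:` check.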
@@ -269,7 +272,7 @@ def stack_head_results(
         incl_remainder: bool = False,
         pos_slice: Union[Slice, SliceInput] = None,
     ) -> TT[T.num_components, T.batch_and_pos_dims:..., T.d_model]:
-        """Returns a stack of all head results (ie residual stream contribution) up to layer L. A good way to decompose the outputs of attention layers into attribution by specific heads.
+        """Returns a stack of all head results (ie residual stream contribution) up to layer L. A good way to decompose the outputs of attention layers into attribution by specific heads. Note that the num_components axis has length layer x n_heads ((layer head_index) in einops notation)
         Assumes that the model has been run with use_attn_results=True
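
Note on the docstring change above: in einops notation, `(layer head_index)` denotes a single flattened axis in which `layer` varies slowest, so the component at flat index `i` belongs to layer `i // n_heads` and head `i % n_heads`. A plain-Python sketch of that bookkeeping follows; `flat_index` and `unflatten` are helper names invented here, and the sizes are chosen only for illustration.

```python
from typing import List, Tuple


def flat_index(layer: int, head: int, n_heads: int) -> int:
    """Position of (layer, head) on a flattened (layer head_index) axis."""
    return layer * n_heads + head


def unflatten(index: int, n_heads: int) -> Tuple[int, int]:
    """Inverse of flat_index: recover (layer, head) from the flat position."""
    return divmod(index, n_heads)


# Illustrative sizes (assumptions, not read from any model config).
n_layers, n_heads = 3, 4

# Nested [layer][head] labels, then flattened with layer as the outer loop.
nested: List[List[str]] = [[f"L{l}H{h}" for h in range(n_heads)] for l in range(n_layers)]
flat: List[str] = [label for row in nested for label in row]

idx = flat_index(2, 1, n_heads)
picked = flat[idx]
recovered = unflatten(idx, n_heads)
print(idx, picked, recovered)
```

The same arithmetic applies when the stack covers only layers below some layer L: the axis length is then L times `n_heads`.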
4 changes: 2 additions & 2 deletions transformer_lens/patching.py
@@ -351,8 +351,8 @@ def get_act_patch_attn_head_all_pos_every(model, corrupted_tokens, clean_cache,
     act_patch_results = []
     act_patch_results.append(get_act_patch_attn_head_out_all_pos(model, corrupted_tokens, clean_cache, metric))
     act_patch_results.append(get_act_patch_attn_head_q_all_pos(model, corrupted_tokens, clean_cache, metric))
-    act_patch_results.append(get_act_patch_attn_head_v_all_pos(model, corrupted_tokens, clean_cache, metric))
     act_patch_results.append(get_act_patch_attn_head_k_all_pos(model, corrupted_tokens, clean_cache, metric))
+    act_patch_results.append(get_act_patch_attn_head_v_all_pos(model, corrupted_tokens, clean_cache, metric))
     act_patch_results.append(get_act_patch_attn_head_pattern_all_pos(model, corrupted_tokens, clean_cache, metric))
     return torch.stack(act_patch_results, dim=0)

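Note on the hunk above: the results are stacked along dimension 0 with no names attached, so any caller that labels the slices depends on the append order. Reading the changed lines, the order after this commit appears to be out, q, k, v, pattern; that reading is an assumption based on the line positions shown here. A minimal sketch of labelling by position follows, using plain lists in place of tensors. `COMPONENT_ORDER` and `label_stacked` are hypothetical names.

```python
from typing import Dict, List

# Assumed append order after this commit (see the note above).
COMPONENT_ORDER = ["out", "q", "k", "v", "pattern"]


def label_stacked(stacked: List[object]) -> Dict[str, object]:
    """Attach component names to results stacked along the first axis."""
    if len(stacked) != len(COMPONENT_ORDER):
        raise ValueError(
            f"expected {len(COMPONENT_ORDER)} stacked results, got {len(stacked)}"
        )
    return dict(zip(COMPONENT_ORDER, stacked))


# Stand-in for a stacked result: one small list per component.
stacked = [[0.1], [0.2], [0.3], [0.4], [0.5]]
by_name = label_stacked(stacked)
print(by_name["k"], by_name["v"])
```

If labels are applied in q, k, v order while the function appends q, v, k, the key and value results are silently swapped, which is the kind of mismatch a reordering like this one avoids.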
@@ -362,8 +362,8 @@ def get_act_patch_attn_head_by_pos_every(model, corrupted_tokens, clean_cache, m
     act_patch_results = []
     act_patch_results.append(get_act_patch_attn_head_out_by_pos(model, corrupted_tokens, clean_cache, metric))
     act_patch_results.append(get_act_patch_attn_head_q_by_pos(model, corrupted_tokens, clean_cache, metric))
-    act_patch_results.append(get_act_patch_attn_head_v_by_pos(model, corrupted_tokens, clean_cache, metric))
     act_patch_results.append(get_act_patch_attn_head_k_by_pos(model, corrupted_tokens, clean_cache, metric))
+    act_patch_results.append(get_act_patch_attn_head_v_by_pos(model, corrupted_tokens, clean_cache, metric))
 
     # Reshape pattern to be compatible with the rest of the results
     pattern_results = (get_act_patch_attn_head_pattern_by_pos(model, corrupted_tokens, clean_cache, metric))