Device aware cache (#250)
* Add hook example

* Fix docstrings

* Convert device to torch.device

* Respect provided cfg device

* Keep cached tensors on device

* Test cache device

* Handle sequence on GPU

* Keep tensor on device

* Test plain cuda

* Update test_multi_gpu.py

fix spelling of cache

---------

Co-authored-by: Joseph Bloom <[email protected]>
slavachalnev and jbloomAus authored Apr 24, 2023
1 parent 3b30122 commit 25a9c07
Showing 4 changed files with 57 additions and 16 deletions.
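The gist of the change, as a minimal sketch (assuming a CUDA machine and the same gpt2-small checkpoint and hook name used by the new test below): cached activations now stay on the device of the layer that produced them, and an explicit device argument still lets you move them elsewhere.

from transformer_lens import HookedTransformer

# Load the model on a GPU; the cache is expected to follow the model's device.
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
_, cache = model.run_with_cache("Hello there")
print(cache["blocks.0.mlp.hook_post"].device)  # expected: cuda:0

# An explicit device= still moves the cached tensors to another device
# (at a performance cost, per the docstring change further down).
_, cpu_cache = model.run_with_cache("Hello there", device="cpu")
print(cpu_cache["blocks.0.mlp.hook_post"].device)  # expected: cpu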
34 changes: 32 additions & 2 deletions tests/acceptance/test_multi_gpu.py
@@ -90,11 +90,17 @@ def test_device_separation_and_cache(gpt2_medium_on_1_device, n_devices):
     gpt2_logits_n_devices, gpt2_cache_n_devices = model_n_devices.run_with_cache(
         gpt2_tokens, remove_batch_dim=True)

+    # Make sure the tensors in cache remain on their respective devices
+    for i in range(model_n_devices.cfg.n_layers):
+        expected_device = get_device_for_block_index(i, cfg=model_n_devices.cfg)
+        cache_device = gpt2_cache_n_devices[f"blocks.{i}.mlp.hook_post"].device
+        assert cache_device == expected_device
+
     assert torch.allclose(gpt2_logits_1_device.to("cpu"),
                           gpt2_logits_n_devices.to("cpu"))
     for key in gpt2_cache_1_device.keys():
-        assert torch.allclose(gpt2_cache_1_device[key],
-                              gpt2_cache_n_devices[key])
+        assert torch.allclose(gpt2_cache_1_device[key].to("cpu"),
+                              gpt2_cache_n_devices[key].to("cpu"))

     cuda_devices = set()
     n_params_on_device = {}
@@ -114,3 +120,27 @@ def test_device_separation_and_cache(gpt2_medium_on_1_device, n_devices):
     print(
         f"Number of devices: {n_devices}, Model loss (1 device): {loss_1_device}, Model loss ({n_devices} devices): {loss_n_devices}, Time taken (1 device): {elapsed_time_1_device:.4f} seconds, Time taken ({n_devices} devices): {elapsed_time_n_devices:.4f} seconds"
     )
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 CUDA devices")
+def test_cache_device():
+    model = HookedTransformer.from_pretrained("gpt2-small", device="cuda:1")
+
+    logits, cache = model.run_with_cache("Hello there")
+    assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device(torch.device("cuda:1"))
+
+    logits, cache = model.run_with_cache("Hello there", device="cpu")
+    assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device(torch.device("cpu"))
+
+    model.to("cuda")
+    logits, cache = model.run_with_cache("Hello there")
+    assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device(logits.device)
+
+
+def norm_device(device):
+    """
+    Convenience function to normalize device strings for comparison.
+    """
+    device_str = str(device)
+    if device_str.startswith("cuda") and ':' not in device_str:
+        device_str += ':0'
+    return device_str
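A note on why the norm_device helper above is needed (a quick standalone check, not part of the diff): PyTorch treats "cuda" and "cuda:0" as distinct device objects, so comparing them directly would fail even on a single-GPU machine.

import torch

# "cuda" parses with index None while "cuda:0" carries an explicit index 0,
# so the two device objects compare as unequal.
print(torch.device("cuda") == torch.device("cuda:0"))  # False
# norm_device("cuda") and norm_device("cuda:0") both return "cuda:0",
# which makes the assertions in test_cache_device robust to this.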
7 changes: 4 additions & 3 deletions transformer_lens/head_detector.py
@@ -122,9 +122,10 @@ def get_duplicate_token_head_detection_pattern(model: HookedTransformer, sequence
         model: Model being used.
         sequence: String being fed to the model."""

-    sequence = model.to_tokens(sequence)
-    token_pattern = [np.array(sequence) for i in range(sequence.shape[-1])]
-    token_pattern = np.concatenate(token_pattern, axis=0)
+    sequence = model.to_tokens(sequence).detach().cpu()
+
+    # Repeat sequence to create a square matrix.
+    token_pattern = sequence.repeat(sequence.shape[-1], 1).numpy()

     # If token_pattern[i][j] matches its transpose, then token j and token i are duplicates.
     eq_mask = np.equal(token_pattern, token_pattern.T).astype(int)
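For intuition on the repeat-and-compare step above, a small standalone sketch with a toy token sequence (illustrative values, independent of any model):

import numpy as np
import torch

tokens = torch.tensor([[5, 7, 5]])  # stand-in for the output of model.to_tokens(...)
token_pattern = tokens.repeat(tokens.shape[-1], 1).numpy()  # 3x3 matrix; every row is the sequence
eq_mask = np.equal(token_pattern, token_pattern.T).astype(int)
print(eq_mask)
# [[1 0 1]
#  [0 1 0]
#  [1 0 1]]  tokens 0 and 2 are flagged as duplicates of each other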
11 changes: 3 additions & 8 deletions transformer_lens/hook_points.py
@@ -280,15 +280,13 @@ def add_caching_hooks(
         Args:
             names_filter (NamesFilter, optional): Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.
             incl_bwd (bool, optional): Whether to also do backwards hooks. Defaults to False.
-            device (_type_, optional): The device to store on. Defaults to CUDA if available else CPU.
+            device (_type_, optional): The device to store on. Defaults to same device as model.
             remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
             cache (Optional[dict], optional): The cache to store activations in, a new dict is created by default. Defaults to None.

         Returns:
             cache (dict): The cache where activations will be stored.
         """
-        if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
         if cache is None:
             cache = {}

@@ -340,8 +338,7 @@ def run_with_cache(
                 list of str, or a function that takes a string and returns a bool. Defaults to None, which
                 means cache everything.
             device (str or torch.Device, optional): The device to cache activations on. Defaults to the
-                model device. Note that this must be set if the model does not have a model.cfg.device
-                attribute. WARNING: Setting a different device than the one used by the model leads to
+                model device. WARNING: Setting a different device than the one used by the model leads to
                 significant performance degradation.
             remove_batch_dim (bool, optional): If True, removes the batch dimension when caching. Only
                 makes sense with batch_size=1 inputs. Defaults to False.
Expand Down Expand Up @@ -382,7 +379,7 @@ def get_caching_hooks(
         Args:
             names_filter (NamesFilter, optional): Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.
             incl_bwd (bool, optional): Whether to also do backwards hooks. Defaults to False.
-            device (_type_, optional): The device to store on. Defaults to CUDA if available else CPU.
+            device (_type_, optional): The device to store on. Keeps on the same device as the layer if None.
             remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
             cache (Optional[dict], optional): The cache to store activations in, a new dict is created by default. Defaults to None.
@@ -391,8 +388,6 @@
             fwd_hooks (list): The forward hooks.
             bwd_hooks (list): The backward hooks. Empty if incl_bwd is False.
         """
-        if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
         if cache is None:
             cache = {}

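To make the revised defaults concrete, a short sketch of the intended caching behaviour (reusing the model loaded in the earlier sketch; the names_filter lambda is illustrative): with device=None, each cached tensor now stays on whatever device its block lives on, instead of being force-moved to CUDA as the old default did.

# Cache only the MLP activations and leave device=None (the new default).
_, cache = model.run_with_cache(
    "Hello there",
    names_filter=lambda name: name.endswith("mlp.hook_post"),
)
for name in cache.keys():
    print(name, cache[name].device)  # expected to match the device of the owning block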
21 changes: 18 additions & 3 deletions transformer_lens/utilities/devices.py
@@ -9,10 +9,25 @@ def get_device_for_block_index(
     cfg: HookedTransformerConfig,
     device: Optional[Union[torch.device, str]] = None,
 ):
+    """
+    Determine the device for a given layer index based on the model configuration.
+
+    This function assists in distributing model layers across multiple devices. The distribution
+    is based on the configuration's number of layers (cfg.n_layers) and devices (cfg.n_devices).
+
+    Args:
+        index (int): Model layer index.
+        cfg (HookedTransformerConfig): Model and device configuration.
+        device (Optional[Union[torch.device, str]], optional): Initial device used for determining the target device.
+            If not provided, the function uses the device specified in the configuration (cfg.device).
+
+    Returns:
+        torch.device: The device for the specified layer index.
+    """
     assert cfg.device is not None
     layers_per_device = cfg.n_layers // cfg.n_devices
     if device is None:
         device = cfg.device
-    if isinstance(device, torch.device):
-        device = device.type
-    return torch.device(device, index // layers_per_device)
+    device = torch.device(device)
+    device_index = (device.index or 0) + (index // layers_per_device)
+    return torch.device(device.type, device_index)
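A worked example of the new index arithmetic (hypothetical numbers; GPT-2 medium, as used in the multi-GPU test, has 24 layers, so cfg.n_devices = 2 gives 12 blocks per device, and a cfg.device of "cuda:1" now offsets the whole mapping):

# Plain-Python illustration of the mapping implemented above.
n_layers, n_devices = 24, 2
layers_per_device = n_layers // n_devices  # 12
base_index = 1                             # e.g. cfg.device == "cuda:1"
for index in (0, 11, 12, 23):
    print(index, f"cuda:{base_index + index // layers_per_device}")
# 0 cuda:1
# 11 cuda:1
# 12 cuda:2
# 23 cuda:2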
