From c5559f48305a303c8c3aae1576f192e0ead3faf6 Mon Sep 17 00:00:00 2001 From: Zain Huda Date: Thu, 19 Sep 2024 12:19:10 -0700 Subject: [PATCH] add optimizer to notebook (#2405) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2405 add optimizer to tutorial notebook Differential Revision: D63041649 --- ...active_Tutorial_Notebook_OSS_version.ipynb | 330 +++++++++++++----- 1 file changed, 249 insertions(+), 81 deletions(-) diff --git a/TorchRec_Interactive_Tutorial_Notebook_OSS_version.ipynb b/TorchRec_Interactive_Tutorial_Notebook_OSS_version.ipynb index 807e8c89e..87dec0c4e 100644 --- a/TorchRec_Interactive_Tutorial_Notebook_OSS_version.ipynb +++ b/TorchRec_Interactive_Tutorial_Notebook_OSS_version.ipynb @@ -100,7 +100,7 @@ "!pip3 install torchmetrics==1.0.3\n", "!pip3 install torchrec --index-url https://download.pytorch.org/whl/nightly/cu121" ], - "execution_count": 47, + "execution_count": null, "outputs": [] }, { @@ -206,7 +206,7 @@ "base_uri": "https://localhost:8080/" }, "id": "1X5C_Dnccso-", - "outputId": "6ee71ccd-7857-4e20-d047-6a02e460f47a" + "outputId": "6a88f771-2f9b-4d1d-f84f-7e0be9715efc" }, "source": [ "num_embeddings, embedding_dim = 10, 4\n", @@ -221,16 +221,16 @@ "output_type": "stream", "name": "stdout", "text": [ - "Weights: tensor([[0.1072, 0.9457, 0.9209, 0.7171],\n", - " [0.0412, 0.2312, 0.3200, 0.5536],\n", - " [0.1699, 0.1457, 0.9400, 0.4208],\n", - " [0.3612, 0.6193, 0.9533, 0.0962],\n", - " [0.2410, 0.8711, 0.1048, 0.9601],\n", - " [0.7713, 0.3515, 0.4125, 0.2535],\n", - " [0.3944, 0.0446, 0.9126, 0.9890],\n", - " [0.2477, 0.8815, 0.6849, 0.5373],\n", - " [0.3581, 0.9593, 0.6951, 0.2933],\n", - " [0.3735, 0.6325, 0.1342, 0.9888]])\n" + "Weights: tensor([[0.0689, 0.7306, 0.9038, 0.1236],\n", + " [0.7800, 0.4969, 0.4306, 0.9177],\n", + " [0.3871, 0.5991, 0.1885, 0.6965],\n", + " [0.1844, 0.9105, 0.4754, 0.2711],\n", + " [0.7920, 0.9244, 0.7218, 0.4849],\n", + " [0.8046, 0.1542, 0.0744, 0.2792],\n", + " [0.4305, 0.8207, 0.9054, 0.2821],\n", + " [0.7047, 0.8777, 0.0902, 0.4385],\n", + " [0.6134, 0.4716, 0.8802, 0.5822],\n", + " [0.3582, 0.1362, 0.3836, 0.3505]])\n" ] } ] @@ -255,7 +255,7 @@ "base_uri": "https://localhost:8080/" }, "id": "bxszzeGdcso-", - "outputId": "a91d8d5e-c8f1-45f5-dbac-83c56c639e17" + "outputId": "eefe0891-2a98-47b3-c54f-0a6d6e066c71" }, "source": [ "# Pass in pre generated weights just for example, typically weights are randomly initialized\n", @@ -282,27 +282,27 @@ "name": "stdout", "text": [ "Embedding Collection Table: Parameter containing:\n", - "tensor([[0.1072, 0.9457, 0.9209, 0.7171],\n", - " [0.0412, 0.2312, 0.3200, 0.5536],\n", - " [0.1699, 0.1457, 0.9400, 0.4208],\n", - " [0.3612, 0.6193, 0.9533, 0.0962],\n", - " [0.2410, 0.8711, 0.1048, 0.9601],\n", - " [0.7713, 0.3515, 0.4125, 0.2535],\n", - " [0.3944, 0.0446, 0.9126, 0.9890],\n", - " [0.2477, 0.8815, 0.6849, 0.5373],\n", - " [0.3581, 0.9593, 0.6951, 0.2933],\n", - " [0.3735, 0.6325, 0.1342, 0.9888]], requires_grad=True)\n", + "tensor([[0.0689, 0.7306, 0.9038, 0.1236],\n", + " [0.7800, 0.4969, 0.4306, 0.9177],\n", + " [0.3871, 0.5991, 0.1885, 0.6965],\n", + " [0.1844, 0.9105, 0.4754, 0.2711],\n", + " [0.7920, 0.9244, 0.7218, 0.4849],\n", + " [0.8046, 0.1542, 0.0744, 0.2792],\n", + " [0.4305, 0.8207, 0.9054, 0.2821],\n", + " [0.7047, 0.8777, 0.0902, 0.4385],\n", + " [0.6134, 0.4716, 0.8802, 0.5822],\n", + " [0.3582, 0.1362, 0.3836, 0.3505]], requires_grad=True)\n", "Embedding Bag Collection Table: Parameter containing:\n", - "tensor([[0.1072, 0.9457, 0.9209, 0.7171],\n", - " [0.0412, 0.2312, 0.3200, 0.5536],\n", - " [0.1699, 0.1457, 0.9400, 0.4208],\n", - " [0.3612, 0.6193, 0.9533, 0.0962],\n", - " [0.2410, 0.8711, 0.1048, 0.9601],\n", - " [0.7713, 0.3515, 0.4125, 0.2535],\n", - " [0.3944, 0.0446, 0.9126, 0.9890],\n", - " [0.2477, 0.8815, 0.6849, 0.5373],\n", - " [0.3581, 0.9593, 0.6951, 0.2933],\n", - " [0.3735, 0.6325, 0.1342, 0.9888]], requires_grad=True)\n", + "tensor([[0.0689, 0.7306, 0.9038, 0.1236],\n", + " [0.7800, 0.4969, 0.4306, 0.9177],\n", + " [0.3871, 0.5991, 0.1885, 0.6965],\n", + " [0.1844, 0.9105, 0.4754, 0.2711],\n", + " [0.7920, 0.9244, 0.7218, 0.4849],\n", + " [0.8046, 0.1542, 0.0744, 0.2792],\n", + " [0.4305, 0.8207, 0.9054, 0.2821],\n", + " [0.7047, 0.8777, 0.0902, 0.4385],\n", + " [0.6134, 0.4716, 0.8802, 0.5822],\n", + " [0.3582, 0.1362, 0.3836, 0.3505]], requires_grad=True)\n", "Input row IDS: tensor([[1, 3]])\n" ] } @@ -328,7 +328,7 @@ "base_uri": "https://localhost:8080/" }, "id": "xkedJeTOcso_", - "outputId": "5f77e32a-563c-47f2-855e-d6c0818acadb" + "outputId": "ff463e9a-d289-4f84-867c-ebc1661618a9" }, "source": [ "embeddings = embedding_collection(ids)\n", @@ -346,8 +346,8 @@ "name": "stdout", "text": [ "Embedding Collection Results: \n", - "tensor([[[0.0412, 0.2312, 0.3200, 0.5536],\n", - " [0.3612, 0.6193, 0.9533, 0.0962]]], grad_fn=)\n", + "tensor([[[0.7800, 0.4969, 0.4306, 0.9177],\n", + " [0.1844, 0.9105, 0.4754, 0.2711]]], grad_fn=)\n", "Shape: torch.Size([1, 2, 4])\n" ] } @@ -373,7 +373,7 @@ "base_uri": "https://localhost:8080/" }, "id": "PmtJkxLccso_", - "outputId": "5ca1b383-1c6a-4b32-91e3-ef3735821452" + "outputId": "fc053ab7-3755-43d6-dab4-a211648e1cc5" }, "source": [ "# nn.EmbeddingBag default pooling is mean, so should be mean of batch dimension of values above\n", @@ -394,9 +394,9 @@ "name": "stdout", "text": [ "Embedding Bag Collection Results: \n", - "tensor([[0.2012, 0.4252, 0.6367, 0.3249]], grad_fn=)\n", + "tensor([[0.4822, 0.7037, 0.4530, 0.5944]], grad_fn=)\n", "Shape: torch.Size([1, 4])\n", - "Mean: tensor([[0.2012, 0.4252, 0.6367, 0.3249]], grad_fn=)\n" + "Mean: tensor([[0.4822, 0.7037, 0.4530, 0.5944]], grad_fn=)\n" ] } ] @@ -510,7 +510,7 @@ "colab": { "base_uri": "https://localhost:8080/" }, - "outputId": "2fadb955-6f0c-4ffb-ec0b-bb9535dd2888" + "outputId": "619d3560-ef27-486c-b033-6b2a7adfae53" }, "source": [ "ebc = torchrec.EmbeddingBagCollection(\n", @@ -582,7 +582,7 @@ "base_uri": "https://localhost:8080/" }, "id": "UuIrEWupcspA", - "outputId": "94c0b3b5-247a-47d9-b3bd-aff2c374dcdc" + "outputId": "3d594ba0-1c11-48aa-d794-9d1320be83f8" }, "source": [ "import inspect\n", @@ -716,7 +716,7 @@ "base_uri": "https://localhost:8080/" }, "id": "t5T5S8_mcspB", - "outputId": "b7c61b24-7a09-4788-82b1-df05d8ece4d5" + "outputId": "f7773d10-3723-4e31-fb57-b7301d210de2" }, "source": [ "# Lengths can be converted to offsets for easy indexing of values\n", @@ -762,7 +762,7 @@ "base_uri": "https://localhost:8080/" }, "id": "2OOK2BBecspB", - "outputId": "56caf27f-a7f6-4af1-c42c-022f938bc68b" + "outputId": "4bead533-0978-41ff-95e1-9423d87ed4ed" }, "source": [ "from torchrec import JaggedTensor\n", @@ -815,7 +815,7 @@ "base_uri": "https://localhost:8080/" }, "id": "fs10Fxu2cspB", - "outputId": "8f675894-cf26-4fdf-e839-7624d1ed78a8" + "outputId": "30ca7527-271b-45d3-9f3c-06d4c76b6ab4" }, "source": [ "from torchrec import KeyedJaggedTensor\n", @@ -858,7 +858,7 @@ "Keys: ['product', 'user']\n", "Lengths: tensor([3, 1, 2, 2])\n", "Values: tensor([1, 2, 1, 5, 2, 3, 4, 1])\n", - "to_dict: {'product': , 'user': }\n", + "to_dict: {'product': , 'user': }\n", "KeyedJaggedTensor({\n", " \"product\": [[1, 2, 1], [5]],\n", " \"user\": [[2, 3], [4, 1]]\n", @@ -888,7 +888,7 @@ "base_uri": "https://localhost:8080/" }, "id": "JeLwyCNRcspB", - "outputId": "d3afa2c0-e28a-4de2-b1b7-927bcf276776" + "outputId": "14e21af4-930d-4759-bf77-ea9bee0060ef" }, "source": [ "# Now we can run a forward pass on our ebc from before\n", @@ -901,7 +901,7 @@ "output_type": "execute_result", "data": { "text/plain": [ - "" + "" ] }, "metadata": {}, @@ -929,7 +929,7 @@ "base_uri": "https://localhost:8080/" }, "id": "R2K4v2vqcspB", - "outputId": "c2e72ff0-b0d2-4631-abfc-a0fab83fc2a3" + "outputId": "d8b9dc90-0093-4904-b33d-fc51fef49e33" }, "source": [ "# Result is a KeyedTensor, which contains a list of the feature names and the embedding results\n", @@ -1012,7 +1012,7 @@ "colab": { "base_uri": "https://localhost:8080/" }, - "outputId": "d670cd23-7bd8-48de-bb89-ef957c3ea86a" + "outputId": "12af6742-bab4-4e70-abd1-d2e886ec1bad" }, "source": [ "# Here we set up our torch distributed environment\n", @@ -1115,7 +1115,7 @@ "base_uri": "https://localhost:8080/" }, "id": "FX65VcQ6cspB", - "outputId": "7e7148c5-7194-4a4d-9580-19f6811783a8" + "outputId": "de29a2cc-13bd-4077-a4cd-e79afc503281" }, "source": [ "# Refresher of our EmbeddingBagCollection module\n", @@ -1160,7 +1160,7 @@ "base_uri": "https://localhost:8080/" }, "id": "1hSzTg4pcspC", - "outputId": "be6bc55e-2742-4940-e2a6-ad9f8eb02198" + "outputId": "06d3f374-5ab3-488b-b626-a21bbaca7520" }, "source": [ "from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder\n", @@ -1182,7 +1182,7 @@ "output_type": "stream", "name": "stdout", "text": [ - "Process Group: \n" + "Process Group: \n" ] } ] @@ -1241,7 +1241,7 @@ "base_uri": "https://localhost:8080/" }, "id": "PQeXnuAGcspC", - "outputId": "03670001-a4ef-43ca-bddf-1f6cc3e12a0b" + "outputId": "3350d389-a3a0-48bb-e7bc-2f0917168695" }, "source": [ "# In our case, 1 GPU and compute on CUDA device\n", @@ -1315,7 +1315,7 @@ "base_uri": "https://localhost:8080/" }, "id": "JIci5Gz6cspC", - "outputId": "a9675129-080e-46b5-927d-721491bd3604" + "outputId": "cbf5b01d-929f-4027-bb71-db5b429ff12d" }, "source": [ "# The static plan that was generated\n", @@ -1355,7 +1355,7 @@ "base_uri": "https://localhost:8080/" }, "id": "2__Do2tqcspC", - "outputId": "cffbb5ae-f4ff-4173-d025-6efd443d036a" + "outputId": "03e23c20-e468-4ddb-f66d-ea2bf92f19f8" }, "source": [ "env = ShardingEnv.from_process_group(pg)\n", @@ -1391,6 +1391,13 @@ } ] }, + { + "cell_type": "markdown", + "source": [], + "metadata": { + "id": "ErXXbYzJmVzI" + } + }, { "cell_type": "markdown", "metadata": { @@ -1426,7 +1433,7 @@ "base_uri": "https://localhost:8080/" }, "id": "rwYzKwyNcspC", - "outputId": "ac678a80-cd99-4c88-a3a0-e8a7e38c73d2" + "outputId": "d0bed3b7-5c10-494d-f623-4893a41407d4" }, "source": [ "from typing import List\n", @@ -1483,7 +1490,7 @@ "base_uri": "https://localhost:8080/" }, "id": "cs41RfzGcspC", - "outputId": "f3f27715-6fec-4465-c962-569ee9ed7a01" + "outputId": "fce85636-9995-411a-e7a5-448977bd0257" }, "source": [ "kjt = kjt.to(\"cuda\")\n", @@ -1497,7 +1504,7 @@ "output_type": "stream", "name": "stdout", "text": [ - "\n" + "\n" ] } ] @@ -1522,7 +1529,7 @@ "base_uri": "https://localhost:8080/" }, "id": "_1sdt75rcspG", - "outputId": "e1288888-2584-4185-f851-670df9ba303c" + "outputId": "a1f49742-64c4-4774-ef3b-c0c81ef2f7bb" }, "source": [ "kt = output.wait()\n", @@ -1555,6 +1562,167 @@ } ] }, + { + "cell_type": "markdown", + "source": [ + "### Adding in the Optimizer\n", + "\n", + "You may have noticed we're missing one key part of model training, the optimizer! The recommended way to apply the optimizer in TorchRec is through PyTorch's `apply_optimizer_in_backward` API. This gives you parameter level granualtiy and as such lets you define different optimizers or optimizer parameters for each paramter in your model.\n", + "\n", + "With this level of control, you can define a different optimizer for a set of parameters or for embedding_bags as a whole. A key thing to note here, you need to be aware of is the parameter naming. EmbeddingBagCollections will be named something similar to `embedding_bags....` and EmbeddingCollections `embeddings...`. Inspect the model parameters through `model.named_parameters()` to understand how your model parameters are named.\n", + "\n", + "TorchRec uses `CombinedOptimizer` which contains `KeyedOptimizers` within. A `CombinedOptimizer` effectively makes it easy to handle multiple optimizers for various sub groups in the model. A `KeyedOptimizer` extends the `torch.optim.Optimizer` and is initialized through a dictionary of parameters exposes the parameters. Each `TBE` module in a `EmbeddingBagCollection` will have it's own KeyedOptimizer which all combined into one `CombinedOptimizer`.\n", + "\n", + "#### Fused optimizer in TorchRec\n", + "\n", + "Using DistributedModelParallel, the optimizer is fused, which means that the optimizer update is done in the backward. Hence the term \"fused\". This is an optimization in TorchRec and FBGEMM, the optimizer embedding gradients are not materialized and direclty applied to the parameters. This brings significant memory savings are embedding gradients are typically size of the parameters themselves. You can, however, choose to make the optimizer \"dense\" which does not apply this optimization and let's you inspect the embedding gradients or apply computations to it as you wish. A dense optimizer in this case would be your [canonical PyTorch model training loop with optimizer.](https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html)\n", + "\n", + "Once the optimizer is created through DistributedModelParallel, you still need to manage an optimizer for the dense parameters. This you will do through the canonical PyTorch training loop, to find the dense parameters, you can call `in_backward_optimizer_filter(model.named_parameters())`. Apply the dense optimizer to those parameters as you would a normal Torch optimizer and combine this and the `model.fused_optimizer` into one `CombinedOptimizer` that you can use in your training loop to `zero_grad` and `step` through.\n", + "\n", + "#### Let's add an optimizer to our EmbeddingBagCollection\n", + "We will do this in two ways, which are equivalent, but give you options depending on your preferences\n", + "1. Passing optimizer kwargs through fused parameters (fused_params) in sharder\n", + "2. Through `apply_optimizer_in_backward`\n", + "Note: `apply_optimizer_in_backward` converts the optimizer parameters to `fused_params` to pass to the `TBE` in the `EmbeddingBagCollection`/`EmbeddingCollection`." + ], + "metadata": { + "id": "zFhggkUCmd7f" + } + }, + { + "cell_type": "code", + "source": [ + "# Approach 1: passing optimizer kwargs through fused parameters\n", + "from torchrec.optim.optimizers import in_backward_optimizer_filter\n", + "from fbgemm_gpu.split_embedding_configs import EmbOptimType\n", + "\n", + "\n", + "# We initialize the sharder with\n", + "fused_params = {\n", + " \"optimizer\": EmbOptimType.EXACT_ROWWISE_ADAGRAD,\n", + " \"learning_rate\": 0.02,\n", + " \"eps\": 0.002,\n", + "}\n", + "\n", + "# Init sharder with fused_params\n", + "sharder_with_fused_params = EmbeddingBagCollectionSharder(fused_params=fused_params)\n", + "\n", + "# We'll use same plan and unsharded EBC as before but this time with our new sharder\n", + "sharded_ebc_fused_params = sharder_with_fused_params.shard(ebc, plan.plan[\"\"], env, torch.device(\"cuda\"))\n", + "\n", + "# Looking at the optimizer of each, we can see that the learning rate changed, which indicates our optimizer has been applied correclty.\n", + "# If seen, we can also look at the TBE logs of the cell to see that our new optimizer is indeed being applied\n", + "print(f\"Original Sharded EBC fused optimizer: {sharded_ebc.fused_optimizer}\")\n", + "print(f\"Sharded EBC with fused parameters fused optimizer: {sharded_ebc_fused_params.fused_optimizer}\")\n", + "\n", + "# We can also check through the filter, we set include=True to show us the parameters that have an optimizer applied to them\n", + "dict(in_backward_optimizer_filter(sharded_ebc_fused_params.named_parameters(), include=True))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "h5BCEFidmnEw", + "outputId": "46069e79-af2c-434e-c48e-4fa1a92c29ef" + }, + "execution_count": 30, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Original Sharded EBC fused optimizer: : EmbeddingFusedOptimizer (\n", + "Parameter Group 0\n", + " lr: 0.01\n", + ")\n", + "Sharded EBC with fused parameters fused optimizer: : EmbeddingFusedOptimizer (\n", + "Parameter Group 0\n", + " lr: 0.02\n", + ")\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "{'embedding_bags.product_table.weight': Parameter containing:\n", + " Parameter(TableBatchedEmbeddingSlice([[-0.0087, -0.0091, 0.0007, ..., -0.0147,\n", + " -0.0050, 0.0064],\n", + " [ 0.0102, -0.0059, -0.0082, ..., 0.0112,\n", + " -0.0076, 0.0103],\n", + " [-0.0122, -0.0123, -0.0143, ..., -0.0108,\n", + " -0.0137, -0.0038],\n", + " ...,\n", + " [-0.0073, -0.0144, -0.0130, ..., -0.0135,\n", + " -0.0074, 0.0020],\n", + " [-0.0048, -0.0110, -0.0124, ..., 0.0110,\n", + " -0.0121, 0.0153],\n", + " [-0.0134, -0.0144, 0.0073, ..., -0.0058,\n", + " -0.0027, 0.0006]], device='cuda:0')),\n", + " 'embedding_bags.user_table.weight': Parameter containing:\n", + " Parameter(TableBatchedEmbeddingSlice([[-0.0019, -0.0022, -0.0048, ..., 0.0086,\n", + " -0.0152, 0.0050],\n", + " [-0.0088, -0.0128, -0.0059, ..., 0.0003,\n", + " -0.0065, -0.0123],\n", + " [-0.0006, 0.0082, -0.0072, ..., -0.0120,\n", + " -0.0070, -0.0017],\n", + " ...,\n", + " [ 0.0065, -0.0024, 0.0130, ..., 0.0118,\n", + " 0.0059, -0.0124],\n", + " [ 0.0012, 0.0025, -0.0111, ..., -0.0152,\n", + " -0.0040, 0.0050],\n", + " [ 0.0081, -0.0097, 0.0031, ..., -0.0064,\n", + " 0.0093, -0.0048]], device='cuda:0'))}" + ] + }, + "metadata": {}, + "execution_count": 30 + } + ] + }, + { + "cell_type": "code", + "source": [ + "from torch.distributed.optim import _apply_optimizer_in_backward as apply_optimizer_in_backward\n", + "import copy\n", + "# Approach 2: applying optimizer through apply_optimizer_in_backward\n", + "# Note: we need to call apply_optimizer_in_backward on unsharded model first and then shard it\n", + "\n", + "# We can achieve the same result as we did in the previous\n", + "ebc_apply_opt = copy.deepcopy(ebc)\n", + "optimizer_kwargs = {\"lr\": 0.5}\n", + "\n", + "for name, param in ebc_apply_opt.named_parameters():\n", + " print(f\"{name=}\")\n", + " apply_optimizer_in_backward(torch.optim.SGD, [param], optimizer_kwargs)\n", + "\n", + "sharded_ebc_apply_opt = sharder.shard(ebc_apply_opt, plan.plan[\"\"], env, torch.device(\"cuda\"))\n", + "# Now when we print the optimizer, we will see our new learning rate, you can verify momentum through the TBE logs as well if outputted\n", + "print(sharded_ebc_apply_opt.fused_optimizer)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "T-xx724MmoKv", + "outputId": "f08b83cb-51b6-4890-db8c-bc7132a002b9" + }, + "execution_count": 29, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "name='embedding_bags.product_table.weight'\n", + "name='embedding_bags.user_table.weight'\n", + ": EmbeddingFusedOptimizer (\n", + "Parameter Group 0\n", + " lr: 0.5\n", + ")\n" + ] + } + ] + }, { "cell_type": "markdown", "metadata": { @@ -1607,7 +1775,7 @@ "source": [ "sharded_ebc" ], - "execution_count": 25, + "execution_count": null, "outputs": [ { "output_type": "execute_result", @@ -1668,7 +1836,7 @@ "# Distribute input KJTs to all other GPUs and receive KJTs\n", "sharded_ebc._input_dists" ], - "execution_count": 26, + "execution_count": null, "outputs": [ { "output_type": "execute_result", @@ -1710,7 +1878,7 @@ "# Distribute output embeddingts to all other GPUs and receive embeddings\n", "sharded_ebc._output_dists" ], - "execution_count": 27, + "execution_count": null, "outputs": [ { "output_type": "execute_result", @@ -1774,7 +1942,7 @@ "source": [ "sharded_ebc._lookups" ], - "execution_count": 28, + "execution_count": null, "outputs": [ { "output_type": "execute_result", @@ -1840,7 +2008,7 @@ "source": [ "ebc" ], - "execution_count": 29, + "execution_count": null, "outputs": [ { "output_type": "execute_result", @@ -1884,7 +2052,7 @@ "source": [ "model = torchrec.distributed.DistributedModelParallel(ebc, device=torch.device(\"cuda\"))" ], - "execution_count": 30, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -1921,7 +2089,7 @@ "out = model(kjt)\n", "out.wait()" ], - "execution_count": 31, + "execution_count": null, "outputs": [ { "output_type": "execute_result", @@ -1960,7 +2128,7 @@ "source": [ "model" ], - "execution_count": 32, + "execution_count": null, "outputs": [ { "output_type": "execute_result", @@ -2066,7 +2234,7 @@ "source": [ "ebc" ], - "execution_count": 33, + "execution_count": null, "outputs": [ { "output_type": "execute_result", @@ -2109,7 +2277,7 @@ " def forward(self, kjt: KeyedJaggedTensor):\n", " return self.ebc_(kjt)" ], - "execution_count": 34, + "execution_count": null, "outputs": [] }, { @@ -2141,7 +2309,7 @@ " # FP32 is default, regularly used for training\n", " print(name, param.shape, param.dtype)" ], - "execution_count": 35, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -2233,7 +2401,7 @@ "\n", "print(f\"Quantized EBC: {qebc}\")" ], - "execution_count": 36, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -2270,7 +2438,7 @@ "source": [ "kjt = kjt.to(\"cpu\")" ], - "execution_count": 37, + "execution_count": null, "outputs": [] }, { @@ -2298,7 +2466,7 @@ "source": [ "qebc(kjt)" ], - "execution_count": 38, + "execution_count": null, "outputs": [ { "output_type": "execute_result", @@ -2341,7 +2509,7 @@ " # post quantization\n", " print(name, buffer.shape, buffer.dtype)" ], - "execution_count": 39, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -2408,7 +2576,7 @@ "\n", "print(f\"Sharded Quantized EBC: {sharded_qebc}\")" ], - "execution_count": 40, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -2462,7 +2630,7 @@ "source": [ "sharded_qebc(kjt)" ], - "execution_count": 41, + "execution_count": null, "outputs": [ { "output_type": "execute_result", @@ -2526,7 +2694,7 @@ "\n", "print(\"Graph Module Created!\")" ], - "execution_count": 42, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -2562,7 +2730,7 @@ "source": [ "print(gm.code)" ], - "execution_count": 43, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -2632,7 +2800,7 @@ "scripted_gm = torch.jit.script(gm)\n", "print(\"Scripted Graph Module Created!\")" ], - "execution_count": 44, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -2676,7 +2844,7 @@ "source": [ "print(scripted_gm.code)" ], - "execution_count": 46, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -2753,4 +2921,4 @@ ] } ] -} +} \ No newline at end of file