Skip to content

Commit

Permalink
update lightning imports
Browse files Browse the repository at this point in the history
  • Loading branch information
ash0ts committed Dec 4, 2023
1 parent c63bb1b commit d57448e
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
"Coupled with the [Weights & Biases integration](https://docs.wandb.com/library/integrations/lightning), you can quickly train and monitor models for full traceability and reproducibility with only 2 extra lines of code:\n",
"\n",
"```python\n",
"from pytorch_lightning.loggers import WandbLogger\n",
"from pytorch_lightning import Trainer\n",
"from lightning.pytorch.loggers import WandbLogger\n",
"from lightning.pytorch import Trainer\n",
"\n",
"wandb_logger = WandbLogger()\n",
"trainer = Trainer(logger=wandb_logger)\n",
Expand Down Expand Up @@ -64,7 +64,7 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install -q pytorch-lightning wandb"
"!pip install -q lightning wandb torchvision"
]
},
{
Expand Down Expand Up @@ -150,6 +150,15 @@
"* Call self.log in `training_step` and `validation_step` to log the metrics"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import lightning.pytorch as pl"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -160,9 +169,8 @@
"from torch.nn import Linear, CrossEntropyLoss, functional as F\n",
"from torch.optim import Adam\n",
"from torchmetrics.functional import accuracy\n",
"from pytorch_lightning import LightningModule\n",
"\n",
"class MNIST_LitModule(LightningModule):\n",
"class MNIST_LitModule(pl.LightningModule):\n",
"\n",
" def __init__(self, n_classes=10, n_layer_1=128, n_layer_2=256, lr=1e-3):\n",
" '''method used to define our model parameters'''\n",
Expand Down Expand Up @@ -273,7 +281,7 @@
"metadata": {},
"outputs": [],
"source": [
"from pytorch_lightning.callbacks import ModelCheckpoint\n",
"from lightning.pytorch.callbacks import ModelCheckpoint\n",
"\n",
"checkpoint_callback = ModelCheckpoint(monitor='val_accuracy', mode='max')"
]
Expand All @@ -284,9 +292,9 @@
"source": [
"## 💡 Tracking Experiments with WandbLogger\n",
"\n",
"PyTorch Lightning has a `WandbLogger` to easily log your experiments with Wights & Biases. Just pass it to your `Trainer` to log to W&B. See the [WandbLogger docs](https://pytorch-lightning.readthedocs.io/en/stable/extensions/generated/pytorch_lightning.loggers.WandbLogger.html#pytorch_lightning.loggers.WandbLogger) for all parameters. Note, to log the metrics to a specific W&B Team, pass your Team name to the `entity` argument in `WandbLogger`\n",
"PyTorch Lightning has a `WandbLogger` to easily log your experiments with Wights & Biases. Just pass it to your `Trainer` to log to W&B. See the [WandbLogger docs](https://lightning.ai/docs/pytorch/stable/extensions/generated/pytorch_lightning.loggers.WandbLogger.html#pytorch_lightning.loggers.WandbLogger) for all parameters. Note, to log the metrics to a specific W&B Team, pass your Team name to the `entity` argument in `WandbLogger`\n",
"\n",
"#### `pytorch_lightning.loggers.WandbLogger()`\n",
"#### `lightning.pytorch.loggers.WandbLogger()`\n",
"\n",
"| Functionality | Argument/Function | PS |\n",
"| ------ | ------ | ------ |\n",
Expand All @@ -295,9 +303,9 @@
"| Organize runs by project | `WandbLogger(... ,project='my_project')` | |\n",
"| Log histograms of gradients and parameters | `WandbLogger.watch(model)` | `WandbLogger.watch(model, log='all')` to log parameter histograms |\n",
"| Log hyperparameters | Call `self.save_hyperparameters()` within `LightningModule.__init__()` |\n",
"| Log custom objects (images, audio, video, molecules…) | Use `WandbLogger.log_text`, `WandbLogger.log_image` and `WandbLogger.log_table` |\n",
"| Log custom objects (images, audio, video, molecules…) | Use `WandbLogger.log_text`, `WandbLogger.log_image` and `WandbLogger.log_table`, etc. |\n",
"\n",
"See the [WandbLogger docs](https://pytorch-lightning.readthedocs.io/en/stable/extensions/generated/pytorch_lightning.loggers.WandbLogger.html#pytorch_lightning.loggers.WandbLogger) here for all parameters. "
"See the [WandbLogger docs](https://lightning.ai/docs/pytorch/stable/extensions/generated/pytorch_lightning.loggers.WandbLogger.html#pytorch_lightning.loggers.WandbLogger) here for all parameters. "
]
},
{
Expand All @@ -306,8 +314,8 @@
"metadata": {},
"outputs": [],
"source": [
"from pytorch_lightning.loggers import WandbLogger\n",
"from pytorch_lightning import Trainer\n",
"from lightning.pytorch.loggers import WandbLogger\n",
"from lightning.pytorch import Trainer\n",
"\n",
"wandb_logger = WandbLogger(project='MNIST', # group runs in \"MNIST\" project\n",
" log_model='all') # log all new checkpoints during training"
Expand All @@ -334,7 +342,7 @@
"metadata": {},
"outputs": [],
"source": [
"from pytorch_lightning.callbacks import Callback\n",
"from lightning.pytorch.callbacks import Callback\n",
" \n",
"class LogPredictionsCallback(Callback):\n",
" \n",
Expand Down
4 changes: 2 additions & 2 deletions colabs/pytorch-lightning/Profile_PyTorch_Code.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install -q wandb pytorch_lightning torch_tb_profiler"
"!pip install -q wandb lightning torch_tb_profiler torchvision"
]
},
{
Expand All @@ -99,7 +99,7 @@
"source": [
"import glob\n",
"\n",
"import pytorch_lightning as pl\n",
"import lightning.pytorch as pl\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install -qqq wandb pytorch-lightning torchmetrics"
"!pip install -qqq wandb lightning torchmetrics"
]
},
{
Expand Down Expand Up @@ -142,15 +142,15 @@
"outputs": [],
"source": [
"# ⚡ PyTorch Lightning\n",
"import pytorch_lightning as pl\n",
"import lightning.pytorch as pl\n",
"import torchmetrics\n",
"pl.seed_everything(hash(\"setting random seeds\") % 2**32 - 1)\n",
"\n",
"# 🏋️‍♀️ Weights & Biases\n",
"import wandb\n",
"\n",
"# ⚡ 🤝 🏋️‍♀️\n",
"from pytorch_lightning.loggers import WandbLogger\n"
"from lightning.pytorch.loggers import WandbLogger\n"
]
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install wandb pytorch-lightning -qqq"
"!pip install wandb lightning torchvision -qqq"
]
},
{
Expand All @@ -56,9 +56,9 @@
"source": [
"import os\n",
"\n",
"import pytorch_lightning as pl\n",
"import lightning.pytorch as pl\n",
"# your favorite machine learning tracking tool\n",
"from pytorch_lightning.loggers import WandbLogger\n",
"from lightning.pytorch.loggers import WandbLogger\n",
"\n",
"import torch\n",
"from torch import nn\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install -q pytorch-lightning wandb"
"!pip install -q lightning wandb torchvision"
]
},
{
Expand Down Expand Up @@ -184,7 +184,7 @@
"outputs": [],
"source": [
"from torchvision import transforms\n",
"import pytorch_lightning as pl\n",
"import lightning.pytorch as pl\n",
"import torch\n",
"from torch.utils.data import Dataset, DataLoader, random_split\n",
"from skimage import io, transform\n",
Expand Down Expand Up @@ -368,7 +368,7 @@
"from torch.nn import Linear, CrossEntropyLoss, functional as F\n",
"from torch.optim import Adam\n",
"from torchmetrics.functional import accuracy\n",
"from pytorch_lightning import LightningModule\n",
"from lightning.pytorch import LightningModule\n",
"from torchvision import models\n",
"\n",
"class NatureLitModule(LightningModule):\n",
Expand Down Expand Up @@ -468,7 +468,7 @@
"metadata": {},
"outputs": [],
"source": [
"from pytorch_lightning.callbacks import Callback\n",
"from lightning.pytorch.callbacks import Callback\n",
"\n",
"class LogPredictionsCallback(Callback):\n",
"\n",
Expand Down Expand Up @@ -525,9 +525,9 @@
"metadata": {},
"outputs": [],
"source": [
"from pytorch_lightning.callbacks import ModelCheckpoint\n",
"from pytorch_lightning.loggers import WandbLogger\n",
"from pytorch_lightning import Trainer\n",
"from lightning.pytorch.callbacks import ModelCheckpoint\n",
"from lightning.pytorch.loggers import WandbLogger\n",
"from lightning.pytorch import Trainer\n",
"\n",
"wandb.init(project=PROJECT_NAME,\n",
" entity=ENTITY,\n",
Expand Down

0 comments on commit d57448e

Please sign in to comment.