From 5758a7eb1c156b2c513930af666acb2d4a2cb3c3 Mon Sep 17 00:00:00 2001 From: JINO ROHIT Date: Thu, 10 Oct 2024 14:34:16 +0530 Subject: [PATCH] ENH LoRA notebook for NER task (#2126) --- .../token_classification/peft_lora_ner.ipynb | 780 ++++++++++++++++++ 1 file changed, 780 insertions(+) create mode 100644 examples/token_classification/peft_lora_ner.ipynb diff --git a/examples/token_classification/peft_lora_ner.ipynb b/examples/token_classification/peft_lora_ner.ipynb new file mode 100644 index 0000000000..fae9b94b4e --- /dev/null +++ b/examples/token_classification/peft_lora_ner.ipynb @@ -0,0 +1,780 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Named Entity Recognition with Peft Model 🤗\n", + "\n", + "##### In this notebook, we will learn how to perform Named Entity Recognition(NER) on the CoNLL-2003 dataset using the Trainer class\n", + "\n", + "##### This notebook has been adapted from the main NLP course here - https://huggingface.co/learn/nlp-course/chapter7/2?fw=pt#fine-tuning-the-model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#install the required libraries\n", + "!pip install -q datasets evaluate transformers seqeval" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# Import required libraries\n", + "from datasets import load_dataset\n", + "from transformers import AutoTokenizer, AutoModelForTokenClassification, DataCollatorForTokenClassification, TrainingArguments, Trainer, pipeline\n", + "from peft import get_peft_model, LoraConfig, TaskType\n", + "import evaluate\n", + "import numpy as np\n", + "from huggingface_hub import notebook_login" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "DatasetDict({\n", + " train: Dataset({\n", + " features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],\n", + " num_rows: 14041\n", + " })\n", + " validation: Dataset({\n", + " features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],\n", + " num_rows: 3250\n", + " })\n", + " test: Dataset({\n", + " features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],\n", + " num_rows: 3453\n", + " })\n", + "})\n" + ] + } + ], + "source": [ + "raw_datasets = load_dataset(\"conll2003\")\n", + "print(raw_datasets)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.']" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Look at the tokens of the first training example\n", + "raw_datasets[\"train\"][0][\"tokens\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[3, 0, 7, 0, 0, 0, 7, 0, 0]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Look at the NER tags of the first training example\n", + "raw_datasets[\"train\"][0][\"ner_tags\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": 
[ + "# Get the label names for the NER tags\n", + "ner_feature = raw_datasets[\"train\"].features[\"ner_tags\"]\n", + "label_names = ner_feature.feature.names\n", + "label_names" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "EU rejects German call to boycott British lamb . \n", + "B-ORG O B-MISC O O O B-MISC O O \n" + ] + } + ], + "source": [ + "words = raw_datasets[\"train\"][0][\"tokens\"]\n", + "labels = raw_datasets[\"train\"][0][\"ner_tags\"]\n", + "line1 = \"\"\n", + "line2 = \"\"\n", + "for word, label in zip(words, labels):\n", + " full_label = label_names[label]\n", + " max_length = max(len(word), len(full_label))\n", + " line1 += word + \" \" * (max_length - len(word) + 1)\n", + " line2 += full_label + \" \" * (max_length - len(full_label) + 1)\n", + "\n", + "print(line1)\n", + "print(line2)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "e:\\open_source\\peft-folder\\ner-examples\\.venv\\Lib\\site-packages\\transformers\\tokenization_utils_base.py:1617: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "# Load the tokenizer\n", + "model_checkpoint = \"bert-base-cased\"\n", + "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['[CLS]',\n", + " 'EU',\n", + " 'rejects',\n", + " 'German',\n", + " 'call',\n", + " 'to',\n", + " 'boycott',\n", + " 'British',\n", + " 'la',\n", + " '##mb',\n", + " '.',\n", + " '[SEP]']" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Tokenize the first training example\n", + "inputs = tokenizer(raw_datasets[\"train\"][0][\"tokens\"], is_split_into_words=True)\n", + "inputs.tokens()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "def align_labels_with_tokens(labels, word_ids):\n", + " new_labels = []\n", + " current_word = None\n", + " for word_id in word_ids:\n", + " if word_id != current_word:\n", + " # Start of a new word!\n", + " current_word = word_id\n", + " label = -100 if word_id is None else labels[word_id]\n", + " new_labels.append(label)\n", + " elif word_id is None:\n", + " # Special token\n", + " new_labels.append(-100)\n", + " else:\n", + " # Same word as previous token\n", + " label = labels[word_id]\n", + " # If the label is B-XXX we change it to I-XXX\n", + " if label % 2 == 1:\n", + " label += 1\n", + " new_labels.append(label)\n", + "\n", + " return new_labels" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[3, 0, 7, 0, 0, 0, 7, 0, 0]\n", + "[-100, 3, 0, 7, 0, 0, 0, 7, 0, 0, 0, -100]\n" + ] + } + ], + "source": [ + "labels = raw_datasets[\"train\"][0][\"ner_tags\"]\n", + "word_ids = inputs.word_ids()\n", + "print(labels)\n", + "print(align_labels_with_tokens(labels, word_ids))" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": 
{}, + "outputs": [], + "source": [ + "def tokenize_and_align_labels(examples):\n", + " tokenized_inputs = tokenizer(\n", + " examples[\"tokens\"], truncation=True, is_split_into_words=True\n", + " )\n", + " all_labels = examples[\"ner_tags\"]\n", + " new_labels = []\n", + " for i, labels in enumerate(all_labels):\n", + " word_ids = tokenized_inputs.word_ids(i)\n", + " new_labels.append(align_labels_with_tokens(labels, word_ids))\n", + "\n", + " tokenized_inputs[\"labels\"] = new_labels\n", + " return tokenized_inputs" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "tokenized_datasets = raw_datasets.map(\n", + " tokenize_and_align_labels,\n", + " batched=True,\n", + " remove_columns=raw_datasets[\"train\"].column_names,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[-100, 3, 0, 7, 0, 0, 0, 7, 0, 0, 0, -100]\n", + "[-100, 1, 2, -100]\n" + ] + } + ], + "source": [ + "for i in range(2):\n", + " print(tokenized_datasets[\"train\"][i][\"labels\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "metric = evaluate.load(\"seqeval\")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "# Create label mappings\n", + "id2label = {i: label for i, label in enumerate(label_names)}\n", + "label2id = {v: k for k, v in id2label.items()}" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + } + ], + "source": [ + "# Load the pre-trained model\n", + "model = AutoModelForTokenClassification.from_pretrained(\n", + " model_checkpoint,\n", + " id2label=id2label,\n", + " label2id=label2id,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "9" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.config.num_labels" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "BertForTokenClassification(\n", + " (bert): BertModel(\n", + " (embeddings): BertEmbeddings(\n", + " (word_embeddings): Embedding(28996, 768, padding_idx=0)\n", + " (position_embeddings): Embedding(512, 768)\n", + " (token_type_embeddings): Embedding(2, 768)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (encoder): BertEncoder(\n", + " (layer): ModuleList(\n", + " (0-11): 12 x BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSdpaSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, 
out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " (intermediate_act_fn): GELUActivation()\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (classifier): Linear(in_features=768, out_features=9, bias=True)\n", + ")" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "trainable params: 301,833 || all params: 108,028,434 || trainable%: 0.2794\n" + ] + } + ], + "source": [ + "# Configure LoRA (Low-Rank Adaptation) for fine-tuning\n", + "peft_config = LoraConfig(target_modules = [\"query\", \"key\"], task_type = TaskType.TOKEN_CLS)\n", + "\n", + "model = get_peft_model(model, peft_config)\n", + "model.print_trainable_parameters()" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "def compute_metrics(eval_preds):\n", + " logits, labels = eval_preds\n", + " predictions = np.argmax(logits, axis=-1)\n", + "\n", + " # Remove ignored index (special tokens) and convert to labels\n", + " true_labels = [[label_names[l] for l in label if l != -100] for label in labels]\n", + " true_predictions = [\n", + " [label_names[p] for (p, l) in zip(prediction, label) if l != -100]\n", + " for prediction, label in zip(predictions, labels)\n", + " ]\n", + " all_metrics = metric.compute(predictions=true_predictions, references=true_labels)\n", + " return {\n", + " \"precision\": all_metrics[\"overall_precision\"],\n", + " \"recall\": all_metrics[\"overall_recall\"],\n", + " \"f1\": all_metrics[\"overall_f1\"],\n", + " \"accuracy\": all_metrics[\"overall_accuracy\"],\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "60bd54dd23de4822891a157430ff47b9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(HTML(value='
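The LoraConfig cell above keeps PEFT's default rank and scaling and adapts only the attention "query" and "key" projections. Below is a minimal sketch of the same setup with those defaults spelled out; the r, lora_alpha, and lora_dropout values and the extra "value" target module are illustrative assumptions, not choices made in the notebook.

```python
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForTokenClassification

# A fresh copy of the base checkpoint, so the notebook's already-wrapped `model` is untouched.
base_model = AutoModelForTokenClassification.from_pretrained(
    "bert-base-cased", id2label=id2label, label2id=label2id
)

# Illustrative, spelled-out LoRA configuration. PEFT's defaults are r=8,
# lora_alpha=8, lora_dropout=0.0; the values below and the added "value"
# module are assumptions for demonstration, not the notebook's settings.
peft_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    target_modules=["query", "key", "value"],
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
)

lora_model = get_peft_model(base_model, peft_config)
lora_model.print_trainable_parameters()
```

Targeting more modules (or raising r) increases the number of trainable parameters, so the percentage printed here will differ from the 0.2794% reported by the notebook's own cell.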
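With the LoRA-wrapped model, the data collator, the tokenized splits, and compute_metrics in place, these pieces would typically be handed to the Trainer class mentioned in the notebook's introduction. The sketch below assumes the output directory and every hyperparameter (learning rate, batch size, epoch count, evaluation cadence); none of them are values taken from this notebook.

```python
from transformers import TrainingArguments, Trainer

# Hypothetical training setup; every value below is an illustrative assumption.
args = TrainingArguments(
    output_dir="peft-lora-ner",          # assumed output path
    learning_rate=2e-4,
    per_device_train_batch_size=16,
    num_train_epochs=3,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=50,
)

trainer = Trainer(
    model=model,                          # the LoRA-wrapped model from the get_peft_model cell
    args=args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)

trainer.train()
```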
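Once training finishes, only the small LoRA adapter needs to be stored. The following sketch saves the adapter, reattaches it to a fresh copy of the bert-base-cased checkpoint, and runs the pipeline imported at the top of the notebook; the adapter path is hypothetical, and the example sentence is simply the first training sentence from CoNLL-2003.

```python
from peft import PeftModel
from transformers import AutoModelForTokenClassification, pipeline

# Save only the LoRA adapter weights ("peft-lora-ner-adapter" is an assumed path).
model.save_pretrained("peft-lora-ner-adapter")

# Reload: attach the saved adapter to a fresh copy of the base checkpoint.
base = AutoModelForTokenClassification.from_pretrained(
    "bert-base-cased", id2label=id2label, label2id=label2id
)
inference_model = PeftModel.from_pretrained(base, "peft-lora-ner-adapter")

# merge_and_unload() folds the adapter into the base weights, so the pipeline
# receives a plain BertForTokenClassification.
ner = pipeline(
    "token-classification",
    model=inference_model.merge_and_unload(),
    tokenizer=tokenizer,
    aggregation_strategy="simple",
)
print(ner("EU rejects German call to boycott British lamb ."))
```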