Add transformer encoder fine-tuning with Adapters #90

Open · wants to merge 10 commits into base: master
17 changes: 17 additions & 0 deletions kiwi/lib/train.py
@@ -418,6 +418,23 @@ def run(
)
return run_info

    if trainer.model.config.model.encoder.adapter is not None:
        if trainer.model.config.model.encoder.adapter.fusion:
            # We just save the entire model
            pass
        else:
            language = trainer.model.config.model.encoder.adapter.language
            adapter_path = (
                Path(checkpoint_callback.best_model_path).parent / f'{language}'
            )
            logger.info(f"Saving the Adapter '{language}' to: {adapter_path}")
            # TODO/FIXME: we should really move towards calling these things with a
            # generic name, like 'trainer.model.encoder.transformer' or something
            try:
                trainer.model.encoder.bert.save_adapter(adapter_path, language)
            except AttributeError:
                trainer.model.encoder.xlm_roberta.save_adapter(adapter_path, language)

if tracking_logger:
# Send best model file to logger
tracking_logger.log_model(best_model_path)
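As a side note on the hunk above: a minimal sketch, not part of this diff, of the adapter save/load round trip it automates. It uses only calls that already appear in this PR; the model name and output directory are placeholders.

```python
from transformers import AdapterType, BertModel
from transformers.adapter_config import PfeifferConfig

model = BertModel.from_pretrained('bert-base-multilingual-cased')

# Training side: register a fresh language adapter, then freeze everything
# except its weights before fine-tuning.
model.add_adapter('en-de', AdapterType.text_lang, config=PfeifferConfig())
model.train_adapter('en-de')
# ... fine-tune, then persist only the adapter weights (what run() does above):
model.save_adapter('runs/bert/en-de', 'en-de')  # hypothetical output directory

# Prediction side: reattach the saved adapter to a fresh encoder.
model = BertModel.from_pretrained('bert-base-multilingual-cased')
model.load_adapter('runs/bert/en-de', AdapterType.text_lang, config=PfeifferConfig())
```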
51 changes: 49 additions & 2 deletions kiwi/systems/encoders/bert.py
@@ -17,19 +17,21 @@
import logging
from collections import Counter, OrderedDict
from pathlib import Path
from typing import Dict, Union
from typing import Dict, List, Union

import torch
from pydantic import confloat
from pydantic import DirectoryPath, confloat
from pydantic.class_validators import validator
from torch import Tensor, nn
from transformers import (
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
AdapterType,
AutoTokenizer,
BertConfig,
BertModel,
)
from transformers.adapter_config import PfeifferConfig

from kiwi import constants as const
from kiwi.data.batch import MultiFieldBatch
@@ -90,6 +92,26 @@ def fit_vocab(
)


class EncoderAdapterConfig(BaseConfig):
    language: str = None
    """Specify a name to create a new language adapter, e.g. en-de, en-zh, etc."""

    load: List[DirectoryPath] = None
    """Load trained adapters to use for prediction or for Adapter Fusion.
    Each entry should point to an adapter's root directory."""

    fusion: bool = False
    """Train Adapter Fusion on top of the loaded adapters."""

    @validator('fusion')
    def check_load(cls, v, values):
        if v and not values.get('load'):
            raise NotImplementedError(
                'Specify adapters to load if you want to fuse them'
            )
        return v


@MetaModule.register_subclass
class BertEncoder(MetaModule):
"""BERT model as presented in Google's paper and using Hugging Face's code
@@ -102,6 +124,9 @@ class Config(BaseConfig):
model_name: Union[str, Path] = 'bert-base-multilingual-cased'
"""Pre-trained BERT model to use."""

adapter: EncoderAdapterConfig = None
"""Use an Adapter to fine tune the encoder."""

use_mismatch_features: bool = False
"""Use Alibaba's mismatch features."""

@@ -156,6 +181,28 @@ def __init__(
)
self.bert = BertModel(bert_config)

        # Add Adapters if specified
        if config.adapter is not None:
            if config.adapter.language is not None:
                # Add an adapter module
                self.bert.add_adapter(
                    config.adapter.language,
                    AdapterType.text_lang,
                    config=PfeifferConfig(),
                )
                self.bert.train_adapter(config.adapter.language)
            if config.adapter.load is not None:
                # Load the adapter module
                for path in config.adapter.load:
                    self.bert.load_adapter(
                        str(path), AdapterType.text_lang, config=PfeifferConfig(),
                    )
                # Add fusion of adapters
                if config.adapter.fusion:
                    adapter_setup = [[path.name for path in config.adapter.load]]
                    self.bert.add_fusion(adapter_setup[0], "dynamic")
                    self.bert.train_fusion(adapter_setup)

self.vocabs = {
const.TARGET: vocabs[const.TARGET],
const.SOURCE: vocabs[const.SOURCE],
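To make the new options concrete, a short sketch, not part of this diff, of how the adapter section of the encoder config could be filled in. The directory names are hypothetical, and because `load` is typed as `List[DirectoryPath]` they must already exist on disk when the config is validated.

```python
from kiwi.systems.encoders.bert import EncoderAdapterConfig

# Train a brand-new language adapter named 'en-de' on top of a frozen encoder.
train_adapter_cfg = EncoderAdapterConfig(language='en-de')

# Load previously saved adapters and train Adapter Fusion on top of them;
# the check_load validator rejects fusion=True when nothing is loaded.
fusion_cfg = EncoderAdapterConfig(
    load=['runs/bert/en-de', 'runs/bert/en-zh'],  # hypothetical adapter directories
    fusion=True,
)
```

This mirrors the saving logic in train.py above: single adapters are exported with save_adapter, while fusion runs fall back to saving the entire model.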
51 changes: 49 additions & 2 deletions kiwi/systems/encoders/xlmroberta.py
@@ -17,18 +17,20 @@
import logging
from collections import Counter, OrderedDict
from pathlib import Path
from typing import Dict, Union
from typing import Dict, List, Union

import torch
from pydantic import confloat
from pydantic import DirectoryPath, confloat
from pydantic.class_validators import validator
from torch import Tensor, nn
from transformers import (
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
AdapterType,
AutoTokenizer,
XLMRobertaConfig,
XLMRobertaModel,
)
from transformers.adapter_config import PfeifferConfig
from typing_extensions import Literal

from kiwi import constants as const
@@ -99,6 +101,26 @@ def fit_vocab(
self.vocab.max_size(vocab_size)


class EncoderAdapterConfig(BaseConfig):
    language: str = None
    """Specify a name to add a new language adapter, e.g. 'en-de', 'en-zh', etc."""

    load: List[DirectoryPath] = None
    """Load trained adapters to use for prediction or for Adapter Fusion.
    Each entry should point to an adapter's root directory."""

    fusion: bool = False
    """Train Adapter Fusion on top of the loaded adapters."""

    @validator('fusion')
    def check_load(cls, v, values):
        if v and not values.get('load'):
            raise NotImplementedError(
                'Specify adapters to load if you want to fuse them'
            )
        return v


@MetaModule.register_subclass
class XLMRobertaEncoder(MetaModule):
"""XLM-RoBERTa model, using HuggingFace's implementation."""
@@ -107,6 +129,9 @@ class Config(BaseConfig):
model_name: Union[str, Path] = 'xlm-roberta-base'
"""Pre-trained XLMRoberta model to use."""

adapter: EncoderAdapterConfig = None
"""Use an Adapter to fine tune the encoder."""

interleave_input: bool = False
"""Concatenate SOURCE and TARGET without internal padding
(111222000 instead of 111002220)"""
@@ -156,6 +181,28 @@ def __init__(
)
self.xlm_roberta = XLMRobertaModel(xlm_roberta_config)

        # Add Adapters if specified
        if config.adapter is not None:
            if config.adapter.language is not None:
                # Add an adapter module
                self.xlm_roberta.add_adapter(
                    config.adapter.language,
                    AdapterType.text_lang,
                    config=PfeifferConfig(),
                )
                self.xlm_roberta.train_adapter(config.adapter.language)
            if config.adapter.load is not None:
                # Load the adapter module
                for path in config.adapter.load:
                    self.xlm_roberta.load_adapter(
                        str(path), AdapterType.text_lang, config=PfeifferConfig(),
                    )
                # Add fusion of adapters
                if config.adapter.fusion:
                    adapter_setup = [[path.name for path in config.adapter.load]]
                    self.xlm_roberta.add_fusion(adapter_setup[0], "dynamic")
                    self.xlm_roberta.train_fusion(adapter_setup)

self.vocabs = {
const.TARGET: vocabs[const.TARGET],
const.SOURCE: vocabs[const.SOURCE],
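Finally, a sketch, not part of this diff, of the fusion path that the __init__ above wires up when `config.adapter.fusion` is enabled. It uses only calls already present in this PR; the adapter directories are placeholders, assumed to match the names the adapters were saved under.

```python
from transformers import AdapterType, XLMRobertaModel
from transformers.adapter_config import PfeifferConfig

model = XLMRobertaModel.from_pretrained('xlm-roberta-base')

# Hypothetical directories previously written by save_adapter().
adapter_dirs = ['runs/xlmr/en-de', 'runs/xlmr/en-zh']
for path in adapter_dirs:
    model.load_adapter(path, AdapterType.text_lang, config=PfeifferConfig())

# Fusion is defined over the loaded adapters, referenced here by directory
# basename, and train_fusion then puts the model in fusion-training mode.
adapter_setup = [[d.rstrip('/').split('/')[-1] for d in adapter_dirs]]
model.add_fusion(adapter_setup[0], "dynamic")
model.train_fusion(adapter_setup)
```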