Example how to pretrain lm + introduction of config_name #57

Open · wants to merge 3 commits into base: master
22 changes: 22 additions & 0 deletions README.md
@@ -62,6 +62,28 @@ learn.save_encoder("enc")
...
```

## How to pre-train your own language model
Here is how you can pre-train a 'de' (German) language model from the command line:
```
$ bash prepare_wiki.sh de
$ python -W ignore -m multifit new multifit_paper_version replace_ --name my_lm - train_ --pretrain-dataset data/wiki/de-100
```

> **sebastianruder** (Collaborator) commented on Nov 17, 2019:
> Looks like there's a superfluous space between `-` and `train_`. Why do we use `train_` here? What is the difference between `train_` and `train`?
This should give you the pre-trained language model in `data/wiki/de-100/models/sp15k/my_lm`.
You can later use it as follows:
```
from fastai.text import *
import multifit

exp = multifit.from_pretrained("data/wiki/de-100/models/sp15k/my_lm")
exp.finetune_lm.train_("data/cls/de-books", num_epochs=20)
exp.classifier.train_(seed=0)
```

Please note that even though `python -m multifit new` lets you pick configurations other than `multifit_paper_version`, this is not recommended:
`from_pretrained` does not yet detect the configuration, so your training-specific parameters will be overwritten with the
defaults from `multifit_paper_version`.
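
To make the caveat concrete, here is a minimal sketch (using the path produced by the command-line example above; `pprint` is the experiment's own inspection helper from `multifit/training.py`):
```
from fastai.text import *
import multifit

exp = multifit.from_pretrained("data/wiki/de-100/models/sp15k/my_lm")
# The loaded experiment is assumed to follow multifit_paper_version, regardless
# of the configuration used at pretraining time, so inspect the training
# parameters before fine-tuning.
exp.pprint()
```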

## Reproducing the results
This repository is a rewrite of the original training scripts, so it lacks the scripts used in the paper.
We are working on a port to fastai v2.0, after which we will add scripts showing how to reproduce the results.
3 changes: 0 additions & 3 deletions multifit/__main__.py
@@ -9,9 +9,6 @@ class Experiment:
def new(self):
return {n: getattr(multifit.configurations,n) for n in multifit.configurations.__all__}

def load(self, model_path):
return multifit.ULMFiT().load_(Path(model_path))

def from_pretrained(self):
return multifit.from_pretrained

34 changes: 24 additions & 10 deletions multifit/configurations.py
@@ -12,6 +12,8 @@
'multifit_lstm',
'multifit1152_lstm_nl3',
'multifit1152_lstm_nl3_fp16_large',

'multifit_mini_test',
]

def multifit1552_fp32(bs=64):
@@ -25,7 +27,7 @@ def multifit1552_fp32(bs=64):
bs=bs,
use_adam_08=False,
early_stopping=None,
name=_use_caller_name()
config_name=_use_caller_name()
)
self.arch.replace_(
tokenizer_type='fsp',
@@ -42,29 +44,29 @@
multifit_fp32 = multifit1552_fp32

def multifit_fp32_nl3():
return multifit1552_fp32().replace_(n_layers=3, name=_use_caller_name())
return multifit1552_fp32().replace_(n_layers=3, config_name=_use_caller_name())

# FP16

def multifit1552_fp16():
return multifit1552_fp32(bs=128).replace_(fp16=True, name=_use_caller_name())
return multifit1552_fp32(bs=128).replace_(fp16=True, config_name=_use_caller_name())

def multifit1552_fp16_nl3_large():
return multifit1552_fp32(bs=448).replace_(fp16=True, n_layers=3, num_epochs=20, name=_use_caller_name())
return multifit1552_fp32(bs=448).replace_(fp16=True, n_layers=3, num_epochs=20, config_name=_use_caller_name())

multifit_fp16 = multifit1552_fp16

def multifit_lstm():
return multifit1552_fp32(bs=128).replace_(qrnn=False, n_hid=1552, name=_use_caller_name())
return multifit1552_fp32(bs=128).replace_(qrnn=False, n_hid=1552, config_name=_use_caller_name())

def multifit1152_lstm_nl3(bs=128):
return multifit1552_fp32(bs).replace_(qrnn=False, n_hid=1152, n_layers=3, name=_use_caller_name())
return multifit1552_fp32(bs).replace_(qrnn=False, n_hid=1152, n_layers=3, config_name=_use_caller_name())

def multifit1152_lstm_nl3_fp16_large():
return multifit1152_lstm_nl3(bs=448).replace_(fp16=True, num_epochs=20, name=_use_caller_name())
return multifit1152_lstm_nl3(bs=448).replace_(fp16=True, num_epochs=20, config_name=_use_caller_name())

def multifit_fp16_nl3():
return multifit1552_fp16().replace_(n_layers=3, name=_use_caller_name())
return multifit1552_fp16().replace_(n_layers=3, config_name=_use_caller_name())

def multifit_paper_version():
self = ULMFiT()
@@ -81,7 +83,7 @@ def multifit_paper_version():
early_stopping=None,
clip=0.12,
dropout_values=dps,
name=_use_caller_name()
config_name=_use_caller_name()
)
self.arch.replace_(
tokenizer_type='sp',
@@ -99,7 +101,7 @@ def ulmfit_orig():
self = multifit_paper_version()
self.replace_(
seed=None,
name=_use_caller_name()
config_name=_use_caller_name()
)
self.arch.replace_(
tokenizer_type='f',
Expand All @@ -110,6 +112,18 @@ def ulmfit_orig():
)
return self

def multifit_mini_test():
self = multifit_paper_version()
self.replace_(
config_name=_use_caller_name(),
n_hid=240,
n_layers=2,
bs=40,
num_epochs=1,
fp16=False,
limit=100
)
return self

def _use_caller_name():
return inspect.stack()[1].function
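
As a self-contained illustration of the `config_name` mechanism used throughout this file (the demo factory below is hypothetical, not part of the package): every configuration factory stamps the returned experiment with its own function name via `_use_caller_name()`, so the configuration is always named after the factory that built it.
```
import inspect

def _use_caller_name():
    # Same helper as above: report the name of the function that called us.
    return inspect.stack()[1].function

def my_custom_config():
    # Stands in for a factory such as multifit_mini_test(); the resulting
    # config_name is simply the factory's own name.
    return {"config_name": _use_caller_name()}

print(my_custom_config())  # {'config_name': 'my_custom_config'}
```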
32 changes: 18 additions & 14 deletions multifit/datasets/dataset.py
@@ -36,7 +36,6 @@ class Dataset:
dataset_path: Path

noise: float = 0.0
limit: int = None

ds_type: str = None
lang: str = None
@@ -96,6 +95,8 @@ def __post_init__(self):
use_lang_as_prefix=True)
else:
self.read_data = read_clas_csv
self.lang = self._language_from_dataset_path()
self.uses_moses = False
self.trn_path = self.dataset_path / self.trn_name
self.val_path = self.dataset_path / self.val_name
self.tst_path = self.dataset_path / self.tst_name
@@ -161,11 +162,6 @@ def load_supervised_data(self):
trn_df = self._add_noise(trn_df, self.noise)
val_df = self._add_noise(val_df, self.noise)

if self.limit is not None:
print("Limiting data set to:", self.limit)
trn_df = trn_df[:self.limit]
val_df = val_df[:self.limit]

return trn_df, val_df, tst_df

def load_unsupervised_data(self):
@@ -200,14 +196,16 @@ def __post_init__(self):
super().__post_init__()
self._vocab = None

def load_lm_databunch(self, bs, bptt):
def load_lm_databunch(self, bs, bptt, limit=None):
lm_suffix = str(bptt) if bptt != 70 else ""
lm_suffix += "" if self.use_tst_for_lm else "-notst"
lm_suffix += "" if limit is None else f"-{limit}"
data_lm = self.load_n_cache_databunch(f"lm{lm_suffix}",
bunch_class=TextLMDataBunch,
data_loader=self.load_unsupervised_data,
bptt=bptt,
bs=bs)
bs=bs,
limit=limit)

with (self.cache_path / "itos.pkl").open('wb') as f:
pickle.dump(data_lm.vocab.itos, f)
@@ -223,24 +221,26 @@ def _load_vocab(self):
self._vocab = self.load_lm_databunch(bs=20, bptt=70).vocab
return self._vocab

def load_clas_databunch(self, bs):
def load_clas_databunch(self, bs, limit=None):
vocab = self._load_vocab()

cls_name = "cls"
if self.limit is not None:
cls_name = f'{cls_name}limit{self.limit}'
if limit is not None:
cls_name = f'{cls_name}limit{limit}'
if self.noise > 0.0:
cls_name = f'{cls_name}noise{self.noise}'

args = dict(vocab=vocab, bunch_class=TextClasDataBunch, bs=bs)
data_cls = self.load_n_cache_databunch(cls_name, data_loader=lambda: self.load_supervised_data()[:2], **args)
trn_val_dl = lambda: self.load_supervised_data()[:2]
data_cls = self.load_n_cache_databunch(cls_name, data_loader=trn_val_dl, limit=limit, **args)
# Hack to load test dataset with labels
data_tst = self.load_n_cache_databunch('tst', data_loader=lambda: self.load_supervised_data()[1:], **args)
val_tst_dl = lambda: self.load_supervised_data()[1:]
data_tst = self.load_n_cache_databunch('tst', data_loader=val_tst_dl, **args)
data_cls.test_dl = data_tst.valid_dl # data_tst.valid_dl holds test data
data_cls.lang = self.lang
return data_cls

def load_n_cache_databunch(self, name, bunch_class, data_loader, bs, **args):
def load_n_cache_databunch(self, name, bunch_class, data_loader, bs, limit=None, **args):
bunch_path = self.cache_path / name
databunch = None
if bunch_path.exists():
@@ -251,6 +251,10 @@ def load_n_cache_databunch(self, name, bunch_class, data_loader, bs, **args):
if databunch is None:
print(f"Running tokenization: '{name}' ...")
train_df, valid_df = data_loader()
if limit is not None:
print(f"Limiting number of examples in train and valid sets to: {limit}")
train_df = train_df[:limit]
valid_df = valid_df[:limit]
databunch = self.databunch_from_df(bunch_class, train_df, valid_df, **args)
databunch.save(name)
print(f"Data {name}, trn: {len(databunch.train_ds)}, val: {len(databunch.valid_ds)}")
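
A small runnable sketch of the behaviour that moved in this file (toy DataFrames, not the real loaders): `limit` is no longer a `Dataset` field but an argument threaded through `load_lm_databunch` / `load_clas_databunch` into `load_n_cache_databunch`, where it truncates both splits before tokenization and becomes part of the cache name.
```
import pandas as pd

def clas_cache_name(limit=None, noise=0.0):
    # Mirrors the cache-key logic in load_clas_databunch: a limited run is
    # cached under a distinct name (e.g. "clslimit100"), so it never collides
    # with a full run.
    name = "cls"
    if limit is not None:
        name = f"{name}limit{limit}"
    if noise > 0.0:
        name = f"{name}noise{noise}"
    return name

def apply_limit(train_df, valid_df, limit=None):
    # Mirrors the truncation now done inside load_n_cache_databunch.
    if limit is not None:
        print(f"Limiting number of examples in train and valid sets to: {limit}")
        train_df, valid_df = train_df[:limit], valid_df[:limit]
    return train_df, valid_df

trn = pd.DataFrame({"text": ["ein Beispiel"] * 1000, "label": [0] * 1000})
val = pd.DataFrame({"text": ["noch eins"] * 200, "label": [1] * 200})
trn, val = apply_limit(trn, val, limit=100)
print(clas_cache_name(limit=100), len(trn), len(val))  # clslimit100 100 100
```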
43 changes: 33 additions & 10 deletions multifit/training.py
@@ -41,6 +41,7 @@ class ULMFiTArchitecture(Params):
tokenizer_type: str = "f"
max_vocab: int = 60000
lang: str = None
config_name: str = None

emb_sz: int = awd_lstm_lm_config['emb_sz']
n_hid: int = awd_lstm_lm_config['n_hid']
@@ -130,7 +131,7 @@ class ULMFiTTrainingCommand(Params):

@property
def model_name(self):
return (self.name or self.arch.model_name()) + (
return (self.name or self.arch.config_name or self.arch.model_name()) + (
"" if self.seed is None or self.seed == 0 or "seed" in self.name else f"seed{self.seed}")

@property
@@ -161,7 +162,7 @@ def save_paramters(self):
exp_path = params.get('experiment_path', None)
if exp_path:
fn = self.info_json
print("Saving dump to", exp_path / fn)
print("Saving args to", exp_path / fn)
json_str = json.dumps(to_json_serializable(params), indent=2)
with (exp_path / fn).open("w") as f:
f.write(json_str)
@@ -194,6 +195,12 @@ def load_(self, experiment_path, tantetive=True, update_arch=True, silent=False)
self.replace_(_verbose_diff=not silent, **d)
return arch

def train_(self, dataset_or_path, **kwargs):
pass

def validate(self, **kwargs):
pass


@dataclass
class ULMFiTPretraining(ULMFiTTrainingCommand):
@@ -210,6 +217,7 @@ class ULMFiTPretraining(ULMFiTTrainingCommand):
clip: float = None
fp16: bool = False
lr: float = 5e-3
limit: int = None

def get_learner(self, data_lm, **additional_trn_args):
config = awd_lstm_lm_config.copy()
@@ -264,7 +272,7 @@ def train_(self, dataset_or_path, tokenizer=None, **train_config):
tokenizer = self.arch.new_tokenizer()

dataset = self._set_dataset_(dataset_or_path, tokenizer)
learn = self.get_learner(data_lm=dataset.load_lm_databunch(bs=self.bs, bptt=self.bptt))
learn = self.get_learner(data_lm=dataset.load_lm_databunch(bs=self.bs, bptt=self.bptt, limit=self.limit))
experiment_path = learn.path / learn.model_dir
print("Experiment", experiment_path)
if self.num_epochs > 0:
@@ -280,7 +288,7 @@ def train_(self, dataset_or_path, tokenizer=None, **train_config):
print("Language model saved to", self.experiment_path)

def validate(self):
raise NotImplementedError("The validation on the language model is not implemented.")
return "not implemented"
> **Collaborator** commented:
> Do we really just want to return a string here?


@property
def model_fnames(self):
@@ -346,6 +354,7 @@ class ULMFiTClassifier(ULMFiTTrainingCommand):
seed: int = 0
bptt: int = 70
fp16: bool = False
limit: int = None
arch: ULMFiTArchitecture = None

def get_learner(self, data_clas, eval_only=False, **additional_trn_args):
@@ -402,7 +411,7 @@ def train_(self, dataset_or_path=None, **train_config):

base_tokenizer = self.base.tokenizer
dataset = self._set_dataset_(dataset_or_path, base_tokenizer)
data_clas = dataset.load_clas_databunch(bs=self.bs)
data_clas = dataset.load_clas_databunch(bs=self.bs, limit=self.limit)
learn = self.get_learner(data_clas=data_clas)
print(f"Training: {learn.path / learn.model_dir}")
learn.unfreeze()
@@ -415,7 +424,6 @@ def train_(self, dataset_or_path=None, **train_config):
print("Classifier model saved to", self.experiment_path)
self.save_paramters()
learn.destroy()
return

def _validate(self, learn, ds_type):
ds_name = ds_type.name.lower()
@@ -438,7 +446,7 @@ def validate(self, *splits, data_cls=None, save_name=CLS_BEST, use_cache=True, s
return json.load(fp)

if data_cls is None:
data_cls = self.dataset.load_clas_databunch(bs=self.bs)
data_cls = self.dataset.load_clas_databunch(bs=self.bs, limit=self.limit)

learn = self.get_learner(data_cls, eval_only=True)
# avg = 'binary' if learn.data.c == 2 else 'macro'
@@ -572,10 +580,25 @@ def pprint(self):
{self.classifier},
)""")

def train_(self, pretrain_dataset=None, clas_dataset=None):
results = {}
if pretrain_dataset is not None:
self.pretrain_lm.train_(pretrain_dataset)
results['pretrain_lm'] = self.pretrain_lm.validate()
if clas_dataset is not None:
self.finetune_lm.train_(clas_dataset)
results['finetune_lm'] = self.finetune_lm.validate(use_cache=None)
self.classifier.train_(clas_dataset)
results['classifier'] = self.classifier.validate(use_cache=None)
return results

def from_pretrained_(self, name, repo="n-waves/multifit-models"):
name = name.rstrip(".tgz") # in case someone passes the .tgz name
url = f"https://github.com/{repo}/releases/download/{name}/{name}.tgz"
path = untar_data(url.rstrip(".tgz"), data=False) # untar_data adds .tgz
if (Path(name)/f"{LM_BEST}.pth").exists():
path = Path(name)
else:
name = name.rstrip(".tgz") # in case someone passes the .tgz name
url = f"https://github.com/{repo}/releases/download/{name}/{name}.tgz"
path = untar_data(url.rstrip(".tgz"), data=False) # untar_data adds .tgz
return self.load_(path)
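
Taken together, the additions in this file give `ULMFiT` a single entry point that chains pretraining, language-model fine-tuning and classification. A sketch of how it might be invoked (the dataset paths are placeholders borrowed from the README example):
```
import multifit

exp = multifit.configurations.multifit_mini_test()  # the new small test configuration
results = exp.train_(
    pretrain_dataset="data/wiki/de-100",  # runs pretrain_lm.train_() and validate()
    clas_dataset="data/cls/de-books",     # then finetune_lm and classifier
)
print(results.keys())  # dict_keys(['pretrain_lm', 'finetune_lm', 'classifier'])
```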


Expand Down