diff --git a/YOCO/README.md b/YOCO/README.md
index e61e137cc..ab3f3becd 100644
--- a/YOCO/README.md
+++ b/YOCO/README.md
@@ -1,6 +1,168 @@
-# YOCO
+# You Only Cache Once: Decoder-Decoder Architectures for Large Language Models
-- May 2024: Code release
-- May 2024: release preprint [YOCO](https://arxiv.org/abs/)
+## Approach
+
+![The YOCO decoder-decoder architecture](./imgs/arch.png)
+
-## Getting Started
+
+![Inference with YOCO](./imgs/inference.png)
+
+## Performance
+### Harness Eval
+Training with 1T Tokens:
+| **Model** | **Arc-c** | **Arc-e** | **BoolQ** | **Hellaswag**$^*$ | **OBQA** | **PIQA** | **Winogrande** | **SciQ** | **Avg** |
+|----------------------------|-----------|-----------|-----------|-------------------|----------|----------|----------------|----------|---------|
+| OpenLLaMA-3B-v2 | 0.339 | 0.676 | 0.657 | **0.700** | 0.260 | 0.767 | 0.629 | 0.924 | 0.619 |
+| StableLM-base-alpha-3B-v2 | 0.324 | 0.673 | 0.646 | 0.686 | 0.264 | 0.760 | 0.621 | 0.921 | 0.612 |
+| StableLM-3B-4E1T | --- | 0.666 | --- | --- | --- | **0.768**| 0.632 | 0.914 | --- |
+| YOCO-3B | **0.379** | **0.731** | 0.645 | 0.689 | **0.298**| 0.763 | 0.639 | 0.924 | **0.634**|
+
+Training with 1.6T Tokens:
+| **Model** | **Arc-c** | **Arc-e** | **BoolQ** | **Hellaswag**$^*$ | **OBQA** | **PIQA** | **Winogrande** | **SciQ** | **Avg** |
+|----------------------------|-----------|-----------|-----------|-------------------|----------|----------|----------------|----------|---------|
+| StableLM-3B-4E1T | --- | 0.688 | --- | --- | --- | 0.762 | 0.627 | 0.913 | --- |
+| YOCO-3B | 0.396 | 0.733 | **0.644** | 0.698 | 0.300 | 0.764 | 0.631 | 0.921 | 0.636 |
+| YOCO-3B-1M | **0.413** | **0.747** | 0.638 | **0.705** | 0.300 | **0.773**| **0.651** | **0.932**| **0.645**|
+### Needle In A Haystack
+
+![Needle-in-a-haystack retrieval results up to 1M context length](./imgs/1m_retrieval.png)
+
+### Multi-Needle Eval
+| **Model** | **Size** | **N=1** | **N=2** | **N=4** | **N=8** |
+|-------------------------|----------|---------|---------|---------|---------|
+| GPT-4-128K | -- | 1.00 | 1.00 | 0.98 | 1.00 |
+| MiniCPM-128K | 2.4B | 1.00 | 1.00 | 0.54 | 0.56 |
+| ChatGLM3-128K | 6B | 0.94 | 0.72 | 0.52 | 0.44 |
+| YaRN-Mistral-128K | 7B | 0.02 | 0.12 | 0.08 | 0.20 |
+| LWM-1M-text | 7B | 1.00 | 0.90 | 0.76 | 0.62 |
+| YOCO-3B-1M | 3B | 0.98 | 0.98 | 0.84 | 0.56 |
+
+## Setup
+
+To install the required packages, use the following command:
+
+```bash
+pip install -r requirements.txt
+```
+
+In addition to these packages, [Apex](https://github.com/NVIDIA/apex) and [Flash-Attention](https://github.com/Dao-AILab/flash-attention) should be installed separately by following their official installation guides.
+
+## Harness Eval
+
+To evaluate models with Harness-Eval, use the script in ```scripts/eval_task.sh```:
+```bash
+cd yoco/
+TASK='harness_boolq'
+
+torchrun --master-port=29505 --nproc_per_node=1 validate.py \
+ --data-dir ../harness_data/ \
+ --criterion harness_eval \
+ --task harness_eval \
+ --batch-size 4 \
+ --eval-data ${TASK} \
+ --log-format simple --log-interval 10 \
+ --bf16 \
+ --tokenizer-pad-to-multiple 8 \
+ --arch yoco_3b_new --tiktoken-model cl100k_base --load-ckpt /path_to_ckpt/YOCO-3B-1M/checkpoint.pth --yoco-model /path_to_ckpt/YOCO-3B-1M --tokens-per-sample 4096
+```
+
+## Needle In A Haystack Evaluation
+Our long-sequence evaluation uses city-number needle pairs. To obtain results up to a given maximum length, use the script in ```scripts/eval_needle.sh```:
+```bash
+cd yoco/
+torchrun --master-port=29504 --nproc_per_node=1 validate.py \
+ --task pseudo \
+ --criterion needle_haystack \
+ --batch-size 1 \
+ --max-epoch 1 \
+ --no-save \
+ --tiktoken-model cl100k_base \
+ --bf16 \
+ --arch yoco_3b_new --tiktoken-model cl100k_base --load-ckpt /path_to_ckpt/YOCO-3B-1M/checkpoint.pth --yoco-model /path_to_ckpt/YOCO-3B-1M --tokens-per-sample 1048576 --interval 1048576
+```
+
+To run Multi-Needle experiments, replace ```--criterion needle_haystack``` with ```--criterion multi_needle --needle-num {num}```.
+
+## Pretraining From Scratch
+To support distributed training, our implementation relies on infinibatch to read data iteratively. The overall data directory should be organized as follows:
+```
+Data/
+├── json/
+│ ├── train.json
+│ └── CC.json
+│ └── StarCoder.json
+│ └── ...
+├── shard/
+│ ├── CC/
+│ │ ├── 00000.jsonl
+│ │ ├── 00001.jsonl
+│ │ └── ...
+│ └── StarCoder/
+│ ├── 00000.jsonl
+│ ├── 00001.jsonl
+│ └── ...
+```
+
+We recommend that each sharded data file contain no more than 10K lines, with one JSON dict per line. Each jsonl file, such as ```Data/shard/CC/00000.jsonl```, should look like this:
+```json
+{"text": "File 1 is here..."}
+{"text": "File 2 is here..."}
+...
+```
+
+Then, for each source, a JSON file lists the paths of all its jsonl files. Take ```Data/json/CC.json``` as an example:
+```json
+[
+ "/path_to_data/Data/shard/CC/00000.jsonl",
+ "/path_to_data/Data/shard/CC/00001.jsonl",
+ ...
+]
+```
+
+Finally, ```train.json``` records each source's name and sampling weight:
+```json
+[
+ {
+ "name": "CC",
+ "weight": 0.5
+ },
+ {
+ "name": "StarCoder",
+ "weight": 0.2
+ },
+ ...
+]
+```
+
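+The layout above can also be generated programmatically. Below is a minimal Python sketch (a hypothetical helper, not part of this repository) that shards in-memory documents into jsonl files and writes the per-source index files and ```train.json```; adapt the paths, shard size, and weights to your own corpora.
+```python
+import json
+import os
+
+def build_data_dir(root, sources):
+    """Hypothetical helper. `sources` maps a source name to (list_of_texts, weight)."""
+    os.makedirs(os.path.join(root, "json"), exist_ok=True)
+    train_config = []
+    for name, (texts, weight) in sources.items():
+        shard_dir = os.path.join(root, "shard", name)
+        os.makedirs(shard_dir, exist_ok=True)
+        shard_paths = []
+        for shard_id, start in enumerate(range(0, len(texts), 10000)):  # <= 10K lines per shard
+            shard_path = os.path.join(shard_dir, f"{shard_id:05d}.jsonl")
+            with open(shard_path, "w", encoding="utf-8") as f:
+                for text in texts[start:start + 10000]:
+                    f.write(json.dumps({"text": text}) + "\n")
+            shard_paths.append(os.path.abspath(shard_path))
+        # Per-source index: Data/json/<name>.json lists the paths of its shards
+        with open(os.path.join(root, "json", f"{name}.json"), "w") as f:
+            json.dump(shard_paths, f, indent=2)
+        train_config.append({"name": name, "weight": weight})
+    # Global config with sampling weights, consumed by the training task
+    with open(os.path.join(root, "json", "train.json"), "w") as f:
+        json.dump(train_config, f, indent=2)
+
+build_data_dir("Data", {"CC": (["File 1 is here...", "File 2 is here..."], 0.5)})
+```
+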
+To pretrain from scratch, run the script in ```scripts/train.sh```:
+```bash
+cd yoco/
+torchrun --nproc-per-node=1 train.py /path_to_data \
+ --save-interval-updates 5000 \
+ --no-epoch-checkpoints \
+ --arch yoco_base \
+ --criterion cross_entropy \
+ --task gpt \
+ --tokens-per-sample 2048 \
+ --tokenizer-pad-to-multiple 8 \
+ --pad-to-max-len \
+ --optimizer adam --adam-betas "(0.9, 0.95)" \
+ --adam-eps 1e-06 \
+ --clip-norm 2.0 \
+ --lr 0.00015 \
+ --lr-scheduler polynomial_decay \
+ --warmup-updates 50 \
+ --weight-decay 0.05 \
+ --batch-size 1 \
+ --model-parallel-size 1 \
+ --update-freq 1 \
+ --batch-read-ahead 1000 \
+ --total-num-update 300000 \
+ --log-format simple --log-interval 10 --disable-validation \
+ --tiktoken-model cl100k_base \
+ --bf16 # bf16 is encouraged in pre-training
+```
diff --git a/YOCO/imgs/1m_retrieval.png b/YOCO/imgs/1m_retrieval.png
new file mode 100644
index 000000000..9fb8d9490
Binary files /dev/null and b/YOCO/imgs/1m_retrieval.png differ
diff --git a/YOCO/imgs/arch.png b/YOCO/imgs/arch.png
new file mode 100644
index 000000000..152406374
Binary files /dev/null and b/YOCO/imgs/arch.png differ
diff --git a/YOCO/imgs/inference.png b/YOCO/imgs/inference.png
new file mode 100644
index 000000000..0751e0a63
Binary files /dev/null and b/YOCO/imgs/inference.png differ
diff --git a/YOCO/requirements.txt b/YOCO/requirements.txt
new file mode 100644
index 000000000..2e1336239
--- /dev/null
+++ b/YOCO/requirements.txt
@@ -0,0 +1,12 @@
+torch>=2.2.0
+triton>=2.2.0
+numpy==1.23.0
+fairscale
+tiktoken
+sentencepiece
+ninja
+boto3
+iopath
+git+https://github.com/sunyt32/fairseq.git@moe3#egg=fairseq
+git+https://github.com/shumingma/infinibatch.git#egg=infinibatch
+git+https://github.com/microsoft/torchscale.git#egg=torchscale
\ No newline at end of file
diff --git a/YOCO/scripts/eval_needle.sh b/YOCO/scripts/eval_needle.sh
new file mode 100644
index 000000000..a6277901f
--- /dev/null
+++ b/YOCO/scripts/eval_needle.sh
@@ -0,0 +1,11 @@
+cd yoco/
+torchrun --master-port=29504 --nproc_per_node=1 validate.py \
+ --task pseudo \
+    --criterion needle_haystack \
+ --batch-size 1 \
+ --max-epoch 1 \
+ --no-save \
+ --tiktoken-model cl100k_base \
+ --bf16 \
+    --arch yoco_3b_new --tiktoken-model cl100k_base --load-ckpt /path_to_ckpt/YOCO-3B-1M/checkpoint.pth --yoco-model /path_to_ckpt/YOCO-3B-1M --tokens-per-sample 1048576 --interval 1048576
+
diff --git a/YOCO/scripts/eval_task.sh b/YOCO/scripts/eval_task.sh
new file mode 100644
index 000000000..07b70593e
--- /dev/null
+++ b/YOCO/scripts/eval_task.sh
@@ -0,0 +1,17 @@
+TASK='harness_boolq'
+# TASK='hendrycksTest-abstract_algebra'
+
+cd yoco/
+torchrun --master-port=29505 --nproc_per_node=1 validate.py \
+ --data-dir ../harness_data/ \
+ --criterion harness_eval \
+ --task harness_eval \
+ --batch-size 4 \
+ --eval-data ${TASK} \
+ --log-format simple --log-interval 10 \
+ --bf16 \
+ --tokenizer-pad-to-multiple 8 \
+    --arch yoco_3b_new --tiktoken-model cl100k_base --load-ckpt /path_to_ckpt/YOCO-3B-1M/checkpoint.pth --yoco-model /path_to_ckpt/YOCO-3B-1M --tokens-per-sample 4096
+ # --arch llama_from_ckpt --llama-model /data/yutao/llama/llama-2-7b --load-ckpt /data/yutao/llama/llama-2-7b/consolidated.00.pth --tokens-per-sample 4096
+
+
diff --git a/YOCO/scripts/train.sh b/YOCO/scripts/train.sh
new file mode 100644
index 000000000..28c13f7bf
--- /dev/null
+++ b/YOCO/scripts/train.sh
@@ -0,0 +1,27 @@
+cd yoco/
+torchrun --master-port=29501 --nproc-per-node=1 train.py /path_to_data \
+ --save-interval-updates 5000 \
+ --no-epoch-checkpoints \
+ --arch yoco_base \
+ --criterion cross_entropy \
+ --task gpt \
+ --tokens-per-sample 2048 \
+ --tokenizer-pad-to-multiple 8 \
+ --pad-to-max-len \
+ --optimizer adam --adam-betas "(0.9, 0.95)" \
+ --adam-eps 1e-06 \
+ --clip-norm 2.0 \
+ --lr 0.00015 \
+ --lr-scheduler polynomial_decay \
+ --warmup-updates 50 \
+ --weight-decay 0.05 \
+ --batch-size 1 \
+ --model-parallel-size 1 \
+ --update-freq 1 \
+ --batch-read-ahead 1000 \
+ --total-num-update 300000 \
+ --log-format simple --log-interval 10 --disable-validation \
+ --tiktoken-model cl100k_base \
+ --no-save \
+ --bf16 \
+
diff --git a/YOCO/yoco/__init__.py b/YOCO/yoco/__init__.py
new file mode 100644
index 000000000..3ae31e250
--- /dev/null
+++ b/YOCO/yoco/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
diff --git a/YOCO/yoco/criterions/__init__.py b/YOCO/yoco/criterions/__init__.py
new file mode 100644
index 000000000..9901f2753
--- /dev/null
+++ b/YOCO/yoco/criterions/__init__.py
@@ -0,0 +1,8 @@
+import importlib
+import os
+
+# automatically import any Python files in the criterions/ directory
+for file in sorted(os.listdir(os.path.dirname(__file__))):
+ if file.endswith(".py") and not file.startswith("_"):
+ file_name = file[: file.find(".py")]
+ importlib.import_module("criterions." + file_name)
\ No newline at end of file
diff --git a/YOCO/yoco/criterions/harness_eval.py b/YOCO/yoco/criterions/harness_eval.py
new file mode 100644
index 000000000..8aed18e36
--- /dev/null
+++ b/YOCO/yoco/criterions/harness_eval.py
@@ -0,0 +1,86 @@
+import torch
+import torch.nn.functional as F
+
+from fairseq import metrics
+from fairseq.criterions import FairseqCriterion, register_criterion
+from fairseq.dataclass import FairseqDataclass
+
+
+@register_criterion("harness_eval", dataclass=FairseqDataclass)
+class HarnessEvalCriterion(FairseqCriterion):
+ def __init__(self, cfg, task):
+ super().__init__(task)
+
+ def forward(self, model, sample, reduce=True):
+ """Compute the loss for the given sample.
+
+ Returns a tuple with three elements:
+ 1) the loss
+ 2) the sample size, which is used as the denominator for the gradient
+ 3) logging outputs to display while training
+ """
+ model.eval()
+ net_output, _ = model(sample["net_input"]["src_tokens"])
+ net_output = net_output[:, :-1, :]
+ targets = sample["net_input"]["src_tokens"][:, 1:]
+ loss_mask = sample["net_input"]["gpt_loss_mask"][:, 1:]
+ label_length = sample["net_input"]["label_length"]
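+        # Each candidate answer is scored by its language-model loss on the answer
+        # tokens (selected by the loss mask); the option with the lowest summed loss
+        # yields "accuracy", the lowest length-normalized loss yields "accuracy_norm".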
+ loss = F.cross_entropy(
+ net_output.float().reshape(-1, net_output.size(-1)),
+ targets.reshape(-1),
+ reduction="none",
+ ignore_index=self.padding_idx,
+ ).reshape(targets.size(0), -1)
+ loss = loss * loss_mask.int()
+ loss_norm = loss.sum(-1) / label_length.float()
+ loss = loss.sum(-1)
+
+ option_num = self.task.harness_task.class_num
+ labels = sample["targets"].view(-1)
+
+ assert sample["targets"].size(0) % option_num == 0
+ sample_size = sample["ntokens"]
+
+ pred_label = torch.argmin(loss.view(-1, option_num), dim=1)
+ pred_norm_label = torch.argmin(loss_norm.view(-1, option_num), dim=1)
+ target_label = labels.view(-1, option_num)[:, 0]
+
+ logging_output = {}
+
+ logging_output.update(
+ {
+ "loss": 0,
+ "nsentences": pred_label.size(0),
+ "sample_size": pred_label.size(0),
+ "ncorrect": (pred_label == target_label).sum().item(),
+ "ncorrect_norm": (pred_norm_label == target_label).sum().item(),
+ }
+ )
+
+ return loss, sample_size, logging_output
+
+ @staticmethod
+ def reduce_metrics(logging_outputs) -> None:
+ """Aggregate logging outputs from data parallel training."""
+ loss = sum(log.get("loss", 0) for log in logging_outputs)
+ nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
+ ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
+ ncorrect_norm = sum(log.get("ncorrect_norm", 0) for log in logging_outputs)
+ metrics.log_scalar(
+ "loss", loss / nsentences, nsentences, round=3
+ )
+ metrics.log_scalar(
+ "accuracy", 100.0 * ncorrect / nsentences, nsentences, round=2
+ )
+ metrics.log_scalar(
+ "accuracy_norm", 100.0 * ncorrect_norm / nsentences, nsentences, round=2
+ )
+
+ @staticmethod
+ def logging_outputs_can_be_summed() -> bool:
+ """
+ Whether the logging outputs returned by `forward` can be summed
+ across workers prior to calling `reduce_metrics`. Setting this
+    to True will improve distributed training speed.
+ """
+ return True
\ No newline at end of file
diff --git a/YOCO/yoco/criterions/multi_needle.py b/YOCO/yoco/criterions/multi_needle.py
new file mode 100644
index 000000000..f1b564ec7
--- /dev/null
+++ b/YOCO/yoco/criterions/multi_needle.py
@@ -0,0 +1,181 @@
+import os
+import random
+import math
+from dataclasses import dataclass, field
+
+import torch
+import torch.nn.functional as F
+
+from fairseq import metrics
+from fairseq.criterions import FairseqCriterion, register_criterion
+from fairseq.dataclass import FairseqDataclass
+
+OURS_TEMPLATE = "There is a special magic number inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the magic number there. {context} "
+RANDOM_NEEDLE_CITIES = [
+ 'Chicago', 'Yangon', 'Antananarivo', 'Colombo', 'Almaty', 'Sydney', 'Chicago', 'Mexico City',
+ 'Seattle', 'Lagos', 'Amsterdam', 'Belgrade', 'Cairo', 'Baghdad', 'Damascus', 'Kigali', 'Dakar',
+ 'Dakar', 'Sofia', 'Kigali', 'Victoria', 'Tashkent', 'Mumbai', 'Barcelona', 'Almaty', 'Amman',
+ 'Toronto', 'Bratislava', 'Johannesburg', 'Thimphu', 'Bangkok', 'Santiago', 'Cairo', 'San Francisco',
+ 'Lagos', 'Amsterdam', 'Paris', 'Rabat', 'Santiago', 'Copenhagen', 'Madrid', 'Kigali',
+ 'Ho Chi Minh City', 'Sarajevo', 'Delhi', 'Istanbul', 'Ho Chi Minh City', 'Khartoum', 'Helsinki',
+ 'Doha', 'Istanbul', 'Kuala Lumpur', 'Budapest', 'Shanghai', 'Moscow', 'Los Angeles', 'Oslo',
+ 'Johannesburg', 'Berlin', 'Bangalore', 'Tokyo', 'Melbourne', 'Barcelona', 'Chicago', 'Port Louis',
+ 'Lisbon', 'Nairobi', 'Kampala', 'Lima', 'Maputo', 'Vancouver', 'Dubai', 'Khartoum', 'Jakarta',
+ 'Madrid', 'Yerevan', 'Beirut', 'Athens', 'Chicago', 'Paris', 'Bucharest', 'Copenhagen', 'Brussels',
+ 'Damascus', 'Seattle', 'Los Angeles', 'Yerevan', 'Victoria', 'Tunis', 'Astana', 'Seoul',
+ 'Buenos Aires', 'Bangkok', 'Colombo', 'Brussels', 'Khartoum', 'Doha', 'San Francisco', 'Vienna', 'Jakarta'
+]
+QUESTION_TEMPLATE = "What is the special magic {city} number? The special magic {city} number is "
+NEEDLE_TEMPLATE = "The special magic {city} number is: {rnd_number}"
+@dataclass
+class NeedleEvalConfig(FairseqDataclass):
+ needle_num: int = field(
+ default=4,
+ metadata={"help":"needle number"}
+ )
+ tokens_per_sample: int = field(
+ default=16384,
+ )
+ interval: int = field(
+ default=1024,
+ )
+ needle_file_path: str = field(
+ default="/mnt/msranlp/yutao/data/PaulGrahamEssays",
+ )
+
+def random_partition(total, n):
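+    # Split `total` into `n` random positive integer parts that sum to `total`.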
+ cuts = random.sample(range(1, total), n - 1)
+ cuts.sort()
+ cuts = [0] + cuts + [total]
+ parts = [cuts[i+1] - cuts[i] for i in range(n)]
+ return parts
+
+@register_criterion("multi_needle", dataclass=NeedleEvalConfig)
+class NeedleEvalCriterion(FairseqCriterion):
+ def __init__(self, cfg: NeedleEvalConfig, task):
+ super().__init__(task)
+ self.cfg = cfg
+ self.essay_list = os.listdir(cfg.needle_file_path) * 5000
+
+ def generate_garbage(self, length):
+ current_text = ""
+ current_length = 0
+ while True:
+ essay = random.choice(self.essay_list)
+ essay = open(os.path.join(self.cfg.needle_file_path, essay)).read().splitlines()
+ for line in essay:
+ tokens = self.task.tokenizer.encode(line + " ")
+ if current_length + len(tokens) > length:
+ return current_text
+ current_text += line + " "
+ current_length += len(tokens)
+
+ def generate_prompt_landmark(self, first_length_list, second_length_list, final_length):
+ """Generates a text file and inserts an passkey at a random position."""
+ lines = []
+ citys = random.sample(RANDOM_NEEDLE_CITIES, self.cfg.needle_num)
+ for length in first_length_list:
+ lines.append(self.generate_garbage(length))
+ city = citys.pop()
+ magic_number = random.randint(1, 50000)
+ information_line = NEEDLE_TEMPLATE.format(city=city, rnd_number=magic_number)
+ lines.append(information_line)
+
+ final_question, answer = QUESTION_TEMPLATE.format(city=city), magic_number
+
+ for length in second_length_list:
+ lines.append(self.generate_garbage(length))
+ city = citys.pop()
+ magic_number = random.randint(1, 50000)
+ information_line = NEEDLE_TEMPLATE.format(city=city, rnd_number=magic_number)
+ lines.append(information_line)
+
+
+ lines.append(self.generate_garbage(final_length))
+ lines.append(final_question)
+ context = "\n".join(lines)
+ return OURS_TEMPLATE.format(context=context), str(answer)
+
+ def forward(self, model, sample, reduce=True):
+ """Compute the loss for the given sample.
+
+ Returns a tuple with three elements:
+ 1) the loss
+ 2) the sample size, which is used as the denominator for the gradient
+ 3) logging outputs to display while training
+ """
+ model.eval()
+ all_retrieval_result = {}
+ random.seed(42)
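+        # Sweep context lengths at fixed intervals; at each length, spread the
+        # needles over 10 depth positions and average accuracy over 5 randomly
+        # generated prompts per position.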
+ for context_length in range(self.cfg.interval, self.cfg.tokens_per_sample + 1, self.cfg.interval):
+ all_length = (context_length - 150)
+ local_retrieval_result = []
+ for depth_ratio in range(1, 11):
+ prefix_length = int(all_length * depth_ratio / 11)
+ suffix_length = all_length - prefix_length
+ n_correct = 0
+ for _ in range(5):
+ if self.cfg.needle_num > 1:
+ first_needle_num = random.randint(1, self.cfg.needle_num - 1)
+ second_needle_num = self.cfg.needle_num + 1 - first_needle_num
+ first_length_list = random_partition(prefix_length, first_needle_num)
+ second_length_list = random_partition(suffix_length, second_needle_num)
+ final_length = second_length_list.pop()
+ else:
+ first_length_list = [prefix_length]
+ second_length_list = []
+ final_length = suffix_length
+ prompt, pass_key = self.generate_prompt_landmark(first_length_list, second_length_list, final_length)
+ prompt_tokens = self.task.tokenizer.encode(prompt, bos=True)
+ prompt_tokens = torch.tensor([prompt_tokens], device="cuda")
+ print(prompt_tokens.shape)
+ output = self.generate(model, prompt_tokens)
+ pred = self.task.tokenizer.decode(output[0, prompt_tokens.shape[1]:])
+ print("Answer: ", pass_key)
+ print("Pred: ", pred)
+ if pass_key in pred:
+ n_correct += 1
+ local_retrieval_result.append(n_correct / 5)
+ all_retrieval_result[context_length] = local_retrieval_result
+
+ print(all_retrieval_result)
+ return 0, 1, {"loss": 0}
+
+ def generate(self, model, net_input, generate_tokens=20, chunk_length = 32768):
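+        # Greedy decoding with chunked prefill: the prompt is consumed chunk by
+        # chunk with skip_cross_decoder=True and is_prefilling=True, then answer
+        # tokens are generated one at a time; bos/pad logits are masked out and
+        # generation stops once every sequence has produced an EOS token.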
+ output_tokens = torch.cat((net_input, torch.full((net_input.shape[0], generate_tokens), self.task.tokenizer.pad_id).long().cuda()), dim=1)
+ begin_pad_index = torch.where(output_tokens == self.task.tokenizer.pad_id)[1].min().item()
+ incremental_state = {}
+ eos_reached = torch.tensor([False] * net_input.shape[0], device="cuda")
+ # prefilling
+ for begin_index in range(0, begin_pad_index - 1, chunk_length):
+ end_index = min(begin_index + chunk_length, begin_pad_index - 1)
+ _, _ = model(output_tokens[:, begin_index : end_index], incremental_state=incremental_state, start_pos=begin_index, skip_cross_decoder=True, is_prefilling=True)
+ # generation
+ for index in range(begin_pad_index, output_tokens.shape[1]):
+ generation_net_output, _ = model(output_tokens[:, index - 1].unsqueeze(-1), incremental_state=incremental_state, start_pos=index - 1, skip_cross_decoder=False, is_prefilling=False)
+ generation_net_output[:, :, self.task.tokenizer.bos_id] = -math.inf
+ generation_net_output[:, :, self.task.tokenizer.pad_id] = -math.inf
+ next_tokens = torch.argmax(generation_net_output[:, -1, :], dim=-1)
+ pad_tokens = output_tokens[:, index]
+ next_tokens = torch.where((pad_tokens == self.task.tokenizer.pad_id) & ~eos_reached, next_tokens, pad_tokens)
+ output_tokens[:, index] = next_tokens
+ eos_reached |= (
+ next_tokens == self.task.tokenizer.eos_id
+ )
+ if all(eos_reached):
+ break
+
+ return output_tokens
+
+ @staticmethod
+ def reduce_metrics(logging_outputs) -> None:
+ pass
+
+ @staticmethod
+ def logging_outputs_can_be_summed() -> bool:
+ """
+ Whether the logging outputs returned by `forward` can be summed
+ across workers prior to calling `reduce_metrics`. Setting this
+    to True will improve distributed training speed.
+ """
+ return True
\ No newline at end of file
diff --git a/YOCO/yoco/criterions/needle_haystack.py b/YOCO/yoco/criterions/needle_haystack.py
new file mode 100644
index 000000000..5cc9f231e
--- /dev/null
+++ b/YOCO/yoco/criterions/needle_haystack.py
@@ -0,0 +1,169 @@
+import os
+import random
+import math
+from dataclasses import dataclass, field
+
+import torch
+import torch.nn.functional as F
+
+from fairseq import metrics
+from fairseq.criterions import FairseqCriterion, register_criterion
+from fairseq.dataclass import FairseqDataclass
+
+OURS_TEMPLATE = "There is a special magic number inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the magic number there. {context} "
+RANDOM_NEEDLE_CITIES = [
+ 'Chicago', 'Yangon', 'Antananarivo', 'Colombo', 'Almaty', 'Sydney', 'Chicago', 'Mexico City',
+ 'Seattle', 'Lagos', 'Amsterdam', 'Belgrade', 'Cairo', 'Baghdad', 'Damascus', 'Kigali', 'Dakar',
+ 'Dakar', 'Sofia', 'Kigali', 'Victoria', 'Tashkent', 'Mumbai', 'Barcelona', 'Almaty', 'Amman',
+ 'Toronto', 'Bratislava', 'Johannesburg', 'Thimphu', 'Bangkok', 'Santiago', 'Cairo', 'San Francisco',
+ 'Lagos', 'Amsterdam', 'Paris', 'Rabat', 'Santiago', 'Copenhagen', 'Madrid', 'Kigali',
+ 'Ho Chi Minh City', 'Sarajevo', 'Delhi', 'Istanbul', 'Ho Chi Minh City', 'Khartoum', 'Helsinki',
+ 'Doha', 'Istanbul', 'Kuala Lumpur', 'Budapest', 'Shanghai', 'Moscow', 'Los Angeles', 'Oslo',
+ 'Johannesburg', 'Berlin', 'Bangalore', 'Tokyo', 'Melbourne', 'Barcelona', 'Chicago', 'Port Louis',
+ 'Lisbon', 'Nairobi', 'Kampala', 'Lima', 'Maputo', 'Vancouver', 'Dubai', 'Khartoum', 'Jakarta',
+ 'Madrid', 'Yerevan', 'Beirut', 'Athens', 'Chicago', 'Paris', 'Bucharest', 'Copenhagen', 'Brussels',
+ 'Damascus', 'Seattle', 'Los Angeles', 'Yerevan', 'Victoria', 'Tunis', 'Astana', 'Seoul',
+ 'Buenos Aires', 'Bangkok', 'Colombo', 'Brussels', 'Khartoum', 'Doha', 'San Francisco', 'Vienna', 'Jakarta'
+]
+QUESTION_TEMPLATE = "What is the special magic {city} number? The special magic {city} number is "
+# NEEDLE_TEMPLATE = "The special magic {city} number is {rnd_number} . Remember it. The special magic {city} number is {rnd_number} . "
+NEEDLE_TEMPLATE = "The special magic {city} number is {rnd_number} . "
+@dataclass
+class NeedleHaystackEvalConfig(FairseqDataclass):
+ max_len_b: int = field(
+ default=5,
+ metadata={"help":"max_len_b"}
+ )
+ tokens_per_sample: int = field(
+ default=16384,
+ )
+ interval: int = field(
+ default=1024,
+ )
+ needle_file_path: str = field(
+ default="/mnt/msranlp/yutao/data/PaulGrahamEssays",
+ )
+
+@register_criterion("needle_haystack", dataclass=NeedleHaystackEvalConfig)
+class NeedleHaystackEvalCriterion(FairseqCriterion):
+ def __init__(self, cfg: NeedleHaystackEvalConfig, task):
+ super().__init__(task)
+ self.cfg = cfg
+ self.essay_list = os.listdir(cfg.needle_file_path) * 5000
+
+ def generate_garbage(self, length):
+ current_text = ""
+ current_length = 0
+ while True:
+ essay = random.choice(self.essay_list)
+ essay = open(os.path.join(self.cfg.needle_file_path, essay)).read().splitlines()
+ for line in essay:
+ tokens = self.task.tokenizer.encode(line + " ")
+ if current_length + len(tokens) > length:
+ return current_text
+ current_text += line + " "
+ current_length += len(tokens)
+
+ def generate_prompt_landmark(self, prefix_length, suffix_length):
+ """Generates a text file and inserts an passkey at a random position."""
+ city = random.choice(RANDOM_NEEDLE_CITIES)
+ magic_number = random.randint(1, 50000)
+ garbage_prefix = self.generate_garbage(prefix_length)
+ garbage_suffix = self.generate_garbage(suffix_length)
+ information_line = NEEDLE_TEMPLATE.format(city=city, rnd_number=magic_number)
+ final_question = QUESTION_TEMPLATE.format(city=city)
+ lines = [
+ garbage_prefix,
+ information_line,
+ garbage_suffix,
+ final_question,
+ ]
+ context = "\n".join(lines)
+ return OURS_TEMPLATE.format(context=context), str(magic_number)
+
+ def forward(self, model, sample, reduce=True):
+ """Compute the loss for the given sample.
+
+ Returns a tuple with three elements:
+ 1) the loss
+ 2) the sample size, which is used as the denominator for the gradient
+ 3) logging outputs to display while training
+ """
+ model.eval()
+ all_retrieval_result = {}
+ random.seed(0)
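+        # Sweep context lengths at fixed intervals; at each length, place the
+        # needle at 11 depth positions (0%-100%) and average accuracy over 10
+        # randomly generated prompts per position.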
+ for context_length in range(self.cfg.interval, self.cfg.tokens_per_sample + 1, self.cfg.interval):
+ all_length = (context_length - 150)
+ local_retrieval_result = []
+ depth_number = 10
+ for depth_ratio in range(0, depth_number + 1):
+ prefix_length = int(all_length * depth_ratio / depth_number)
+ suffix_length = all_length - prefix_length
+ n_correct = 0
+ times = 10
+ for _ in range(times):
+ prompt, pass_key = self.generate_prompt_landmark(prefix_length, suffix_length)
+ prompt_tokens = self.task.tokenizer.encode(prompt, bos=True)
+ prompt_tokens = torch.tensor([prompt_tokens], device="cuda")
+ print(prompt_tokens.shape)
+ output = self.generate(model, prompt_tokens)
+ pred = self.task.tokenizer.decode(output[0, prompt_tokens.shape[1]:])
+ print("Answer: ", pass_key)
+ print("Pred: ", pred)
+ if pass_key in pred:
+ n_correct += 1
+ local_retrieval_result.append(n_correct / times)
+ all_retrieval_result[context_length] = local_retrieval_result
+
+ print(all_retrieval_result)
+ return 0, 1, {"loss": 0}
+
+ def generate(self, model, net_input, generate_tokens=20, chunk_length = 32768):
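+        # Greedy decoding with chunked prefill: the prompt is consumed chunk by
+        # chunk with skip_cross_decoder=True and is_prefilling=True, then answer
+        # tokens are generated one at a time until every sequence emits EOS.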
+ output_tokens = torch.cat((net_input, torch.full((net_input.shape[0], generate_tokens), self.task.tokenizer.pad_id).long().cuda()), dim=1)
+ begin_pad_index = torch.where(output_tokens == self.task.tokenizer.pad_id)[1].min().item()
+ incremental_state = {}
+ eos_reached = torch.tensor([False] * net_input.shape[0], device="cuda")
+ # prefilling
+ for begin_index in range(0, begin_pad_index - 1, chunk_length):
+ end_index = min(begin_index + chunk_length, begin_pad_index - 1)
+ _, _ = model(output_tokens[:, begin_index : end_index], incremental_state=incremental_state, start_pos=begin_index, skip_cross_decoder=True, is_prefilling=True)
+ # generation
+ for index in range(begin_pad_index, output_tokens.shape[1]):
+ generation_net_output, _ = model(output_tokens[:, index - 1].unsqueeze(-1), incremental_state=incremental_state, start_pos=index - 1, skip_cross_decoder=False, is_prefilling=False)
+ generation_net_output[:, :, self.task.tokenizer.bos_id] = -math.inf
+ generation_net_output[:, :, self.task.tokenizer.pad_id] = -math.inf
+ next_tokens = torch.argmax(generation_net_output[:, -1, :], dim=-1)
+ pad_tokens = output_tokens[:, index]
+ next_tokens = torch.where((pad_tokens == self.task.tokenizer.pad_id) & ~eos_reached, next_tokens, pad_tokens)
+ output_tokens[:, index] = next_tokens
+ eos_reached |= (
+ next_tokens == self.task.tokenizer.eos_id
+ )
+ if all(eos_reached):
+ break
+
+ return output_tokens
+
+ @staticmethod
+ def reduce_metrics(logging_outputs) -> None:
+ """Aggregate logging outputs from data parallel training."""
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+ metric_sum = sum(log.get("metric", 0) for log in logging_outputs)
+ nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+ metrics.log_scalar(
+ "loss", loss_sum / ntokens, ntokens, round=3
+ )
+ metrics.log_scalar(
+ "metric", metric_sum / nsentences, nsentences, round=3
+ )
+
+
+ @staticmethod
+ def logging_outputs_can_be_summed() -> bool:
+ """
+ Whether the logging outputs returned by `forward` can be summed
+ across workers prior to calling `reduce_metrics`. Setting this
+    to True will improve distributed training speed.
+ """
+ return True
\ No newline at end of file
diff --git a/YOCO/yoco/models/__init__.py b/YOCO/yoco/models/__init__.py
new file mode 100644
index 000000000..1ff184f30
--- /dev/null
+++ b/YOCO/yoco/models/__init__.py
@@ -0,0 +1,41 @@
+import argparse
+import importlib
+import os
+
+try:
+    from torch._six import inf
+except ImportError:
+    # torch._six was removed in recent PyTorch releases; register a shim so that
+    # fairseq modules which still import from it keep working.
+    import sys
+    import torch
+    sys.modules["torch._six"] = torch
+    torch.string_classes = str
+
+MODEL_REGISTRY = {}
+MODEL_DATACLASS_REGISTRY = {}
+ARCH_MODEL_REGISTRY = {}
+ARCH_MODEL_NAME_REGISTRY = {}
+ARCH_MODEL_INV_REGISTRY = {}
+ARCH_CONFIG_REGISTRY = {}
+
+# automatically import any Python files in the models/ directory
+models_dir = os.path.dirname(__file__)
+for file in os.listdir(models_dir):
+ path = os.path.join(models_dir, file)
+ if (
+ not file.startswith("_")
+ and not file.startswith(".")
+ and (file.endswith(".py") or os.path.isdir(path))
+ ):
+ model_name = file[: file.find(".py")] if file.endswith(".py") else file
+ module = importlib.import_module("models." + model_name)
+
+ # extra `model_parser` for sphinx
+ if model_name in MODEL_REGISTRY:
+ parser = argparse.ArgumentParser(add_help=False)
+ group_archs = parser.add_argument_group("Named architectures")
+ group_archs.add_argument(
+ "--arch", choices=ARCH_MODEL_INV_REGISTRY[model_name]
+ )
+ group_args = parser.add_argument_group("Additional command-line arguments")
+ MODEL_REGISTRY[model_name].add_args(group_args)
+ globals()[model_name + "_parser"] = parser
\ No newline at end of file
diff --git a/YOCO/yoco/models/decoder/__init__.py b/YOCO/yoco/models/decoder/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/YOCO/yoco/models/decoder/cross_attention.py b/YOCO/yoco/models/decoder/cross_attention.py
new file mode 100644
index 000000000..09c31a893
--- /dev/null
+++ b/YOCO/yoco/models/decoder/cross_attention.py
@@ -0,0 +1,46 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from fairseq.model_parallel.megatron.mpu import (
+ ColumnParallelLinear,
+ RowParallelLinear,
+)
+
+from .model_parallel_init import init_method
+from .kernel.rotary import apply_rotary_emb
+from flash_attn import flash_attn_func
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self,
+ args,
+ ):
+ super().__init__()
+ self.args = args
+ self.embed_dim = args.dim
+ self.num_heads = args.n_attn_heads // args.model_parallel_size
+ self.num_kv_heads = args.n_attn_kv_heads // args.model_parallel_size
+
+ self.head_dim = args.dim // args.n_attn_heads
+ self.q_proj = ColumnParallelLinear(args.dim, args.dim, bias=False, gather_output=False, init_method=init_method)
+ self.out_proj = RowParallelLinear(args.dim, args.dim, bias=False, input_is_parallel=True, init_method=init_method)
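+        # Only queries are projected here: the cross-decoder attends to the
+        # key/value cache produced once by the self-decoder, which is the
+        # "cache once" part of YOCO.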
+
+ def forward(
+ self,
+ x,
+ key,
+ value,
+ rel_pos
+ ):
+ bsz, tgt_len, _ = x.size()
+
+ q = self.q_proj(x)
+ q = q.view(bsz, tgt_len, self.num_heads, self.head_dim)
+ q = apply_rotary_emb(q, *rel_pos, interleaved=True)
+
+ attn = flash_attn_func(q, key, value, causal=True)
+ attn = attn.view(bsz, tgt_len, self.head_dim * self.num_heads)
+
+ attn = self.out_proj(attn)
+ return attn
\ No newline at end of file
diff --git a/YOCO/yoco/models/decoder/feedforward_network.py b/YOCO/yoco/models/decoder/feedforward_network.py
new file mode 100644
index 000000000..3972068fe
--- /dev/null
+++ b/YOCO/yoco/models/decoder/feedforward_network.py
@@ -0,0 +1,33 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from fairseq.model_parallel.megatron.mpu import (
+ ColumnParallelLinear,
+ RowParallelLinear,
+)
+
+from .kernel.swiglu import swiglu
+from .model_parallel_init import init_method
+
+class FeedForwardNetwork(nn.Module):
+ def __init__(
+ self,
+ embed_dim,
+ ffn_dim,
+ load_checkpoint=False,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.fc1 = ColumnParallelLinear(self.embed_dim, ffn_dim, bias=False, gather_output=False, init_method=init_method)
+ self.gate = ColumnParallelLinear(self.embed_dim, ffn_dim, bias=False, gather_output=False, init_method=init_method)
+ self.fc2 = RowParallelLinear(ffn_dim, self.embed_dim, bias=False, input_is_parallel=True, init_method=init_method)
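+        # SwiGLU feed-forward: fc2(silu(fc1(x)) * gate(x)), built from column-/row-
+        # parallel linears so that the FFN can be sharded across model-parallel ranks.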
+
+ def forward(self, x):
+ x_shape = x.shape
+ x = x.reshape(-1, x.size(-1))
+ x = self.fc2(swiglu(self.fc1(x), self.gate(x)))
+ output = x.view(x_shape)
+ return output
\ No newline at end of file
diff --git a/YOCO/yoco/models/decoder/gate_retention.py b/YOCO/yoco/models/decoder/gate_retention.py
new file mode 100644
index 000000000..089164c77
--- /dev/null
+++ b/YOCO/yoco/models/decoder/gate_retention.py
@@ -0,0 +1,87 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from fairseq.model_parallel.megatron.mpu import (
+ ColumnParallelLinear,
+ RowParallelLinear,
+)
+
+from .rms_norm import RMSNorm
+
+from .kernel.gate_recurrent import chunk_gate_retention, recurrent_gate_retention
+from .kernel.rotary import apply_rotary_emb
+from .kernel.swiglu import swiglu
+
+from .model_parallel_init import qkvg_init_method, out_init_method
+
+class GateRetention(nn.Module):
+
+ def __init__(
+ self,
+ args,
+ gate_logit_normalizer: int = 16,
+ ):
+ super().__init__()
+ self.args = args
+ self.embed_dim = args.dim
+ self.num_heads = args.n_self_heads // args.model_parallel_size
+ self.head_dim = args.dim // args.n_self_heads
+
+ self.q_proj = ColumnParallelLinear(args.dim, args.dim, bias=False, gather_output=False, init_method=qkvg_init_method)
+ self.k_proj = ColumnParallelLinear(args.dim, args.dim, bias=False, gather_output=False, init_method=qkvg_init_method)
+ self.v_proj = ColumnParallelLinear(args.dim, args.dim, bias=False, gather_output=False, init_method=qkvg_init_method)
+ self.g_proj = ColumnParallelLinear(args.dim, args.dim, bias=False, gather_output=False, init_method=qkvg_init_method)
+ self.gt_proj = ColumnParallelLinear(args.dim, args.n_self_heads, bias=False, gather_output=False, init_method=qkvg_init_method)
+
+ self.out_proj = RowParallelLinear(args.dim, args.dim, bias=False, input_is_parallel=True, init_method=out_init_method)
+
+ self.subln = RMSNorm(self.head_dim, elementwise_affine=False, eps=args.norm_eps)
+
+ self.gate_logit_normalizer = gate_logit_normalizer
+
+ def forward(
+ self,
+ x,
+ rel_pos,
+ incremental_state=None,
+ is_prefilling=False,
+ ):
+ bsz, tgt_len, _ = x.size()
+
+ q = self.q_proj(x)
+ k = self.k_proj(x)
+ v = self.v_proj(x)
+ g = self.g_proj(x)
+ gt = self.gt_proj(x)
+
+ qr = q.view(bsz, tgt_len, self.num_heads, self.head_dim)
+ kr = k.view(bsz, tgt_len, self.num_heads, self.head_dim)
+ v = v.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
+ gt = gt.view(bsz, tgt_len, self.num_heads).transpose(1, 2)
+
+ qr = apply_rotary_emb(qr, *rel_pos, interleaved=True).transpose(1, 2)
+ kr = apply_rotary_emb(kr, *rel_pos, interleaved=True).transpose(1, 2)
+ gt = (F.logsigmoid(gt) / self.gate_logit_normalizer)
+
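+        # Three execution paths: single-token recurrent decoding (incremental state,
+        # not prefilling), chunk-wise prefilling that also updates the running hidden
+        # state, and plain chunk-wise computation for training.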
+ if incremental_state is not None and not is_prefilling:
+ o = recurrent_gate_retention(qr, kr, v, gt, incremental_state)
+ else:
+ if incremental_state is not None:
+ index_mask = incremental_state["index_mask"]
+ gt_sum = gt.float().masked_fill(index_mask, 0).sum(dim=-1, keepdim=True)
+ gt_mask = (gt_sum - gt.float().cumsum(dim=-1)).exp().masked_fill(index_mask, 0)
+ next_hidden_state = (kr.transpose(-1, -2) * (self.head_dim ** -0.5)) @ (v * gt_mask.to(v.dtype).unsqueeze(-1))
+ if "last_hidden_state" in incremental_state:
+ last_hidden_state = incremental_state["last_hidden_state"]
+ next_hidden_state += last_hidden_state * gt_sum.exp().unsqueeze(-1).to(v.dtype) if last_hidden_state is not None else 0
+ else:
+ last_hidden_state = None
+ incremental_state["last_hidden_state"] = next_hidden_state
+ o = chunk_gate_retention(qr, kr, v, gt, chunk_size=256, last_hidden_state=last_hidden_state)
+ else:
+ o = chunk_gate_retention(qr, kr, v, gt, chunk_size=256)
+
+ o = self.subln(o).transpose(1, 2).reshape(bsz, tgt_len, self.num_heads * self.head_dim)
+ o = swiglu(g, o)
+ o = self.out_proj(o)
+ return o
diff --git a/YOCO/yoco/models/decoder/kernel/gate_recurrent.py b/YOCO/yoco/models/decoder/kernel/gate_recurrent.py
new file mode 100644
index 000000000..304131ccc
--- /dev/null
+++ b/YOCO/yoco/models/decoder/kernel/gate_recurrent.py
@@ -0,0 +1,302 @@
+import time
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+
+torch.backends.cudnn.allow_tf32 = True
+
+@triton.jit
+def _fwd_recurrence(
+ S, d,
+ O,
+ NUM_HEAD, NUM_BLOCK,
+ D_MODEL_K: tl.constexpr, D_MODEL_V: tl.constexpr,
+ BLOCK_MODEL_K: tl.constexpr, BLOCK_MODEL_V: tl.constexpr,
+ last_kv: Optional[tl.tensor]
+ ):
+ offset_bh = tl.program_id(0)
+ offset_d = tl.program_id(1)
+ offset_s = tl.program_id(2)
+
+ S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL_K + tl.arange(0, BLOCK_MODEL_K)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL_V + tl.arange(0, BLOCK_MODEL_V)[None, :]
+
+ O = O + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL_K + tl.arange(0, BLOCK_MODEL_K)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL_V + tl.arange(0, BLOCK_MODEL_V)[None, :]
+
+ if last_kv is not None:
+ last_kv = last_kv + offset_bh * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL_K + tl.arange(0, BLOCK_MODEL_K)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL_V + tl.arange(0, BLOCK_MODEL_V)[None, :]
+ acc = tl.load(last_kv).to(tl.float32)
+ else:
+ acc = tl.zeros([BLOCK_MODEL_K, BLOCK_MODEL_V], dtype=tl.float32)
+
+ tl.store(O, acc.to(O.dtype.element_ty))
+ O += D_MODEL_K * D_MODEL_V
+ d = d + offset_bh * NUM_BLOCK
+ for i in range(NUM_BLOCK-1):
+ d_i = tl.load(d)
+ S_i = tl.load(S)
+ acc = acc * d_i + S_i
+ tl.store(O, acc.to(O.dtype.element_ty))
+ d += 1
+ S += D_MODEL_K * D_MODEL_V
+ O += D_MODEL_K * D_MODEL_V
+
+
+## NUM_SPLIT_K/V. K/V dimension split into NUM_SPLIT_K/V parts with equal size BLOCK_MODEL
+@triton.jit
+def _bwd_recurrence(
+ S, d,
+ DI, DG, DL, DS,
+ NUM_HEAD, NUM_BLOCK,
+ D_MODEL_K: tl.constexpr, D_MODEL_V: tl.constexpr,
+ BLOCK_MODEL_K: tl.constexpr, BLOCK_MODEL_V: tl.constexpr,
+
+ ):
+ offset_bh = tl.program_id(0)
+ offset_d = tl.program_id(1)
+ offset_s = tl.program_id(2)
+
+ # offset_h = offset_bh % NUM_HEAD
+ NUM_K = D_MODEL_K // BLOCK_MODEL_K
+ NUM_V = D_MODEL_V // BLOCK_MODEL_V
+ # skip the last chunk because it is never used
+ S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL_K + tl.arange(0, BLOCK_MODEL_K)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL_V + tl.arange(0, BLOCK_MODEL_V)[None, :] + (NUM_BLOCK - 2) * D_MODEL_K * D_MODEL_V
+
+ DI = DI + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL_K + tl.arange(0, BLOCK_MODEL_K)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL_V + tl.arange(0, BLOCK_MODEL_V)[None, :] + (NUM_BLOCK - 2) * D_MODEL_K * D_MODEL_V
+
+ # start from the last chunk
+ DS = DS + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL_K + tl.arange(0, BLOCK_MODEL_K)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL_V + tl.arange(0, BLOCK_MODEL_V)[None, :] + (NUM_BLOCK - 1) * D_MODEL_K * D_MODEL_V
+
+ DG = DG + offset_bh * NUM_BLOCK * NUM_K * NUM_V + offset_d * NUM_V + offset_s + (NUM_BLOCK - 2) * NUM_K * NUM_V
+
+ d = d + offset_bh * NUM_BLOCK + (NUM_BLOCK - 1)
+
+ Dacc = tl.zeros([BLOCK_MODEL_K, BLOCK_MODEL_V], dtype=tl.float32)
+
+ # ignore the first chunk
+ for i in range(NUM_BLOCK - 1):
+ S_i = tl.load(S)
+ DS_i = tl.load(DS)
+ d_i = tl.load(d)
+ Dacc = Dacc * d_i + DS_i
+ DG_i = tl.sum(Dacc * S_i.to(tl.float32))
+
+ tl.store(DG, DG_i.to(DG.dtype.element_ty))
+ tl.store(DI, Dacc.to(DI.dtype.element_ty))
+
+ S -= D_MODEL_K * D_MODEL_V
+ DI -= D_MODEL_K * D_MODEL_V
+ DS -= D_MODEL_K * D_MODEL_V
+ DG -= NUM_K * NUM_V
+ d -= 1
+
+ DL = DL + offset_bh * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL_K + tl.arange(0, BLOCK_MODEL_K)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL_V + tl.arange(0, BLOCK_MODEL_V)[None, :]
+ DS_i = tl.load(DS)
+ d_i = tl.load(d)
+ Dacc = Dacc * d_i + DS_i
+ tl.store(DL, Dacc.to(DL.dtype.element_ty))
+
+class ChunkGateRecurrent(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, kv, cross_decay, last_kv=None):
+ cross_decay = cross_decay.contiguous()
+ kv = kv.contiguous()
+
+ B, H, N, D_k, D_v = kv.shape
+ output = torch.empty_like(kv)
+ BLOCK_MODEL_K = 64
+ BLOCK_MODEL_V = 16
+
+ assert D_k % BLOCK_MODEL_K == 0
+ assert D_v % BLOCK_MODEL_V == 0
+
+ grid = (B*H, D_k//BLOCK_MODEL_K, D_v//BLOCK_MODEL_V)
+ ctx.grid = grid
+ ctx.have_last_kv = last_kv is not None
+ ctx.BLOCK_MODEL_K = BLOCK_MODEL_K
+ ctx.BLOCK_MODEL_V = BLOCK_MODEL_V
+
+ _fwd_recurrence[grid](
+ kv,
+ cross_decay,
+ output,
+ D_MODEL_K=D_k, D_MODEL_V=D_v,
+ NUM_BLOCK=N, NUM_HEAD=H,
+ BLOCK_MODEL_K=BLOCK_MODEL_K,
+ BLOCK_MODEL_V=BLOCK_MODEL_V,
+ last_kv=last_kv
+ )
+
+ ctx.save_for_backward(output, cross_decay)
+ return output
+
+ @staticmethod
+ def backward(ctx, DO):
+ DO = DO.contiguous()
+
+ output, cross_decay = ctx.saved_tensors
+
+ B, H, N, D_k, D_v = output.shape
+
+ BLOCK_MODEL_K = 64
+ BLOCK_MODEL_V = 16
+
+ grid = (B*H, D_k//BLOCK_MODEL_K, D_v//BLOCK_MODEL_V)
+
+ DI = torch.empty_like(DO)
+ DG = torch.empty(B*H, N, D_k//BLOCK_MODEL_K, D_v//BLOCK_MODEL_V, device=cross_decay.device, dtype=cross_decay.dtype)
+ DL = torch.empty(B, H, D_k, D_v, device=output.device, dtype=output.dtype)
+ _bwd_recurrence[grid](
+ output, cross_decay,
+ DI, DG, DL, DO,
+ NUM_HEAD=H, NUM_BLOCK = N,
+ D_MODEL_K = D_k,
+ D_MODEL_V = D_v,
+ BLOCK_MODEL_K=BLOCK_MODEL_K,
+ BLOCK_MODEL_V=BLOCK_MODEL_V,
+ )
+
+ DI[:, :, -1] = 0
+ DG[:, -1] = 0
+ DG = DG.view(B, H, N, -1).sum(dim=-1)
+ return DI, DG, DL if ctx.have_last_kv else None
+
+def cross_chunk(q, k, v, g, last_hidden_state=None):
+ kv = k.transpose(-1, -2) @ (v * (-g + g[:, :, :, -1, None]).exp()[..., None].to(v.dtype))
+ cross_decay = g[:, :, :, -1].exp().to(kv.dtype)
+ S = chunk_gate_recurrent(kv, cross_decay, last_hidden_state)
+ cross = (q * g[..., None].exp().to(q.dtype)) @ S
+ return cross
+
+@torch.compile
+def inner_chunk(q, k, v, g):
+ attn = q @ k.transpose(-1, -2)
+ causal_mask = torch.full([q.shape[-2], q.shape[-2]], float("-inf"), device=q.device).triu(1).type_as(q)
+ attn = attn * (g[..., None] - g[..., None, :] + causal_mask).exp().to(attn.dtype)
+ inner = attn @ v
+ return inner
+
+def chunk_gate_retention(q, k, v, g, chunk_size=64, last_hidden_state=None):
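+    # Chunk-wise gated retention: the sequence is split into chunks, a recurrent
+    # cross-chunk state is propagated by the Triton kernel, and within-chunk
+    # contributions are computed with a decay-masked attention-style product.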
+ bsz, num_head, tgt_len, key_dim = q.shape
+ head_dim = v.shape[-1]
+ num_chunk = tgt_len // chunk_size
+ q = q.view(bsz, num_head, num_chunk, chunk_size, key_dim)
+ k = k.view(bsz, num_head, num_chunk, chunk_size, key_dim) * (key_dim ** -0.5)
+ v = v.view(bsz, num_head, num_chunk, chunk_size, head_dim)
+ g = g.view(bsz, num_head, num_chunk, chunk_size)
+ g = g.float().cumsum(-1)
+ cross = cross_chunk(q, k, v, g, last_hidden_state=last_hidden_state)
+ inner = inner_chunk(q, k, v, g)
+ o = cross + inner
+ return o.view(bsz, num_head, tgt_len, head_dim)
+
+# for long sequence parallelism
+def hier_chunk_gate_retention(q, k, v, g, chunk_size=64, hier_chunk_size=16384):
+ bsz, num_head, tgt_len, key_dim = q.shape
+ head_dim = v.shape[-1]
+ num_hier_chunk = tgt_len // hier_chunk_size
+ assert tgt_len == num_hier_chunk * hier_chunk_size
+
+ q = q.view(bsz, num_head, num_hier_chunk, hier_chunk_size, key_dim)
+ k = k.view(bsz, num_head, num_hier_chunk, hier_chunk_size, key_dim)
+ v = v.view(bsz, num_head, num_hier_chunk, hier_chunk_size, head_dim)
+ g = g.view(bsz, num_head, num_hier_chunk, hier_chunk_size)
+ hier_cross = cross_chunk(q, k * (key_dim ** -0.5), v, g.float().cumsum(-1)).view(bsz, num_head, tgt_len, head_dim)
+
+ qi = q.transpose(1, 2).reshape(bsz * num_hier_chunk, num_head, hier_chunk_size, key_dim)
+ ki = k.transpose(1, 2).reshape(bsz * num_hier_chunk, num_head, hier_chunk_size, key_dim)
+ vi = v.transpose(1, 2).reshape(bsz * num_hier_chunk, num_head, hier_chunk_size, head_dim)
+ gi = g.transpose(1, 2).reshape(bsz * num_hier_chunk, num_head, hier_chunk_size)
+ inner_cross = chunk_gate_retention(qi, ki, vi, gi, chunk_size)
+
+ inner_cross = inner_cross.view(bsz, num_hier_chunk, num_head, hier_chunk_size, head_dim).transpose(1, 2).reshape(bsz, num_head, tgt_len, head_dim)
+ o = hier_cross + inner_cross
+ return o
+
+def recurrent_gate_retention(q, k, v, g, incremental_state):
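+    # O(1)-per-token decoding: decay the running key-value state, add the current
+    # outer product k^T v, and read it out with the query.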
+ bsz, num_head, _, key_dim = q.shape
+ k *= key_dim ** -0.5
+ g = g.view(bsz, num_head, 1, 1).float().exp()
+ kv = k.transpose(-1, -2) * v
+ if "last_hidden_state" in incremental_state:
+ prev_kv = incremental_state["last_hidden_state"]
+ kv += prev_kv * g.to(prev_kv.dtype)
+
+ incremental_state["last_hidden_state"] = kv
+ o = q @ kv
+ return o
+
+def parallel_gate_retention(q, k, v, g):
+ k = k * (q.shape[-1] ** -0.5)
+ causal_mask = torch.full([q.shape[-2], q.shape[-2]], float("-inf"), device=q.device).triu(1).type_as(q)
+ g = g.float().cumsum(-1)
+ mask = g[..., None] - g[..., None, :] + causal_mask
+ mask = mask.exp()
+
+ attn = q @ k.transpose(-1, -2)
+ attn = attn * mask.to(attn.dtype)
+ o = attn @ v
+ return o
+
+def naive_kv_recurrent(kv, cross_decay, last_kv=None):
+ BSZ, NUM_HEAD, NUM_BLOCK, D_MODEL_K, D_MODEL_V = kv.shape
+ kv_recurrent = []
+ kv_state = torch.zeros(BSZ, NUM_HEAD, D_MODEL_K, D_MODEL_V, dtype=kv.dtype, device="cuda") if last_kv is None else last_kv
+ # accumulate kv by loop
+ for i in range(NUM_BLOCK):
+ kv_recurrent.append(kv_state)
+ kv_state = kv_state * cross_decay[:, :, i, None, None] + kv[:, :, i]
+
+ kv_recurrent = torch.stack(kv_recurrent, dim=2)
+ return kv_recurrent
+
+chunk_gate_recurrent = ChunkGateRecurrent.apply
+
+def main():
+ BSZ = 4
+ NUM_HEAD = 4
+ NUM_BLOCK = 16
+ D_MODEL_K = 256
+ D_MODEL_V = 432
+ dtype = torch.float16
+ kv = torch.randn(BSZ, NUM_HEAD, NUM_BLOCK, D_MODEL_K, D_MODEL_V, dtype=dtype, device="cuda")
+ last_kv = torch.randn(BSZ, NUM_HEAD, D_MODEL_K, D_MODEL_V, dtype=dtype, device="cuda")
+ kv_triton = kv.clone().detach()
+ last_kv_triton = last_kv.clone().detach()
+ cross_decay = torch.randn(BSZ, NUM_HEAD, NUM_BLOCK, dtype=dtype, device="cuda")
+ cross_decay = torch.sigmoid(cross_decay)
+ cross_decay_triton = cross_decay.clone().detach()
+ grad_weight = torch.randn(BSZ, NUM_HEAD, NUM_BLOCK, D_MODEL_K, D_MODEL_V, dtype=dtype, device="cuda")
+ kv.requires_grad = True
+ kv_triton.requires_grad = True
+ last_kv.requires_grad = True
+ last_kv_triton.requires_grad = True
+ cross_decay.requires_grad = True
+ cross_decay_triton.requires_grad = True
+
+ start = time.time()
+ kv_recurrent = naive_kv_recurrent(kv, cross_decay, last_kv)
+ kv_recurrent.mul(grad_weight).sum().backward()
+ print("naive time:", time.time() - start)
+
+ start = time.time()
+ kv_recurrent_triton = chunk_gate_recurrent(kv_triton, cross_decay_triton, last_kv_triton)
+ kv_recurrent_triton.mul(grad_weight).sum().backward()
+ print("triton time:", time.time() - start)
+
+ print(torch.allclose(kv_recurrent, kv_recurrent_triton, atol=1e-3))
+ print((kv_recurrent - kv_recurrent_triton).abs().max(), (kv_recurrent - kv_recurrent_triton).abs().mean())
+
+ print(torch.allclose(kv.grad, kv_triton.grad, atol=1e-3))
+ print((kv.grad - kv_triton.grad).abs().max(), (kv.grad - kv_triton.grad).abs().mean())
+
+ print(torch.allclose(last_kv.grad, last_kv_triton.grad, atol=1e-3))
+ print((last_kv.grad - last_kv_triton.grad).abs().max(), (last_kv.grad - last_kv_triton.grad).abs().mean())
+
+ print(torch.allclose(cross_decay.grad, cross_decay_triton.grad, atol=1e-3))
+ print((cross_decay.grad - cross_decay_triton.grad).abs().max(), (cross_decay.grad - cross_decay_triton.grad).abs().mean())
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/YOCO/yoco/models/decoder/kernel/rotary.py b/YOCO/yoco/models/decoder/kernel/rotary.py
new file mode 100644
index 000000000..8ee2cb938
--- /dev/null
+++ b/YOCO/yoco/models/decoder/kernel/rotary.py
@@ -0,0 +1,332 @@
+# Copyright (c) 2023, Tri Dao.
+
+from typing import Optional, Union
+
+import torch
+
+import triton
+import triton.language as tl
+
+
+# @triton.autotune(
+# configs=[
+# triton.Config({"BLOCK_M": 2}),
+# triton.Config({"BLOCK_M": 4}),
+# triton.Config({"BLOCK_M": 8}),
+# triton.Config({"BLOCK_M": 16}),
+# ],
+# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"],
+# )
+@triton.jit
+def rotary_kernel(
+ OUT, # Pointers to matrices
+ X,
+ COS,
+ SIN,
+ CU_SEQLENS,
+ SEQLEN_OFFSETS, # this could be int or a pointer
+ # Matrix dimensions
+ seqlen,
+ nheads,
+ rotary_dim,
+ seqlen_ro,
+ CACHE_KEY_SEQLEN,
+ # strides
+ stride_out_batch,
+ stride_out_seqlen,
+ stride_out_nheads,
+ stride_out_headdim,
+ stride_x_batch,
+ stride_x_seqlen,
+ stride_x_nheads,
+ stride_x_headdim,
+ # Meta-parameters
+ BLOCK_K: tl.constexpr,
+ IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
+ IS_VARLEN: tl.constexpr,
+ INTERLEAVED: tl.constexpr,
+ CONJUGATE: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+):
+ pid_m = tl.program_id(axis=0)
+ pid_batch = tl.program_id(axis=1)
+ pid_head = tl.program_id(axis=2)
+ rotary_dim_half = rotary_dim // 2
+
+ if not IS_VARLEN:
+ X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
+ OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
+ else:
+ start_idx = tl.load(CU_SEQLENS + pid_batch)
+ seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
+ X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
+ OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads
+
+ if pid_m * BLOCK_M >= seqlen:
+ return
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ if not IS_SEQLEN_OFFSETS_TENSOR:
+ rm_cs = rm + SEQLEN_OFFSETS
+ else:
+ rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
+ rk = tl.arange(0, BLOCK_K)
+ rk_half = tl.arange(0, BLOCK_K // 2)
+
+ if not INTERLEAVED:
+ # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
+ X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)
+ COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
+ SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
+ cos = tl.load(
+ COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0
+ ).to(tl.float32)
+ sin = tl.load(
+ SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0
+ ).to(tl.float32)
+ x0 = tl.load(
+ X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0
+ ).to(tl.float32)
+ x1 = tl.load(
+ X + rotary_dim_half * stride_x_headdim,
+ mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
+ other=0.0,
+ ).to(tl.float32)
+ if CONJUGATE:
+ sin = -sin
+ o0 = x0 * cos - x1 * sin
+ o1 = x0 * sin + x1 * cos
+ # write back result
+ OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
+ tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
+ tl.store(
+ OUT + rotary_dim_half * stride_out_headdim,
+ o1,
+ mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
+ )
+ else:
+ # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
+ # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
+ # Loading x0 will be fast but x1 will be slow.
+ # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
+ # Then we do the calculation and use tl.where to pick put the right outputs for the even
+ # and for the odd indices.
+ rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...
+ rk_repeat = tl.arange(0, BLOCK_K) // 2
+ X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
+ X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)
+ COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
+ SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
+ cos = tl.load(
+ COS,
+ mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
+ other=1.0,
+ ).to(tl.float32)
+ sin = tl.load(
+ SIN,
+ mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
+ other=0.0,
+ ).to(tl.float32)
+ x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(
+ tl.float32
+ )
+ x1 = tl.load(
+ X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
+ ).to(tl.float32)
+ if CONJUGATE:
+ sin = -sin
+ x0_cos = x0 * cos
+ x1_sin = x1 * sin
+ out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
+ OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
+ tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
+
+
+def apply_rotary(
+ x: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+ interleaved=False,
+ inplace=False,
+ conjugate=False,
+) -> torch.Tensor:
+ """
+ Arguments:
+ x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
+ else (total_seqlen, nheads, headdim).
+ cos: (seqlen_ro, rotary_dim / 2)
+ sin: (seqlen_ro, rotary_dim / 2)
+ seqlen_offsets: integer or integer tensor of size (batch,)
+ cu_seqlens: (batch + 1,) or None
+ max_seqlen: int
+ Returns:
+ y: (batch, seqlen, nheads, headdim)
+ """
+ is_varlen = cu_seqlens is not None
+ if not is_varlen:
+ batch, seqlen, nheads, headdim = x.shape
+ else:
+ assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed"
+ total_seqlen, nheads, headdim = x.shape
+ batch_p_1 = cu_seqlens.shape[0]
+ batch = batch_p_1 - 1
+ seqlen = max_seqlen
+ seqlen_ro, rotary_dim = cos.shape
+ assert sin.shape == cos.shape
+ rotary_dim *= 2
+ assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
+ assert headdim <= 256, "Only support headdim <= 256"
+ assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
+
+ assert (
+ cos.dtype == sin.dtype
+ ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
+ assert (
+ x.dtype == cos.dtype
+ ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
+
+ cos, sin = cos.contiguous(), sin.contiguous()
+ if isinstance(seqlen_offsets, torch.Tensor):
+ assert seqlen_offsets.shape == (batch,)
+ assert seqlen_offsets.dtype in [torch.int32, torch.int64]
+ seqlen_offsets = seqlen_offsets.contiguous()
+ else:
+ assert seqlen_offsets + seqlen <= seqlen_ro
+
+ output = torch.empty_like(x) if not inplace else x
+ if rotary_dim < headdim and not inplace:
+ output[..., rotary_dim:].copy_(x[..., rotary_dim:])
+
+ BLOCK_K = (
+ 32
+ if rotary_dim <= 32
+ else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
+ )
+ grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa
+ BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
+
+ # Need this, otherwise Triton tries to launch from cuda:0 and we get
+ # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
+ with torch.cuda.device(x.device.index):
+ rotary_kernel[grid](
+ output, # data ptrs
+ x,
+ cos,
+ sin,
+ cu_seqlens,
+ seqlen_offsets,
+ seqlen, # shapes
+ nheads,
+ rotary_dim,
+ seqlen_ro,
+ seqlen // 128, # key for triton cache (limit number of compilations)
+ output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0
+ output.stride(-3), # seqlen_stride or total_seqlen_stride
+ output.stride(-2), # nheads_stride
+ output.stride(-1), # headdim_stride
+ x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0
+ x.stride(-3), # seqlen stride or total_seqlen_stride
+ x.stride(-2), # nheads stride
+ x.stride(-1), # headdim stride
+ BLOCK_K,
+ isinstance(seqlen_offsets, torch.Tensor),
+ is_varlen,
+ interleaved,
+ conjugate,
+ BLOCK_M,
+ )
+ return output
+
+class ApplyRotaryEmb(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x,
+ cos,
+ sin,
+ interleaved=False,
+ inplace=False,
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+ ):
+ out = apply_rotary(
+ x,
+ cos,
+ sin,
+ seqlen_offsets=seqlen_offsets,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ interleaved=interleaved,
+ inplace=inplace,
+ )
+ if isinstance(seqlen_offsets, int):
+ # Can't save int with save_for_backward
+ ctx.save_for_backward(cos, sin, cu_seqlens)
+ ctx.seqlen_offsets = seqlen_offsets
+ else:
+ ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
+ ctx.seqlen_offsets = None
+ ctx.interleaved = interleaved
+ ctx.inplace = inplace
+ ctx.max_seqlen = max_seqlen
+ return out if not inplace else x
+
+ @staticmethod
+ def backward(ctx, do):
+ seqlen_offsets = ctx.seqlen_offsets
+ if seqlen_offsets is None:
+ cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
+ else:
+ cos, sin, cu_seqlens = ctx.saved_tensors
+ # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
+ # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
+ if not ctx.interleaved and not ctx.inplace:
+ do = do.clone()
+ dx = apply_rotary(
+ do,
+ cos,
+ sin,
+ seqlen_offsets=seqlen_offsets,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=ctx.max_seqlen,
+ interleaved=ctx.interleaved,
+ inplace=ctx.inplace,
+ conjugate=True,
+ )
+ return dx, None, None, None, None, None, None, None
+
+
+def apply_rotary_emb(
+ x,
+ cos,
+ sin,
+ interleaved=False,
+ inplace=False,
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+):
+ """
+ Arguments:
+ x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
+ else (total_seqlen, nheads, headdim)
+ cos, sin: (seqlen_rotary, rotary_dim / 2)
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
+ of 1st half and 2nd half (GPT-NeoX style).
+ inplace: if True, apply rotary embedding in-place.
+ seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
+ Most commonly used in inference when we have KV cache.
+ cu_seqlens: (batch + 1,) or None
+ max_seqlen: int
+ Return:
+ out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
+ else (total_seqlen, nheads, headdim)
+ rotary_dim must be <= headdim
+ Apply rotary embedding to the first rotary_dim of x.
+ """
+ return ApplyRotaryEmb.apply(
+ x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
+ )
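+
+# Usage sketch (shapes assumed): given q of shape (batch, seqlen, nheads, headdim) and an
+# angle table freqs of shape (seqlen, headdim // 2), as in the attention modules elsewhere
+# in this diff:
+#   cos, sin = torch.cos(freqs), torch.sin(freqs)
+#   q = apply_rotary_emb(q, cos, sin, interleaved=True)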
diff --git a/YOCO/yoco/models/decoder/kernel/swiglu.py b/YOCO/yoco/models/decoder/kernel/swiglu.py
new file mode 100644
index 000000000..d57589d21
--- /dev/null
+++ b/YOCO/yoco/models/decoder/kernel/swiglu.py
@@ -0,0 +1,32 @@
+import torch
+
+
+swiglu_fwd_codestring = """
+template T swiglu_fwd(T x, T y) {
+ return float(x) * float(y) / (1.0f + ::exp(-float(x)));
+}
+"""
+swiglu_bwd_codestring = """
+template T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
+ float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
+ dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
+ dy = float(x) * x_sigmoid * float(g);
+}
+"""
+swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)
+swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)
+
+
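+# Fused SwiGLU via torch.cuda.jiterator: the forward kernel computes y * x * sigmoid(x)
+# (i.e. y * silu(x)); the backward kernel returns both d/dx and d/dy of that product.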
+class SwiGLUFunction(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, y):
+ ctx.save_for_backward(x, y)
+ return swiglu_fwd(x, y)
+
+ @staticmethod
+ def backward(ctx, dout):
+ x, y = ctx.saved_tensors
+ return swiglu_bwd(x, y, dout)
+
+swiglu = SwiGLUFunction.apply
diff --git a/YOCO/yoco/models/decoder/model_parallel_init.py b/YOCO/yoco/models/decoder/model_parallel_init.py
new file mode 100644
index 000000000..3eb50a854
--- /dev/null
+++ b/YOCO/yoco/models/decoder/model_parallel_init.py
@@ -0,0 +1,16 @@
+import math
+
+import torch
+import torch.nn as nn
+
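+# Weight initializers passed as `init_method` to the Megatron-style column/row-parallel
+# linear layers and vocab-parallel embeddings used throughout this diff.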
+def init_method(tensor, **kwargs):
+ nn.init.kaiming_uniform_(tensor, a=math.sqrt(5))
+
+def qkvg_init_method(tensor, **kwargs):
+ nn.init.xavier_uniform_(tensor, gain = 2 ** -2.5)
+
+def out_init_method(tensor, **kwargs):
+ nn.init.xavier_uniform_(tensor, gain = 2 ** -1)
+
+def vocab_init_method(tensor, **kwargs):
+ torch.nn.init.normal_(tensor, mean=0, std=tensor.shape[1] ** -0.5)
diff --git a/YOCO/yoco/models/decoder/rms_norm.py b/YOCO/yoco/models/decoder/rms_norm.py
new file mode 100644
index 000000000..fccb027ec
--- /dev/null
+++ b/YOCO/yoco/models/decoder/rms_norm.py
@@ -0,0 +1,26 @@
+import torch
+import torch.nn as nn
+
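+# RMSNorm: y = x / sqrt(mean(x^2) + eps) * weight, with the normalization computed in
+# float32 and the result cast back to the input dtype.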
+class RMSNorm(nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True):
+ super().__init__()
+ self.dim = dim
+ self.eps = eps
+ self.elementwise_affine = elementwise_affine
+ if self.elementwise_affine:
+ self.weight = nn.Parameter(torch.ones(dim))
+ else:
+ self.register_parameter('weight', None)
+
+ def _norm(self, x):
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ output = self._norm(x.float()).type_as(x)
+ if self.weight is not None:
+ output = output * self.weight
+ return output
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
+
\ No newline at end of file
diff --git a/YOCO/yoco/models/decoder/sliding_window_attention.py b/YOCO/yoco/models/decoder/sliding_window_attention.py
new file mode 100644
index 000000000..3d744956d
--- /dev/null
+++ b/YOCO/yoco/models/decoder/sliding_window_attention.py
@@ -0,0 +1,68 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from fairseq.model_parallel.megatron.mpu import (
+ ColumnParallelLinear,
+ RowParallelLinear,
+)
+
+from .model_parallel_init import init_method
+from .kernel.rotary import apply_rotary_emb
+
+from flash_attn import flash_attn_func
+
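+# Self-decoder attention with a fixed sliding window: during incremental decoding only the
+# most recent `window_size` key/value pairs per layer are kept in `incremental_state`, and
+# the window is enforced via flash-attention's `window_size` argument.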
+class SlidingWindowAttention(nn.Module):
+ def __init__(self, args):
+ super().__init__()
+ self.args = args
+ self.embed_dim = args.dim
+ self.num_heads = args.n_self_heads // args.model_parallel_size
+ self.window_size = args.sliding_window - 1 # compatible with flash attention
+
+ self.head_dim = args.dim // args.n_self_heads
+
+ self.q_proj = ColumnParallelLinear(args.dim, args.dim, bias=False, gather_output=False, init_method=init_method)
+ self.k_proj = ColumnParallelLinear(args.dim, args.dim, bias=False, gather_output=False, init_method=init_method)
+ self.v_proj = ColumnParallelLinear(args.dim, args.dim, bias=False, gather_output=False, init_method=init_method)
+ self.out_proj = RowParallelLinear(args.dim, args.dim, bias=False, input_is_parallel=True, init_method=init_method)
+
+ def forward(
+ self,
+ x,
+ rel_pos,
+ start_pos=0,
+ incremental_state=None,
+ ):
+ bsz, tgt_len, embed_dim = x.size()
+ src_len = tgt_len
+
+ q = self.q_proj(x)
+ k = self.k_proj(x)
+ v = self.v_proj(x)
+
+ q = q.view(bsz, tgt_len, self.num_heads, self.head_dim)
+ k = k.view(bsz, src_len, self.num_heads, self.head_dim)
+ v = v.view(bsz, src_len, self.num_heads, self.head_dim)
+
+ q = apply_rotary_emb(q, *rel_pos, interleaved=True)
+ k = apply_rotary_emb(k, *rel_pos, interleaved=True)
+ if incremental_state is not None:
+ if "prev_key" not in incremental_state:
+ incremental_state["prev_key"] = torch.empty(self.args.max_batch_size, self.window_size, self.num_heads, self.head_dim, device=x.device, dtype=x.dtype)
+ incremental_state["prev_value"] = torch.empty(self.args.max_batch_size, self.window_size, self.num_heads, self.head_dim, device=x.device, dtype=x.dtype)
+
+ key = torch.cat([incremental_state["prev_key"][:bsz, :start_pos], k], dim=1)
+ value = torch.cat([incremental_state["prev_value"][:bsz, :start_pos], v], dim=1)
+ if key.shape[1] > self.window_size:
+ incremental_state["prev_key"][:bsz] = key[:, -self.window_size:]
+ incremental_state["prev_value"][:bsz] = value[:, -self.window_size:]
+ else:
+ incremental_state["prev_key"][:bsz, start_pos : start_pos + tgt_len] = k
+ incremental_state["prev_value"][:bsz, start_pos : start_pos + tgt_len] = v
+
+ attn = flash_attn_func(q, k, v, causal=True, window_size=(self.window_size - 1, 0))
+ attn = attn.reshape(bsz, tgt_len, self.head_dim * self.num_heads)
+
+ attn = self.out_proj(attn)
+ return attn
\ No newline at end of file
diff --git a/YOCO/yoco/models/decoder/transformer.py b/YOCO/yoco/models/decoder/transformer.py
new file mode 100644
index 000000000..f41edf583
--- /dev/null
+++ b/YOCO/yoco/models/decoder/transformer.py
@@ -0,0 +1,251 @@
+import json
+import math
+import os
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+
+from flash_attn import flash_attn_func
+
+from fairseq.model_parallel.megatron.mpu import (
+ ColumnParallelLinear,
+ RowParallelLinear,
+ copy_to_model_parallel_region,
+ VocabParallelEmbedding
+)
+
+from fairscale.nn import checkpoint_wrapper
+
+from .rms_norm import RMSNorm
+from .kernel.rotary import apply_rotary_emb
+from .model_parallel_init import init_method, vocab_init_method
+
+
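+# Rotary-embedding angle table: entry [t, i] equals t * theta^(-2i / dim). build_rel_pos
+# later slices it, takes cos/sin, and feeds the pair to apply_rotary_emb.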
+def precompute_freqs_cis(dim: int, end: int, theta: float) -> torch.Tensor:
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+ t = torch.arange(end, device=freqs.device) # type: ignore
+ freqs = torch.outer(t, freqs).float() # type: ignore
+ return freqs
+
+
+@dataclass
+class ModelArgs:
+ dim: int
+ n_layers: int
+ head_dim: int
+ hidden_dim: int
+ n_heads: int
+ n_kv_heads: int
+ norm_eps: float
+ vocab_size: int
+
+ max_batch_size: int = 0
+ max_seq_len: int = -1
+ model_parallel_size: int = 1
+ load_checkpoint: bool = False
+ rope_theta: float = 10000.0
+ sliding_window: Optional[int] = None
+
+
+class Attention(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.args = args
+
+ self.dim = args.dim
+ self.head_dim = args.head_dim
+ self.hidden_dim = args.n_heads * args.head_dim
+ self.key_value_dim = args.n_kv_heads * args.head_dim
+ self.n_heads = args.n_heads // args.model_parallel_size
+ self.n_kv_heads = args.n_kv_heads // args.model_parallel_size
+ self.activate_sliding_window = args.sliding_window is not None
+ self.cache_len = args.sliding_window - 1 if self.activate_sliding_window else args.max_seq_len
+
+ self.repeats = self.n_heads // self.n_kv_heads
+
+ self.scale = self.args.head_dim**-0.5
+
+ self.wq = ColumnParallelLinear(self.dim, self.hidden_dim, bias=False, gather_output=False, init_method=init_method)
+ self.wk = ColumnParallelLinear(self.dim, self.key_value_dim, bias=False, gather_output=False, init_method=init_method)
+ self.wv = ColumnParallelLinear(self.dim, self.key_value_dim, bias=False, gather_output=False, init_method=init_method)
+ self.wo = RowParallelLinear(self.hidden_dim, self.dim, bias=False, input_is_parallel=True, init_method=init_method)
+
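+    # Grouped-query attention with a rolling KV cache: `cache_len` is the sliding window
+    # (when enabled) or `max_seq_len`; flash_attn_func accepts fewer KV heads than query
+    # heads directly, so the key/value tensors are not repeated here.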
+ def forward(
+ self,
+ x: torch.Tensor,
+ rel_pos: Tuple[torch.Tensor, torch.Tensor],
+ start_pos: int,
+ incremental_state = None,
+ ) -> torch.Tensor:
+ bsz, seqlen, _ = x.shape
+
+ xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+ xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
+ xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
+ xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
+ xq = apply_rotary_emb(xq, *rel_pos)
+ xk = apply_rotary_emb(xk, *rel_pos)
+ if incremental_state is not None:
+ if "cache_k" not in incremental_state:
+ incremental_state["cache_k"] = torch.zeros(
+ (
+ self.args.max_batch_size,
+ self.cache_len,
+ self.n_kv_heads,
+ self.head_dim,
+ )
+ ).to(xk)
+ incremental_state["cache_v"] = torch.zeros(
+ (
+ self.args.max_batch_size,
+ self.cache_len,
+ self.n_kv_heads,
+ self.head_dim,
+ )
+ ).to(xv)
+            key = torch.cat([incremental_state["cache_k"][:bsz, :start_pos], xk], dim=1)
+            value = torch.cat([incremental_state["cache_v"][:bsz, :start_pos], xv], dim=1)
+ if key.shape[1] > self.cache_len:
+ incremental_state["cache_k"][:bsz] = key[:, -self.cache_len:]
+ incremental_state["cache_v"][:bsz] = value[:, -self.cache_len:]
+ else:
+ incremental_state["cache_k"][:bsz, start_pos : start_pos + seqlen] = xk
+ incremental_state["cache_v"][:bsz, start_pos : start_pos + seqlen] = xv
+ else:
+ key, value = xk, xv
+
+ output = flash_attn_func(xq, key, value, causal=True, window_size=(self.args.sliding_window - 1, 0) if self.activate_sliding_window else (-1, -1))
+
+ return self.wo(output.view(bsz, seqlen, self.n_heads * self.head_dim))
+
+
+class FeedForward(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.w1 = ColumnParallelLinear(args.dim, args.hidden_dim, bias=False, gather_output=False, init_method=init_method)
+ self.w2 = RowParallelLinear(args.hidden_dim, args.dim, bias=False, input_is_parallel=True, init_method=init_method)
+ self.w3 = ColumnParallelLinear(args.dim, args.hidden_dim, bias=False, gather_output=False, init_method=init_method)
+
+ def forward(self, x) -> torch.Tensor:
+ return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.n_heads = args.n_heads
+ self.dim = args.dim
+ self.attention = Attention(args)
+ self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+ self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+ self.args = args
+
+ self.feed_forward: nn.Module
+ self.feed_forward = FeedForward(args=args)
+
+ def forward(
+ self, x: torch.Tensor, rel_pos: Tuple[torch.Tensor, torch.Tensor], start_pos: int, incremental_state = None
+ ) -> torch.Tensor:
+ r = self.attention.forward(self.attention_norm(x), rel_pos, start_pos, incremental_state)
+ h = x + r
+ r = self.feed_forward.forward(self.ffn_norm(h))
+ out = h + r
+ return out
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ args: ModelArgs,
+ mp_rank: int = 0,
+ checkpoint_activations: bool = False
+ ):
+ super().__init__()
+ self.args = args
+ self.vocab_size = args.vocab_size
+ self.n_layers = args.n_layers
+ self._precomputed_freqs_cis: Optional[torch.Tensor] = None
+ self._window_precomputed_freqs_cis: Optional[torch.Tensor] = None
+ self._global_precomputed_freqs_cis: Optional[torch.Tensor] = None
+ assert self.vocab_size > 0
+ self.mp_rank = mp_rank
+ self.checkpoint_activations = checkpoint_activations
+ self.tok_embeddings = VocabParallelEmbedding(
+ args.vocab_size, args.dim, -1, init_method=vocab_init_method
+ )
+ self.norm = RMSNorm(args.dim, eps=args.norm_eps)
+ self.output = nn.Linear(args.dim, args.vocab_size // args.model_parallel_size, bias=False)
+        # Build all transformer layers (tensor parallelism splits weights inside each layer, not across layers).
+ layers = [TransformerBlock(args=args) for idx in range(args.n_layers)]
+ if checkpoint_activations:
+ layers = [checkpoint_wrapper(layer) for layer in layers]
+ self.layers = nn.ModuleList(layers)
+ self.n_local_layers = len(self.layers)
+
+ @property
+ def dtype(self) -> torch.dtype:
+ return next(self.parameters()).dtype
+
+ @property
+ def device(self) -> torch.device:
+ return next(self.parameters()).device
+
+ def build_rel_pos(self, x, start_pos):
+ if self._precomputed_freqs_cis is None:
+ theta = self.args.rope_theta
+ self._precomputed_freqs_cis = precompute_freqs_cis(
+ self.args.head_dim, self.args.max_seq_len, theta
+ )
+ if self._precomputed_freqs_cis.device != self.device:
+ self._precomputed_freqs_cis = self._precomputed_freqs_cis.to(
+ device=self.device
+ )
+ cos = torch.cos(self._precomputed_freqs_cis[start_pos:start_pos+x.size(1)])
+ sin = torch.sin(self._precomputed_freqs_cis[start_pos:start_pos+x.size(1)])
+ rel_pos = (cos.to(x.dtype), sin.to(x.dtype))
+ return rel_pos
+
+ def forward_partial(
+ self,
+ input_ids: torch.Tensor,
+ start_pos: Optional[int] = 0,
+ incremental_state = None,
+ ) -> torch.Tensor:
+ h = self.tok_embeddings(input_ids)
+ rel_pos = self.build_rel_pos(h, start_pos)
+ for local_layer_id, layer in enumerate(self.layers):
+ if incremental_state is not None:
+ if local_layer_id not in incremental_state:
+ incremental_state[local_layer_id] = {}
+ h = layer(h, rel_pos, start_pos, incremental_state=incremental_state[local_layer_id] if incremental_state is not None else None)
+
+ return self.norm(h)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ start_pos: Optional[int] = 0,
+ incremental_state = None,
+    ) -> Tuple[torch.Tensor, None]:
+ h = self.forward_partial(input_ids, start_pos, incremental_state)
+ if self.args.model_parallel_size > 1:
+ h = copy_to_model_parallel_region(h)
+ outs = self.output(h)
+ return outs.float(), None
+
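+    # Shard an unsplit checkpoint for this model-parallel rank: embeddings and the output
+    # head are split along the vocab dimension, column-parallel weights (wq/wk/wv/w1/w3)
+    # along dim 0, and row-parallel weights (wo/w2) along dim 1.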
+ def load_state_dict(self, state_dict, strict=False, assign=False):
+ state_to_load = {}
+ for k, v in state_dict.items():
+ if k.startswith("tok_embeddings") or k.startswith("output"):
+ state_to_load[k] = v.view(self.args.model_parallel_size, self.vocab_size // self.args.model_parallel_size, self.args.dim)[self.mp_rank]
+ elif "wq" in k or "wk" in k or "wv" in k or "w1" in k or "w3" in k:
+ state_to_load[k] = v.view(self.args.model_parallel_size, -1, v.shape[1])[self.mp_rank]
+ elif "wo" in k or "w2" in k:
+ state_to_load[k] = v.view(v.shape[0], self.args.model_parallel_size, -1)[:, self.mp_rank]
+ else:
+ state_to_load[k] = v
+ super().load_state_dict(state_to_load, strict=False, assign=assign)
+ print("Loaded state dict from checkpoint.")
diff --git a/YOCO/yoco/models/decoder/yoco.py b/YOCO/yoco/models/decoder/yoco.py
new file mode 100644
index 000000000..6fb0d01e8
--- /dev/null
+++ b/YOCO/yoco/models/decoder/yoco.py
@@ -0,0 +1,294 @@
+import math
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from fairscale.nn import checkpoint_wrapper
+
+from fairseq.model_parallel.megatron.mpu import (
+ ColumnParallelLinear,
+ copy_to_model_parallel_region,
+ VocabParallelEmbedding
+)
+
+from .gate_retention import GateRetention
+from .sliding_window_attention import SlidingWindowAttention
+from .cross_attention import CrossAttention
+from .feedforward_network import FeedForwardNetwork, init_method
+from .rms_norm import RMSNorm
+
+from .kernel.rotary import apply_rotary_emb
+from .model_parallel_init import vocab_init_method, init_method
+
+
+@dataclass
+class YOCOArgs:
+ dim: int
+ n_layers: int
+ hidden_dim: int
+ n_self_heads: int
+ n_attn_heads: int
+ n_attn_kv_heads: int
+ vocab_size: int
+
+ max_batch_size: int = 0
+ max_seq_len: int = -1
+ model_parallel_size: int = 1
+ load_checkpoint: bool = False
+ rope_theta: float = 10000.0
+ norm_eps: float = 1e-5
+ sliding_window: Optional[int] = None
+
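+# One block of the decoder-decoder stack. Self-decoder layers use an efficient self-attention
+# variant (gated retention, or sliding-window attention when args.sliding_window is set);
+# cross-decoder layers attend to a single shared key/value cache instead.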
+class DecoderLayer(nn.Module):
+ def __init__(
+ self,
+ args: YOCOArgs,
+ is_cross_layer=False
+ ):
+ super().__init__()
+ self.args = args
+ self.is_cross_layer = is_cross_layer
+
+ if is_cross_layer:
+ self.mixer = CrossAttention(args)
+ elif args.sliding_window is not None:
+ self.mixer = SlidingWindowAttention(args)
+ else:
+ self.mixer = GateRetention(args)
+
+ self.mixer_layer_norm = RMSNorm(args.dim, eps=args.norm_eps)
+
+ self.ffn = FeedForwardNetwork(
+ args.dim,
+ args.hidden_dim,
+ args.load_checkpoint
+ )
+
+ self.final_layer_norm = RMSNorm(args.dim, eps=args.norm_eps)
+
+ def forward(
+ self,
+ x,
+ start_pos=0,
+ key=None,
+ value=None,
+ rel_pos=None,
+ incremental_state=None,
+ is_prefilling=False,
+ ):
+ residual = x
+ x = self.mixer_layer_norm(x)
+
+ if self.is_cross_layer:
+ x = self.mixer(
+ x,
+ key,
+ value,
+ rel_pos=rel_pos,
+ )
+ elif self.args.sliding_window is not None:
+ x = self.mixer(
+ x,
+ rel_pos=rel_pos,
+ start_pos=start_pos,
+ incremental_state=incremental_state,
+ )
+ else:
+ x = self.mixer(
+ x,
+ rel_pos=rel_pos,
+ incremental_state=incremental_state,
+ is_prefilling=is_prefilling,)
+
+ x = x + residual
+ residual = x
+ x = self.final_layer_norm(x)
+
+ x = self.ffn(x)
+
+ x = x + residual
+ return x
+
+class SelfDecoder(nn.Module):
+ def __init__(
+ self,
+ args: YOCOArgs,
+ checkpoint_activations: bool = False
+ ):
+ super().__init__()
+ self.args = args
+ layers = [DecoderLayer(args, is_cross_layer=False,) for idx in range(args.n_layers // 2)]
+ if checkpoint_activations:
+ layers = [checkpoint_wrapper(layer) for layer in layers]
+ self.layers = nn.ModuleList(layers)
+ self.head_dim = args.dim // args.n_self_heads
+ self.block_size = 256
+ self._precomputed_freqs_cis = None
+
+ def build_rel_pos(self, x, start_pos):
+ if self._precomputed_freqs_cis is None:
+ angle = 1.0 / (self.args.rope_theta ** torch.linspace(0, 1, self.head_dim // 2, dtype=torch.float, device=x.device))
+ index = torch.arange(self.args.max_seq_len).to(angle)
+ self._precomputed_freqs_cis = index[:, None] * angle
+
+ cos = torch.cos(self._precomputed_freqs_cis[start_pos:start_pos+x.size(1)])
+ sin = torch.sin(self._precomputed_freqs_cis[start_pos:start_pos+x.size(1)])
+ rel_pos = (cos.to(x.dtype), sin.to(x.dtype))
+ return rel_pos
+
+ def get_index_mask(self, x, length, pad_length):
+ return torch.arange(pad_length, device=x.device) >= length
+
+ def forward(
+ self,
+ x,
+ incremental_state=None,
+ is_prefilling=False,
+ start_pos=0
+ ):
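+        # For the gated-retention path, prefill inputs are right-padded to a multiple of
+        # block_size (presumably so the kernel can operate on fixed-size chunks); the padding
+        # is stripped again before returning.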
+ if is_prefilling and x.size(1) % self.block_size != 0 and self.args.sliding_window is None:
+ padding_len = self.block_size - x.size(1) % self.block_size
+ x = F.pad(x, (0, 0, 0, padding_len), value=0)
+ else:
+ padding_len = 0
+
+ if incremental_state is not None and is_prefilling:
+ index_mask = self.get_index_mask(x, x.size(1) - padding_len, x.size(1))
+
+ rel_pos = self.build_rel_pos(x, start_pos)
+ for idx, layer in enumerate(self.layers):
+ if incremental_state is not None:
+ if idx not in incremental_state:
+ incremental_state[idx] = {}
+ if is_prefilling:
+ incremental_state[idx]["index_mask"] = index_mask
+ x = layer(
+ x,
+ start_pos=start_pos,
+ rel_pos=rel_pos,
+ incremental_state=incremental_state[idx] if incremental_state is not None else None,
+ is_prefilling=is_prefilling,)
+
+ x = x[:, :x.size(1) - padding_len, :]
+ return x
+
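+# Cross-decoder: keys and values are projected once from the self-decoder output and shared
+# by every cross-attention layer, so a single KV cache serves the whole stack ("you only
+# cache once"). skip_cross_decoder fills that cache during prefilling while returning a
+# dummy output instead of running the cross-attention layers.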
+class CrossDecoder(nn.Module):
+ def __init__(
+ self,
+ args: YOCOArgs,
+ checkpoint_activations: bool = False
+ ):
+ super().__init__()
+ self.args = args
+ self.num_heads = args.n_attn_kv_heads
+ self.head_dim = args.dim // args.n_attn_heads
+ self.k_proj = ColumnParallelLinear(args.dim, self.head_dim * args.n_attn_kv_heads, bias=False, gather_output=False, init_method=init_method)
+ self.v_proj = ColumnParallelLinear(args.dim, self.head_dim * args.n_attn_kv_heads, bias=False, gather_output=False, init_method=init_method)
+ self.kv_layer_norm = RMSNorm(args.dim, eps=args.norm_eps)
+ layers = [DecoderLayer(args, is_cross_layer=True) for idx in range(args.n_layers // 2)]
+ if checkpoint_activations:
+ layers = [checkpoint_wrapper(layer) for layer in layers]
+ self.layers = nn.ModuleList(layers)
+ self._precomputed_freqs_cis = None
+
+ def build_rel_pos(self, x, start_pos):
+ if self._precomputed_freqs_cis is None:
+ angle = 1.0 / (self.args.rope_theta ** torch.linspace(0, 1, self.head_dim // 2, dtype=torch.float, device=x.device))
+ index = torch.arange(self.args.max_seq_len).to(angle)
+ self._precomputed_freqs_cis = index[:, None] * angle
+
+ cos = torch.cos(self._precomputed_freqs_cis[start_pos:start_pos+x.size(1)])
+ sin = torch.sin(self._precomputed_freqs_cis[start_pos:start_pos+x.size(1)])
+ rel_pos = (cos.to(x.dtype), sin.to(x.dtype))
+ return rel_pos
+
+ def forward(
+ self,
+ x,
+ incremental_state=None,
+ start_pos=0,
+ skip_cross_decoder=False,
+ ):
+ bsz, seqlen, embed_dim = x.size()
+ x_norm = self.kv_layer_norm(x)
+ key, value = self.k_proj(x_norm), self.v_proj(x_norm)
+ key = key.view(bsz, seqlen, self.num_heads, self.head_dim)
+ value = value.view(bsz, seqlen, self.num_heads, self.head_dim)
+ rel_pos = self.build_rel_pos(x, start_pos)
+ key = apply_rotary_emb(key, *rel_pos, interleaved=True)
+ if incremental_state is not None:
+ if "prev_key" not in incremental_state:
+ incremental_state["prev_key"] = torch.empty(bsz, self.args.max_seq_len, self.num_heads, self.head_dim, device=x.device, dtype=x.dtype)
+ incremental_state["prev_value"] = torch.empty(bsz, self.args.max_seq_len, self.num_heads, self.head_dim, device=x.device, dtype=x.dtype)
+ incremental_state["prev_key"][:, start_pos : start_pos + seqlen] = key
+ incremental_state["prev_value"][:, start_pos : start_pos + seqlen] = value
+ key = incremental_state["prev_key"][:, : start_pos + seqlen]
+ value = incremental_state["prev_value"][:, : start_pos + seqlen]
+
+ if skip_cross_decoder:
+ return torch.zeros(bsz, 1, embed_dim, device=x.device, dtype=x.dtype)
+ for layer in self.layers:
+ x = layer(
+ x,
+ key=key,
+ value=value,
+ rel_pos=rel_pos)
+
+ return x
+
+class YOCO(nn.Module):
+ def __init__(
+ self,
+ args: YOCOArgs,
+ checkpoint_activations: bool = False,
+ share_input_output_embed: bool = False,
+ ):
+ super().__init__()
+ self.args = args
+ self.embed_scale = math.sqrt(args.dim)
+ self.embed_tokens = VocabParallelEmbedding(
+ args.vocab_size, args.dim, -1, init_method=vocab_init_method
+ )
+ self.output_projection = nn.Linear(args.dim, args.vocab_size, bias=False)
+ if share_input_output_embed:
+ self.output_projection.weight = self.embed_tokens.weight
+
+ self.self_decoder = SelfDecoder(args, checkpoint_activations)
+ self.cross_decoder = CrossDecoder(args, checkpoint_activations)
+ self.layer_norm = RMSNorm(args.dim, eps=args.norm_eps)
+
+ def forward(
+ self,
+ x,
+ start_pos=0,
+ incremental_state=None,
+ is_prefilling=True,
+ skip_cross_decoder=False
+ ):
+ x = self.embed_scale * self.embed_tokens(x)
+
+ x = self.self_decoder(
+ x,
+ incremental_state=incremental_state,
+ is_prefilling=is_prefilling,
+ start_pos=start_pos,
+ )
+
+ x = self.cross_decoder(
+ x,
+ start_pos=start_pos,
+ incremental_state=incremental_state,
+ skip_cross_decoder=skip_cross_decoder,
+ )
+
+ x = self.layer_norm(x)
+ x = self.output_layer(x)
+
+ return x, None
+
+ def output_layer(self, features):
+ if self.args.model_parallel_size > 1:
+ features = copy_to_model_parallel_region(features)
+ return self.output_projection(features)
\ No newline at end of file
diff --git a/YOCO/yoco/models/transformer.py b/YOCO/yoco/models/transformer.py
new file mode 100644
index 000000000..3fcaa78dc
--- /dev/null
+++ b/YOCO/yoco/models/transformer.py
@@ -0,0 +1,141 @@
+import json
+import os
+from dataclasses import dataclass, field
+from typing import Optional
+
+import torch
+
+from fairseq.model_parallel.megatron.mpu import (
+ initialize_model_parallel,
+ model_parallel_is_initialized,
+ get_model_parallel_rank
+)
+
+from fairseq.dataclass import FairseqDataclass
+from fairseq.models import (
+ FairseqIncrementalDecoder,
+ FairseqLanguageModel,
+ register_model,
+ register_model_architecture,
+)
+
+from omegaconf import II
+
+from .decoder.transformer import ModelArgs, Transformer
+
+DEFAULT_MAX_TARGET_POSITIONS = 4096
+
+@dataclass
+class LanguageConfig(FairseqDataclass):
+ llama_model: Optional[str] = field(
+ default=None,
+ metadata={"help": "path to load tokenizer and config"},
+ )
+ load_ckpt: Optional[str] = field(
+ default=None,
+ metadata={"help": "path to load checkpoint from"},
+ )
+ init_from_config: bool = field(
+ default=False,
+ )
+ dim: int = field(
+ default=1024,
+ )
+ n_layers: int = field(
+ default=8,
+ )
+ n_heads: int = field(
+ default=8,
+ )
+ n_kv_heads: int = field(
+ default=2,
+ )
+ batch_size: int = field(
+ default=1,
+ )
+ rope_theta: Optional[float] = field(
+ default=10000.0,
+ )
+ checkpoint_activations: bool = field(
+ default=False, metadata={"help": "checkpoint activations at each layer"}
+ )
+ tokens_per_sample: int = II("task.tokens_per_sample")
+ model_parallel_size: int = II("common.model_parallel_size")
+
+@register_model("llama", dataclass=LanguageConfig)
+class LanguageModel(FairseqLanguageModel):
+ def __init__(self, args, decoder, tokenizer):
+ self.args = args
+ self.tokenizer = tokenizer
+ super().__init__(decoder)
+
+ @classmethod
+ def build_model(cls, args, task):
+ if not model_parallel_is_initialized():
+ initialize_model_parallel(args.model_parallel_size)
+
+ if not args.init_from_config:
+ params = {
+ "dim": args.dim,
+ "n_layers": args.n_layers,
+ "n_heads": args.n_heads,
+ "head_dim": args.dim // args.n_heads,
+ "n_kv_heads": args.n_kv_heads,
+ "hidden_dim": int(args.dim * 8 / 3),
+ "vocab_size": task.tokenizer.n_words,
+ "max_batch_size": args.batch_size,
+ "max_seq_len": args.tokens_per_sample,
+ "model_parallel_size": args.model_parallel_size,
+ "load_checkpoint": args.load_ckpt is not None,
+ "rope_theta": args.rope_theta,
+ }
+ model_args: ModelArgs = ModelArgs(
+ **params,
+ )
+ else:
+ with open(os.path.join(args.llama_model, "params.json"), "r") as f:
+ params = json.load(f)
+ model_args = ModelArgs(**params)
+ model_args.max_batch_size = args.batch_size
+ model_args.max_seq_len = args.tokens_per_sample
+ model_args.model_parallel_size = args.model_parallel_size
+ model_args.load_checkpoint = args.load_ckpt is not None
+ model = Transformer(
+ model_args,
+ mp_rank=get_model_parallel_rank(),
+ checkpoint_activations=args.checkpoint_activations,
+ )
+ if args.load_ckpt is not None:
+ loaded = torch.load(args.load_ckpt, mmap=True)
+ model.load_state_dict(loaded, assign=True)
+
+ model = LLaMA(model)
+ return cls(args, model, task.tokenizer)
+
+class LLaMA(FairseqIncrementalDecoder):
+ def __init__(self, model):
+ super().__init__(None)
+ self.model = model
+
+ def forward(self, src_tokens, start_pos = 0, **kwargs):
+ padding = src_tokens < 0
+ src_tokens = torch.where(padding, torch.zeros_like(src_tokens), src_tokens)
+ return self.model.forward(src_tokens, start_pos, **kwargs)
+
+ def max_positions(self):
+ return self.model.args.max_seq_len
+
+@register_model_architecture("llama", "llama_from_scratch")
+def llama_from_scratch(args):
+ args.init_from_config = getattr(args, "init_from_config", False)
+ args.dim = getattr(args, "dim", 1024)
+ args.n_layers = getattr(args, "n_layers", 8)
+ args.n_heads = getattr(args, "n_heads", 8)
+ args.n_kv_heads = getattr(args, "n_kv_heads", 2)
+
+@register_model_architecture("llama", "llama_from_ckpt")
+def llama_from_ckpt(args):
+ args.init_from_config = getattr(args, "init_from_config", True)
+
+
+
\ No newline at end of file
diff --git a/YOCO/yoco/models/yoco.py b/YOCO/yoco/models/yoco.py
new file mode 100644
index 000000000..580d15bd2
--- /dev/null
+++ b/YOCO/yoco/models/yoco.py
@@ -0,0 +1,158 @@
+import os
+import json
+import logging
+from dataclasses import dataclass, field
+from typing import Optional
+
+import torch
+from fairseq import distributed_utils, utils
+from fairseq.dataclass import FairseqDataclass
+from fairseq.models import (
+ FairseqIncrementalDecoder,
+ FairseqLanguageModel,
+ register_model,
+ register_model_architecture,
+)
+
+from omegaconf import II
+
+from fairseq.model_parallel.megatron.mpu import (
+ initialize_model_parallel,
+ model_parallel_is_initialized
+)
+from .decoder.yoco import YOCO, YOCOArgs
+
+DEFAULT_MAX_TARGET_POSITIONS = 4096
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class LanguageConfig(FairseqDataclass):
+ yoco_model: Optional[str] = field(
+ default=None,
+ metadata={"help": "path to load params from"},
+ )
+ load_ckpt: Optional[str] = field(
+ default=None,
+ metadata={"help": "path to load checkpoint from"},
+ )
+ dim: int = field(
+ default=1024,
+ )
+ hidden_dim: int = field(
+ default=3072,
+ )
+ n_layers: int = field(
+ default=24,
+ )
+ n_self_heads: int = field(
+ default=4,
+ )
+ n_attn_heads: int = field(
+ default=8,
+ )
+ n_attn_kv_heads: Optional[int] = field(
+ default=None,
+ )
+ batch_size: int = field(
+ default=1,
+ )
+ share_input_output_embed: bool = field(
+ default=False, metadata={"help": "share decoder input and output embeddings"}
+ )
+ sliding_window: Optional[bool] = field(
+ default=None,
+ )
+ rope_theta: Optional[float] = field(
+ default=10000.0,
+ )
+ checkpoint_activations: bool = field(
+ default=False, metadata={"help": "checkpoint activations at each layer"}
+ )
+ tokens_per_sample: int = II("task.tokens_per_sample")
+ model_parallel_size: int = II("common.model_parallel_size")
+
+
+@register_model("yoco", dataclass=LanguageConfig)
+class LanguageModel(FairseqLanguageModel):
+ def __init__(self, args, decoder, tokenizer):
+ self.args = args
+ self.tokenizer = tokenizer
+ super().__init__(decoder)
+
+ @classmethod
+ def build_model(cls, args, task):
+ if not model_parallel_is_initialized():
+ initialize_model_parallel(args.model_parallel_size)
+
+ if args.yoco_model is None:
+ params = {
+ "dim": args.dim,
+ "n_layers": args.n_layers,
+ "n_self_heads": args.n_self_heads,
+ "n_attn_heads": args.n_attn_heads,
+ "n_attn_kv_heads": args.n_attn_kv_heads,
+ "hidden_dim": args.hidden_dim,
+ "vocab_size": task.tokenizer.n_words,
+ "max_batch_size": args.batch_size,
+ "max_seq_len": args.tokens_per_sample,
+ "model_parallel_size": args.model_parallel_size,
+ "load_checkpoint": args.load_ckpt is not None,
+ "rope_theta": args.rope_theta,
+ }
+ model_args: YOCOArgs = YOCOArgs(
+ **params,
+ )
+ else:
+ with open(os.path.join(args.yoco_model, "params.json"), "r") as f:
+ params = json.load(f)
+ model_args = YOCOArgs(**params)
+ model_args.max_batch_size = args.batch_size
+ model_args.max_seq_len = args.tokens_per_sample
+ model_args.model_parallel_size = args.model_parallel_size
+ model_args.load_checkpoint = args.load_ckpt is not None
+
+ model = YOCO(
+ model_args,
+ checkpoint_activations=args.checkpoint_activations,
+ )
+ if args.load_ckpt is not None:
+ loaded = torch.load(args.load_ckpt, mmap=True)
+ model.load_state_dict(loaded, assign=True)
+ model = YOCOModel(model)
+ return cls(args, model, task.tokenizer)
+
+class YOCOModel(FairseqIncrementalDecoder):
+ def __init__(self, model):
+ super().__init__(None)
+ self.model = model
+
+ def forward(self, src_tokens, **kwargs):
+ return self.model.forward(src_tokens, **kwargs)
+
+ def max_positions(self):
+ return self.model.args.max_seq_len
+
+def default(args):
+ args.n_attn_kv_heads = getattr(args, "n_attn_kv_heads", args.n_attn_heads)
+ args.sliding_window = getattr(args, "sliding_window", False)
+ args.rope_theta = getattr(args, "rope_theta", 10000.0)
+ args.share_input_output_embed = getattr(
+ args, "share_input_output_embed", False
+ )
+ args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
+
+
+@register_model_architecture("yoco", "yoco_3b")
+def yoco_3b(args):
+ args.dim = getattr(args, "dim", 3072)
+ args.hidden_dim = getattr(args, "hidden_dim", 8192)
+ args.n_layers = getattr(args, "n_layers", 26)
+ args.n_self_heads = getattr(args, "n_self_heads", 24)
+ args.n_attn_heads = getattr(args, "n_attn_heads", 24)
+ args.n_attn_kv_heads = getattr(args, "n_attn_kv_heads", 8)
+ default(args)
+
+
+
+
diff --git a/YOCO/yoco/tasks/__init__.py b/YOCO/yoco/tasks/__init__.py
new file mode 100644
index 000000000..1da9d1238
--- /dev/null
+++ b/YOCO/yoco/tasks/__init__.py
@@ -0,0 +1,32 @@
+import argparse
+import importlib
+import os
+
+# register dataclass
+TASK_DATACLASS_REGISTRY = {}
+TASK_REGISTRY = {}
+TASK_CLASS_NAMES = set()
+
+# automatically import any Python files in the tasks/ directory
+tasks_dir = os.path.dirname(__file__)
+for file in os.listdir(tasks_dir):
+ path = os.path.join(tasks_dir, file)
+ if (
+ not file.startswith("_")
+ and not file.startswith(".")
+ and (file.endswith(".py") or os.path.isdir(path))
+ ):
+ task_name = file[: file.find(".py")] if file.endswith(".py") else file
+ module = importlib.import_module("tasks." + task_name)
+
+ # expose `task_parser` for sphinx
+ if task_name in TASK_REGISTRY:
+ parser = argparse.ArgumentParser(add_help=False)
+ group_task = parser.add_argument_group("Task name")
+ # fmt: off
+ group_task.add_argument('--task', metavar=task_name,
+ help='Enable this task with: ``--task=' + task_name + '``')
+ # fmt: on
+ group_args = parser.add_argument_group("Additional command-line arguments")
+ TASK_REGISTRY[task_name].add_args(group_args)
+ globals()[task_name + "_parser"] = parser
diff --git a/YOCO/yoco/tasks/data/__init__.py b/YOCO/yoco/tasks/data/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/YOCO/yoco/tasks/data/basic_loader.py b/YOCO/yoco/tasks/data/basic_loader.py
new file mode 100644
index 000000000..d6f06f2ac
--- /dev/null
+++ b/YOCO/yoco/tasks/data/basic_loader.py
@@ -0,0 +1,75 @@
+import torch
+from infinibatch.iterators import CheckpointableIterator
+
+from . import utils
+
+
+class BaseBatchGen(CheckpointableIterator):
+ """
+ This is a base class for batch generators that use infinibatch
+ """
+
+ def __init__(self):
+ self._iter = None
+ self.epoch = 1
+ self.next_epoch_idx = 1
+ self.sharded_checkpoint = True
+ self.should_close_after_finished = True
+
+ def _build_iter(self):
+ """
+ Build infinibatch iterator and assign to self._iter
+ """
+ raise NotImplementedError()
+
+ def _move_to_tensor(self, batch):
+ def to_tensor(x):
+ return torch.tensor(x)
+
+ return utils.apply_to_sample(to_tensor, batch)
+
+ @property
+ def iterator(self):
+ if self._iter is None:
+            raise NotImplementedError("_build_iter() must be called first")
+ return self._iter
+
+ def __iter__(self):
+ if self._iter is None:
+            raise NotImplementedError("_build_iter() must be called first")
+ return self._iter
+
+ def __next__(self):
+ return next(self._iter)
+
+ def setstate(self, value):
+ self._iter.setstate(value)
+
+ def getstate(self):
+ return self._iter.getstate()
+
+ def close(self):
+ self._iter.close()
+
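+    # The underlying stream is effectively unbounded; report a very large length so that
+    # length-based consumers do not stop the iterator early.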
+ def __len__(self) -> int:
+ return 819200000
+
+ def next_epoch_itr(
+ self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True
+ ):
+ return self
+
+ def end_of_epoch(self) -> bool:
+ return False
+
+ def state_dict(self):
+ """Returns a dictionary containing a whole state of the iterator."""
+ return self.getstate()
+
+ def load_state_dict(self, state_dict):
+ """Copies the state of the iterator from the given *state_dict*."""
+ self.setstate(state_dict)
+
+ @property
+ def first_batch(self):
+ return "DUMMY"
diff --git a/YOCO/yoco/tasks/data/llama_tokenizer.py b/YOCO/yoco/tasks/data/llama_tokenizer.py
new file mode 100644
index 000000000..fad3d206b
--- /dev/null
+++ b/YOCO/yoco/tasks/data/llama_tokenizer.py
@@ -0,0 +1,38 @@
+from pathlib import Path
+from sentencepiece import SentencePieceProcessor
+from typing import List
+
+
+class LLaMATokenizer:
+ def __init__(self, model_path: str):
+ assert Path(model_path).exists(), model_path
+ self._model = SentencePieceProcessor(model_file=model_path)
+ assert self._model.vocab_size() == self._model.get_piece_size()
+
+ @property
+ def n_words(self) -> int:
+ return self._model.vocab_size()
+
+ @property
+ def bos_id(self) -> int:
+ return self._model.bos_id()
+
+ @property
+ def eos_id(self) -> int:
+ return self._model.eos_id()
+
+ @property
+ def pad_id(self) -> int:
+ return self._model.pad_id()
+
+ def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
+ assert isinstance(s, str)
+ t = self._model.encode(s)
+ if bos:
+ t = [self.bos_id, *t]
+ if eos:
+ t = [*t, self.eos_id]
+ return t
+
+ def decode(self, t: List[int]) -> str:
+ return self._model.decode(t)
\ No newline at end of file
diff --git a/YOCO/yoco/tasks/data/lm_loader.py b/YOCO/yoco/tasks/data/lm_loader.py
new file mode 100644
index 000000000..825a82239
--- /dev/null
+++ b/YOCO/yoco/tasks/data/lm_loader.py
@@ -0,0 +1,303 @@
+import os
+import random
+import math
+import numpy as np
+import json
+
+from infinibatch import iterators
+from .utils import FixedBlockwiseShuffleIterator, NativeCheckpointableIterator, WeightNoRandomStateIterator
+from .basic_loader import BaseBatchGen
+
+
+class LMLoader(BaseBatchGen):
+ def __init__(
+ self,
+ args,
+ dataset,
+ tokenizer,
+ max_tokens=None,
+ max_sentences=None,
+ max_positions=None,
+ ignore_invalid_inputs=False,
+ required_batch_size_multiple=1,
+ seed=1,
+ epoch=1,
+ num_shards=1,
+ shard_id=0,
+ reject_sampling=1,
+ ):
+ super().__init__()
+ self.args = args
+ self.data = dataset.data
+ self.data_dir = dataset.data_dir
+ self.shuffle = dataset.shuffle
+ self.tokenizer = tokenizer
+
+ self.max_tokens = max_tokens
+ self.max_sentences = max_sentences
+ self.max_positions = max_positions
+ self.tokens_per_sample = args.tokens_per_sample
+ self.mlm_cut_length = getattr(args, "mlm_cut_length", 0)
+ self.mlm_tokens_proportion = getattr(args, "mlm_tokens_proportion", 0)
+ self.pad_to_max_len = getattr(args, "pad_to_max_len", False)
+ self.ignore_invalid_inputs = ignore_invalid_inputs
+ self.required_batch_size_multiple = required_batch_size_multiple
+ self.seed = str(seed)
+ self.epoch = epoch
+ self.num_shards = num_shards
+ self.shard_id = shard_id
+
+ self.batch_read_ahead = args.batch_read_ahead
+ self.sharded_checkpoint = True
+
+ self._build_iter()
+
+ def _build_iter(self):
+ tokenized_lines = self._tokenize()
+ self.padded_batches = self._batchify(tokenized_lines)
+
+ prefetch_batches = iterators.PrefetchIterator(
+ self.padded_batches,
+ buffer_size=10,
+ buffer_in_main_process=True,
+ log_empty_buffer_warning=True and self.shard_id == 0,
+ )
+
+ prefetch_batches = iterators.MapIterator(
+ prefetch_batches, self._move_to_tensor
+ )
+
+ self._iter = prefetch_batches
+
+ def _tokenize(self):
+ '''
+ data:
+ {
+ 'source': list[Path],
+ }
+ '''
+ dataset = list(zip(self.data['source']))
+
+ if self.shuffle:
+ chunk_files = \
+ iterators.InfinitePermutationSourceIterator(
+ dataset,
+ seed=self.seed,
+ shuffle=self.shuffle,
+ num_instances=self.num_shards,
+ instance_rank=self.shard_id,
+ )
+ else:
+ chunk_files = \
+ iterators.ChunkedSourceIterator(
+ dataset,
+ num_instances=self.num_shards,
+ instance_rank=self.shard_id,
+ )
+
+ tokenized_lines = iterators.SelectManyIterator(chunk_files, lambda files: self._read_from_files(*files))
+ tokenized_lines = iterators.SamplingRandomMapIterator(tokenized_lines, self._prepare, self.seed)
+
+ return tokenized_lines
+
+ def getstate(self):
+ state = super().getstate()
+ state["epoch"] = self.epoch
+ state["iterations_in_epoch"] = None
+ return state
+
+ def _batchify(self, lines):
+
+ if self.max_sentences is not None:
+ if self.batch_read_ahead > 0:
+ lines = FixedBlockwiseShuffleIterator(lines, self.batch_read_ahead, self.seed)
+ batches = iterators.FixedBatchIterator(lines, self.max_sentences)
+ else:
+ def dynamic_batch_size(sample):
+ lengths = [len(x) for x in sample]
+ batch_size = self.max_tokens // max(lengths) // self.required_batch_size_multiple * self.required_batch_size_multiple
+ return max(1, batch_size)
+
+ batches = iterators.BucketedReadaheadBatchIterator(
+ lines,
+ read_ahead=self.batch_read_ahead,
+ key=(lambda x: max(len(x[0]), len(x[1]))) if self.shuffle else None,
+ batch_size=dynamic_batch_size,
+ shuffle=self.shuffle,
+ seed=self.seed,
+ )
+
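+        # Collate token lists into fixed-width int arrays: source ids drop each sample's last
+        # token and target ids drop the first, forming next-token prediction pairs padded
+        # with pad_id.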
+ def collate(batch):
+ batch_size = len(batch)
+ gpt_max_length = max([len(x[0]) for x in batch])
+ if self.pad_to_max_len:
+ gpt_max_length = self.tokens_per_sample + 1
+
+ gpt_source_ids = np.full(shape=(batch_size, gpt_max_length-1), dtype=np.int32,
+ fill_value=self.tokenizer.pad_id)
+ gpt_target_ids = np.full(shape=(batch_size, gpt_max_length-1), dtype=np.int32,
+ fill_value=self.tokenizer.pad_id)
+ gpt_input_mask_all = np.full(shape=(batch_size, gpt_max_length-1), dtype=np.int32, fill_value=0)
+ gpt_loss_mask_all = np.full(shape=(batch_size, gpt_max_length-1), dtype=np.int32, fill_value=1)
+
+ for i, (gpt_ids, gpt_input_mask, gpt_loss_mask) in enumerate(batch):
+ gpt_source_ids[i, :len(gpt_ids)-1] = gpt_ids[:-1]
+ gpt_target_ids[i, :len(gpt_ids)-1] = gpt_ids[1:]
+ gpt_input_mask_all[i, :len(gpt_ids)-1] = gpt_input_mask[:-1]
+ gpt_loss_mask_all[i, :len(gpt_ids)-1] = gpt_loss_mask[1:]
+
+ ret_batch = {
+ 'net_input': {
+ 'src_tokens': gpt_source_ids.astype(np.int64),
+ },
+ 'target': gpt_target_ids.astype(np.int64),
+ 'nsentences': batch_size,
+ 'ntokens': sum([len(x[0]) for x in batch]),
+ }
+
+ return ret_batch
+
+ padded_batches = iterators.MapIterator(
+ batches, collate
+ )
+
+ return padded_batches
+
+ def _prepare(self, doc):
+ gpt_input_mask = [0] * len(doc)
+ gpt_loss_mask = [1] * len(doc)
+ full_tokens = doc
+ return full_tokens, gpt_input_mask, gpt_loss_mask
+
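+    # Note: this second definition of _tokenize overrides the single-source version above;
+    # it mixes multiple corpora, weighted by 'weight' (or 'count'), via a multiplex iterator.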
+ def _tokenize(self):
+ multilingual_iters = []
+ weights = []
+
+ for data in self.data:
+ multilingual_iters.append(
+ self._tokenize_foreach_lang(data)
+ )
+ if 'weight' in data:
+ weights.append(float(data['weight']))
+ else:
+ weights.append(int(data['count']))
+
+ if len(multilingual_iters) == 1:
+ return multilingual_iters[0]
+
+ sampling_iterator = WeightNoRandomStateIterator(weights, self.seed)
+ control_iterator = NativeCheckpointableIterator(sampling_iterator)
+ tokenized_lines = iterators.MultiplexIterator(control_iterator, multilingual_iters)
+
+ return tokenized_lines
+
+ def _tokenize_foreach_lang(self, data):
+ # if 'epoch' in data:
+ _random = random.Random(self.seed)
+ if 'source' not in data or len(data['source']) == 0:
+ # load source from single file, format: self.data_dir/json/{name}.json
+ file_path = os.path.join(self.data_dir, 'json', f"{data['name']}.json")
+ if not os.path.exists(file_path):
+                raise FileNotFoundError(f"file {file_path} does not exist")
+ with open(file_path, 'r', encoding='utf8') as f:
+ data_source = json.load(f)
+ data['source'] = data_source
+ data_source = data['source']
+ epoch_num = 50
+ temp_list = math.ceil(epoch_num) * data_source
+ _random.shuffle(temp_list)
+ dataset = list(zip(temp_list))
+ # print('data name: ', data['name'], 'len(dataset): ', len(dataset))
+ chunk_files = iterators.ChunkedSourceIterator(
+ dataset,
+ num_instances=self.num_shards,
+ instance_rank=self.shard_id,)
+
+ tokenized_lines = iterators.SelectManyIterator(chunk_files, lambda files: self._read_from_files(*files))
+ tokenized_lines = iterators.MapIterator(tokenized_lines, self._prepare)
+
+ return tokenized_lines
+
+ @staticmethod
+ def _doc_to_ids(text, tokenizer=None):
+ tokenized_ids = [] # list of list of ids
+ lines = text.split('\n\n')
+ for line_idx, line in enumerate(lines):
+ suffix = '\n\n' if line_idx != len(lines) - 1 else ''
+ if len(line) == 0:
+ continue
+
+ sublines = line.split('\n')
+ for idx, subline in enumerate(sublines):
+ if len(subline) > 200000:
+ continue
+ if len(subline) == 0:
+ continue
+ if idx == len(sublines) - 1:
+ tokenized_ids.append(tokenizer.encode(subline + suffix))
+ else:
+ tokenized_ids.append(tokenizer.encode(subline + '\n'))
+
+ tokenized_ids[-1].append(tokenizer.eos_id)
+ return tokenized_ids
+
+ def _read_lines(self, file_path):
+ try:
+ with open(file_path, 'r', encoding='utf8') as f:
+ lines = f.read().strip().split('\n')
+ except:
+ return iter([]) # skip bad file
+ return lines
+
+ def _read_from_files(self, source_file):
+ data = []
+ if self.args.absolute_path:
+ file_path = source_file
+ else:
+ file_path = os.path.join(self.data_dir, source_file)
+
+ if not os.path.exists(file_path):
+            print('| file {} does not exist'.format(file_path), flush=True)
+ return iter([]) # skip bad file
+
+ lines = self._read_lines(file_path)
+
+ tokenized_ids = []
+ for doc_jsonstr in lines:
+ try:
+ json_obj = json.loads(doc_jsonstr)
+
+ if 'text' in json_obj:
+ text = json_obj['text']
+ elif 'content' in json_obj:
+ text = json_obj['content']
+ elif 'raw_content_lines' in json_obj:
+ text = "\n".join(json_obj['raw_content_lines'])
+ else:
+ print('no text in json_obj')
+
+ if len(text) == 0:
+ continue
+ ret = LMLoader._doc_to_ids(text, self.tokenizer)
+ tokenized_ids.extend(ret)
+ except Exception as e:
+ print(source_file, flush=True)
+ print(e, flush=True)
+
+ # ###################################################
+
+ doc = [self.tokenizer.bos_id]
+ for ids in tokenized_ids:
+ if len(doc) + len(ids) > self.tokens_per_sample + 1:
+ doc.extend(ids)
+ doc = doc[:self.tokens_per_sample + 1]
+ data.append(doc)
+ doc = [self.tokenizer.bos_id]
+ else:
+ doc.extend(ids)
+
+ # if len(doc) > 1 and len(doc) <= self.tokens_per_sample + 1:
+ # data.append(doc)
+ return data
+
diff --git a/YOCO/yoco/tasks/data/tiktoken_tokenizer.py b/YOCO/yoco/tasks/data/tiktoken_tokenizer.py
new file mode 100644
index 000000000..3a041cc09
--- /dev/null
+++ b/YOCO/yoco/tasks/data/tiktoken_tokenizer.py
@@ -0,0 +1,81 @@
+import tiktoken
+from typing import List
+
+
+class TiktokenTokenizer:
+ def __init__(self,
+ tiktoken_model: str,
+ tokenizer_pad_to_multiple: int = 8,
+ bos="",
+ pad="",
+ eos="",
+ unk="",
+ ):
+ self.symbols = [bos, pad, eos, unk]
+ self.indices = {s: i for i, s in enumerate(self.symbols)}
+ self.tokenizer_pad_to_multiple = tokenizer_pad_to_multiple
+ cl100k_base = tiktoken.get_encoding(tiktoken_model)
+ self._model = tiktoken.Encoding(
+ # If you're changing the set of special tokens, make sure to use a different name
+ # It should be clear from the name what behaviour to expect.
+ name="cl100k_im",
+ pat_str=cl100k_base._pat_str,
+ mergeable_ranks=cl100k_base._mergeable_ranks,
+ special_tokens={
+ **cl100k_base._special_tokens,
+ "": 100264,
+ "": 100265,
+ "": 100266,
+ "": 100267,
+ "": 100268,
+ "": 100269,
+ "": 100270,
+ "": 100271,
+ "": 100272,
+ "": 100273,
+ "": 100274,
+ "": 100275,
+ "": 100276,
+ "": 100277,
+ "": 100278,
+ "": 100279,
+ "": 100280,
+ "": 100281,
+ }
+ )
+
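+    # Id layout: the four symbols above occupy ids 0-3 and every tiktoken id is shifted up by
+    # len(self.symbols) in encode()/decode(); n_words is rounded up to a multiple of
+    # tokenizer_pad_to_multiple.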
+ @property
+ def n_words(self) -> int:
+ n_words = self._model.n_vocab + len(self.symbols)
+ n_words = (n_words + self.tokenizer_pad_to_multiple - 1) // self.tokenizer_pad_to_multiple * self.tokenizer_pad_to_multiple
+ return n_words
+
+ @property
+ def bos_id(self) -> int:
+ return self.indices[""]
+
+ @property
+ def eos_id(self) -> int:
+ return self.indices[""]
+
+ @property
+ def pad_id(self) -> int:
+ return self.indices[""]
+
+ @property
+ def unk_id(self) -> int:
+ return self.indices[""]
+
+ def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
+ assert isinstance(s, str)
+ t = self._model.encode(s, allowed_special="all")
+ t = [i + len(self.symbols) for i in t]
+ if bos:
+ t = [self.bos_id, *t]
+ if eos:
+ t = [*t, self.eos_id]
+ return t
+
+ def decode(self, t: List[int]) -> str:
+ t = [i - len(self.symbols) for i in t if i >= len(self.symbols)]
+ return self._model.decode(t)
\ No newline at end of file
diff --git a/YOCO/yoco/tasks/data/utils.py b/YOCO/yoco/tasks/data/utils.py
new file mode 100644
index 000000000..fd850d73f
--- /dev/null
+++ b/YOCO/yoco/tasks/data/utils.py
@@ -0,0 +1,267 @@
+import collections
+from random import Random
+from typing import Dict, Iterable, Optional
+
+import torch
+import numpy as np
+from infinibatch import iterators
+from infinibatch.iterators import CheckpointableIterator, FixedBatchIterator, SelectManyIterator, MapIterator
+
+from fairseq.data import BaseWrapperDataset, FairseqDataset, data_utils
+
+def apply_to_sample(f, sample):
+ if hasattr(sample, "__len__") and len(sample) == 0:
+ return {}
+
+ def _apply(x):
+ if isinstance(x, np.ndarray):
+ return f(x)
+ elif isinstance(x, collections.OrderedDict):
+ # OrderedDict has attributes that needs to be preserved
+ od = collections.OrderedDict(
+ (key, _apply(value)) for key, value in x.items()
+ )
+ od.__dict__ = x.__dict__
+ return od
+ elif isinstance(x, dict):
+ return {key: _apply(value) for key, value in x.items()}
+ elif isinstance(x, list):
+ return [_apply(x) for x in x]
+ elif isinstance(x, tuple):
+ return tuple(_apply(x) for x in x)
+ elif isinstance(x, set):
+ return {_apply(x) for x in x}
+ else:
+ return x
+
+ return _apply(sample)
+
+
+class NativeCheckpointableIterator(iterators.CheckpointableIterator):
+ def __init__(self, iterable: Iterable):
+ self._input_iterable = iterable
+ self.setstate(None)
+
+ def getstate(self) -> Dict:
+ return {"num_items_yielded": self._num_items_yielded}
+
+ def setstate(self, checkpoint: Optional[Dict]):
+ self._iterator = iter(self._input_iterable)
+ self._num_items_yielded = (
+ iterators._advance_iterator(self._iterator, checkpoint["num_items_yielded"])
+ if checkpoint is not None
+ else 0
+ )
+
+ def __next__(self):
+ item = next(self._iterator)
+ self._num_items_yielded += 1
+ return item
+
+ def close(self):
+ pass
+
+
+class WeightIterator(object):
+ def __init__(self, weights, seed):
+ self.weights = weights
+ self.seed = seed
+ self.control_index = list(range(len(weights)))
+ self.setstate(None)
+
+ def __iter__(self):
+ return self
+
+ def getstate(self):
+ return {"random_state": self._random_state}
+
+ def setstate(self, checkpoint):
+ self._random_state = checkpoint["random_state"] if checkpoint else None
+ self._random = (
+ None # this will trigger the lazy initialization in self.__next__
+ )
+
+ def __next__(self):
+ if self._random is None:
+ self._random = Random(self.seed)
+ if self._random_state is not None:
+ self._random.setstate(self._random_state)
+ idx = self._random.choices(self.control_index, self.weights)[0]
+ self._random_state = self._random.getstate()
+ return idx
+
+ def close(self):
+ pass
+
+
+def FixedBlockwiseShuffleIterator(source_iterator: CheckpointableIterator, block_size: int, seed: int=0):
+ """
+ Shuffles a sequence of items by grouping consecutive items in blocks of fixed size, shuffling
+ each block, and yielding the shuffled items of all blocks as a flat sequence.
+
+ E.g. [1, 2, 3, 4, 5, 6, 7, 8] with block_size = 3 may yield [3, 1, 2, 4, 6, 5, 8, 7].
+
+ Args:
+ source_iterator: checkpointable iterator or restartable iterable over input items to shuffle
+ block_size: size of the buffer in number of items used for shuffling
+ seed: random seed used for shuffling (or None)
+ """
+ # This is implemented as a pipeline:
+ # - group N consecutive items together
+ # - shuffle them
+ # - flatten the result
+ blocks = FixedBatchIterator(source_iterator, batch_size=block_size)
+ def shuffle_block_fn(block):
+ _random = Random(seed)
+ _random.shuffle(block)
+ return block
+ shuffled_blocks = MapIterator(blocks, transform=shuffle_block_fn)
+ # samples = SelectManyNoSkipIterator(shuffled_blocks, collection_selector=lambda shuffled_block: iter(shuffled_block))
+ samples = SelectManyIterator(shuffled_blocks, collection_selector=lambda shuffled_block: iter(shuffled_block))
+ return samples
+
+
+class IndexIterator(object):
+ def __init__(self, num):
+ self.num = num
+ self.setstate(None)
+
+ def __iter__(self):
+ return self
+
+ def getstate(self):
+ return {'num_items_yielded': self._num_items_yielded}
+
+ def setstate(self, checkpoint):
+        self._num_items_yielded = checkpoint['num_items_yielded'] if checkpoint is not None else 0
+
+ def __next__(self):
+ item = self._num_items_yielded % self.num
+ self._num_items_yielded += 1
+ return item
+
+ def close(self):
+ pass
+
+
+class WeightNoRandomStateIterator(object):
+ def __init__(self, weights, seed):
+ self.weights = weights
+ self.seed = seed
+ self.control_index = list(range(len(weights)))
+ self.setstate(None)
+
+ def __iter__(self):
+ return self
+
+ def getstate(self):
+ return {'num_items_yielded': self._num_items_yielded}
+
+ def setstate(self, checkpoint):
+        self._num_items_yielded = checkpoint['num_items_yielded'] if checkpoint is not None else 0
+
+ def __next__(self):
+ self._random = Random(int(self.seed) + self._num_items_yielded)
+ idx = self._random.choices(self.control_index, self.weights)[0]
+ self._num_items_yielded += 1
+ return idx
+
+ def close(self):
+ pass
+
+
+class SelectManyNoSkipIterator(CheckpointableIterator):
+ """
+ Projects each element of a source sequence to a sequence and flattens the resulting sequences into one sequence.
+ """
+ def __init__(self, source_iterator: CheckpointableIterator, collection_selector=None):
+ """
+ Args:
+ source_iterator: iterator over the items to pass to collection_selector()
+ collection_selector: user callback that maps an item into an Iterable, whose items will be yielded.
+ The returned Iterator is used only once. Hence, it is also allowed to
+ return self-iterables, such as iterators and generator expressions.
+ If None is given, no callback is applied.
+ """
+ if not isinstance(source_iterator, CheckpointableIterator):
+ raise ValueError('source_iterator has to be a CheckpointableIterator')
+ self._source_iterator = source_iterator # type: CheckpointableIterator
+ self._collection_selector = collection_selector
+ self.setstate(None)
+
+ def getstate(self) -> Dict:
+ return {'source_state': self._source_state,
+ 'flattened_items_yielded': self._flattened_items_yielded}
+
+ def setstate(self, checkpoint: Optional[Dict]):
+ self._source_state = checkpoint['source_state'] if checkpoint else None
+ self._flattened_items_yielded = 0
+ self._source_iterator.setstate(self._source_state)
+ def _generate():
+ skip_to_checkpoint = self._flattened_items_yielded
+            # main loop over source items
+ for source_item in self._source_iterator:
+ if self._collection_selector is not None:
+ data = iter(self._collection_selector(source_item))
+ else:
+ data = iter(source_item)
+ self._flattened_items_yielded = 0
+ # if skip_to_checkpoint:
+ # #print("Skipping to index", skip_to_checkpoint, file=sys.stderr)
+ # self._flattened_items_yielded += _advance_iterator(data, skip_to_checkpoint)
+ # skip_to_checkpoint = 0
+ # main loop over lines
+ for item in data:
+ self._flattened_items_yielded += 1
+ yield item
+ self._source_state = self._source_iterator.getstate()
+ self._iterator = _generate()
+
+ def __next__(self):
+ return next(self._iterator)
+
+ def close(self):
+ self._source_iterator.close()
+
+
+class RawArrayDataset(FairseqDataset):
+
+ def __init__(self, dataset, datatype="token"):
+ super().__init__()
+ self.dataset = dataset
+ self.datatype = datatype
+ if hasattr(dataset, 'sizes'):
+ self._sizes = dataset.sizes
+ else:
+ try:
+ self._sizes = np.array([len(x) for x in self.dataset])
+            except Exception:
+ self._sizes = np.array([1 for x in self.dataset])
+
+ def __getitem__(self, index):
+        if not isinstance(self.dataset[index][0], list):
+ if self.datatype == "token":
+ return torch.Tensor(self.dataset[index]).long()
+ else:
+ return torch.Tensor(self.dataset[index]).bool()
+ else:
+ return self.dataset[index]
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def collater(self, samples):
+ if hasattr(self.dataset, 'collater'):
+ return self.dataset.collater(samples)
+ else:
+ raise NotImplementedError()
+
+ @property
+ def sizes(self):
+ return self._sizes
+
+ def num_tokens(self, index):
+ return self.dataset.num_tokens(index)
+
+ def size(self, index):
+ return self.dataset.size(index)
diff --git a/YOCO/yoco/tasks/gpt.py b/YOCO/yoco/tasks/gpt.py
new file mode 100644
index 000000000..70dd10283
--- /dev/null
+++ b/YOCO/yoco/tasks/gpt.py
@@ -0,0 +1,176 @@
+import os
+from typing import Optional
+import json
+from argparse import Namespace
+import torch
+
+from fairseq.tasks import register_task, FairseqDataclass, FairseqTask
+from dataclasses import dataclass, field
+from omegaconf import II
+
+from .data.lm_loader import LMLoader
+from .data.tiktoken_tokenizer import TiktokenTokenizer
+from .data.llama_tokenizer import LLaMATokenizer
+
+
+@dataclass
+class GPTLanguageModelingConfig(FairseqDataclass):
+ data: Optional[str] = field(
+ default=None, metadata={"help": "path to data directory"}
+ )
+ tokens_per_sample: int = field(
+ default=1024,
+ metadata={"help": "max number of tokens per sample for LM dataset"},
+ )
+ max_target_positions: Optional[int] = field(
+ default=None, metadata={"help": "max number of tokens in the target sequence"}
+ )
+ llama_model: Optional[str] = field(
+ default=None,
+ metadata={"help": "path to load tokenizer and config"},
+ )
+ tiktoken_model: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "tiktoken model to tokenize the data"
+ },
+ )
+ batch_read_ahead: int = field(
+ default=10000,
+ metadata={"help": "batch read ahead size for infinibatch"},
+ )
+ pad_to_max_len: bool = field(
+ default=False,
+ metadata={"help": "pad each sentence to max length"},
+ )
+ absolute_path: bool = field(
+ default=False,
+ metadata={"help": "use absolute path in data config"},
+ )
+ tokenizer_pad_to_multiple: int = field(
+ default=8,
+ metadata={"help": "pad to multiple of this value"},
+ )
+ seed: int = II("common.seed")
+ batch_size: Optional[int] = II("dataset.batch_size")
+
+
+@register_task('gpt', dataclass=GPTLanguageModelingConfig)
+class GPTPretrainingTask(FairseqTask):
+ def __init__(self, args, tokenizer):
+ super().__init__(args)
+ self.cfg = args
+ self.tokenizer = tokenizer
+
+ @classmethod
+ def setup_task(cls, cfg, **kwargs):
+ """Setup the task (e.g., load dictionaries).
+
+ Args:
+ args (argparse.Namespace): parsed command-line arguments
+ """
+ if cfg.llama_model is not None:
+ tokenizer = LLaMATokenizer(os.path.join(cfg.llama_model, "tokenizer.model"))
+ elif cfg.tiktoken_model is not None:
+ tokenizer = TiktokenTokenizer(cfg.tiktoken_model, cfg.tokenizer_pad_to_multiple)
+ else:
+ raise ValueError("No tokenizer model provided")
+
+ return cls(cfg, tokenizer)
+
+ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
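+        # The "dataset" is a lightweight namespace holding the parsed {split}.json config,
+        # the data directory and a shuffle flag; batching itself happens lazily in
+        # LMLoader (see get_batch_iterator below).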
+ self.datasets[split] = {
+ 'data': json.load(open(f'{self.cfg.data}/json/{split}.json')),
+ 'data_dir': self.cfg.data,
+ 'shuffle': True if split == 'train' else False,
+ }
+ self.datasets[split] = Namespace(**self.datasets[split])
+
+ def dataset(self, split):
+ if split not in self.datasets:
+ raise KeyError("Dataset not loaded: " + split)
+
+ return self.datasets[split]
+
+ def get_batch_iterator(
+ self,
+ dataset,
+ max_tokens=None,
+ max_sentences=None,
+ max_positions=None,
+ ignore_invalid_inputs=False,
+ required_batch_size_multiple=1,
+ seed=1,
+ num_shards=1,
+ shard_id=0,
+ num_workers=0,
+ epoch=1,
+ data_buffer_size=0,
+ disable_iterator_cache=False,
+ skip_remainder_batch=False,
+ grouped_shuffling=False,
+ update_epoch_batch_itr=False
+ ):
+ return LMLoader(
+ self.cfg,
+ dataset,
+ self.tokenizer,
+ max_tokens=max_tokens,
+ max_sentences=max_sentences,
+ max_positions=max_positions,
+ ignore_invalid_inputs=ignore_invalid_inputs,
+ required_batch_size_multiple=required_batch_size_multiple,
+ seed=seed,
+ epoch=epoch,
+ num_shards=num_shards,
+ shard_id=shard_id,
+ )
+
+ def train_step(
+ self, sample, model, criterion, optimizer, update_num, ignore_grad=False
+ ):
+ """
+ Do forward and backward, and return the loss as computed by *criterion*
+ for the given *model* and *sample*.
+
+ Args:
+ sample (dict): the mini-batch. The format is defined by the
+ :class:`~fairseq.data.FairseqDataset`.
+ model (~fairseq.models.BaseFairseqModel): the model
+ criterion (~fairseq.criterions.FairseqCriterion): the criterion
+ optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
+ update_num (int): the current update
+ ignore_grad (bool): multiply loss by 0 if this is set to True
+
+ Returns:
+ tuple:
+ - the loss
+ - the sample size, which is used as the denominator for the
+ gradient
+ - logging outputs to display while training
+ """
+ model.train()
+ model.set_num_updates(update_num)
+ with torch.autograd.profiler.record_function("forward"):
+ loss, sample_size, logging_output = criterion(model, sample)
+ if ignore_grad:
+ loss *= 0
+ with torch.autograd.profiler.record_function("backward"):
+ optimizer.backward(loss)
+ return loss, sample_size, logging_output
+
+ def valid_step(self, sample, model, criterion):
+ model.eval()
+ with torch.no_grad():
+ loss, sample_size, logging_output = criterion(model, sample)
+ return loss, sample_size, logging_output
+
+ @property
+ def target_dictionary(self):
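+        # Minimal stand-in for a fairseq Dictionary: it only exposes pad(), returning
+        # the tokenizer's padding index.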
+ padding_idx = self.tokenizer.pad_id
+ class Dict:
+ def pad(self):
+ return padding_idx
+ dictionary = Dict()
+ return dictionary
+
diff --git a/YOCO/yoco/tasks/harness_eval.py b/YOCO/yoco/tasks/harness_eval.py
new file mode 100644
index 000000000..0b0621aea
--- /dev/null
+++ b/YOCO/yoco/tasks/harness_eval.py
@@ -0,0 +1,151 @@
+import os
+from typing import Optional
+import logging
+
+from fairseq.data import (
+ IdDataset,
+ NumSamplesDataset,
+ NumelDataset,
+ NestedDictionaryDataset,
+ RightPadDataset,
+ RawLabelDataset,
+)
+
+from fairseq.tasks import register_task, FairseqDataclass, LegacyFairseqTask
+from dataclasses import dataclass, field
+
+from .data.tiktoken_tokenizer import TiktokenTokenizer
+from .data.llama_tokenizer import LLaMATokenizer
+from .data.utils import RawArrayDataset
+
+from .harness_task import (
+    HarnessAnlir1, HarnessAnlir2, HarnessAnlir3, HarnessArc_challenge, HarnessArc_easy,
+    HarnessBoolq, HarnessCopa, HarnessOpenbookqa, HarnessPiqa, HarnessRte, HarnessWic,
+    HarnessWinogrande, HarnessHellaswag, HarnessRecord, HarnessTruthfullqaMC1,
+    HarnessTruthfullqaMC2, HarnessSCIQ,
+)
+from .harness_task import HarnessArc_challenge25s, HarnessHellaswag10s
+
+
+logger = logging.getLogger(__name__)
+
+task_map = {
+ "harness_anli_r1": HarnessAnlir1,
+ "harness_anli_r2": HarnessAnlir2,
+ "harness_anli_r3": HarnessAnlir3,
+ "harness_boolq": HarnessBoolq,
+ "harness_copa": HarnessCopa,
+ "harness_openbookqa": HarnessOpenbookqa,
+ "harness_piqa": HarnessPiqa,
+ "harness_rte": HarnessRte,
+ "harness_wic": HarnessWic,
+ "harness_winogrande": HarnessWinogrande,
+ "harness_hellaswag": HarnessHellaswag,
+ "harness_arc_challenge": HarnessArc_challenge,
+ "harness_arc_easy": HarnessArc_easy,
+ "harness_record": HarnessRecord,
+ "harness_truthfullqa_mc1": HarnessTruthfullqaMC1,
+ "harness_truthfullqa_mc2": HarnessTruthfullqaMC2,
+ "harness_arc_challenge_25s": HarnessArc_challenge25s,
+ "harness_hellaswag_10s": HarnessHellaswag10s,
+ "harness_sciq": HarnessSCIQ,
+}
+
+from .mmlu_task import create_mmlu_tasks
+mmlu_tasks = create_mmlu_tasks()
+task_map.update(mmlu_tasks)
+
+@dataclass
+class HarnessEvalConfig(FairseqDataclass):
+ data_dir: str = field(
+ default="/mnt/msranlp/shaohanh/data/fs_eval/harness/",
+ metadata={"help": "path to data directory"},
+ )
+ eval_data: str = field(default="", metadata={"help": "dataset name"})
+ tokens_per_sample: int = field(
+ default=2048,
+ metadata={"help": "max number of tokens per sample for LM dataset"},
+ )
+ max_target_positions: Optional[int] = field(
+ default=None, metadata={"help": "max number of tokens in the target sequence"}
+ )
+ llama_model: Optional[str] = field(
+ default=None,
+ metadata={"help": "path to load tokenizer and config"},
+ )
+ tiktoken_model: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "tiktoken model to tokenize the data"
+ },
+ )
+ tokenizer_pad_to_multiple: int = field(
+ default=8,
+ metadata={"help": "pad to multiple of this value"},
+ )
+
+
+@register_task('harness_eval', dataclass=HarnessEvalConfig)
+class HarnessEval(LegacyFairseqTask):
+
+ def __init__(self, cfg, tokenizer):
+ super().__init__(cfg)
+ self.cfg = cfg
+ self.tokenizer = tokenizer
+ self.harness_task = task_map[self.cfg.eval_data](tokenizer=self.tokenizer, data_dir=cfg.data_dir, tokens_per_sample=cfg.tokens_per_sample)
+
+ @classmethod
+ def setup_task(cls, cfg, **kwargs):
+ if cfg.llama_model is not None:
+ tokenizer = LLaMATokenizer(os.path.join(cfg.llama_model, "tokenizer.model"))
+ elif cfg.tiktoken_model is not None:
+ tokenizer = TiktokenTokenizer(cfg.tiktoken_model, cfg.tokenizer_pad_to_multiple)
+ else:
+ raise ValueError("No tokenizer model provided")
+
+ return cls(cfg, tokenizer)
+
+ def load_dataset(self, split, combine=False, **kwargs):
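+        # get_data_for_evaluation() flattens every (example, answer option) pair into one row:
+        # token ids, a loss mask over the answer span, the answer length and the gold label id.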
+ src_tokens, gpt_loss_mask, label_length, labels = self.harness_task.get_data_for_evaluation()
+
+ src_tokens = RawArrayDataset(src_tokens)
+ gpt_loss_mask = RawArrayDataset(gpt_loss_mask, datatype="mask")
+ label_length = RawLabelDataset(label_length)
+ label_ids = RawLabelDataset(labels)
+        # Input format: src_tokens + option_tokens
+ data_dict = {
+ 'id': IdDataset(),
+ 'net_input': {
+ 'src_tokens': RightPadDataset(
+ src_tokens,
+ pad_idx=self.tokenizer.pad_id,
+ ),
+ 'gpt_loss_mask': RightPadDataset(
+ gpt_loss_mask,
+ pad_idx=False,
+ ),
+ 'label_length': label_length,
+ 'src_lengths': NumelDataset(src_tokens, reduce=False),
+ },
+ 'targets': label_ids,
+ 'nsentences': NumSamplesDataset(),
+ 'ntokens': NumelDataset(src_tokens, reduce=True),
+ }
+ dataset = NestedDictionaryDataset(
+ data_dict,
+ sizes=[src_tokens.sizes],
+ )
+
+ print('| Loaded {} with {} samples'.format(split, len(dataset)))
+
+ self.datasets[split] = dataset
+ return self.datasets[split]
+
+ @property
+ def target_dictionary(self):
+ padding_idx = self.tokenizer.pad_id
+ class Dict:
+ def pad(self):
+ return padding_idx
+ dictionary = Dict()
+ return dictionary
+
+
\ No newline at end of file
diff --git a/YOCO/yoco/tasks/harness_task.py b/YOCO/yoco/tasks/harness_task.py
new file mode 100644
index 000000000..3e87f96a2
--- /dev/null
+++ b/YOCO/yoco/tasks/harness_task.py
@@ -0,0 +1,289 @@
+import json
+import numpy as np
+
+class HarnessBaseTask:
+ def __init__(self, tokenizer, data_dir, tokens_per_sample=1024):
+ self.tokenizer = tokenizer
+ self.class_num = 1
+ self.tokens_per_sample = tokens_per_sample
+ self.base_dir = data_dir
+ self.set_dataname()
+ self.set_class_num()
+ self.dataset = self.load_data()
+
+ def load_data(self):
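+        # Each task file is JSON-lines: one pre-processed harness example per line with
+        # (depending on the task) "ctx", "label", "choices" and "gold" fields.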
+ import os
+ datasets = []
+ with open(os.path.join(self.base_dir, self.dataname), "r", encoding='utf-8') as fin:
+ for line in fin:
+ obj = json.loads(line)
+ datasets.append(
+ {
+ "text": obj["ctx"] if "ctx" in obj else None,
+ "label": obj["label"] if "label" in obj else None,
+ "choices": obj["choices"] if "choices" in obj else [],
+ "gold": obj["gold"] if "gold" in obj else None,
+ "raw": obj,
+ }
+ )
+ return datasets
+
+ def set_class_num(self):
+ raise NotImplementedError
+
+ def set_dataname(self):
+ raise NotImplementedError
+
+ def preprocess_example(self, example):
+ raise NotImplementedError
+
+ def get_data_for_evaluation(self):
+ src_tokens = []
+ gpt_loss_mask = []
+ label_length = []
+ labels = []
+ cut_num = 0
+ for i, example in enumerate(self.dataset):
+ input_str, label_str, label = self.preprocess_example(example)
+ if i < 2:
+ print(f"input str is {input_str}")
+ print(f"label str is {label_str}")
+
+ for j in range(len(input_str)):
+ sub_input_str, sub_label_str = input_str[j], label_str[j]
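+                # Tokenize prompt and prompt+answer jointly, then slice off the prompt part so
+                # the answer tokens match how they are tokenized in context.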
+ input_token = self.tokenizer.encode(sub_input_str)
+ label_token = self.tokenizer.encode(sub_input_str + sub_label_str)[len(input_token):]
+ if len(input_token) + len(label_token) + 1 >= self.tokens_per_sample:
+ cut_num += 1
+ input_token = input_token[-(self.tokens_per_sample - len(label_token) - 1):]
+
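+                # BOS + prompt tokens are masked out (False); only the answer tokens (True)
+                # are scored, presumably by the harness_eval criterion.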
+ src_tokens.append([self.tokenizer.bos_id] + input_token + label_token)
+ gpt_loss_mask.append([False] * (len(input_token) + 1) + [True] * len(label_token))
+ label_length.append(len(sub_label_str.strip()))
+ labels.append(label)
+
+ if cut_num > 0:
+ print(f"cut {cut_num} examples")
+
+ return np.array(src_tokens), np.array(gpt_loss_mask), np.array(label_length), np.array(labels)
+
+
+class HarnessAnlir1(HarnessBaseTask):
+ def set_class_num(self):
+ self.class_num = 3
+
+ def set_dataname(self):
+ self.dataname = "anli_r1"
+
+ def preprocess_example(self, example):
+ input_str = [example["text"]] * self.class_num
+ answer_str = [" True", " Neither", " False"]
+ label = example["label"]
+ return input_str, answer_str, label
+
+class HarnessAnlir2(HarnessAnlir1):
+ def set_dataname(self):
+ self.dataname = "anli_r2"
+
+class HarnessAnlir3(HarnessAnlir1):
+ def set_dataname(self):
+ self.dataname = "anli_r3"
+
+class HarnessArc_challenge(HarnessBaseTask):
+ '''
+ using harness to evaluate arc challenge
+ '''
+ def set_class_num(self):
+ self.class_num = 5
+
+ def set_dataname(self):
+ self.dataname = "arc_challenge"
+
+ def preprocess_example(self, example):
+ input_str = [example["text"]] * len(example["choices"])
+ answer_str = [' ' + item for item in example["choices"]]
+ label = example["gold"]
+ return input_str, answer_str, label
+
+class HarnessArc_challenge25s(HarnessBaseTask):
+ '''
+ using harness to evaluate arc challenge
+ '''
+ def set_class_num(self):
+ self.class_num = 5
+
+ def set_dataname(self):
+ self.dataname = "arc_challenge_25s"
+
+ def preprocess_example(self, example):
+ input_str = [example["text"]] * len(example["choices"])
+ answer_str = [' ' + item for item in example["choices"]]
+ label = example["gold"]
+ return input_str, answer_str, label
+
+class HarnessArc_easy(HarnessArc_challenge):
+ def set_class_num(self):
+ self.class_num = 5
+
+ def set_dataname(self):
+ self.dataname = "arc_easy"
+
+class HarnessBoolq(HarnessBaseTask):
+ def set_class_num(self):
+ self.class_num = 2
+
+ def set_dataname(self):
+ self.dataname = "boolq"
+
+ def preprocess_example(self, example):
+ input_str = [example["text"]] * self.class_num
+ answer_str = [" no", " yes"]
+ label = example["label"]
+ return input_str, answer_str, label
+
+class HarnessCopa(HarnessBaseTask):
+ def set_class_num(self):
+ self.class_num = 2
+
+ def set_dataname(self):
+ self.dataname = "copa"
+
+ def preprocess_example(self, example):
+ input_str = [example["text"]] * self.class_num
+ answer_str = [' ' + example['raw']['choice1'], ' ' + example['raw']['choice2']]
+ label = example["label"]
+ return input_str, answer_str, label
+
+class HarnessOpenbookqa(HarnessArc_challenge):
+ def set_class_num(self):
+ self.class_num = 4
+
+ def set_dataname(self):
+ self.dataname = "openbookqa"
+
+class HarnessPiqa(HarnessArc_challenge):
+ def set_class_num(self):
+ self.class_num = 2
+
+ def set_dataname(self):
+ self.dataname = "piqa"
+
+class HarnessRte(HarnessBaseTask):
+ def set_class_num(self):
+ self.class_num = 2
+
+ def set_dataname(self):
+ self.dataname = "rte"
+
+ def preprocess_example(self, example):
+ input_str = [example["text"]] * self.class_num
+ answer_str = [' True', ' False']
+ label = example["label"]
+ return input_str, answer_str, label
+
+class HarnessWic(HarnessRte):
+ def set_dataname(self):
+ self.dataname = "wic"
+
+class HarnessWinogrande(HarnessBaseTask):
+ def set_class_num(self):
+ self.class_num = 2
+
+ def set_dataname(self):
+ self.dataname = "winogrande"
+
+ def preprocess_example(self, example):
+ pronoun_loc = example['raw']['sentence'].index("_")
+ input_str = []
+ input_str.append(example['raw']['sentence'][:pronoun_loc].strip() + ' ' + example['raw']['option1'])
+ input_str.append(example['raw']['sentence'][:pronoun_loc].strip() + ' ' + example['raw']['option2'])
+ answer_str = [" " + example['raw']["sentence"][pronoun_loc + 1:].strip()] * self.class_num
+ label = int(example['raw']['answer']) - 1
+ return input_str, answer_str, label
+
+class HarnessHellaswag(HarnessBaseTask):
+ def set_class_num(self):
+ self.class_num = 4
+
+ def set_dataname(self):
+ self.dataname = "hellaswag"
+
+ def preprocess_example(self, example):
+ input_str = [example["text"]] * self.class_num
+ answer_str = [' ' + item for item in example["choices"]]
+ label = example["gold"]
+ return input_str, answer_str, label
+
+
+class HarnessHellaswag10s(HarnessBaseTask):
+ def set_class_num(self):
+ self.class_num = 4
+
+ def set_dataname(self):
+ self.dataname = "hellaswag_10s"
+
+ def preprocess_example(self, example):
+ input_str = [example["text"]] * self.class_num
+ answer_str = [' ' + item for item in example["choices"]]
+ label = example["gold"]
+ return input_str, answer_str, label
+
+
+class HarnessTruthfullqaMC1(HarnessBaseTask):
+ def set_class_num(self):
+ self.class_num = 1
+
+ def set_dataname(self):
+ self.dataname = "truthfulqa_mc"
+
+ def preprocess_example(self, example):
+ input_str = [example["text"]] * len(example["raw"]["mc1_targets"]["choices"])
+ answer_str = [' ' + item for item in example["raw"]["mc1_targets"]["choices"]]
+ label = 0 # dummy label
+ return input_str, answer_str, label
+
+
+
+class HarnessTruthfullqaMC2(HarnessBaseTask):
+ def set_class_num(self):
+ self.class_num = 1
+
+ def set_dataname(self):
+ self.dataname = "truthfulqa_mc"
+
+ def preprocess_example(self, example):
+ input_str = [example["text"]] * len(example["raw"]["mc2_targets"]["choices"])
+ answer_str = [' ' + item for item in example["raw"]["mc2_targets"]["choices"]]
+ label = 0 # dummy label
+ return input_str, answer_str, label
+
+
+class HarnessRecord(HarnessBaseTask):
+ def set_class_num(self):
+ self.class_num = 1
+
+ def set_dataname(self):
+ self.dataname = "record"
+
+ def preprocess_example(self, example):
+ input_str = [example["text"]] * len(example["raw"]["entities"])
+ answer_str = [f' - {example["raw"]["query"]}'.replace("@placeholder", item) for item in example["raw"]["entities"]]
+ label = 0 # dummy label
+ return input_str, answer_str, label
+
+class HarnessSCIQ(HarnessBaseTask):
+ def set_class_num(self):
+ self.class_num = 4
+
+ def set_dataname(self):
+ self.dataname = "sciq"
+
+ def preprocess_example(self, example):
+ input_str = [example["text"]] * self.class_num
+ answer_str = [' ' + example["raw"]["distractor1"],
+ ' ' + example["raw"]["distractor2"],
+ ' ' + example["raw"]["distractor3"],
+ ' ' + example["raw"]["correct_answer"]
+ ]
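+        # the correct answer is always placed last among the options, hence gold index 3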
+ label = 3
+ return input_str, answer_str, label
\ No newline at end of file
diff --git a/YOCO/yoco/tasks/mmlu_task.py b/YOCO/yoco/tasks/mmlu_task.py
new file mode 100644
index 000000000..a93476c72
--- /dev/null
+++ b/YOCO/yoco/tasks/mmlu_task.py
@@ -0,0 +1,92 @@
+from .harness_task import HarnessBaseTask
+
+
+SUBJECTS = [
+ "abstract_algebra",
+ "anatomy",
+ "astronomy",
+ "business_ethics",
+ "clinical_knowledge",
+ "college_biology",
+ "college_chemistry",
+ "college_computer_science",
+ "college_mathematics",
+ "college_medicine",
+ "college_physics",
+ "computer_security",
+ "conceptual_physics",
+ "econometrics",
+ "electrical_engineering",
+ "elementary_mathematics",
+ "formal_logic",
+ "global_facts",
+ "high_school_biology",
+ "high_school_chemistry",
+ "high_school_computer_science",
+ "high_school_european_history",
+ "high_school_geography",
+ "high_school_government_and_politics",
+ "high_school_macroeconomics",
+ "high_school_mathematics",
+ "high_school_microeconomics",
+ "high_school_physics",
+ "high_school_psychology",
+ "high_school_statistics",
+ "high_school_us_history",
+ "high_school_world_history",
+ "human_aging",
+ "human_sexuality",
+ "international_law",
+ "jurisprudence",
+ "logical_fallacies",
+ "machine_learning",
+ "management",
+ "marketing",
+ "medical_genetics",
+ "miscellaneous",
+ "moral_disputes",
+ "moral_scenarios",
+ "nutrition",
+ "philosophy",
+ "prehistory",
+ "professional_accounting",
+ "professional_law",
+ "professional_medicine",
+ "professional_psychology",
+ "public_relations",
+ "security_studies",
+ "sociology",
+ "us_foreign_policy",
+ "virology",
+ "world_religions",
+]
+
+
+def create_mmlu_tasks():
+ """Creates a dictionary of tasks from a list of subjects
+ :return: {task_name: task}
+ e.g. {hendrycksTest-abstract_algebra: Task, hendrycksTest-anatomy: Task}
+ """
+ return {f"hendrycksTest-{sub}": create_task(f"hendrycksTest-{sub}") for sub in SUBJECTS}
+
+
+def create_task(subject):
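+    # Class factory: each MMLU subject gets its own subclass whose set_dataname()
+    # captures the subject name via closure.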
+ class HendrycksTest(GeneralHendrycksTest):
+ def set_dataname(self):
+ self.dataname = f"{subject}"
+
+ return HendrycksTest
+
+class GeneralHendrycksTest(HarnessBaseTask):
+ def set_class_num(self):
+ self.class_num = 4
+
+ def preprocess_example(self, example):
+        # find the last occurrence of "Question:" in example["text"], and remove everything before it
+ # this is to remove the context
+ # last_question = example["text"].rfind("Question:")
+ # example["text"] = example["text"][last_question:]
+ input_str = [example["text"]] * self.class_num
+ answer_str = [' ' + item for item in example["choices"]]
+ label = example["gold"]
+ return input_str, answer_str, label
diff --git a/YOCO/yoco/tasks/pseudo.py b/YOCO/yoco/tasks/pseudo.py
new file mode 100644
index 000000000..87b51a12b
--- /dev/null
+++ b/YOCO/yoco/tasks/pseudo.py
@@ -0,0 +1,202 @@
+import os
+from typing import Optional
+import torch
+
+from fairseq.data import FairseqDataset
+from fairseq.tasks import register_task, FairseqDataclass, LegacyFairseqTask
+from dataclasses import dataclass, field
+from omegaconf import II
+
+from .data.tiktoken_tokenizer import TiktokenTokenizer
+from .data.llama_tokenizer import LLaMATokenizer
+
+
+class PseudoIterator(FairseqDataset):
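+    """
+    Endless iterator over batches of random token ids. The `pseudo` task uses it so
+    evaluations such as the needle-in-a-haystack run can proceed without loading real data.
+    """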
+ def __init__(self, batch_size, length, vocab_size):
+ super().__init__()
+ self.batch_size = batch_size
+ self.length = length
+ self.vocab_size = vocab_size
+
+ self.epoch = 1
+ self.next_epoch_idx = 1
+ self.sharded_checkpoint = True
+ self.should_close_after_finished = True
+
+ def __iter__(self):
+ while True:
+ yield self.__next__()
+
+ def __next__(self):
+ net_input = torch.randint(size=(self.batch_size, self.length), dtype=torch.long, low=0, high=self.vocab_size - 1)
+ return {
+ "net_input": {"src_tokens": net_input},
+ "target": net_input,
+ "ntokens": self.batch_size * self.length,
+ }
+
+ def __len__(self) -> int:
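+        # arbitrarily large sentinel: the iterator itself never ends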
+ return 819200000
+
+ def next_epoch_itr(self, **kwargs):
+ return self
+
+ @property
+ def first_batch(self):
+ return "DUMMY"
+
+ def end_of_epoch(self) -> bool:
+ return False
+
+ def state_dict(self):
+ return None
+
+ def load_state_dict(self, state_dict):
+ pass
+
+ def setstate(self, value):
+ pass
+
+ def getstate(self):
+ pass
+
+ def close(self):
+ pass
+
+@dataclass
+class PseudoConfig(FairseqDataclass):
+ tokens_per_sample: int = field(
+ default=1024,
+ metadata={"help": "max number of tokens per sample for LM dataset"},
+ )
+ max_target_positions: Optional[int] = field(
+ default=None, metadata={"help": "max number of tokens in the target sequence"}
+ )
+ llama_model: Optional[str] = field(
+ default=None,
+ metadata={"help": "path to load tokenizer and config"},
+ )
+ tiktoken_model: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "tiktoken model to tokenize the data"
+ },
+ )
+ batch_read_ahead: int = field(
+ default=10000,
+ metadata={"help": "batch read ahead size for infinibatch"},
+ )
+ pad_to_max_len: bool = field(
+ default=False,
+ metadata={"help": "pad each sentence to max length"},
+ )
+ absolute_path: bool = field(
+ default=False,
+ metadata={"help": "use absolute path in data config"},
+ )
+ tokenizer_pad_to_multiple: int = field(
+ default=8,
+ metadata={"help": "pad to multiple of this value"},
+ )
+ seed: int = II("common.seed")
+ batch_size: Optional[int] = II("dataset.batch_size")
+
+
+@register_task('pseudo', dataclass=PseudoConfig)
+class PseudoTask(LegacyFairseqTask):
+ def __init__(self, args, tokenizer):
+ super().__init__(args)
+ self.cfg = args
+ self.tokenizer = tokenizer
+
+ @classmethod
+ def setup_task(cls, cfg, **kwargs):
+ """Setup the task (e.g., load dictionaries).
+
+ Args:
+ args (argparse.Namespace): parsed command-line arguments
+ """
+ if cfg.llama_model is not None:
+ tokenizer = LLaMATokenizer(os.path.join(cfg.llama_model, "tokenizer.model"))
+ elif cfg.tiktoken_model is not None:
+ tokenizer = TiktokenTokenizer(cfg.tiktoken_model, cfg.tokenizer_pad_to_multiple)
+ else:
+ raise ValueError("No tokenizer model provided")
+
+ return cls(cfg, tokenizer)
+
+ def load_dataset(self, split, **kwargs):
+ pass
+ # self.datasets[split] = None
+
+ def dataset(self, split):
+ return None
+
+ def get_batch_iterator(
+ self,
+ dataset,
+ max_tokens=None,
+ max_sentences=None,
+ max_positions=None,
+ ignore_invalid_inputs=False,
+ required_batch_size_multiple=1,
+ seed=1,
+ num_shards=1,
+ shard_id=0,
+ num_workers=0,
+ epoch=1,
+ data_buffer_size=0,
+ disable_iterator_cache=False,
+ skip_remainder_batch=False,
+ grouped_shuffling=False,
+ update_epoch_batch_itr=False
+ ):
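+        # Dummy batches of shape (batch_size, tokens_per_sample) drawn from a hard-coded
+        # 10k-token vocabulary; all other batching arguments are ignored.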
+ return PseudoIterator(max_sentences, self.cfg.tokens_per_sample, 10000)
+
+ def train_step(
+ self, sample, model, criterion, optimizer, update_num, ignore_grad=False
+ ):
+ """
+ Do forward and backward, and return the loss as computed by *criterion*
+ for the given *model* and *sample*.
+
+ Args:
+ sample (dict): the mini-batch. The format is defined by the
+ :class:`~fairseq.data.FairseqDataset`.
+ model (~fairseq.models.BaseFairseqModel): the model
+ criterion (~fairseq.criterions.FairseqCriterion): the criterion
+ optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
+ update_num (int): the current update
+ ignore_grad (bool): multiply loss by 0 if this is set to True
+
+ Returns:
+ tuple:
+ - the loss
+ - the sample size, which is used as the denominator for the
+ gradient
+ - logging outputs to display while training
+ """
+ model.train()
+ model.set_num_updates(update_num)
+ with torch.autograd.profiler.record_function("forward"):
+ loss, sample_size, logging_output = criterion(model, sample)
+ if ignore_grad:
+ loss *= 0
+ with torch.autograd.profiler.record_function("backward"):
+ optimizer.backward(loss)
+ return loss, sample_size, logging_output
+
+ def valid_step(self, sample, model, criterion):
+ model.eval()
+ with torch.no_grad():
+ loss, sample_size, logging_output = criterion(model, sample)
+ return loss, sample_size, logging_output
+
+ @property
+ def target_dictionary(self):
+ padding_idx = self.tokenizer.pad_id
+ class Dict:
+ def pad(self):
+ return padding_idx
+ dictionary = Dict()
+ return dictionary
\ No newline at end of file
diff --git a/YOCO/yoco/train.py b/YOCO/yoco/train.py
new file mode 100644
index 000000000..ee6615d87
--- /dev/null
+++ b/YOCO/yoco/train.py
@@ -0,0 +1,7 @@
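+# Importing these local packages registers the custom YOCO models, tasks and criterions with fairseq.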
+import models
+import tasks
+import criterions
+from fairseq_cli.train import cli_main
+
+if __name__ == "__main__":
+ cli_main()
diff --git a/YOCO/yoco/validate.py b/YOCO/yoco/validate.py
new file mode 100644
index 000000000..e3815ca6a
--- /dev/null
+++ b/YOCO/yoco/validate.py
@@ -0,0 +1,294 @@
+#!/usr/bin/env python3 -u
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Validate a trained model on one or across multiple GPUs.
+"""
+import models
+import tasks
+import criterions
+
+import argparse
+import logging
+import math
+import os
+import sys
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+# We need to set up the root logger before importing any fairseq libraries.
+logging.basicConfig(
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
+ stream=sys.stdout,
+)
+logger = logging.getLogger("fairseq_cli.train")
+
+import numpy as np
+import torch
+from omegaconf import DictConfig, OmegaConf
+
+from fairseq import checkpoint_utils, options, quantization_utils, tasks, utils
+from fairseq.data import data_utils, iterators
+from fairseq.data.plasma_utils import PlasmaStore
+from fairseq.dataclass.configs import FairseqConfig
+from fairseq.dataclass.utils import convert_namespace_to_omegaconf
+from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap
+from fairseq.distributed import utils as distributed_utils
+from fairseq.file_io import PathManager
+from fairseq.logging import meters, metrics, progress_bar
+
+
+def main(cfg: FairseqConfig) -> None:
+ if isinstance(cfg, argparse.Namespace):
+ cfg = convert_namespace_to_omegaconf(cfg)
+
+ utils.import_user_module(cfg.common)
+
+ if (
+ distributed_utils.is_master(cfg.distributed_training)
+ and "job_logging_cfg" in cfg
+ ):
+        # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
+ logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))
+
+ assert (
+ cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
+ ), "Must specify batch size either with --max-tokens or --batch-size"
+ metrics.reset()
+
+ if cfg.common.log_file is not None:
+ handler = logging.FileHandler(filename=cfg.common.log_file)
+ logger.addHandler(handler)
+
+ np.random.seed(cfg.common.seed)
+ utils.set_torch_seed(cfg.common.seed)
+
+ # if distributed_utils.is_master(cfg.distributed_training):
+ # checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)
+
+ # Print args
+ logger.info(cfg)
+
+ if cfg.checkpoint.write_checkpoints_asynchronously:
+ try:
+ import iopath # noqa: F401
+ except ImportError:
+ logging.exception(
+ "Asynchronous checkpoint writing is specified but iopath is "
+ "not installed: `pip install iopath`"
+ )
+ return
+
+ # Setup task, e.g., translation, language modeling, etc.
+ task = tasks.setup_task(cfg.task)
+
+ assert cfg.criterion, "Please specify criterion to train a model"
+
+ # Build model and criterion
+ if cfg.distributed_training.ddp_backend == "fully_sharded":
+ with fsdp_enable_wrap(cfg.distributed_training):
+ model = fsdp_wrap(task.build_model(cfg.model))
+ else:
+ model = task.build_model(cfg.model)
+ criterion = task.build_criterion(cfg.criterion)
+
+ tpu = cfg.common.tpu
+ cuda = torch.cuda.is_available() and not cfg.common.cpu and not tpu
+ if cuda:
+ device = torch.device("cuda")
+ elif tpu:
+ device = utils.get_tpu_device()
+ else:
+ device = torch.device("cpu")
+ if cfg.common.fp16:
+ criterion = criterion.half()
+ model = model.half()
+ elif cfg.common.bf16:
+ criterion = criterion.to(dtype=torch.bfloat16)
+ model = model.to(dtype=torch.bfloat16)
+ criterion = criterion.to(device)
+ model = model.to(device)
+
+ logger.info(model)
+ logger.info("task: {}".format(task.__class__.__name__))
+ logger.info("model: {}".format(model.__class__.__name__))
+ logger.info("criterion: {}".format(criterion.__class__.__name__))
+ logger.info(
+ "num. shared model params: {:,} (num. trained: {:,})".format(
+ sum(
+ p.numel() for p in model.parameters() if not getattr(p, "expert", False)
+ ),
+ sum(
+ p.numel()
+ for p in model.parameters()
+ if not getattr(p, "expert", False) and p.requires_grad
+ ),
+ )
+ )
+
+ logger.info(
+ "num. expert model params: {} (num. trained: {})".format(
+ sum(p.numel() for p in model.parameters() if getattr(p, "expert", False)),
+ sum(
+ p.numel()
+ for p in model.parameters()
+ if getattr(p, "expert", False) and p.requires_grad
+ ),
+ )
+ )
+
+ # Load valid dataset (we load training data below, based on the latest checkpoint)
+ # We load the valid dataset AFTER building the model
+ data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg)
+ if cfg.dataset.combine_valid_subsets:
+ task.load_dataset("valid", combine=True, epoch=1)
+ else:
+ for valid_sub_split in cfg.dataset.valid_subset.split(","):
+ task.load_dataset(valid_sub_split, combine=False, epoch=1)
+
+ # Load the latest checkpoint if one is available and restore the
+ # corresponding train iterator
+ # try:
+ # state_dict = torch.load(cfg.checkpoint.restore_file)
+ # model.load_state_dict(state_dict['model'])
+ # print(f"Loaded model from {cfg.checkpoint.restore_file}")
+ # except Exception as e:
+ # print(e)
+ # print(f"No checkpoint found from {cfg.checkpoint.restore_file}")
+
+ valid_subsets = cfg.dataset.valid_subset.split(",")
+ logger.info("Start validating")
+
+ validate(
+ cfg, task, model, criterion, valid_subsets,
+ )
+
+@torch.no_grad()
+def validate(
+ cfg: DictConfig,
+ task: tasks.FairseqTask,
+ model,
+ criterion,
+ subsets: List[str],
+) -> List[Optional[float]]:
+ """Evaluate the model on the validation set(s) and return the losses."""
+ if cfg.dataset.fixed_validation_seed is not None:
+ # set fixed seed for every validation
+ utils.set_torch_seed(cfg.dataset.fixed_validation_seed)
+
+ valid_losses = []
+ for subset in subsets:
+ logger.info('begin validation on "{}" subset'.format(subset))
+
+ # Initialize data iterator
+ itr = task.get_batch_iterator(
+ dataset=task.dataset(subset),
+ max_tokens=cfg.dataset.max_tokens_valid,
+ max_sentences=cfg.dataset.batch_size_valid,
+ required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
+ seed=cfg.common.seed,
+ num_workers=cfg.dataset.num_workers_valid,
+ # always pass a fixed "epoch" to keep validation data consistent
+ # across training epochs
+ epoch=1,
+ data_buffer_size=cfg.dataset.data_buffer_size,
+ ).next_epoch_itr(
+ shuffle=False, set_dataset_epoch=False # use a fixed valid set
+ )
+ if cfg.common.tpu:
+ itr = utils.tpu_data_loader(itr)
+ progress = progress_bar.progress_bar(
+ itr,
+ log_format=cfg.common.log_format,
+ log_interval=cfg.common.log_interval,
+ prefix=f"valid on '{subset}' subset",
+ tensorboard_logdir=(
+ cfg.common.tensorboard_logdir
+ if distributed_utils.is_master(cfg.distributed_training)
+ else None
+ ),
+ default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
+ wandb_project=(
+ cfg.common.wandb_project
+ if distributed_utils.is_master(cfg.distributed_training)
+ else None
+ ),
+ wandb_run_name=os.environ.get(
+ "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
+ ),
+ )
+
+ # create a new root metrics aggregator so validation metrics
+ # don't pollute other aggregators (e.g., train meters)
+ with metrics.aggregate(new_root=True) as agg:
+ logging_outputs = []
+ for i, sample in enumerate(progress):
+ if cfg.dataset.max_valid_steps is not None and i > cfg.dataset.max_valid_steps:
+ break
+ sample = utils.move_to_cuda(sample)
+ _, _, inner_logging_outputs = task.valid_step(
+ sample, model, criterion
+ )
+ logging_outputs.append(inner_logging_outputs)
+ task.reduce_metrics(logging_outputs, criterion)
+
+ # with metrics.aggregate(new_root=True) as agg:
+ # for i, sample in enumerate(progress):
+ # if (
+ # cfg.dataset.max_valid_steps is not None
+ # and i > cfg.dataset.max_valid_steps
+ # ):
+ # break
+ # trainer.valid_step(sample)
+
+ stats = get_valid_stats(cfg, agg.get_smoothed_values())
+
+ progress.print(stats, tag=subset)
+
+ valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])
+ return valid_losses
+
+
+def get_valid_stats(
+ cfg: DictConfig, stats: Dict[str, Any]
+) -> Dict[str, Any]:
+ if hasattr(checkpoint_utils.save_checkpoint, "best"):
+ key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric)
+ best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min
+ stats[key] = best_function(
+ checkpoint_utils.save_checkpoint.best,
+ stats[cfg.checkpoint.best_checkpoint_metric],
+ )
+ return stats
+
+
+def cli_main(
+ modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None
+) -> None:
+ parser = options.get_training_parser()
+ args = options.parse_args_and_arch(parser, modify_parser=modify_parser)
+
+ cfg = convert_namespace_to_omegaconf(args)
+
+ if cfg.common.use_plasma_view:
+ server = PlasmaStore(path=cfg.common.plasma_path)
+ logger.info(
+ f"Started plasma server pid {server.server.pid} {cfg.common.plasma_path}"
+ )
+
+ if args.profile:
+ with torch.cuda.profiler.profile():
+ with torch.autograd.profiler.emit_nvtx():
+ distributed_utils.call_main(cfg, main)
+ else:
+ distributed_utils.call_main(cfg, main)
+
+ # if cfg.common.use_plasma_view:
+ # server.server.kill()
+
+
+if __name__ == "__main__":
+ cli_main()
\ No newline at end of file