Commit
Merge branch 'patch20240822' of github.com:openpsi-project/realhf into profile
garrett4wade committed Aug 28, 2024
2 parents f12bd7a + fad2214 commit 11e0e7a
Showing 6 changed files with 188 additions and 37 deletions.
4 changes: 2 additions & 2 deletions examples/scripts/local/gen.sh
@@ -16,9 +16,9 @@ python3 -m realhf.apps.quickstart gen \
allocation_mode=manual \
model.type._class=$MODEL_FAMILY \
model.path=$SFT_MODEL_PATH \
dataset.path=/lustre/fw/datasets/imdb/rl/ppo_prompt.jsonl \
dataset.path=/lustre/fw/datasets/imdb/rl/ppo_prompt-small.jsonl \
dataset.max_prompt_len=1024 \
dataset.train_bs_n_seqs=128 \
dataset.train_bs_n_seqs=100 \
allocation.parallel.pipeline_parallel_size=1 \
allocation.parallel.model_parallel_size=2 \
allocation.parallel.data_parallel_size=4 \
5 changes: 4 additions & 1 deletion realhf/api/core/data_api.py
@@ -48,7 +48,10 @@ def load_hf_tokenizer(
if padding_side is not None:
kwargs["padding_side"] = padding_side
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_name_or_path, fast_tokenizer=fast_tokenizer, **kwargs
model_name_or_path,
fast_tokenizer=fast_tokenizer,
trust_remote_code=True,
**kwargs,
)
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
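For reference, a sketch of the updated function once this hunk is applied: trust_remote_code=True lets tokenizers that ship custom Python code load non-interactively, and the existing pad-token fallback is unchanged. The parameter defaults shown here are assumptions reconstructed from the hunk, not copied from the full source file.

import transformers


def load_hf_tokenizer(model_name_or_path, fast_tokenizer=True, padding_side=None):
    kwargs = {}
    if padding_side is not None:
        kwargs["padding_side"] = padding_side
    # trust_remote_code=True allows models whose tokenizers define custom code
    # to load without an interactive confirmation prompt.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_name_or_path,
        fast_tokenizer=fast_tokenizer,
        trust_remote_code=True,
        **kwargs,
    )
    # Fall back to the EOS token when the tokenizer defines no pad token.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer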
7 changes: 0 additions & 7 deletions realhf/api/core/system_api.py
@@ -186,11 +186,6 @@ class ExperimentSaveEvalControl:
:param benchmark_steps: Terminate the training after this number of steps.
Used by system benchmark only. Please leave it to None for normal training.
:type benchmark_steps: Optional[int]
:param save_eval_timeout: Timeout in seconds for saving and evaluation.
Will be used for the last step of the experiment. The master worker will sleep
for `save_eval_timeout` seconds to wait all save or evaluations to finish.
Defaults to 120 seconds.
:type save_eval_timeout: int
"""

total_train_epochs: int = 1
@@ -204,8 +199,6 @@ class ExperimentSaveEvalControl:
eval_freq_secs: Optional[int] = None
# benchmark
benchmark_steps: Optional[int] = None
# Graceful exit
save_eval_timeout: int = 120


@dataclasses.dataclass
32 changes: 31 additions & 1 deletion realhf/experiments/common/gen_exp.py
@@ -8,6 +8,7 @@
ModelInterfaceType,
ModelName,
)
from realhf.api.core.data_api import DatasetUtility, load_hf_tokenizer
from realhf.api.core.dfg import MFCDef
from realhf.api.core.model_api import GenerationHyperparameters
from realhf.api.quickstart.dataset import PromptOnlyDatasetConfig
@@ -45,6 +46,35 @@ class GenerationConfig(CommonExperimentConfig):
)
allocation: MFCConfig = dataclasses.field(default_factory=MFCConfig)

output_file: str = "output.jsonl"

def __post_init__(self):
from realhf.impl.dataset.prompt_dataset import PromptDataset

util = DatasetUtility(
seed=0,
ddp_rank=0,
world_size=1,
tokenizer=load_hf_tokenizer(self.model.path),
)
d = PromptDataset(
util, max_length=self.dataset.max_prompt_len, dataset_path=self.dataset.path
)
if len(d) % self.dataset.train_bs_n_seqs != 0:
raise ValueError(
f"The size of the dataset must be a multiple of batch size for generation. "
f"Otherwise the final batch will be dropped. Please pad your dataset size with random prompts. "
f"Current dataset size: {len(d)}, batch size: {self.dataset.train_bs_n_seqs}."
)
if self.output_file is not None:
if not self.output_file.endswith(".jsonl"):
raise ValueError("Output path must end with .jsonl")
if "/" in self.output_file:
raise ValueError(
"Output path must not contain '/'. It should be a simple "
"filename that will be saved to the logging directory."
)

@property
def models(self):
return {
@@ -59,7 +89,7 @@ def rpcs(self):
# Customized dataclass objects will not work in that case.
interface = ModelInterfaceAbstraction(
"generation",
args={"generation_config": OmegaConf.to_container(self.gen, resolve=True)},
args={"generation_config": OmegaConf.to_container(self.gen, resolve=True), "output_file": self.output_file},
)
gen = MFCDef(
name="gen",
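The new __post_init__ check above requires the prompt dataset size to be an exact multiple of dataset.train_bs_n_seqs, and the error message asks users to pad the dataset otherwise. A minimal sketch of one way to do that padding, assuming a jsonl file of prompt records; the file names, helper name, and usage line are illustrative only and not part of this commit. If your dataset requires unique ids per record, adjust the duplicated rows accordingly.

import json
import random


def pad_prompt_jsonl(src_path, dst_path, batch_size):
    # Read the original prompt records.
    with open(src_path) as f:
        rows = [json.loads(line) for line in f if line.strip()]
    # Duplicate randomly chosen prompts until the count is a multiple of
    # batch_size, so the final generation batch is not dropped.
    remainder = len(rows) % batch_size
    if remainder:
        rows += random.choices(rows, k=batch_size - remainder)
    with open(dst_path, "w") as f:
        for row in rows:
            f.write(json.dumps(row) + "\n")


# Hypothetical usage: pad_prompt_jsonl("prompts.jsonl", "prompts_padded.jsonl", 100)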
128 changes: 114 additions & 14 deletions realhf/impl/model/interface/gen_interface.py
@@ -1,17 +1,43 @@
import dataclasses
import fcntl
import json
import os

import colorama
import torch

import realhf.api.core.model_api as model_api
from realhf.api.core.data_api import SequenceSample
from realhf.base import constants, logging
from realhf.base.datapack import flat2d

logger = logging.getLogger("Generation Interface", "benchmark")


def acquire_lock(lock_file):
fd = open(lock_file, "w")
fcntl.flock(fd, fcntl.LOCK_EX)
return fd


def release_lock(lock_fd):
fcntl.flock(lock_fd, fcntl.LOCK_UN)
lock_fd.close()


def write_dict_to_jsonl(dict_data, file_path, lock_file):
lock_fd = acquire_lock(lock_file)
try:
with open(file_path, "a") as file:
json.dump(dict_data, file)
file.write("\n")
finally:
release_lock(lock_fd)


@dataclasses.dataclass
class GenerationInterface(model_api.ModelInterface):
output_file: str | None = None
generation_config: dict = dataclasses.field(default_factory=dict)

def __post_init__(self):
@@ -45,27 +71,101 @@ def generate(
if res is None:
return None

gen_tokens, logprobs, *_ = res
gen_tokens, *_ = res

# Decode and log the first generated sentence.
l = input_.seqlens["packed_prompts"][0][0]
tokens = torch.cat(
[input_.data["packed_prompts"][:l], gen_tokens[0]]
).unsqueeze(0)
out = model.tokenizer.batch_decode(
tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
if constants.model_parallel_rank() == 0 and constants.is_last_pipe_stage():
res = {
"generated_length": gen_tokens.shape[1],
"batch_size": gen_tokens.shape[0],
}
if not (
constants.model_parallel_rank() == 0 and constants.is_last_pipe_stage()
):
# Not DP head, return stats.
return res

if self.output_file is not None:

# Concatenate prompts with gen_tokens, decode, and output to file.
prompt_lens = flat2d(input_.seqlens["packed_prompts"])
gen_lengths = (gen_tokens != model.tokenizer.pad_token_id).logical_and(
gen_tokens != model.tokenizer.eos_token_id
).sum(dim=-1) + 1
gen_lengths = gen_lengths.clip(max=gen_tokens.shape[-1])
assert len(gen_lengths) == len(prompt_lens) == input_.bs, (
input_.bs,
len(prompt_lens),
len(gen_lengths),
)

prompt_tokens_lis = []
ans_tokens_lis = []
prompt_offset = 0
for i, (prompt_len, gen_len) in enumerate(zip(prompt_lens, gen_lengths)):
prompt_tokens_lis.append(
input_.data["packed_prompts"][
prompt_offset : prompt_offset + prompt_len
]
)
ans_tokens_lis.append(gen_tokens[i, :gen_len])
prompt_offset += prompt_len
assert prompt_offset == sum(prompt_lens)
seq_tokens_lis = [
torch.cat([a, b]) for a, b in zip(prompt_tokens_lis, ans_tokens_lis)
]

prompt_str = model.tokenizer.batch_decode(
prompt_tokens_lis,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
ans_str = model.tokenizer.batch_decode(
ans_tokens_lis,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
seq_str = model.tokenizer.batch_decode(
seq_tokens_lis,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)

lock_file = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"_gen.lock",
)
output_file = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
self.output_file,
)
if constants.data_parallel_rank() == 0:
logger.info(f"Dumping output to: {output_file}...")
for p, a, s, _id in zip(prompt_str, ans_str, seq_str, input_.ids):
d = dict(
prompt=p,
answer=a,
seq=s,
id=_id,
)
write_dict_to_jsonl(d, output_file, lock_file)
else:
# Decode and log the first generated sentence.
l = input_.seqlens["packed_prompts"][0][0]
tokens = torch.cat(
[input_.data["packed_prompts"][:l], gen_tokens[0]]
).unsqueeze(0)
out = model.tokenizer.batch_decode(
tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
dp_rank = constants.data_parallel_rank()
logger.info(
f"DP rank {dp_rank}, the first generated sequence "
f"is: {colorama.Fore.YELLOW + colorama.Style.DIM}{out[0]}{colorama.Style.RESET_ALL}"
)

res = {
"generated_length": gen_tokens.shape[1],
"batch_size": gen_tokens.shape[0],
}
return res


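The fcntl-based helpers added above let every data-parallel worker append to the same output .jsonl without interleaving partial lines. A standalone sketch of the same pattern, assuming local file names "output.jsonl" and "_gen.lock" and a made-up worker count; it is an illustration of the locking scheme, not code from this commit.

import fcntl
import json
import multiprocessing as mp


def append_record(record, file_path, lock_file):
    # Same pattern as acquire_lock/write_dict_to_jsonl above: hold an exclusive
    # flock on a separate lock file while appending, so lines written by
    # different processes never interleave.
    with open(lock_file, "w") as lock_fd:
        fcntl.flock(lock_fd, fcntl.LOCK_EX)
        try:
            with open(file_path, "a") as f:
                f.write(json.dumps(record) + "\n")
        finally:
            fcntl.flock(lock_fd, fcntl.LOCK_UN)


def worker(rank):
    append_record({"id": rank, "answer": f"sample answer {rank}"}, "output.jsonl", "_gen.lock")


if __name__ == "__main__":
    procs = [mp.Process(target=worker, args=(r,)) for r in range(4)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()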
49 changes: 37 additions & 12 deletions realhf/system/master_worker.py
@@ -156,13 +156,13 @@ def group_rpc_blocked(
handle_type: str,
datas: List,
verbose: bool = True,
) -> List:
request_ids = request_all(stream, handlers, handle_type, datas, verbose=verbose)
res = []
for req_id in request_ids:
r = stream.poll(pattern=create_exact_match_pattern([req_id]), block=True)
res.append(r)
return res
):
req_ids = request_all(stream, handlers, handle_type, datas, verbose=verbose)
payloads = [
stream.poll(pattern=create_exact_match_pattern([req_id]), block=True)
for req_id in req_ids
]
return [p.data for p in payloads]


def _request_parameter_sync(
@@ -815,7 +815,7 @@ async def model_eval_thread_func(
)
eval_stats = _gather_stat(list(filter(lambda x: bool(x), eval_stats)))
logger.info(
f"Evaluation results at epoch {epoch + 1} step {epoch_step + 1}: {eval_stats}"
f"Evaluation results at epoch {epoch} step {epoch_step}: {eval_stats}"
)


@@ -1359,12 +1359,37 @@ def _poll(self):
self.__benchmark_steps is not None
and self._global_step >= self.__benchmark_steps
) or (is_new_epoch and self._epoch > self.__total_train_epochs):
if should_eval or should_save:
if should_eval:
eval_stats = group_rpc_blocked(
self.__stream,
self.__all_model_handlers,
"evaluate",
[None for _ in self.__all_model_handlers],
)
eval_stats = _gather_stat(
list(filter(lambda x: bool(x), eval_stats))
)
logger.info(
f"Evaluation results at epoch {self._epoch} step {self._epoch_step}: {eval_stats}"
)
if should_save:
model_save_dirs = [
os.path.join(
self.MODEL_SAVE_ROOT,
s.model_name.role,
f"epoch{self._epoch}epochstep{self._epoch_step}globalstep{self._global_step}",
)
for s in self.__trainable_model_handlers
]
group_rpc_blocked(
self.__stream,
self.__trainable_model_handlers,
"save",
model_save_dirs,
)
logger.info(
f"Waiting for all save/eval requests at the last step"
f" for {self.config.exp_ctrl.save_eval_timeout} secs..."
f"Save models at epoch {self._epoch} step {self._epoch_step}."
)
time.sleep(self.config.exp_ctrl.save_eval_timeout)
if self.__benchmark_steps is not None:
logger.info(
f"Finished benchmark {self.__benchmark_steps}. "
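After this change, group_rpc_blocked unwraps the reply payloads before returning, which is what lets the new evaluate/save branches in _poll consume the results directly. A toy illustration of the difference; the Payload class here is a stand-in for realhf's actual reply object, not its real definition.

from dataclasses import dataclass
from typing import Any, List


@dataclass
class Payload:
    # Stand-in for the reply object returned by stream.poll(); the real class
    # carries more fields than just .data.
    data: Any


def unwrap_replies(payloads: List[Payload]) -> List[Any]:
    # New behavior: return the .data of each reply instead of the wrappers, so
    # callers can filter and gather stats without touching payload internals.
    return [p.data for p in payloads]


replies = [Payload({"ppl": 3.2}), Payload({}), Payload({"ppl": 2.9})]
print(list(filter(bool, unwrap_replies(replies))))  # [{'ppl': 3.2}, {'ppl': 2.9}]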
