forked from ml-explore/mlx-examples
-
Notifications
You must be signed in to change notification settings - Fork 1
/
hf_t5.py
53 lines (44 loc) · 1.85 KB
/
hf_t5.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import argparse
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5EncoderModel
def embed(t5_model: str):
    """Run only the T5 encoder over a fixed two-sentence batch and print
    each input alongside its last-hidden-state embeddings.

    Args:
        t5_model: Hugging Face model name or path (e.g. "t5-small").
    """
    batch = [
        "translate English to German: That is good.",
        "This is an example of T5 working on MLX.",
    ]
    tokenizer = AutoTokenizer.from_pretrained(t5_model)
    torch_model = T5EncoderModel.from_pretrained(t5_model)
    torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
    # Only the last hidden state is used, so there is no need to request
    # all intermediate hidden states.
    torch_forward = torch_model(**torch_tokens)
    torch_output = torch_forward.last_hidden_state.detach().numpy()
    # Fixed banner: this is a PyTorch T5 encoder, not "TF BERT".
    print("\n PyTorch T5 encoder:")
    for input_str, embedding in zip(batch, torch_output):
        print("Input:", input_str)
        print(embedding)
        print()
def generate(t5_model: str):
    """Greedily decode a German translation of a fixed English prompt
    with the full seq2seq T5 model and print the result.

    Args:
        t5_model: Hugging Face model name or path (e.g. "t5-small").
    """
    prompt = "translate English to German: As much as six inches of rain could fall in the New York City region through Monday morning, and officials warned of flooding along the coast."
    tok = AutoTokenizer.from_pretrained(t5_model)
    model = AutoModelForSeq2SeqLM.from_pretrained(t5_model)
    input_ids = tok(prompt, return_tensors="pt", padding=True).input_ids
    # Deterministic (greedy) decoding, capped at 512 tokens.
    generated = model.generate(input_ids, do_sample=False, max_length=512)
    decoded = tok.decode(generated[0], skip_special_tokens=True)
    print(decoded)
if __name__ == "__main__":
    # CLI entry point: choose between encoder-only embedding and full
    # seq2seq generation for a given T5 checkpoint.
    arg_parser = argparse.ArgumentParser(
        description="Run the T5 model using Hugging Face Transformers."
    )
    arg_parser.add_argument(
        "--encode-only",
        default=False,
        action="store_true",
        help="Only run the encoder and print the embeddings.",
    )
    arg_parser.add_argument(
        "--model",
        help="The huggingface name of the T5 model to save.",
        default="t5-small",
    )
    cli_args = arg_parser.parse_args()
    # Dispatch to the requested mode.
    run = embed if cli_args.encode_only else generate
    run(cli_args.model)