-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathutils.py
101 lines (82 loc) · 3.12 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from typing import List, Union
import torch
from torch import Tensor
from ortools.linear_solver import pywraplp
import binpacking
from transformers import AutoTokenizer
from model import CustomCausalLlamaModel, CustomCausalMistralModel
# As implemented here:
# https://github.com/pytorch/pytorch/issues/10536#issuecomment-1320935162
def left_pad_sequence(
    sequences: Union[Tensor, List[Tensor]],
    batch_first: bool = True,
    padding_value: float = 0.0,
) -> Tensor:
    """Pad variable-length sequences on the left (front) to a common length.

    Mirror image of ``torch.nn.utils.rnn.pad_sequence``. Each sequence is
    reversed along its first dimension and right-padded. The padded result is
    then reversed back along the sequence dimension, so the padding ends up in
    front of each sequence.

    Args:
        sequences: Tensors whose first dimension is the variable-length one;
            all trailing dimensions must match. A single tensor is iterated
            along dim 0.
        batch_first: If True the output is ``(B, T, *)``, otherwise
            ``(T, B, *)``.
        padding_value: Value written into the padded positions.

    Returns:
        The left-padded tensor.
    """
    reversed_seqs = [seq.flip(0) for seq in sequences]
    # Public API instead of the private torch._C._nn binding, whose signature
    # is not a stable contract across PyTorch releases.
    padded = torch.nn.utils.rnn.pad_sequence(
        reversed_seqs, batch_first=batch_first, padding_value=padding_value
    )
    # The sequence (time) axis is dim 1 for (B, T, *) and dim 0 for (T, B, *).
    time_dim = 1 if batch_first else 0
    return padded.flip(time_dim)
def greedy_packing(length_dict, max_bin_size):
    """Pack items into bins of capacity ``max_bin_size`` heuristically.

    Thin wrapper over ``binpacking.to_constant_volume``, the fast greedy
    counterpart to ``integer_program_packing``.

    Args:
        length_dict: Mapping of item key -> weight (e.g. sequence length).
        max_bin_size: Capacity of every bin.

    Returns:
        Whatever ``binpacking.to_constant_volume`` returns. For dict input this
        is presumably a list of dicts (key -> weight), one per bin; confirm
        against the binpacking documentation.
    """
    packed_bins = binpacking.to_constant_volume(length_dict, max_bin_size)
    return packed_bins
# https://developers.google.com/optimization/pack/bin_packing
def integer_program_packing(length_dict, max_bin_size):
    """Pack items into the fewest bins of capacity ``max_bin_size`` exactly.

    Builds the mixed-integer program from the OR-Tools bin-packing guide
    (https://developers.google.com/optimization/pack/bin_packing) and solves
    it with SCIP. One candidate bin is created per item, so the model has
    O(n^2) binary variables.

    Args:
        length_dict: Mapping of item key -> weight (e.g. sequence length).
            Keys may be any hashable values; they are not assumed to be
            ``0..n-1``.
        max_bin_size: Capacity of every bin.

    Returns:
        A list of dicts, one per used bin, each mapping item key -> weight.
        Returns ``[]`` for empty input, and ``None`` if the SCIP backend
        cannot be created.

    Raises:
        RuntimeError: If the solver does not report an optimal solution
            (e.g. some item is heavier than ``max_bin_size``).
    """
    solver = pywraplp.Solver.CreateSolver("SCIP")
    if not solver:
        return None

    # Work with positional indices so arbitrary keys are supported. The old
    # code indexed a positional list by dict key, which was only correct for
    # keys 0..n-1 in insertion order.
    items = list(length_dict.keys())
    weights = [length_dict[key] for key in items]
    n = len(items)
    if n == 0:
        return []
    item_ids = range(n)
    bin_ids = range(n)

    # x[i, j] == 1  <=>  item i is placed in bin j.
    x = {
        (i, j): solver.IntVar(0, 1, f"x_{i}_{j}")
        for i in item_ids
        for j in bin_ids
    }
    # y[j] == 1  <=>  bin j is used.
    y = {j: solver.IntVar(0, 1, f"y[{j}]") for j in bin_ids}

    # Every item goes in exactly one bin.
    for i in item_ids:
        solver.Add(sum(x[i, j] for j in bin_ids) == 1)
    # Capacity: a used bin holds at most max_bin_size; an unused bin holds 0 weight.
    for j in bin_ids:
        solver.Add(
            sum(x[i, j] * weights[i] for i in item_ids) <= y[j] * max_bin_size
        )
    # Linking constraint: an item may only sit in a used bin. Without it,
    # zero-weight items could be assigned to an unused bin and then be
    # silently dropped from the result.
    for i in item_ids:
        for j in bin_ids:
            solver.Add(x[i, j] <= y[j])

    solver.Minimize(solver.Sum([y[j] for j in bin_ids]))

    status = solver.Solve()
    if status != pywraplp.Solver.OPTIMAL:
        # The original `raise ("...")` raised a str, which is itself a TypeError.
        raise RuntimeError("The problem does not have an optimal solution.")

    # Solution values are floats; threshold at 0.5 instead of comparing to
    # exactly 1 / 0 so solver round-off cannot drop or duplicate items.
    result = []
    for j in bin_ids:
        if y[j].solution_value() > 0.5:
            result.append(
                {
                    items[i]: weights[i]
                    for i in item_ids
                    if x[i, j].solution_value() > 0.5
                }
            )
    return result
def load_model_and_tokenizer(
    base_model: str = "llama1b",
    loadbit: int = 8,
):
    """Load a custom causal LM and its tokenizer.

    Args:
        base_model: Either a known alias ("llama1b", "llama2") or a Hugging
            Face hub id / local path. The model class is chosen by whether the
            name contains "llama" (checked first) or "mistral".
        loadbit: 8 -> ``load_in_8bit``, 4 -> ``load_in_4bit``. Any other value
            loads the model unquantized and moves it to CUDA if available,
            else CPU.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.

    Raises:
        ValueError: If ``base_model`` names neither a llama nor a mistral
            model. Previously this surfaced as an UnboundLocalError, and any
            non-alias name (including every mistral name) failed on the
            unbound ``path``.
    """
    aliases = {
        "llama1b": "princeton-nlp/Sheared-LLaMA-1.3B",
        # TODO(review): placeholder path; point at a real Llama-2 checkpoint.
        "llama2": "/path/to/llama2",
    }

    # Validate the model family before doing any (expensive) loading.
    if "llama" in base_model:
        model_cls = CustomCausalLlamaModel
    elif "mistral" in base_model:
        model_cls = CustomCausalMistralModel
    else:
        raise ValueError(
            f"Unsupported base_model {base_model!r}: expected a name "
            "containing 'llama' or 'mistral'."
        )

    # Unknown names are treated as a hub id / local path, which makes the
    # mistral branch reachable.
    path = aliases.get(base_model, base_model)

    tokenizer = AutoTokenizer.from_pretrained(path)
    # NOTE(review): "[PAD]" may not be in the Llama vocabulary, in which case
    # HF maps it to the unk id; confirm this is the intended pad id.
    tokenizer.pad_token = "[PAD]"

    load_in_8bit = loadbit == 8
    load_in_4bit = loadbit == 4
    # NOTE(review): newer transformers versions prefer
    # quantization_config=BitsAndBytesConfig(...) over these kwargs.
    model = model_cls.from_pretrained(
        path, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit
    )
    model.eval()

    # Quantized (8/4-bit) models are not moved here; only full-precision
    # loads are placed on a device explicitly.
    if loadbit != 8 and loadbit != 4:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
    return model, tokenizer