Skip to content

Commit

Permalink
[REMOVED SCIPY]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye Gomez authored and Kye Gomez committed May 27, 2024
1 parent 726468e commit 4d19440
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 7 deletions.
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "zetascale"
version = "2.5.2"
version = "2.5.4"
description = "Rapidly Build, Optimize, and Deploy SOTA AI Models"
authors = ["Zeta Team <[email protected]>"]
license = "MIT"
Expand All @@ -21,15 +21,14 @@ torch = ">=2.1.1,<3.0"
pytest = "8.2.1"
torchfix = "*"
einops = "0.7.0"
bitsandbytes = "0.43.0"
bitsandbytes = "*"
transformers = "4.41.0"
einops-exts = "0.0.4"
torchvision = "0.18.0"
accelerate = "0.30.1"
datasets = "*"
loguru = "*"
vector-quantize-pytorch = "1.14.7"
scipy = "1.9.3"
beartype = "0.17.2"
tqdm = "4.66.3"
rich = "13.7.1"
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ torchfix
torchdiffeq>=0.2.3,<0.3.0
beartype>=0.15.0,<0.16.0
vector-quantize-pytorch>=1.12.0,<1.13.0
scipy>=1.9.3,<1.10.0
loguru
rich==13.7.1
tiktoken==0.6.0
Expand Down
6 changes: 3 additions & 3 deletions zeta/quant/qlora.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.stats import norm
# from scipy.stats import norm
from tqdm import tqdm

bnb_available = False
Expand Down Expand Up @@ -362,9 +362,9 @@ def get_nf4(cached=True) -> torch.Tensor:
)

offset = 0.9677083
v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist()
v1 = torch.linspace(offset, 0.5, 9)[:-1].tolist()
# v2 = [0]*(256-15)
v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
v3 = (torch.linspace(offset, 0.5, 8)[:-1]).tolist()
# v = v1 + v3 + 0.0
nkf = torch.tensor(v1 + v3 + [0.0])
nkf = nkf.sort().values
Expand Down

0 comments on commit 4d19440

Please sign in to comment.