Merge pull request #312 from wildkid1024/master
Added new ops to support low-level op operations
ztxz16 authored Sep 27, 2023
2 parents 60de06f + d920823 commit 435583e
Showing 19 changed files with 761 additions and 63 deletions.
5 changes: 4 additions & 1 deletion CMakeLists.txt
@@ -71,6 +71,9 @@ target_link_libraries(main fastllm)
add_executable(quant tools/src/quant.cpp)
target_link_libraries(quant fastllm)

# Test binary for the new low-level op API
add_executable(testOps test/ops/cppOps.cpp)
target_link_libraries(testOps fastllm)

add_executable(webui example/webui/webui.cpp)
target_link_libraries(webui fastllm)
add_custom_command(
@@ -113,4 +116,4 @@ else()
)
endif()

endif()
endif()
2 changes: 2 additions & 0 deletions include/fastllm.h
@@ -287,6 +287,8 @@ namespace fastllm {

void PrintShape() const; // print the shape

std::vector<int> Shape() const; // return the shape as a vector of dims

void Print() const; // print the contents

void CalcWeightSum(); // compute WeightSum
13 changes: 11 additions & 2 deletions pyfastllm/README.md
@@ -12,6 +12,12 @@ pyfastllm is a Python API implementation built on fastllm; with pyfastllm you can more…

## Version updates

### v0.1.4 2023-09-12

- Fixed some bugs caused by backend interface changes
- Added new ops to support low-level op operations


### v0.1.3 2023-07-08

- Added usage and API documentation
@@ -80,6 +86,7 @@ demo/cli_thread.py: multi-threaded API call example (recommended)
demo/cli_low_api.py: low-level API call example
demo/convert_model.py: model conversion example
demo/web_api.py, demo/web_api_client.py: FastAPI web API calls
demo/test_ops: usage examples and tests for some of the ops (a minimal sketch follows)
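
A minimal sketch of the new low-level op interface (adapted from demo/test_ops.py; it assumes the compiled pyfastllm extension is importable as `fastllm`):

```python
import fastllm

# A Tensor takes a dtype, a shape, and the flat data.
x = fastllm.Tensor(fastllm.float32, [1, 2], [1, 5])
w = fastllm.Tensor(fastllm.float32, [1, 2], [1, 3])
out = fastllm.ops.rms_norm(x, w, eps=1e-6)
print(out)
```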

### Command-line tools

@@ -187,8 +194,10 @@ python web_api.py -m 0 -p path_for_chatglm --max_batch_size 32

- [x] Change response_batch's output_str function to return the answer as a return value
- [x] Optimize the encode/decode paths and merge the different return types
- [ ] Interoperate with numpy and other matrix libraries
- [ ] Deep copy and shallow copy for Tensor, plus basic operator overloading
- [ ] Fix the pastKV copy bug under low_api
- [ ] A model-runtime-parameter class encapsulating runtime parameters: model path, thread count, low-memory mode, penalty factor, temperature, etc.
- [ ] Expose more low-level API interfaces; define model nodes module-style and compose them into custom models
- [x] A model-runtime-parameter class encapsulating runtime parameters: model path, thread count, low-memory mode, penalty factor, temperature, etc.
- [ ] Add more ops
- [ ] Add more modules

19 changes: 12 additions & 7 deletions pyfastllm/build_libs.py
@@ -3,6 +3,7 @@
import platform
import sys
import argparse
import glob

parser = argparse.ArgumentParser(description='build fastllm libs')
parser.add_argument('--cuda', dest='cuda', action='store_true', default=False,
@@ -23,20 +24,24 @@ def build_libs():
    os.makedirs(cmake_build_dir)
    os.chdir(cmake_build_dir)

    # build it
    # build it
    cpu_num = min(os.cpu_count(), 4)
    args = parser.parse_args()
    if IS_WINDOWS:
        os.system('cmake -G "Ninja" -DPY_API=ON .. && ninja pyfastllm')
        os.system('cmake -G Ninja -DPY_API=ON .. && ninja pyfastllm')
    elif IS_LINUX:
        extra_opts = ' -DPY_API=ON '
        extra_opts += ' -DUSE_CUDA=ON ' if args.cuda else ' '
        build_cmd = 'cmake ' + extra_opts + ' .. && make pyfastllm -j4'
        build_cmd = f"cmake {extra_opts} .. && make pyfastllm -j{cpu_num}"
        print(build_cmd)
        os.system('cmake ' + extra_opts + ' .. && make pyfastllm -j4')
        os.system(f"cmake {extra_opts} .. && make pyfastllm -j{cpu_num}")
    else:
        extra_opts = '-DPY_API=ON'
        os.system('cmake ' + extra_opts + '.. && make pyfastllm -j4')

        os.system(f"cmake {extra_opts} .. && make pyfastllm -j{cpu_num}")

    # Copy the freshly built extension modules into the Python package.
    # Note: glob.glob()'s root_dir keyword requires Python >= 3.10.
    so_files = glob.glob("*.so", root_dir=cmake_build_dir)
    for file in so_files:
        shutil.copy(os.path.join(cmake_build_dir, file), os.path.join(root_dir, "pyfastllm/fastllm"))

if __name__ == '__main__':
    build_libs()
    build_libs()
2 changes: 1 addition & 1 deletion pyfastllm/demo/cli.py
@@ -3,7 +3,7 @@
import platform
import logging
import argparse
sys.path.append('./build-py')
sys.path.append('../../build-py')
import pyfastllm # or fastllm

logging.info(f"python gcc version:{platform.python_compiler()}")
93 changes: 93 additions & 0 deletions pyfastllm/demo/test_ops.py
@@ -0,0 +1,93 @@
import fastllm
import numpy as np
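
# NumPy reference implementations, used below to sanity-check the fastllm ops.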

def np_rms_norm(inputs, weights, eps):
    # RMSNorm: y = x / sqrt(mean(x^2) + eps) * w
    mean_sq = np.mean(inputs**2, axis=-1, keepdims=True)
    return inputs / np.sqrt(mean_sq + eps) * weights


def np_layer_norm(inputs, gamma, beta, axis=-1):
    assert axis < len(inputs.shape), "axis should be less than the number of input dims"
    mean = np.mean(inputs, axis=axis, keepdims=True)
    var = np.var(inputs, axis=axis, keepdims=True)
    # normalize by the standard deviation; the small epsilon guards against
    # division by zero (added here, not part of the original snippet)
    output = (inputs - mean) / np.sqrt(var + 1e-5) * gamma + beta
    return output

def np_linear(inputs, weights, bias):
    output = np.matmul(inputs, weights.T) + bias
    return output

def np_softmax(inputs, axis=None):
    # subtract the max for numerical stability
    maxv = inputs.max(axis, keepdims=True)
    exp_v = np.exp(inputs - maxv)
    exp_sum = np.sum(exp_v, axis=axis, keepdims=True)
    return exp_v / exp_sum

def np_silu(inputs):
    # SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))
    return inputs / (1 + np.exp(-inputs))

def np_attention(q, k, v, mask=None, group=None, scale=None):
    # mask and group are accepted for signature parity with the fastllm op
    # but ignored in this reference implementation
    qk = np_softmax(q @ k.T * scale, axis=-1)
    attn = qk @ v
    return attn

def test_linear():
    inputs = np.array([[1, 2]])
    weight = np.array([[3, 4, 5, 5, 6, 7]]).reshape([3, 2])
    bias = np.array([0, 1, 1])
    np_output = np_linear(inputs, weight, bias)
    print(np_output)

    input = fastllm.Tensor(fastllm.float32, [1, 2], [1, 2])
    weights = fastllm.Tensor(fastllm.float32, [3, 2], [3, 4, 5, 5, 6, 7])
    bias = fastllm.Tensor(fastllm.float32, [3], [0, 1, 1])
    out = fastllm.ops.linear(input, weights, bias)
    print(out)

def test_rms_norm():
    inputs = np.array([1, 5]).reshape([1, 2])
    weights = np.array([1, 3]).reshape([1, 2])
    eps = 1e-6

    np_out = np_rms_norm(inputs, weights, eps)
    print(np_out)

    input = fastllm.Tensor(fastllm.float32, [1, 2], [1, 5])
    weights = fastllm.Tensor(fastllm.float32, [1, 2], [1, 3])
    out = fastllm.ops.rms_norm(input, weights, eps=1e-6)
    print(out)

def test_silu():
    # Currently exercises softmax; swap in the commented lines to test silu.
    inputs = np.array([1, 5]).reshape([1, 2])
    output = np_softmax(inputs)
    # output = np_silu(inputs)
    print(output)

    inputs = fastllm.Tensor(fastllm.float32, [1, 2], [1, 5])
    out = fastllm.ops.activation(input=inputs, activate_type="softmax")
    # out = fastllm.ops.activation(input=inputs, activate_type="silu")
    print(out)

def test_attention():
    q = np.array([1, 2, 3, 4, 5, 6]).reshape([2, 3])
    k = np.array([5, 6, 7, 8, 9, 10]).reshape([2, 3])
    v = np.array([1, 1, 1, 2, 1, 3]).reshape([2, 3])
    scale = 1 / np.sqrt(q.shape[-1])
    output = np_attention(q, k, v, scale=scale)
    print(output)

    # The fastllm tensors carry an extra leading dim compared with the 2-D NumPy reference.
    q = fastllm.Tensor(fastllm.float32, [1, 2, 3], [1, 2, 3, 4, 5, 6])
    k = fastllm.Tensor(fastllm.float32, [1, 2, 3], [5, 6, 7, 8, 9, 10])
    v = fastllm.Tensor(fastllm.float32, [1, 2, 3], [1, 1, 1, 2, 1, 3])
    mask = fastllm.Tensor()
    output = fastllm.ops.attention(q, k, v, mask, group=1, scale=scale, attentionType=0)
    print(output)

if __name__ == "__main__":
    test_attention()
    test_silu()
    test_linear()
    test_rms_norm()
14 changes: 13 additions & 1 deletion pyfastllm/fastllm/__init__.py
@@ -1,2 +1,14 @@
import os
import sys
import ctypes
import glob

_BASE_DIR = os.path.dirname(__file__)
sys.path.append(_BASE_DIR)
# libs = glob.glob("*.so")
# for lib in libs: _cdll = ctypes.cdll.LoadLibrary(lib)

from pyfastllm import *
from . import utils
from . import utils
from . import functions as ops
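# `functions` is re-exported as `ops`, so callers can write fastllm.ops.rms_norm(...).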

1 change: 1 addition & 0 deletions pyfastllm/fastllm/functions/__init__.py
@@ -0,0 +1 @@
from .fastllm_ops import *
Empty file.
78 changes: 78 additions & 0 deletions pyfastllm/fastllm/functions/fastllm_ops.py
@@ -0,0 +1,78 @@
import pyfastllm as fastllm


def embedding(data: fastllm.Tensor):
    # some check
    return fastllm.embedding(data)

def rms_norm(input: fastllm.Tensor, weight: fastllm.Tensor, eps: float, output: fastllm.Tensor = None):
    # the `output` argument is currently unused; the op allocates its own result
    output = fastllm.rms_norm(input, weight, eps)
    return output

def layer_norm(input: fastllm.Tensor,
               gamma: fastllm.Tensor,
               beta: fastllm.Tensor,
               axis: int = -1):
    output = fastllm.layer_norm(input, gamma, beta, axis)
    return output

def linear(input: fastllm.Tensor,
           weight: fastllm.Tensor,
           bias: fastllm.Tensor):
    output = fastllm.linear(input, weight, bias)
    return output

def matmul(input0: fastllm.Tensor,
           input1: fastllm.Tensor,
           alpha: fastllm.Tensor):
    output = fastllm.matmul(input0, input1, alpha)
    return output

def attention(q: fastllm.Tensor,
              k: fastllm.Tensor,
              v: fastllm.Tensor,
              mask: fastllm.Tensor,
              group: int,
              scale: float,
              attentionType: int):
    output = fastllm.attention(q, k, v, mask, group, scale, attentionType)
    return output
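
# Note: demo/test_ops.py calls this with scale = 1/sqrt(head_dim) and group=1.
# group > 1 presumably maps to grouped-query attention; that is an assumption,
# not something this diff confirms.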

def activation(input: fastllm.Tensor, axis=-1, activate_type="silu"):
    assert activate_type in ("softmax", "silu", "gelu", "swiglu")
    func = getattr(fastllm, activate_type)
    if activate_type == "softmax":
        return func(input, axis)
    return func(input)
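
# Example call, mirroring demo/test_ops.py:
#   out = activation(input=x, activate_type="softmax")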

def mul(input: fastllm.Tensor, v: int):
    output = fastllm.mul(input, v)
    return output

def matmul_transB():
    pass

def add(input0: fastllm.Tensor, input1: fastllm.Tensor):
    output = fastllm.add(input0, input1)
    return output

def AttentionMask():
    pass

def AlibiMask():
    pass

def topk():
    pass

def RotatePosition2D():
    pass

def NearlyRotatePosition2D():
    pass

def LlamaRotatePosition2D():
    pass

def RepeatPenalty():
    pass