diff --git a/CMakeLists.txt b/CMakeLists.txt index b4708eae..75cfaafe 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -71,6 +71,9 @@ target_link_libraries(main fastllm) add_executable(quant tools/src/quant.cpp) target_link_libraries(quant fastllm) +add_executable(testOps test/ops/cppOps.cpp) +target_link_libraries(testOps fastllm) + add_executable(webui example/webui/webui.cpp) target_link_libraries(webui fastllm) add_custom_command( @@ -113,4 +116,4 @@ else() ) endif() -endif() \ No newline at end of file +endif() diff --git a/include/fastllm.h b/include/fastllm.h index 6030b41a..c5a8284f 100644 --- a/include/fastllm.h +++ b/include/fastllm.h @@ -287,6 +287,8 @@ namespace fastllm { void PrintShape() const; // 输出形状 + std::vector Shape() const; + void Print() const; // 输出 void CalcWeightSum(); // 计算WeightSum diff --git a/pyfastllm/README.md b/pyfastllm/README.md index 7dae01ea..ec22614b 100644 --- a/pyfastllm/README.md +++ b/pyfastllm/README.md @@ -12,6 +12,12 @@ pyfastllm是基于fastllm的python api接口实现,通过pyfastllm可以更加 ## 版本更新 +### v0.1.4 2023-09-12 + +- 修复了一些后端接口变动的bug +- 增加了新的ops, 支持低级op操作 + + ### v0.1.3 2023-07-08 - 增加使用和API接口文档 @@ -80,6 +86,7 @@ demo/cli_thread.py: 多线程调用api接口示例(推荐) demo/cli_low_api.py: 底层API调用示例 demo/convert_model.py: 模型转换示例 demo/web_api.py, demo/web_api_client.py: fastapi webapi调用 +demo/test_ops: 部分op的使用样例及测试 ### 命令行工具 @@ -187,8 +194,10 @@ python web_api.py -m 0 -p path_for_chatglm --max_batch_size 32 - [x] 修改response_batch的output_str函数,以返回值的形式返回答案 - [x] 编解码部分优化,合并不同的返回类型 +- [ ] 对接numpy等矩阵库 - [ ] Tensor的深复制和浅复制,以及基础运算符重载 - [ ] fix low_api下pastKV复制的bug -- [ ] 模型运行参数对象类,封装模型运行时参数,包含模型路径、运行线程数、是否为低内存模型、惩罚因子、温度等 -- [ ] 暴露更多的底层api接口,按照module的方式定义模型的点,拼接model实现自定义model +- [x] 模型运行参数对象类,封装模型运行时参数,包含模型路径、运行线程数、是否为低内存模型、惩罚因子、温度等 +- [ ] 增加更多的op +- [ ] 增加module diff --git a/pyfastllm/build_libs.py b/pyfastllm/build_libs.py index 2fa3cd5b..0aeb8c1b 100644 --- a/pyfastllm/build_libs.py +++ b/pyfastllm/build_libs.py @@ -3,6 +3,7 @@ import platform import sys import argparse +import glob parser = argparse.ArgumentParser(description='build fastllm libs') parser.add_argument('--cuda', dest='cuda', action='store_true', default=False, @@ -23,20 +24,24 @@ def build_libs(): os.makedirs(cmake_build_dir) os.chdir(cmake_build_dir) - # build it + # build it + cpu_num = min(os.cpu_count(), 4) args = parser.parse_args() if IS_WINDOWS: - os.system('cmake -G "Ninja" -DPY_API=ON .. && ninja pyfastllm') + os.system('cmake -G Ninja -DPY_API=ON .. && ninja pyfastllm') elif IS_LINUX: extra_opts = ' -DPY_API=ON ' extra_opts += ' -DUSE_CUDA=ON ' if args.cuda else ' ' - build_cmd = 'cmake ' + extra_opts + ' .. && make pyfastllm -j4' + build_cmd = f"cmake {extra_opts} .. && make pyfastllm -j{cpu_num}" print(build_cmd) - os.system('cmake ' + extra_opts + ' .. && make pyfastllm -j4') + os.system(f"cmake {extra_opts} .. && make pyfastllm -j{cpu_num}") else: extra_opts = '-DPY_API=ON' - os.system('cmake ' + extra_opts + '.. && make pyfastllm -j4') - + os.system(f"cmake {extra_opts} .. 
&& make pyfastllm -j{cpu_num}") + + so_files = glob.glob("*.so", root_dir=cmake_build_dir) + for file in so_files: + shutil.copy(os.path.join(cmake_build_dir, file), os.path.join(root_dir, "pyfastllm/fastllm")) if __name__ == '__main__': - build_libs() + build_libs() \ No newline at end of file diff --git a/pyfastllm/demo/cli.py b/pyfastllm/demo/cli.py index 07cf4a55..6a3f9b16 100644 --- a/pyfastllm/demo/cli.py +++ b/pyfastllm/demo/cli.py @@ -3,7 +3,7 @@ import platform import logging import argparse -sys.path.append('./build-py') +sys.path.append('../../build-py') import pyfastllm # 或fastllm logging.info(f"python gcc version:{platform.python_compiler()}") diff --git a/pyfastllm/demo/test_ops.py b/pyfastllm/demo/test_ops.py new file mode 100644 index 00000000..bc7c2118 --- /dev/null +++ b/pyfastllm/demo/test_ops.py @@ -0,0 +1,93 @@ +import fastllm +import numpy as np + +def np_rms_norm(inputs, weights, eps): + channel = inputs.shape[-1] + sqrt_mean = np.sqrt(np.sum(inputs**2)/channel + eps) + return inputs / sqrt_mean *weights + + +def np_layer_norm(inputs, gamma, beta, axis=-1): + assert axis < len(inputs.shapes), "axis should less than inputs dims" + channel = inputs.shape[axis] + mean = np.mean(inputs, axis=axis) + var = np.var(inputs, axis=axis) + + output = (inputs - mean) / var * gamma + beta + return output + +def np_linear(inputs, weights, bias): + output = np.matmul(inputs, weights.T) + bias + return output + +def np_softmax(inputs, axis=None): + maxv = inputs.max(axis, keepdims=True) + exp_v = np.exp(inputs - maxv) + exp_sum = np.sum(exp_v, axis=axis) + return exp_v / exp_sum + +def np_silu(inputs, ): + return inputs / (1 + np.exp(-inputs)) + +def np_attention(q, k, v, mask=None, group=None, scale=None): + qk = np_softmax(q @ k.T * scale, axis=-1) + attn = qk @ v + return attn + +def test_linear(): + inputs = np.array([[1, 2]]) + weight = np.array([[3, 4, 5, 5, 6, 7]]).reshape([3, 2]) + bias = np.array([0, 1, 1]) + np_output = np_linear(inputs, weight, bias) + print(np_output) + + input = fastllm.Tensor(fastllm.float32, [1, 2], [1, 2]) + weights = fastllm.Tensor(fastllm.float32, [3, 2], [3, 4, 5, 5, 6, 7]) + bias = fastllm.Tensor(fastllm.float32, [3], [0, 1, 1]) + out = fastllm.ops.linear(input, weights, bias) + print(out) + +def test_rms_norm(): + inputs = np.array([1, 5]).reshape([1, 2]) + weights = np.array([1, 3]).reshape([1, 2]) + eps = 1e-6 + + np_out = np_rms_norm(inputs, weights, eps) + print(np_out) + + input = fastllm.Tensor(fastllm.float32, [1, 2], [1, 5]) + weights = fastllm.Tensor(fastllm.float32, [1, 2], [1, 3]) + out = fastllm.Tensor() + out = fastllm.ops.rms_norm(input, weights, eps=1e-6) + print(out) + +def test_silu(): + inputs = np.array([1, 5]).reshape([1, 2]) + output = np_softmax(inputs) + # output = np_silu(inputs) + print(output) + + inputs = fastllm.Tensor(fastllm.float32, [1, 2], [1, 5]) + out = fastllm.ops.activation(input=inputs, activate_type="softmax") + # out = fastllm.ops.activation(input=inputs, activate_type="silu") + print(out) + +def test_attention(): + q = np.array([1, 2, 3, 4, 5, 6]).reshape([2, 3]) + k = np.array([5, 6, 7, 8, 9, 10]).reshape([2, 3]) + v = np.array([1, 1, 1, 2, 1, 3]).reshape([2, 3]) + scale = 1 / np.sqrt(q.shape[-1]) + output = np_attention(q, k, v, scale=scale) + print(output) + + q = fastllm.Tensor(fastllm.float32, [1, 2, 3], [1, 2, 3, 4, 5, 6]) + k = fastllm.Tensor(fastllm.float32, [1, 2, 3], [5, 6, 7, 8, 9, 10]) + v = fastllm.Tensor(fastllm.float32, [1, 2, 3], [1, 1, 1, 2, 1, 3]) + mask = fastllm.Tensor() + output = 
fastllm.ops.attention(q, k, v, mask, group=1, scale=scale, attentionType=0) + print(output) + +test_attention() +test_silu() +test_linear() +test_rms_norm() diff --git a/pyfastllm/fastllm/__init__.py b/pyfastllm/fastllm/__init__.py index ef89d923..8fdc0b53 100644 --- a/pyfastllm/fastllm/__init__.py +++ b/pyfastllm/fastllm/__init__.py @@ -1,2 +1,14 @@ +import os +import sys +import ctypes +import glob + +_BASE_DIR = os.path.dirname(__file__) +sys.path.append(_BASE_DIR) +# libs = glob.glob("*.so") +# for lib in libs: _cdll = ctypes.cdll.LoadLibrary(lib) + from pyfastllm import * -from . import utils \ No newline at end of file +from . import utils +from . import functions as ops + diff --git a/pyfastllm/fastllm/functions/__init__.py b/pyfastllm/fastllm/functions/__init__.py new file mode 100644 index 00000000..efb0421c --- /dev/null +++ b/pyfastllm/fastllm/functions/__init__.py @@ -0,0 +1 @@ +from .fastllm_ops import * diff --git a/pyfastllm/fastllm/functions/custom_ops.py b/pyfastllm/fastllm/functions/custom_ops.py new file mode 100644 index 00000000..e69de29b diff --git a/pyfastllm/fastllm/functions/fastllm_ops.py b/pyfastllm/fastllm/functions/fastllm_ops.py new file mode 100644 index 00000000..33cc13f1 --- /dev/null +++ b/pyfastllm/fastllm/functions/fastllm_ops.py @@ -0,0 +1,78 @@ +import pyfastllm as fastllm + + +def embedding(data: fastllm.Tensor, ): + # some check + return fastllm.embedding(data, ) + +def rms_norm(input:fastllm.Tensor, weight: fastllm.Tensor, eps: float, output: fastllm.Tensor=None): + output = fastllm.rms_norm(input, weight, eps) + return output + +def layer_norm(input: fastllm.Tensor, + gamma: fastllm.Tensor, + beta: fastllm.Tensor, + axis:int=-1 ): + output = fastllm.layer_norm(input, gamma, beta,axis) + return output + +def linear(input: fastllm.Tensor, + weight: fastllm.Tensor, + bias: fastllm.Tensor): + output = fastllm.linear(input, weight, bias) + return output + +def matmul(input0: fastllm.Tensor, + input1: fastllm.Tensor, + alpha: fastllm.Tensor): + output = fastllm.matmul(input0, input1, alpha) + return output + +def attention(q: fastllm.Tensor, + k: fastllm.Tensor, + v: fastllm.Tensor, + mask: fastllm.Tensor, + group: int, + scale: float, + attentionType: int): + output = fastllm.attention(q, k, v, mask, group, scale, attentionType) + return output + +def activation(input: fastllm.Tensor, axis=-1, activate_type="silu"): + assert activate_type in ("softmax", "silu", "gelu", "swiglu") + func = getattr(fastllm, activate_type) + if activate_type == "softmax": + return func(input, axis) + return func(input) + +def mul(input: fastllm.Tensor, v: int): + output = fastllm.mul(input, v) + return output + +def matmul_transB(): + pass + +def add(input0: fastllm.Tensor, input1: fastllm.Tensor): + output = fastllm.add(input0, input1) + return output + +def AttentionMask(): + pass + +def AlibiMask(): + pass + +def topk(): + pass + +def RotatePosition2D(): + pass + +def NearlyRotatePosition2D(): + pass + +def LlamaRotatePosition2D(): + pass + +def RepeatPenalty(): + pass diff --git a/pyfastllm/fastllm/models.py b/pyfastllm/fastllm/models.py new file mode 100644 index 00000000..90200180 --- /dev/null +++ b/pyfastllm/fastllm/models.py @@ -0,0 +1,168 @@ +#!encoding=utf8 +import fastllm +import logging +from typing import List, Tuple + +try: + import torch +except Exception as e: + logging.warn("You must install torch before using this module!") + + +class ModelConfig(): + def __init__(self, **argv) -> None: + self._dict = dict(argv) + self._c_config = None + + for attr, value 
in argv.items(): + setattr(self, attr, value) + + def _to_dict(self, ): + return self._dict + + def to_c_config(self, ): + attr_map = { + 'max_length': 'output_token_limit', + } + + if not self._c_config: + self._c_config = fastllm.GenerationConfig() + + for attr, value in self._dict: + setattr(self._c_config, attr_map.get(attr) or attr, value) + + return self._c_config + + def __str__(self, ): + print("ModelConfig: ") + for key, value in self._dict.items(): + print(f"{key} : {value}") + + +class baseModel: + def __init__ (self, config:ModelConfig): + pass + + def chat(self, ): + pass + + def get_input_embeddings(self, ): + pass + + def set_input_embeddings(self, ): + pass + + def get_prompt(self, ): + pass + + def forward(self, ): + pass + + +def chatBaseModel(): + def get_output_embeddings(): + pass + + def set_output_embeddings(): + pass + + def prepare_inputs_for_generation(): + pass + + def forward(): + pass + + def stream_chat(): + pass + + def stream_generate(): + pass + + + + def get_prompt(self, + query: str, + history: List[Tuple[str, str]] = None) -> str: + if (not(history)): + history = [] + prompt = "" + for i, (old_query, response) in enumerate(history): + prompt = fastllm.make_history_llm_model(self.model, prompt.encode(), i, old_query.encode(), response.encode()).decode() + prompt = fastllm.make_input_llm_model(self.model, prompt.encode(), len(history), query.encode()).decode() + return prompt + + def save(self, path : str): + fastllm.save_llm_model(self.model, path.encode()) + + def response(self, + query: str, + history: List[Tuple[str, str]] = None) -> str: + prompt = query if self.direct_query else self.get_prompt(query, history) + ret = fastllm.response_str_llm_model(self.model, prompt.encode()).decode() + return ret + + def stream_response(self, + query: str, + history: List[Tuple[str, str]] = None, + one_by_one = True): + prompt = query if self.direct_query else self.get_prompt(query, history) + handle = fastllm.launch_response_str_llm_model(self.model, prompt.encode()) + res = "" + ret = b'' + while True: + ret += fastllm.fetch_response_str_llm_model(self.model, handle) + cur = "" + try: + cur = ret.decode() + ret = b'' + except: + pass + if (cur == ""): + break + if one_by_one: + yield cur + else: + res += cur + yield res + + def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192, num_beams=1, + do_sample = True, top_p = 0.8, temperature = 0.8, logits_processor = None, **kwargs): + if (not(history)): + history = [] + prompt = query if self.direct_query else self.get_prompt(query, history) + print("prompt", prompt) + input = tokenizer.encode(prompt) + handle = fastllm.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input)) + + result = [] + while True: + cur = fastllm.fetch_response_llm_model(self.model, handle) + if (cur == -1): + break + result.append(cur) + response = tokenizer.decode(result) + history = history + [(query, response)] + return response, history + + def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values = None, + max_length: int = 8192, do_sample = True, top_p = 0.8, temperature = 0.8, logits_processor = None, + return_past_key_values = False, **kwargs) -> str: + if (not(history)): + history = [] + prompt = query if self.direct_query else self.get_prompt(query, history) + input = tokenizer.encode(prompt) + handle = fastllm.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input)) + tokens = [] + while 
True: + cur = fastllm.fetch_response_llm_model(self.model, handle) + if (cur == -1): + break + tokens.append(cur) + response = tokenizer.decode(tokens) + new_history = history + [(query, response)] + if return_past_key_values: + yield response, new_history, None + else: + yield response, new_history + + diff --git a/pyfastllm/fastllm/nn/BaseModule.py b/pyfastllm/fastllm/nn/BaseModule.py new file mode 100644 index 00000000..942834e4 --- /dev/null +++ b/pyfastllm/fastllm/nn/BaseModule.py @@ -0,0 +1,18 @@ +from typing import Any + + +class Module(): + def __init__(self) -> None: + pass + + def __call__(self, *args: Any, **kwds: Any) -> Any: + return self.forward(*args, **args) + + @classmethod + def forward(self, ): + pass + + def _init_weight(self, ): + pass + + diff --git a/pyfastllm/fastllm/nn/__init__.py b/pyfastllm/fastllm/nn/__init__.py new file mode 100644 index 00000000..f026804c --- /dev/null +++ b/pyfastllm/fastllm/nn/__init__.py @@ -0,0 +1 @@ +from BaseModule import Module diff --git a/pyfastllm/setup.py b/pyfastllm/setup.py index 4082fd8c..6aa63e56 100644 --- a/pyfastllm/setup.py +++ b/pyfastllm/setup.py @@ -11,62 +11,68 @@ args, unknown = parser.parse_known_args() sys.argv = [sys.argv[0]] + unknown -__VERSION__ = "'0.1.3'" +__VERSION__ = "'0.1.4'" BASE_DIR = os.path.dirname(os.path.dirname(__file__)) -ext_modules = [] -try: - from pybind11.setup_helpers import Pybind11Extension - source_files = glob.glob(os.path.join(BASE_DIR, "src/**/*.cpp"), recursive=True) - for file in source_files: - if file.endswith("cudadevice.cpp"): - source_files.remove(file) - extra_compile_args = ["-w", "-DPY_API"] - # If any libraries are used, e.g. libabc.so - include_dirs = [os.path.join(BASE_DIR, "include/"), os.path.join(BASE_DIR, "include/devices/cpu/"), os.path.join(BASE_DIR, "include/models"), os.path.join(BASE_DIR, "include/utils")] - library_dirs = [] - - # (optional) if the library is not in the dir like `/usr/lib/` - # either to add its dir to `runtime_library_dirs` or to the env variable "LD_LIBRARY_PATH" - # MUST be absolute path - runtime_library_dirs = [] - libraries = [] +def config_ext(): + ext_modules = [] + try: + from pybind11.setup_helpers import Pybind11Extension + source_files = glob.glob(os.path.join(BASE_DIR, "src/**/*.cpp"), recursive=True) + for file in source_files: + if file.endswith("cudadevice.cpp"): + source_files.remove(file) - if args.cuda: - assert False, "Not Implement Yet!" - extra_compile_args.append("-DUSE_CUDA -Wl,-rpath,$ORIGIN/") + extra_compile_args = ["-w", "-DPY_API"] + # If any libraries are used, e.g. libabc.so + include_dirs = [os.path.join(BASE_DIR, "include/"), os.path.join(BASE_DIR, "include/devices/cpu/"), os.path.join(BASE_DIR, "include/models"), os.path.join(BASE_DIR, "include/utils")] + library_dirs = [] + + # (optional) if the library is not in the dir like `/usr/lib/` + # either to add its dir to `runtime_library_dirs` or to the env variable "LD_LIBRARY_PATH" + # MUST be absolute path + runtime_library_dirs = [] + libraries = [] - source_files.append(os.path.join(BASE_DIR, "src/devices/cuda/cudadevice.cpp")) - include_dirs.append(os.path.join(BASE_DIR, "include/devices/cuda/")) + if args.cuda: + assert False, "Not Implement Yet!" 
+            extra_compile_args.append("-DUSE_CUDA -Wl,-rpath,$ORIGIN/")
 
-        source_files.append(os.path.join(BASE_DIR, "src/devices/cuda/cudadevice.cpp"))
-        include_dirs.append(os.path.join(BASE_DIR, "include/devices/cuda/"))
+            source_files.append(os.path.join(BASE_DIR, "src/devices/cuda/cudadevice.cpp"))
+            include_dirs.append(os.path.join(BASE_DIR, "include/devices/cuda/"))
 
-        library_dirs.append("/usr/local/cuda/lib64/")
-        library_dirs.append(os.path.join(BASE_DIR, "pyfastllm/"))
+            library_dirs.append("/usr/local/cuda/lib64/")
+            library_dirs.append(os.path.join(BASE_DIR, "pyfastllm/"))
 
-        libraries.append("fastllm_cuda")
+            libraries.append("fastllm_cuda")
 
-    ext_modules = [
-        Pybind11Extension(
-            "pyfastllm",
-            source_files,
-            define_macros=[('VERSION_INFO', __VERSION__)],
-            include_dirs=include_dirs,
-            library_dirs=library_dirs,
-            runtime_library_dirs=runtime_library_dirs,
-            libraries=libraries,
-            extra_compile_args=extra_compile_args,
-            cxx_std=17,
-            language='c++'
-        ),
-    ]
-except Exception as e:
-    print(f"some errors happened: ")
-    print(e)
-    sys.exit(1)
+        ext_modules = [
+            Pybind11Extension(
+                "pyfastllm",
+                source_files,
+                define_macros=[('VERSION_INFO', __VERSION__)],
+                include_dirs=include_dirs,
+                library_dirs=library_dirs,
+                runtime_library_dirs=runtime_library_dirs,
+                libraries=libraries,
+                extra_compile_args=extra_compile_args,
+                cxx_std=17,
+                language='c++'
+            ),
+        ]
+    except Exception as e:
+        print(f"some errors happened: ")
+        print(e)
+        sys.exit(1)
+
+    return ext_modules
 
-cmdclass = {}
+cmdclass = {}
+dyn_libs = glob.glob("*.so", root_dir="./fastllm")
+dyn_libs += glob.glob("*.dll", root_dir="./fastllm")
+# print(dyn_libs)
 setup(
     name='fastllm',
     version=eval(__VERSION__),
@@ -77,14 +83,18 @@
     maintainer_email='',
     url='',
     long_description='',
-    ext_modules=ext_modules,
+    # ext_modules=ext_modules,
+    # packages = ['fastllm', 'fastllm.utils'],
     packages = find_packages(),
+    package_data={
+        'fastllm': dyn_libs,
+    },
     cmdclass=cmdclass,
-    setup_requires=["pybind11"],
+    setup_requires=[""],
     install_requires=[""],
     python_requires='>=3.6',
-    # data_files = [('', ['libfastllm_cuda.so'])],
-    include_package_data=False,
+    # data_files = [('', dyn_libs)],
+    include_package_data=True,
     entry_points={
         'console_scripts':[
             'fastllm-convert = fastllm.convert:main'
diff --git a/pyfastllm/test_func.sh b/pyfastllm/test_func.sh
new file mode 100644
index 00000000..1f8083d9
--- /dev/null
+++ b/pyfastllm/test_func.sh
@@ -0,0 +1,7 @@
+pip uninstall -y fastllm
+rm -rf fastllm/pyfastllm.cpython-310-x86_64-linux-gnu.so
+rm -rf build/
+python3 build_libs.py
+python3 setup.py sdist bdist_wheel
+pip install dist/fastllm-0.1.4-py3-none-any.whl
+python3 demo/test_ops.py
\ No newline at end of file
diff --git a/src/fastllm.cpp b/src/fastllm.cpp
index f9c97f0b..7bc4a1d3 100644
--- a/src/fastllm.cpp
+++ b/src/fastllm.cpp
@@ -247,6 +247,7 @@ namespace fastllm {
     }
 
     Data::Data(fastllm::DataType type, const std::vector <int> &dims, const std::vector <float> &data) : Data::Data(type, dims) {
+        // std::cout<<"调用数值构造"<<std::endl;
         this->Allocate();
         if (type == DataType::FLOAT32) {
             std::memcpy(this->cpuData, data.data(), this->GetBytes());
@@ -258,6 +259,7 @@ namespace fastllm {
     }
 
     void Data::CopyFrom(const Data &ori) {
+        // std::cout<<"调用拷贝构造"<<std::endl;
         if (ori.dims != this->dims || this->cpuData == nullptr) {
             if (ori.dims.size() == 0) {
                 delete[] this->cpuData;
@@ -515,6 +517,10 @@ namespace fastllm {
         printf("\n");
     }
 
+    std::vector <int> Data::Shape() const {
+        return this->dims;
+    }
+
     void Data::Print() const {
         printf("shape: ");
         for (int i : this->dims) {
diff --git a/src/pybinding.cpp b/src/pybinding.cpp
index d5454582..ca7c0aae 100644
--- a/src/pybinding.cpp
+++ b/src/pybinding.cpp
@@ -1,6 +1,104 @@
 #include "model.h"
 #include "factoryllm.h"
+namespace pyfastllm{
+    // 对接不断更新的后端接口
+    // 需优化,减少内存拷贝
+    fastllm::Data RMSNorm(const fastllm::Data &input, const fastllm::Data &weight, float eps){
+        fastllm::Data output;
+        // std::cout<<"run rms norm"<<std::endl;
+        fastllm::RMSNorm(input, weight, eps, output);
+        return output;
+    }
+
+    fastllm::Data LayerNorm(const fastllm::Data &input, const fastllm::Data &gamma, const fastllm::Data &beta, int axis){
+        fastllm::Data output;
+        fastllm::LayerNorm(input, gamma, beta, axis, output);
+        return output;
+    }
+
+    fastllm::Data Linear(const fastllm::Data &input, const fastllm::Data &weight, const fastllm::Data &bias){
+        fastllm::Data output;
+        fastllm::Linear(input, weight, bias, output);
+        return output;
+    }
+
+    fastllm::Data MatMul(const fastllm::Data &input0, const fastllm::Data &input1, float alpha){
+        fastllm::Data output;
+        fastllm::MatMul(input0, input1, output, alpha);
+        return output;
+    }
+
+    fastllm::Data Softmax(const fastllm::Data &input, int axis){
+        fastllm::Data output;
+        fastllm::Softmax(input, output, axis);
+        return output;
+    }
+
+    fastllm::Data Silu(const fastllm::Data &input){
+        fastllm::Data output;
+        fastllm::Silu(input, output);
+        return output;
+    }
+
+    fastllm::Data Gelu(const fastllm::Data &input){
+        fastllm::Data output;
+        fastllm::GeluNew(input, output);
+        return output;
+    }
+
+    fastllm::Data Swiglu(const fastllm::Data &input){
+        fastllm::Data output;
+        fastllm::Swiglu(input, output);
+        return output;
+    }
+
+    fastllm::Data Mul(const fastllm::Data &input, float v){
+        fastllm::Data output;
+        fastllm::Mul(input, v, output);
+        return output;
+    }
+
+    fastllm::Data Attention(const fastllm::Data &q, const fastllm::Data &k, const fastllm::Data &v, const fastllm::Data &mask,
+                            int group, float scale, int attentionType){
+        fastllm::Data output;
+        fastllm::Attention(q, k, v, mask, output, group, scale, attentionType);
+        return output;
+    }
+
+    std::string String(const fastllm::Data &data){
+        std::string ss = "[";
+        int n = data.dims[0], m = data.dims[1];
+        for (int i = 0; i < n; i++) {
+            if (i > 0) ss += "\n";
+            for (int j = 0; j < 10 && j < m; j++) {
+                if (j>0) ss += " ";
+                ss += std::to_string(reinterpret_cast<float *>(data.cpuData)[i*m+j]);
+            }
+            if (m > 10) {
+                ss += "... ";
+                for (int j = 0; j < 10 && j < m; j++) {
+                    if (j>0) ss += " ";
+                    ss += std::to_string(reinterpret_cast<float *>(data.cpuData)[i*m + (m-10+j)]);
+                }
+            }
+
+        }
+        ss += "]";
+        return ss;
+    }
+}
+
 #ifdef PY_API
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
@@ -48,6 +146,34 @@ PYBIND11_MODULE(pyfastllm, m) {
 
     // low level
     m.def("get_llm_type", &fastllm::GetModelTypeFromFile);
+    m.def("llm_sampling", &fastllm::LLMSampling)
+    // .def("embedding", &fastllm::Embedding)
+        .def("rms_norm", &pyfastllm::RMSNorm)
+        .def("layer_norm", &pyfastllm::LayerNorm)
+        .def("linear", &pyfastllm::Linear)
+        // .def("split", &fastllm::Split)
+        // .def("cat", &fastllm::Cat)
+        // .def("cat_direct", &fastllm::CatDirect)
+        .def("matmul", &pyfastllm::MatMul)
+        // .def("matmul_transB", &fastllm::MatMulTransB)
+        .def("softmax", &pyfastllm::Softmax)
+        .def("silu", &pyfastllm::Silu)
+        .def("gelu", &pyfastllm::Gelu)
+        .def("swiglu", &pyfastllm::Swiglu)
+        .def("mul", &pyfastllm::Mul)
+        .def("attention", &pyfastllm::Attention);
+    // .def("mul_to", &fastllm::MulTo)
+    // .def("add_to", &fastllm::AddTo)
+    // .def("attention_mask", &fastllm::AttentionMask)
+    // .def("alibi_mask", &fastllm::AlibiMask)
+    // .def("permute", &fastllm::Permute)
+    // .def("permute_self", &fastllm::PermuteSelf)
+    // .def("topk", &fastllm::TopK)
+    // .def("rotateposition2D", &fastllm::RotatePosition2D)
+    // .def("nearlyrotateposition2D", &fastllm::NearlyRotatePosition2D)
+    // .def("llama_rotateposition2D", &fastllm::LlamaRotatePosition2D)
+    // .def("repeat_penalty", &fastllm::RepeatPenalty);
+
     py::enum_<fastllm::DataType>(m, "Dtype")
         .value("float32", fastllm::DataType::FLOAT32)
         .value("bfloat16", fastllm::DataType::BFLOAT16)
@@ -60,13 +186,25 @@ PYBIND11_MODULE(pyfastllm, m) {
         .value("int32param", fastllm::DataType::INT32PARAM)
         .export_values();
 
-    py::class_<fastllm::Data>(m, "Tensor")
+    py::class_<fastllm::Data>(m, "Tensor", py::buffer_protocol())
+        .def_buffer([](fastllm::Data &m) -> py::buffer_info {
+            return py::buffer_info(
+                m.cpuData,                               /* Pointer to buffer */
+                sizeof(float),                           /* Size of one scalar */
+                py::format_descriptor<float>::format(),  /* Python struct-style format descriptor */
+                m.dims.size(),                           /* Number of dimensions */
+                m.dims,                                  /* Buffer dimensions */
+                { sizeof(float) * m.dims[1],             /* Strides (in bytes) for each index */
+                  sizeof(float) }
+            );
+        })
         .def_readonly("dims", &fastllm::Data::dims)
         .def(py::init<>())
         .def(py::init<fastllm::DataType>())
         .def(py::init<fastllm::DataType, const std::vector<int>&>())
         .def(py::init<fastllm::DataType, const std::vector<int>&, const std::vector<float>&>())
         .def(py::init<fastllm::Data>())
+        .def_readonly("shape", &fastllm::Data::dims)
         .def("copy_from", &fastllm::Data::CopyFrom)
         .def("count", &fastllm::Data::Count)
         .def("to_list", [](fastllm::Data& data){
@@ -76,6 +214,7 @@
             }
             return vecData;
         })
+        .def("__str__", &pyfastllm::String)
         .def("print", &fastllm::Data::Print)
         .def("to", static_cast<void (fastllm::Data::*)(void *)>(&fastllm::Data::ToDevice));
 
diff --git a/test/ops/cppOps.cpp b/test/ops/cppOps.cpp
new file mode 100644
index 00000000..d4dd865f
--- /dev/null
+++ b/test/ops/cppOps.cpp
@@ -0,0 +1,147 @@
+#include "fastllm.h"
+#include <cmath>
+#include <cstdio>
+
+void callBaseOp(int optype=0){
+    fastllm::Data inputs = fastllm::Data(fastllm::DataType::FLOAT32, {1, 2}, {1, 5});
+    fastllm::Data outputs = fastllm::Data(fastllm::DataType::FLOAT32, {1, 2}, {3, 4});
+
+    switch (optype)
+    {
+        case 0:
+            fastllm::AddTo(inputs, outputs, 1);
+            break;
+        case 1:
+            fastllm::Cat(inputs, inputs, 0, outputs);
+            break;
+        case 2:
+            fastllm::Mul(inputs, 2, outputs);
+            break;
+        case 3:
+            fastllm::Permute(inputs, {1, 0}, outputs);
+            break;
+        case 4:
+            fastllm::Split(inputs, 0, 0, 1, outputs);
+            break;
+        case 5:
+            fastllm::Permute(inputs, {1, 0}, outputs);
+            fastllm::MatMul(inputs, outputs, outputs);
+            break;
+        default:
+            break;
+    }
+    outputs.Print();
+}
+
+void callNormOp(int normType=0){
+    fastllm::Data inputs = fastllm::Data(fastllm::DataType::FLOAT32, {1, 2}, {1, 5});
+    fastllm::Data weights = fastllm::Data(fastllm::DataType::FLOAT32, {1, 2}, {1, 2});
+    fastllm::Data gamma = fastllm::Data(fastllm::DataType::FLOAT32, {1, 2}, {1, 1});
+    fastllm::Data beta = fastllm::Data(fastllm::DataType::FLOAT32, {1, 2}, {0, 0});
+    fastllm::Data outputs;
+
+    switch (normType)
+    {
+        case 0:
+            fastllm::LayerNorm(inputs, gamma, beta, -1, outputs);
+            break;
+        case 1:
+            fastllm::RMSNorm(inputs, weights, 1e-5, outputs);
+            break;
+        default:
+            break;
+    }
+    outputs.Print();
+}
+
+
+void callLinearOp(){
+    fastllm::Data inputs = fastllm::Data(fastllm::DataType::FLOAT32, {1, 2}, {1, 2});
+    fastllm::Data weights = fastllm::Data(fastllm::DataType::FLOAT32, {3, 2}, {3, 4, 5, 5, 6, 7});
+    fastllm::Data bias = fastllm::Data(fastllm::DataType::FLOAT32, {1, 3}, {0, 1, 1});
+    fastllm::Data outputs;
+    fastllm::Linear(inputs, weights, bias, outputs);
+    outputs.Print();
+}
+
+void callActivationOp(int activateType=0){
+    fastllm::Data inputs = fastllm::Data(fastllm::DataType::FLOAT32, {1, 2}, {1, 5});
+    fastllm::Data outputs;
+    switch (activateType)
+    {
+        case 0:
+            fastllm::Silu(inputs, outputs);
+            break;
+        case 1:
+            fastllm::Softmax(inputs, outputs, -1);
+            break;
+        case 2:
+            fastllm::GeluNew(inputs, outputs);
+            break;
+        case 3:
+            fastllm::Swiglu(inputs, outputs);
+            break;
+        default:
+            break;
+    }
+    outputs.Print();
+}
+
+void callAttentionOp(int group=1, int attentionType=0){
+    const fastllm::Data q = fastllm::Data(fastllm::DataType::FLOAT32, {1, 2, 3}, {1, 2, 3, 4, 5, 6});
+    const fastllm::Data k = fastllm::Data(fastllm::DataType::FLOAT32, {1, 2, 3}, {5, 6, 7, 8, 9, 10});
+    const fastllm::Data v = fastllm::Data(fastllm::DataType::FLOAT32, {1, 2, 3}, {1, 1, 1, 2, 1, 3});
+    const fastllm::Data mask = fastllm::Data();
+    int dims = q.dims.back();
+    float scale = 1/sqrt(dims);
+    fastllm::Data output;
+
+    fastllm::Attention(q, k, v, mask, output, group, scale, attentionType);
+}
+
+void testBase(){
+    printf("testing BaseOp...\n");
+    for (int i=0;i<6;i++){
+        callBaseOp(i);
+    }
+    printf("test BaseOp finished!\n");
+}
+
+void testActivation(){
+    printf("testing ActivationOp...\n");
+    for (int i=0;i<4;i++){
+        callActivationOp(i);
+    }
+    printf("test ActivationOp finished!\n");
+}
+
+void testAttention(){
+    printf("testing AttentionOp...\n");
+    callAttentionOp();
+    printf("test AttentionOp finished!\n");
+}
+
+void testLinear(){
+    printf("testing LinearOp...\n");
+    callLinearOp();
+    printf("test LinearOp finished!\n");
+}
+
+void testNorm(){
+    printf("testing NormOp...\n");
+    for (int i=0;i<2;i++){
+        callNormOp(i);
+    }
+    printf("test NormOp finished!\n");
+}
+
+void testAll(){
+    testBase();
+    testActivation();
+    testAttention();
+    testNorm();
+    testLinear();
+}
+
+
+int main(){
+    testAll();
+}
\ No newline at end of file
diff --git a/third_party/pybind11 b/third_party/pybind11
deleted file mode 160000
index 84932280..00000000
--- a/third_party/pybind11
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 849322806cd4b3697ad1d35eedd6d0352c5f267a