Implement quantization on-the-fly #100

Draft: wants to merge 6 commits into master
4 changes: 2 additions & 2 deletions README.md
@@ -130,10 +130,10 @@ This option would require a little more manual work, but you can use it with any

```commandline
# Windows
python rwkv\convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float16
python rwkv\convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin FP16

# Linux / MacOS
python rwkv/convert_pytorch_to_ggml.py ~/Downloads/RWKV-4-Pile-169M-20220807-8023.pth ~/Downloads/rwkv.cpp-169M.bin float16
python rwkv/convert_pytorch_to_ggml.py ~/Downloads/RWKV-4-Pile-169M-20220807-8023.pth ~/Downloads/rwkv.cpp-169M.bin FP16
```

**Optionally**, quantize the model into one of the quantized formats from the table above:
169 changes: 130 additions & 39 deletions rwkv.cpp

Large diffs are not rendered by default.

43 changes: 43 additions & 0 deletions rwkv.h
@@ -83,10 +83,53 @@ extern "C" {
// - ctx: the context the retrieve the error for, or NULL for the global error.
RWKV_API enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx);

enum rwkv_init_from_file_option_key {
// Sets target format of model parameters.
//
// If an FP16 or FP32 model is being loaded, and this option is set,
// parameters will be quantized just-in-time into the specified format.
// If an already quantized model is being loaded, the value of this option is ignored.
// The function will not read the whole model file at once, but will quantize it tensor-by-tensor;
// it is safe to load big models that fit into RAM only when quantized.
// Use of this option will introduce a significant one-time delay when loading the model.
//
// The intended use case is to keep only the FP16 model on disk, without wasting
// disk space on copies in every available quantized format.
//
// Allowed values:
// - Q4_0
// - Q4_1
// - Q5_0
// - Q5_1
// - Q8_0
RWKV_INIT_FROM_FILE_OPTION_TARGET_FORMAT_NAME,
// Do not use this as an actual option key.
RWKV_INIT_FROM_FILE_OPTION_COUNT
};

struct rwkv_init_from_file_option {
// Key of the option.
enum rwkv_init_from_file_option_key key;
// Value of the option as a NULL-terminated, UTF-8 encoded string.
char * value;
};

// Loads the model from a file and prepares it for inference.
// Loading behavior can be customized with options, but none of them are required.
// Function behavior when multiple options with the same key are specified is undefined.
// Returns NULL on any error.
// - model_file_path: path to model file in ggml format.
// - n_threads: count of threads to use, must be positive.
// - options: array of options. Passing NULL is the same as setting option_count to 0.
// - option_count: size of the options array.
RWKV_API struct rwkv_context * rwkv_init_from_file_ex(
const char * model_file_path,
const uint32_t n_threads,
const struct rwkv_init_from_file_option * options,
const size_t option_count
);
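
For reference, here is a minimal usage sketch of the new entry point (not part of the diff; the model path and thread count are placeholders). The same pattern appears in tests/test_quantization_on_the_fly.c below.

```c
#include "rwkv.h"

int main(void) {
    // Ask the loader to quantize an FP16/FP32 model into Q5_1 while loading.
    struct rwkv_init_from_file_option options[1] = {
        { RWKV_INIT_FROM_FILE_OPTION_TARGET_FORMAT_NAME, "Q5_1" }
    };

    // Passing NULL and 0 instead of the options array behaves like plain rwkv_init_from_file.
    struct rwkv_context * ctx = rwkv_init_from_file_ex("model-FP16.bin", 2, options, 1);

    if (!ctx) {
        // rwkv_get_last_error(NULL) holds the global error flags.
        return 1;
    }

    rwkv_free(ctx);

    return 0;
}
```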
Collaborator Author

@LoganDark I think now the interface is generic enough to painlessly add new options in the future -- for mmap, etc.

Contributor

eh... this does not inspire confidence for some reason. I am not sure why. I think all the existing parameters should be moved to the options structure, but also that the library needs more work before it can move to an options structure at all.

loading from file itself I intended to move into its own option, because for really insane use cases, I'm literally thinking of things like streaming the model from the network so it doesn't touch the disk at all. I imagine this being used for something like microcontrollers that don't have a filesystem. it sounds really stupid, I know, but it's a contrived example.

one of the things I planned to do first was move rwkv.cpp into using multiple files because I think its file is getting quite long and is a bit disorganized, with file reading functions and inference functions and quantization functions all in the same file. I think it works for ggml but rwkv.cpp is getting long enough that it's somewhat uncomfortable to navigate.

it's probably a bit weird of me to say that I already had a roadmap in mind but I don't think an interim solution like this would be very great, especially since having it here would encourage us to keep it.

so I would probably either hold off on merging this (I was planning to implement it myself anyway) or find a way that doesn't involve moving to an options dict so soon. but I think rwkv.cpp does not need quantize on load at all yet - it will become more useful when it can load directly from pytorch checkpoints, as those cannot be quantized at all, so quantizing on load would be the only option, but that subsystem does not exist yet and I will account for it when it does.

Collaborator Author

> I already had a roadmap in mind
> so I would probably either hold off on merging this

If possible, can you share the roadmap (even if it is rough), and some timelines? Specifically for PyTorch loading support.

Honestly, there is no hurry to merge the PR, since everything worked fine before it and no one complained. But I would like to have a somewhat good reason to postpone it.

> one of the things I planned to do first was move rwkv.cpp into using multiple files because I think its file is getting quite long and is a bit disorganized

Completely agree, had these ideas myself. Not related to this PR tho :)

> loading from file itself I intended to move into its own option

I don't think rwkv.cpp should support network loading or other non-file use cases, so the file_path argument will most probably stay. As you said yourself, these use cases are insane lol

Contributor (@LoganDark, Jun 14, 2023)

> If possible, can you share the roadmap (even if it is rough), and some timelines? Specifically for PyTorch loading support.

I plan to implement in this rough order:

  • new implementation of WKV for sequence mode (should be faster), but may not be finished
  • PR my own python bindings for the library that should handle errors a bit better / be slightly easier to use
  • reorganize rwkv.cpp into files (maybe move into src directly and use #include), no need to complicate the cmake
  • make the loading system more generic (should support any method of loading), probably a total API redesign
  • implement mmap, pytorch loading, quantize on load on top of that
  • I also have ideas for a compressed model format, maybe with magic number "ggzf" for "zip", because I noticed that pytorch checkpoints are far smaller than ggml models due to using zip compression, and I would love to reduce the size of large ggml models (by 10-20 GB per model!!) by using compression. This could also speed up load times :)

also, I want to make model loading one-shot again (only read the file once), because depending on fseek and fstat and ftell is hurting our cross-platform compatibility. Additionally, that would remove the dependency on a hash map at runtime (is it a hash map? some kind of map) to load the tensors directly into the model. I have a working version of this in Rust actually, but it would need to be ported to C++ (should be easy).

Anyway, the overall goal is to make the library a lot more flexible. It was specialized as a prototype to load a single model from a binary file and evaluate single tokens, but it'll get a lot more exciting and faster if we make it more flexible.

Imagine downloading a compressed model file, and either loading it directly, or using the library itself to decompress it and then using mmap (without requiring python). Or even imagine downloading fresh pytorch checkpoints, minutes after BlinkDL first releases them, and either converting them tensor by tensor (like quantization) or just using them that way.

Imagine using this on desktops, servers, mobile phones, embedded devices (possibly with TPU ?!), whatever.

Imagine training models with rwkv.cpp, too (that is not on my roadmap because I don't know how I would do that yet, but I can still dream :3)

> Not related to this PR tho :)

It's related as in I consider it a blocker, i.e. I wouldn't implement the options until the source code is organized enough.

> non-file use cases

You know mmap support is the biggest non-file use case. rwkv.cpp will have to implement loading from memory anyway. The only difference is whether we allow third-party programs to use this functionality. Ideally it would be implemented in such a way that rwkv.cpp will not have to support network loading or anything insane like that. It would just support "any kind of loading", and programs would be able to implement their own network loading if they wanted to.
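
To illustrate the shape such an interface could take, here is a purely hypothetical sketch; none of these names exist in rwkv.cpp or in this PR, it only shows the "bring your own read callback" idea described above.

```c
// Hypothetical sketch only; not part of rwkv.cpp or of this PR.
// The loader would pull bytes through a caller-supplied callback, so the same
// code path could be backed by a file, an mmap'd region, or a network stream.
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

struct rwkv_data_source {
    // Reads exactly `size` bytes into `dest`; returns false on failure.
    bool (*read)(void * user_data, void * dest, size_t size);
    void * user_data;
};

// Example backing: a plain FILE *, which is what the current loader effectively uses.
static bool rwkv_read_from_file(void * user_data, void * dest, size_t size) {
    return fread(dest, 1, size, (FILE *) user_data) == size;
}

// A hypothetical entry point built on top of the data source; a program that
// wants network loading would simply supply its own read callback.
// struct rwkv_context * rwkv_init_from_source(struct rwkv_data_source source, uint32_t n_threads);
```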

@lin72h

> loading from file itself I intended to move into its own option, because for really insane use cases, I'm literally thinking of things like streaming the model from the network so it doesn't touch the disk at all. I imagine this being used for something like microcontrollers that don't have a filesystem. it sounds really stupid, I know, but it's a contrived example.

Mind blowing. I think this is actually an excellent use case for RWKV. From my limited understanding of the RWKV internals, the context memory is constant and the memory access pattern is sequential (backward or forward). So it makes a lot of sense to convert the source of truth (FP16 weights) to the latest quantized format on the fly, much like a load-time JIT compiler.

Both of you 🤘

Contributor

@lin72h the source of truth is actually the f32 version, as that's what BlinkDL trains, but f16 would still count as a source of truth if you're using it to generate a quantized model. :)


// Same as rwkv_init_from_file_ex, but passing an empty array of options.
RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads);

// Creates a new context from an existing one.
12 changes: 7 additions & 5 deletions rwkv/convert_pytorch_to_ggml.py
@@ -1,5 +1,5 @@
# Converts an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file.
# Usage: python convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float32
# Usage: python convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin FP16
# Get model checkpoints from https://huggingface.co/BlinkDL
# See FILE_FORMAT.md for the documentation on the file format.

@@ -12,7 +12,7 @@ def parse_args():
parser = argparse.ArgumentParser(description='Convert an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file')
parser.add_argument('src_path', help='Path to PyTorch checkpoint file')
parser.add_argument('dest_path', help='Path to rwkv.cpp checkpoint file, will be overwritten')
parser.add_argument('data_type', help='Data type, float16 or float32', type=str, choices=['float16', 'float32'], default='float32')
parser.add_argument('data_type', help='Data type, FP16 or FP32', type=str, choices=['FP16', 'FP32'], default='FP16')
return parser.parse_args()

def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int:
@@ -26,6 +26,8 @@ def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int:
return n_layer

def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_type: str) -> None:
is_FP16 = data_type == 'FP16' or data_type == 'float16'

emb_weight: torch.Tensor = state_dict['emb.weight']

n_layer = get_layer_count(state_dict)
@@ -42,7 +44,7 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
n_vocab,
n_embed,
n_layer,
1 if data_type == 'float16' else 0
1 if is_FP16 else 0
))

for k in state_dict.keys():
@@ -56,8 +58,8 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
if '.time_decay' in k:
tensor = -torch.exp(tensor)

# Keep 1-dim vectors in fp32
if data_type == 'float16' and len(tensor.shape) > 1:
# Keep 1-dim vectors in FP32
if is_FP16 and len(tensor.shape) > 1:
tensor = tensor.half()

shape = tensor.shape
2 changes: 1 addition & 1 deletion rwkv/convert_pytorch_to_ggml.test.py
@@ -13,7 +13,7 @@ def test() -> None:
'blocks.0.ln1.weight': torch.tensor([1], dtype=torch.float32)
}

convert_pytorch_to_ggml.write_state_dict(state_dict, dest_path=test_file_path, data_type='float32')
convert_pytorch_to_ggml.write_state_dict(state_dict, dest_path=test_file_path, data_type='FP32')

with open(test_file_path, 'rb') as input:
actual_bytes: bytes = input.read()
7 changes: 5 additions & 2 deletions rwkv/rwkv_cpp_model.py
@@ -2,7 +2,7 @@
import torch
import multiprocessing
import rwkv_cpp_shared_library
from typing import Tuple, Optional
from typing import Dict, Tuple, Optional

class RWKVModel:
"""
@@ -15,6 +15,7 @@ def __init__(
model_path: str,
thread_count: int = max(1, multiprocessing.cpu_count() // 2),
gpu_layers_count: int = 0,
options: Optional[Dict[rwkv_cpp_shared_library.RWKVInitFromFileOptionKey, str]] = None
):
"""
Loads the model and prepares it for inference.
@@ -28,6 +29,8 @@ def __init__(
Path to RWKV model file in ggml format.
thread_count : int
Thread count to use. If not set, defaults to CPU count / 2.
options : Optional[Dict[RWKVInitFromFileOptionKey, str]]
Options passed to rwkv_init_from_file_ex.
"""

assert os.path.isfile(model_path), f'{model_path} is not a file'
@@ -36,7 +39,7 @@ def __init__(

self._library = shared_library

self._ctx = self._library.rwkv_init_from_file(model_path, thread_count)
self._ctx = self._library.rwkv_init_from_file(model_path, thread_count, options)

if gpu_layers_count > 0:
self._library.rwkv_gpu_offload_layers(self._ctx, gpu_layers_count)
51 changes: 46 additions & 5 deletions rwkv/rwkv_cpp_shared_library.py
@@ -2,7 +2,8 @@
import sys
import ctypes
import pathlib
from typing import Optional
import enum
from typing import Dict, Optional

QUANTIZED_FORMAT_NAMES = (
'Q4_0',
@@ -14,6 +15,29 @@

P_FLOAT = ctypes.POINTER(ctypes.c_float)

class RWKVInitFromFileOptionKey(enum.Enum):
# Sets target format of model parameters.
#
# If an FP16 or FP32 model is being loaded, and this option is set,
# parameters will be quantized just-in-time into the specified format.
# If an already quantized model is being loaded, the value of this option is ignored.
# The function will not read the whole model file at once, but will quantize it tensor-by-tensor;
# it is safe to load big models that fit into RAM only when quantized.
# Use of this option will introduce a significant one-time delay when loading the model.
#
# The intended use case is to keep only the FP16 model on disk, without wasting
# disk space on copies in every available quantized format.
#
# For allowed values, see QUANTIZED_FORMAT_NAMES.
TARGET_FORMAT_NAME = 0

class RWKVInitFromFileOption(ctypes.Structure):

_fields_ = [
('key', ctypes.c_int),
('value', ctypes.c_char_p)
]

class RWKVContext:

def __init__(self, ptr: ctypes.pointer):
@@ -37,8 +61,8 @@ def __init__(self, shared_library_path: str):

self.library = ctypes.cdll.LoadLibrary(shared_library_path)

self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32]
self.library.rwkv_init_from_file.restype = ctypes.c_void_p
self.library.rwkv_init_from_file_ex.argtypes = [ctypes.c_char_p, ctypes.c_uint32, ctypes.POINTER(RWKVInitFromFileOption), ctypes.c_size_t]
self.library.rwkv_init_from_file_ex.restype = ctypes.c_void_p

self.library.rwkv_gpu_offload_layers.argtypes = [ctypes.c_void_p, ctypes.c_uint32]
self.library.rwkv_gpu_offload_layers.restype = ctypes.c_bool
@@ -70,9 +94,10 @@ def __init__(self, shared_library_path: str):
self.library.rwkv_get_system_info_string.argtypes = []
self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p

def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext:
def rwkv_init_from_file(self, model_file_path: str, thread_count: int, options: Optional[Dict[RWKVInitFromFileOptionKey, str]] = None) -> RWKVContext:
"""
Loads the model from a file and prepares it for inference.
Loading behavior can be customized with options, but none of them are required.
Throws an exception in case of any error. Error messages would be printed to stderr.

Parameters
@@ -81,9 +106,25 @@ def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVCo
Path to model file in ggml format.
thread_count : int
Count of threads to use, must be positive.
options : Optional[Dict[RWKVInitFromFileOptionKey, str]]
Options passed to rwkv_init_from_file_ex.
"""

ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count))
options_count = 0
options_ptr = None

if options is not None and len(options) > 0:
options_count = len(options)
options_ptr = (RWKVInitFromFileOption * options_count)()

i = 0
for k, v in options.items():
options_ptr[i].key = k.value
options_ptr[i].value = v.encode('utf-8')

i += 1

ptr = self.library.rwkv_init_from_file_ex(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count), options_ptr, options_count)

assert ptr is not None, 'rwkv_init_from_file failed, check stderr'

7 changes: 4 additions & 3 deletions tests/CMakeLists.txt
@@ -12,6 +12,7 @@ file(COPY tiny-rwkv-660K-FP32.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
file(COPY tiny-rwkv-660K-FP16.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
file(COPY expected_logits.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})

rwkv_add_test(test_ggml_basics.c)
rwkv_add_test(test_tiny_rwkv.c)
rwkv_add_test(test_context_cloning.c)
file(GLOB tests *.c)
foreach (test ${tests})
rwkv_add_test(${test})
endforeach()
4 changes: 3 additions & 1 deletion tests/test_context_cloning.c
@@ -1,4 +1,6 @@
#include <rwkv.h>
// Tests that evaluation gives identical results after cloning the context.

#include "rwkv.h"

#include <stdlib.h>
#include <stdio.h>
91 changes: 91 additions & 0 deletions tests/test_quantization_on_the_fly.c
@@ -0,0 +1,91 @@
// Tests that results from an on-the-fly quantized model are identical to results of a pre-quantized model.

#include "ggml.h"
#include "rwkv.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define N_THREADS 2

int main(void) {
rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q5_1.bin", "Q5_1");

struct rwkv_context * prequantized_ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32-Q5_1.bin", N_THREADS);

if (!prequantized_ctx) {
enum rwkv_error_flags error = rwkv_get_last_error(NULL);
fprintf(stderr, "Unexpected error 0x%.8X\n", error);
return EXIT_FAILURE;
}

// ---

struct rwkv_init_from_file_option option = {RWKV_INIT_FROM_FILE_OPTION_TARGET_FORMAT_NAME, "Q5_1"};

struct rwkv_context * on_the_fly_quantized_ctx = rwkv_init_from_file_ex("tiny-rwkv-660K-FP32.bin", N_THREADS, &option, 1);

if (!on_the_fly_quantized_ctx) {
enum rwkv_error_flags error = rwkv_get_last_error(NULL);
fprintf(stderr, "Unexpected error 0x%.8X\n", error);
return EXIT_FAILURE;
}

// ---

float * state = calloc(rwkv_get_state_len(prequantized_ctx), sizeof(float));

if (!state) {
fprintf(stderr, "Failed to allocate state\n");
return EXIT_FAILURE;
}

float * expected_logits = calloc(rwkv_get_logits_len(prequantized_ctx), sizeof(float));

if (!expected_logits) {
fprintf(stderr, "Failed to allocate logits\n");
return EXIT_FAILURE;
}

const unsigned char prompt[12] = "hello world";

rwkv_eval(prequantized_ctx, prompt[0], NULL, state, expected_logits);

for (int i = 1; prompt[i] != 0; i++) {
rwkv_eval(prequantized_ctx, prompt[i], state, state, expected_logits);
}

// ---

float * actual_logits = calloc(rwkv_get_logits_len(on_the_fly_quantized_ctx), sizeof(float));

if (!actual_logits) {
fprintf(stderr, "Failed to allocate logits\n");
return EXIT_FAILURE;
}

rwkv_eval(on_the_fly_quantized_ctx, prompt[0], NULL, state, actual_logits);

for (int i = 1; prompt[i] != 0; i++) {
rwkv_eval(on_the_fly_quantized_ctx, prompt[i], state, state, actual_logits);
}

// ---

if (memcmp(expected_logits, actual_logits, rwkv_get_logits_len(on_the_fly_quantized_ctx) * sizeof(float))) {
fprintf(stderr, "Results not identical :(\n");
return EXIT_FAILURE;
} else {
fprintf(stdout, "Results identical, success!\n");
}

rwkv_free(on_the_fly_quantized_ctx);
rwkv_free(prequantized_ctx);

free(expected_logits);
free(actual_logits);
free(state);

return 0;
}