Implement quantization on-the-fly #100
Draft: saharNooby wants to merge 6 commits into master from quantization-on-the-fly (base: master)
Changes from all commits (6 commits):
- 65cdb51 Implement on-the-fly quantization (saharNooby)
- dca26e9 Resolve TODO items (saharNooby)
- 7a13fd2 Fix error code (saharNooby)
- 4d27fa8 Reformat code (saharNooby)
- d3b6749 Consistently use FP16 and FP32 for rwkv.cpp data types (saharNooby)
- c49d3d8 Add test for on-the-fly quantization (saharNooby)
New test file (+91 lines):

// Tests that results from on-the-fly quantized model are identical with results of pre-quantized model.

#include "ggml.h"
#include "rwkv.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define N_THREADS 2

int main(void) {
    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q5_1.bin", "Q5_1");

    struct rwkv_context * prequantized_ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32-Q5_1.bin", N_THREADS);

    if (!prequantized_ctx) {
        enum rwkv_error_flags error = rwkv_get_last_error(NULL);
        fprintf(stderr, "Unexpected error 0x%.8X\n", error);
        return EXIT_FAILURE;
    }

    // ---

    struct rwkv_init_from_file_option option = {RWKV_INIT_FROM_FILE_OPTION_TARGET_FORMAT_NAME, "Q5_1"};

    struct rwkv_context * on_the_fly_quantized_ctx = rwkv_init_from_file_ex("tiny-rwkv-660K-FP32.bin", N_THREADS, &option, 1);

    if (!on_the_fly_quantized_ctx) {
        enum rwkv_error_flags error = rwkv_get_last_error(NULL);
        fprintf(stderr, "Unexpected error 0x%.8X\n", error);
        return EXIT_FAILURE;
    }

    // ---

    float * state = calloc(rwkv_get_state_len(prequantized_ctx), sizeof(float));

    if (!state) {
        fprintf(stderr, "Failed to allocate state\n");
        return EXIT_FAILURE;
    }

    float * expected_logits = calloc(rwkv_get_logits_len(prequantized_ctx), sizeof(float));

    if (!expected_logits) {
        fprintf(stderr, "Failed to allocate logits\n");
        return EXIT_FAILURE;
    }

    const unsigned char prompt[12] = "hello world";

    rwkv_eval(prequantized_ctx, prompt[0], NULL, state, expected_logits);

    for (int i = 1; prompt[i] != 0; i++) {
        rwkv_eval(prequantized_ctx, prompt[i], state, state, expected_logits);
    }

    // ---

    float * actual_logits = calloc(rwkv_get_logits_len(on_the_fly_quantized_ctx), sizeof(float));

    if (!actual_logits) {
        fprintf(stderr, "Failed to allocate logits\n");
        return EXIT_FAILURE;
    }

    rwkv_eval(on_the_fly_quantized_ctx, prompt[0], NULL, state, actual_logits);

    for (int i = 1; prompt[i] != 0; i++) {
        rwkv_eval(on_the_fly_quantized_ctx, prompt[i], state, state, actual_logits);
    }

    // ---

    if (memcmp(expected_logits, actual_logits, rwkv_get_logits_len(on_the_fly_quantized_ctx) * sizeof(float))) {
        fprintf(stderr, "Results not identical :(\n");
        return EXIT_FAILURE;
    } else {
        fprintf(stdout, "Results identical, success!\n");
    }

    rwkv_free(on_the_fly_quantized_ctx);
    rwkv_free(prequantized_ctx);

    free(expected_logits);
    free(actual_logits);
    free(state);

    return 0;
}
@LoganDark I think now the interface is generic enough to painlessly add new options in the future -- for mmap, etc.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
eh... this does not inspire confidence for some reason, and I'm not sure why. I think all the existing parameters should be moved to the options structure, but also that the library needs more work before it can move to an options structure at all.

loading from a file itself I intended to move into its own option, because for really insane use cases I'm literally thinking of things like streaming the model from the network so it doesn't touch the disk at all. I imagine this being used for something like microcontrollers that don't have a filesystem. it sounds really stupid, I know, but it's a contrived example.

one of the things I planned to do first was move rwkv.cpp into multiple files, because its single file is getting quite long and a bit disorganized, with file reading functions, inference functions, and quantization functions all in the same file. I think it works for ggml, but rwkv.cpp is getting long enough that it's somewhat uncomfortable to navigate.

it's probably a bit weird of me to say that I already had a roadmap in mind, but I don't think an interim solution like this would be great, especially since having it here would encourage us to keep it.

so I would probably either hold off on merging this (I was planning to implement it myself anyway) or find a way that doesn't involve moving to an options dict so soon. but I think rwkv.cpp does not need quantize-on-load at all yet; it will become more useful when it can load directly from PyTorch checkpoints, since those cannot be pre-quantized at all, so quantizing on load would be the only option. that subsystem does not exist yet, though, and I will account for it when it does.
If possible, can you share the roadmap (even if it is rough), and some timelines? Specifically for PyTorch loading support.

Honestly, there is no hurry to merge the PR, since everything worked fine before it and no one complained. But I would like to have a reasonably good reason to postpone it.

Completely agree, had these ideas myself. Not related to this PR tho :)

I don't think rwkv.cpp should support network loading or other non-file use cases, so the file_path argument most probably will stay. As you said yourself, these use cases are insane lol
I plan to implement in this rough order: include src directly and use #include, no need to complicate the cmake.

also, I want to make model loading one-shot again (only read the file once), because depending on fseek, fstat, and ftell hurts our cross-platform compatibility. Additionally, that would remove the runtime dependency on a hash map (is it a hash map? some kind of map) to load the tensors directly into the model. I have a working version of this in Rust, actually, but it would need to be ported to C++ (should be easy).
Anyway, overall the goal is to make the library a lot more flexible. It was specialized as a prototype to load a single model from a binary file and evaluate single tokens, but it'll get a lot more exciting and faster if we make it more flexible.
Imagine downloading a compressed model file, and either loading it directly, or using the library itself to decompress it and then using mmap (without requiring python). Or even imagine downloading fresh pytorch checkpoints, minutes after BlinkDL first releases them, and either converting them tensor by tensor (like quantization) or just using them that way.
Imagine using this on desktops, servers, mobile phones, embedded devices (possibly with TPU ?!), whatever.
Imagine training models with rwkv.cpp, too (that is not on my roadmap because I don't know how I would do that yet, but I can still dream :3)
It's related as in I consider it a blocker, i.e. I wouldn't implement the options until the source code is organized enough.
You know mmap support is the biggest non-file use case. rwkv.cpp will have to implement loading from memory anyway. The only difference is whether we allow third party programs to use this functionality. Ideally it would be implemented in such a way that rwkv.cpp will not have to support network loading or anything insane like that. It would just support "any kind of loading" and programs would be able to implement their own network loading if they wanted
Mind blowing. I think this is actually an excellent use case for RWKV. From my limited understanding of the RWKV internals, the context memory is constant and the memory access pattern is sequential (backward or forward). So it makes a lot of sense to convert the source of truth (FP16 weights) to the latest quantized format on the fly, much like a load-time JIT compiler.
Both of you 🤘
@lin72h the source of truth is actually the f32 version, as that's what BlinkDL trains, but f16 would still count as a source of truth if you're using it to generate a quantized model. :)