mithral C++ parameter optimization #29
From my understanding, the C++ code does not contain parameter optimization; only the Python code called here does. I have been trying to piece together which pieces of the Python code need to be exported to make the C++ code work, but I'm failing at some of them. For example: here, while at the same time creating the query lookup tables as shown here, I can't seem to figure out what to pass into the equivalent C++ functions:
Overall, while the paper's results look impressive, I don't think it is currently possible to reproduce the speed and accuracy numbers reported in the paper with a single codebase. The Python code produces the correct effectiveness numbers and tradeoffs, but without proper integration into the C++ bindings it doesn't show the same speed characteristics as reported in the paper.
@mpetri I'm also having problems with using the Python version to learn the parameters and the C++ version to do the rest, and I'm still facing some problems with the calculation process. The last n/2 columns of my output matrix are always zeroed out (this also happens when not loading the previously learned parameters into the object, hence just running …). First I generate X in Python and run …. I'm also curious how @dumpinfo has managed the py -> C++ parameter migration.
I have the same problem. I was printing the … and counting the different …. Maybe @dblalock could help us move forward? Do you have some code you can share?
Yes, padding with an extra zero at the end is probably the best solution. There are logically only 15, but the C++ cares about alignment and so wants blocks of 16 of them.

The best overall solution would be having clean wrappers for the C++ code and having the Python call those instead of the Python implementations. This would basically just solve everything. But...I just never got to that. The results in the paper just join a table of accuracies and a table of speeds on the matrix shapes + num_codebooks, which is, in the words of the experiments readme, "kind of an abomination."
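A minimal numpy sketch of the padding described above, assuming the 15 logical split values of one codebook (1 + 2 + 4 + 8 from a four-level tree) get packed into a single aligned block of 16 bytes with a trailing zero that is never read. `pad_splitvals` is a hypothetical helper, not a function from the repo:

```python
import numpy as np

def pad_splitvals(splitvals, block=16):
    """Zero-pad one codebook's split values out to an aligned block.

    Hypothetical helper: `splitvals` is the 15 logical split values of
    one codebook; the C++ side wants aligned blocks of 16, so the final
    byte is padding and never read.
    """
    splitvals = np.asarray(splitvals, dtype=np.int8)
    padded = np.zeros(block, dtype=np.int8)
    padded[: len(splitvals)] = splitvals
    return padded

# 15 logical values -> one aligned 16-byte block
vals = np.arange(1, 16, dtype=np.int8)
blockvals = pad_splitvals(vals)
assert blockvals.shape == (16,) and blockvals[-1] == 0
```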
@fjrdev do you have some code to share that extracts the Python values and imports them into the C++ codebase?
@mpetri I print the split dims, split values, centroids, scales, and offsets to a file. What procedure are you following when learning the parameters in Python?
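One way to dump the learned parameters so a C++ loader can read them back with a plain `fread()` is raw little-endian binary files, one per array. This is only a sketch; `export_params`, the file naming, and the dtypes (int32 dims, int8 split values, float32 for the rest) are all assumptions, not the repo's actual API:

```python
import os
import tempfile
import numpy as np

def export_params(prefix, splitdims, splitvals, centroids, scales, offsets):
    """Hypothetical export helper: write each learned parameter array as
    raw binary (native little-endian on x86) for a C++ loader."""
    np.asarray(splitdims, dtype=np.int32).tofile(prefix + "_splitdims.bin")
    np.asarray(splitvals, dtype=np.int8).tofile(prefix + "_splitvals.bin")
    np.asarray(centroids, dtype=np.float32).tofile(prefix + "_centroids.bin")
    np.asarray(scales, dtype=np.float32).tofile(prefix + "_scales.bin")
    np.asarray(offsets, dtype=np.float32).tofile(prefix + "_offsets.bin")

# round-trip check in a temp directory
with tempfile.TemporaryDirectory() as d:
    prefix = os.path.join(d, "params")
    centroids = np.arange(12, dtype=np.float32).reshape(3, 4)
    export_params(prefix, [0, 2, 1, 3], np.zeros(16), centroids, [1.0], [0.0])
    back = np.fromfile(prefix + "_centroids.bin", dtype=np.float32)
    assert np.array_equal(back, centroids.ravel())
```

Note that `tofile`/`fromfile` store no shape or dtype metadata, so both sides must agree on those out of band.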
@fjrdev I ported the Python code to Rust and it produces equivalent results. However, I'm now trying to incorporate the more efficient C++ code and I'm encountering more and more issues. For example: it is unclear whether scaling this into int8_t ranges is safe, and it is never explained anywhere. Any thoughts? After spending substantial time on this I'm getting more convinced that without the help of the original author this can't be made to work as intended :(
The int8 vs uint8 shouldn't matter for the purpose of splitting. Any affine transformation is fine as long as it's applied to the data and the split vals the same way.

If I remember correctly, the split vals are each supposed to be sequences of 16x8bit vectors. It's one such vector per codebook. So the split vals array is a contiguous sequence of C 16B vectors, for a total of 16C bytes. The first split val is at element 0 in each vector, then the next two vals are at indices {1, 2}, etc. The last element is unused and just for alignment. I think (based on the encoding logic + the fact that it's the minimal representation) that the splitdims are just C sequences of 4 ints, all contiguous.

And wow, I am really regretting not commenting the relevant functions better--I thought I had the main stuff doxygenated like in bolt.hpp, but boy was I wrong. Sorry about that.

Also, kind of a moot point, but I don't think column-major is that weird; it's the default in Eigen, Matlab, and Julia IIRC. But you're totally right that transposing will add overhead if you're starting with a rowmajor matrix.
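The affine-invariance claim above can be sanity-checked numerically. This is just a sketch, and it deliberately ignores integer rounding: once both sides are quantized to int8, values that collapse onto the same quantized level can flip a decision, which the float check below does not capture.

```python
import numpy as np

# Claim: an affine map with positive scale, applied identically to the
# data and the split values, leaves every split decision (x >= v)
# unchanged -- so mapping floats toward the int8 range is safe as long
# as both sides use the same scale and offset.
rng = np.random.default_rng(0)
x = rng.normal(size=1000)  # stand-in for data values along one split dim
v = rng.normal(size=1000)  # stand-in for the corresponding split values

scale = 127.0 / np.abs(x).max()  # arbitrary positive scale
offset = 3.0                     # arbitrary offset, cancels in the comparison

before = x >= v
after = (scale * x + offset) >= (scale * v + offset)
assert np.array_equal(before, after)
```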
"If I remember correctly, the split vals are each supposed to be sequences of 16x8bit vectors. It's one such vector per codebook." @dblalock I'm a little bit confused by the dimensions of the splitvals matrix. Since there are (# codebooks * 4) columns, I assumed that one codebook fills 4 columns of the splitvals matrix. The corresponding part of the splitvals matrix would look like this: … According to your answer, the correct alignment would look like this: … I don't quite understand what happens to the 3 remaining columns reserved for this codebook.
I looked at the code some more and I think your first (zero-padded) version is correct. There's no packing into a single vector happening.

```cpp
int split_idx = 0;
for (int c = 0; c < ncodebooks; c++) {
    // compute input and output column starts
    ...
    for (int s = 0; s < nsplits_per_codebook; s++) {
        ...
        auto splitvals_ptr = all_splitvals + (vals_per_split * split_idx);
        current_vsplitval_luts[s] = _mm256_broadcastsi128_si256(
            load_si128i((const __m128i*)splitvals_ptr));
    }
    split_idx += nsplits_per_codebook;
}
```
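The indexing in that loop can be mirrored in numpy to check an exported layout. This is only a sketch under stated assumptions: that `vals_per_split` is 16 (one aligned 16-byte block per split, matching the 128-bit load), and that split `s` of a codebook's tree uses only its first `2**s` entries with the rest zero-padded. Neither assumption is confirmed by the repo's docs.

```python
import numpy as np

# Assumed layout: every split (4 per codebook) owns its own 16-byte
# block, zero-padded, laid out contiguously codebook by codebook.
ncodebooks, nsplits_per_codebook, vals_per_split = 2, 4, 16
all_splitvals = np.zeros(
    (ncodebooks * nsplits_per_codebook, vals_per_split), dtype=np.int8)

# split s of the 4-level tree has 2**s logical thresholds; rest is padding
for c in range(ncodebooks):
    for s in range(nsplits_per_codebook):
        row = c * nsplits_per_codebook + s
        all_splitvals[row, : 2 ** s] = 1  # placeholder threshold values

# the C++ pointer arithmetic then reduces to plain row indexing
split_idx = 1 * nsplits_per_codebook + 2  # codebook 1, split 2
flat = all_splitvals.reshape(-1)
block = flat[vals_per_split * split_idx : vals_per_split * (split_idx + 1)]
assert np.array_equal(block, all_splitvals[split_idx])
```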
Maybe I don't really understand the algorithm correctly, but from my understanding you use the split vals to walk down a binary tree, updating the codes with 2 * code or 2 * code + 1 depending on whether you go left or right in the tree. This is what this function does in the Python version of the …. However, there is obviously a dependency on having processed the first level of the binary tree to decide which path in the tree to take next, if we load …
Actually, looking at the code some more, I see we have a for loop over blocks and inside that a for loop over split vals, so we can make those decisions sequentially!
Yes, it's not obvious because it's SIMD-ified, but we actually are walking the tree in this loop (which I think is the one you're referring to). The node for a given input is stored in the …
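The tree walk described above can be written out as a scalar (non-SIMD) sketch. `encode_codebook` is a hypothetical name, and the parameter layout (one split dim per level, level `s` holding its `2**s` thresholds in its own array) is an assumption for illustration:

```python
import numpy as np

def encode_codebook(X, splitdims, splitvals):
    """Scalar sketch of the 4-level tree walk (not the SIMD-ified C++).

    At level s, every row compares one feature against the threshold
    selected by its current code (its current tree node), then descends
    left or right via code = 2*code + (x >= val).
    """
    codes = np.zeros(len(X), dtype=np.int64)
    for s, dim in enumerate(splitdims):
        vals = splitvals[s][codes]               # threshold at each row's node
        codes = 2 * codes + (X[:, dim] >= vals)  # left child or right child
    return codes

rng = np.random.default_rng(0)
X = rng.normal(size=(8, 5))
splitdims = [0, 2, 1, 3]                          # one split dim per level
splitvals = [rng.normal(size=2 ** s) for s in range(4)]
codes = encode_codebook(X, splitdims, splitvals)
assert codes.min() >= 0 and codes.max() < 16      # 4 levels -> 16 leaves
```

The sequential dependency is visible here too: level `s` cannot run until the codes from level `s - 1` exist, which is why the C++ loops over split vals inside each block rather than the other way around.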
Hello @mpetri, can you share your Python code that splits the result of …?
@VpouL A wrong result at out_mat(N, M)? |
@fjrdev Just like you, I see that splitdims and all_splitvals can be derived from the MithralEncoder::splits_lists array, but I don't know how to properly extract the values from the Python code.
Does the mithral C++ codebase provide functions for parameter optimization? I have problems finding them, and executing run_matmul() in the struct mithral_amm_task returns a wrong output matrix, since the optimization parameters are set randomly. Thank you!