I tried optimizing a quantized model with gradient descent to minimize the error relative to the original model.
Specifically, I hacked torchtune to replace the bf16 weight matrices in all the nn.Linear layers with the separate qs/scales/d components of the quantized Q6_K format. These are stored as floating-point "latent weights" (like in the BitNet paper), but the forward pass rounds, clamps, and runs the normal Q6_K dequantization on the fly, so the weights that actually get used in the linear layers are always exactly those of some valid Q6_K quant.
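For anyone who wants to see what that looks like, here's a stripped-down sketch of the idea (not my actual torchtune code, names made up): one float scale per 16-element block and a straight-through estimator for the round/clamp. The real Q6_K layout (int8 sub-block scales inside 256-element super-blocks with an fp16 d) has more moving parts, but the mechanics are the same.
[code]
import torch
import torch.nn as nn
import torch.nn.functional as F

class RoundClampSTE(torch.autograd.Function):
    """Round+clamp in the forward pass, straight-through gradient in the backward."""
    @staticmethod
    def forward(ctx, x, lo, hi):
        return x.round().clamp(lo, hi)

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out, None, None

class LatentQuantLinear(nn.Module):
    """Linear layer whose weight is rebuilt every forward pass from float "latent"
    quant components, so the effective weight is always a valid quantized tensor.
    Simplified vs. real Q6_K: just one float scale per 16-element block here."""

    def __init__(self, weight_fp: torch.Tensor, block_size: int = 16):
        super().__init__()
        out_f, in_f = weight_fp.shape
        assert in_f % block_size == 0
        w = weight_fp.float().reshape(out_f, in_f // block_size, block_size)
        # Naive round-to-nearest init: scale * q ~= w with 6-bit q in [-32, 31].
        scale = w.abs().amax(dim=-1, keepdim=True) / 31.0 + 1e-12
        self.q = nn.Parameter(w / scale)    # latent 6-bit codes (kept as floats)
        self.scale = nn.Parameter(scale)    # latent per-block scales (kept as floats)

    def effective_weight(self) -> torch.Tensor:
        q = RoundClampSTE.apply(self.q, -32.0, 31.0)  # snap to valid 6-bit codes
        w = q * self.scale                            # dequantize on the fly
        return w.reshape(w.shape[0], -1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.effective_weight())
[/code]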
I took a normal Q6_K quant of Llama 3 8B Instruct and optimized each layer separately to minimize the error it introduces relative to the same layer in the original model. Doing one layer at a time keeps VRAM usage down, which matters because I eventually want to apply this to L3 70B. The whole run took about 6 hours on 1x 4090. Then I converted the results back to a Q6_K GGUF.
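Each per-layer run boils down to a plain distillation loop on cached activations, roughly like this (simplified; Adam and MSE on the layer outputs are just placeholders here, not necessarily what you'd keep):
[code]
import torch
import torch.nn.functional as F

def tune_layer(latent_layer, fp16_layer, calib_batches, steps=1000, lr=1e-4):
    """Fit one latent-quant layer to reproduce the original fp16 layer's outputs.
    calib_batches: list of activation tensors captured at this layer's input."""
    opt = torch.optim.Adam(latent_layer.parameters(), lr=lr)
    for step in range(steps):
        x = calib_batches[step % len(calib_batches)]
        with torch.no_grad():
            target = fp16_layer(x)                  # what the original layer produces
        loss = F.mse_loss(latent_layer(x), target)  # error introduced by the quant
        opt.zero_grad()
        loss.backward()
        opt.step()
    return latent_layer
[/code]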
KL-divergence on wiki.test.raw improved by a small amount:
old: 0.004234
new: 0.003945
delta: 0.000289 (7%)
I have a few ideas to try next to improve the results:
>Train 2-4 layers together. My hope is that this will give the optimizer some flexibility to have the layers cancel out each other's errors.
>Train the layers sequentially. First, train layer 0 as normal. Then, instead of training layer 1 to map layer_0_fp16_output -> layer_1_fp16_output, train it to map layer_0_quant_output -> layer_1_fp16_output. This lets the layer 1 optimizer know about the errors introduced by layer 0 so it can correct for them (rough sketch below).
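Rough sketch of that second idea: same kind of loop as above, but with two activation streams, one from the already-quantized prefix and one from the original fp16 model (all names made up):
[code]
import torch
import torch.nn.functional as F

def tune_layer_sequential(latent_layer, fp16_layer, quant_batches, fp16_batches,
                          steps=1000, lr=1e-4):
    """Like tune_layer, but the input comes from the already-tuned quantized
    layers 0..i-1 while the target is still the original model's activation at
    layer i, so the optimizer can also correct upstream quantization error."""
    opt = torch.optim.Adam(latent_layer.parameters(), lr=lr)
    for step in range(steps):
        x_quant = quant_batches[step % len(quant_batches)]  # layer i-1 quant output
        x_fp16 = fp16_batches[step % len(fp16_batches)]     # layer i-1 fp16 output
        with torch.no_grad():
            target = fp16_layer(x_fp16)                     # layer i fp16 output
        loss = F.mse_loss(latent_layer(x_quant), target)
        opt.zero_grad()
        loss.backward()
        opt.step()
    return latent_layer
[/code]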
But I'd love to hear other suggestions too. I'm sure there are anons ITT who know a lot more about ML than I do.