>>102898557
This script performs the per-tensor weight scaling.
import sys
import safetensors.torch
import torch
sd = safetensors.torch.load_file(sys.argv[1])
dtype_to = torch.float8_e4m3fn
inf = torch.finfo(dtype_to)
max_value = max(inf.max, abs(inf.min))
out_sd = {}
for k in sd:
if k.endswith(".weight"):
w = sd[k].float()
if w.dim() == 2:
calc = torch.max(torch.abs(w)) / max_value
out_sd[k] = (w / calc).clip(min=inf.min, max=inf.max).to(dtype=dtype_to)
out_sd[k[:-len(".weight")] + ".scale_weight"] = calc
else:
out_sd[k] = sd[k]
else:
out_sd[k] = sd[k]
out_sd["scaled_fp8"] = torch.tensor([])
safetensors.torch.save_file(out_sd, sys.argv[2])
The input scaling has to be calculated during inference, but it's not strictly necessary for Flux; it just increases quality when doing the fp8 matrix multiplications, provided your GPU supports running them in fp8.