>>103230340
If you want to use it, it's something like:
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset

# dataset_name, tokenizer, model, and learning_rate are assumed to be defined already
dataset = load_dataset(dataset_name, split="train")
print(len(dataset))
# Tokenize; pad to a fixed length so the default collate can stack batches,
# and drop the raw text column so with_format("torch") only sees tensor columns
dataset = dataset.map(
    lambda x: tokenizer(x["text"], truncation=True, padding="max_length", max_length=1024),
    batched=True, remove_columns=["text"])
dataset = dataset.with_format("torch")
print(dataset.format)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
# initialize SMT
smt = SMT(model, dataloader, sparsity_ratio=0.01, warmup_steps=100)
# you can print the actual number of trainable parameters in SMT
smt.print_trainable_params()
optimizer = torch.optim.AdamW(smt.get_trainable_params(), lr=learning_rate)
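If you get that far, the loop itself should just be the usual PyTorch thing. This part is a sketch off the top of my head: it assumes a Hugging Face causal LM that returns a loss when you pass labels, num_epochs is made up, and I'm assuming SMT does its sparsity stuff through hooks so the loop doesn't need to call anything on it directly. Untested, same as the rest.

# minimal training loop sketch -- untested; assumes the model computes
# its own loss when given labels (standard HF causal LM behavior)
num_epochs = 1
model.train()
for epoch in range(num_epochs):
    for batch in dataloader:
        # move input_ids / attention_mask onto the model's device
        batch = {k: v.to(model.device) for k, v in batch.items()}
        loss = model(**batch, labels=batch["input_ids"]).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()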
But I can't guarantee that it works. I mostly posted it because I'm hoping someone smarter than me can make it better lol.