
2:4 Sparse training #1425

Open
phyllispeng123 opened this issue Dec 17, 2024 · 4 comments

@phyllispeng123 commented Dec 17, 2024

Hi, when I was training a model with a 2:4 semi-structured sparse mask following the tutorial https://pytorch.org/tutorials/prototype/semi_structured_sparse.html?highlight=transformer, I found that sparsifier.step() does not update the mask during training (I saved FakeSparsity.mask for each layer at every epoch, and the saved masks are always the same). I thought 2:4 semi-sparse training would train the mask and also fine-tune the model itself. Did I misunderstand?

@supriyar (Contributor)

cc @jcaip

@jcaip (Contributor) commented Dec 18, 2024

Hi @phyllispeng123, the tutorial describes a one-shot pruning flow, where we only calculate the mask once (before fine-tuning) and then train to update the weights. So if you're following the tutorial, I would expect the mask to be the same for each epoch.

Sparsifier.step() should update the mask, though, so if you've modified the code to call step() during training, then this is unexpected. Can you share your code in that case?
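For reference, the one-shot flow from the tutorial looks roughly like this (a minimal sketch; model, optimizer, loss_fn, train_dataloader and num_epochs are placeholders, not names from the tutorial):

import torch
from torch.ao.pruning import WeightNormSparsifier

sparsifier = WeightNormSparsifier(
    sparsity_level=1.0, sparse_block_shape=(1, 4), zeros_per_block=2
)

# prune every nn.Linear weight
sparse_config = [
    {"tensor_fqn": f"{name}.weight"}
    for name, module in model.named_modules()
    if isinstance(module, torch.nn.Linear)
]

sparsifier.prepare(model, sparse_config)
sparsifier.step()  # the 2:4 mask is computed ONCE here, before fine-tuning

for epoch in range(num_epochs):  # fine-tuning only updates the (masked) weights
    for batch in train_dataloader:
        loss = loss_fn(model(batch["input"]), batch["target"])
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

sparsifier.squash_mask()  # fold the mask into the weights once training is done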

@phyllispeng123 (Author)

> Hi @phyllispeng123, the tutorial describes a one-shot pruning flow, where we only calculate the mask once (before fine-tuning) and then train to update the weights. So if you're following the tutorial, I would expect the mask to be the same for each epoch.
>
> Sparsifier.step() should update the mask, though, so if you've modified the code to call step() during training, then this is unexpected. Can you share your code in that case?

@jcaip Thank you for your reply!!!! I am training a transformer model. My simplified training code looks like the snippet below. Hope you can give me some hints on fine-tuning the weights together with the mask, many thanks!!!!!

import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.ao.pruning import WeightNormSparsifier

# force the CUTLASS backend for the 2:4 sparse kernels
SparseSemiStructuredTensor._FORCE_CUTLASS = True


torch.manual_seed(42)
sparsifier = WeightNormSparsifier(
    # apply sparsity to all blocks
    sparsity_level=1.0,
    # shape of 4 elements is a block
    sparse_block_shape=(1, 4),
    # two zeros for every block of 4
    zeros_per_block=2
    )


### load my model
model = load_model('my_model_ckpt')
model.cuda().half()

### pruning
sparse_config = []
with torch.no_grad():
    for param_tensor, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            sparse_config.append({"tensor_fqn": f"{param_tensor}.weight"})
            logger.info(f"{param_tensor} get semi-sparse")
sparsifier.prepare(model, sparse_config) 

### training, expect mask is updating                            
for epoch in range(num_train_epochs):
    for step, batch in enumerate(train_dataloader):
        model_pred = model(batch['input'])
        model_target = model(batch['target'])
        loss = loss_fn(model_target - model_pred)
        loss.backward()
        optimizer.step()
        sparsifier.step()            # expect this to recompute the 2:4 mask
        optimizer.zero_grad()
    
    
    #### save model, and check if mask is updated
    save_model(model)
    for param_tensor, module in model.named_modules():
        if '.parametrizations.weight.0' in param_tensor:
            mask_state = {}
            weight_name = param_tensor.replace('.parametrizations.weight.0','.weight')
            weight_name = weight_name.strip()
            mask_state[weight_name] = module.mask.float().cpu()
            save_file(mask_state, save_masktensor_path)
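
For reference, comparing the saved masks across epochs could look like this (a sketch assuming save_file/load_file are the safetensors.torch helpers and that each epoch was written to its own file; the file names are hypothetical):

import torch
from safetensors.torch import load_file

masks_a = load_file("masks_epoch0.safetensors")  # hypothetical per-epoch files
masks_b = load_file("masks_epoch1.safetensors")

for name, mask_a in masks_a.items():
    unchanged = torch.equal(mask_a, masks_b[name])
    print(f"{name}: mask {'unchanged' if unchanged else 'changed'}")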
                                
        

jcaip self-assigned this Dec 18, 2024
@jcaip (Contributor) commented Dec 18, 2024

@phyllispeng123 Can you try using the torchao sparsifier instead of the torch.ao one? It's the exact same code; we've ported it over to live in this repo. I think there was a bug with how the masks get saved that I fixed in torchao but didn't upstream to torch.ao, since that code will be deprecated soon.

from torchao.prototype.sparsity.sparsifier.weight_norm_sparsifier import WeightNormSparsifier
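
In the training script above that should just mean swapping the import; since it is the same code, the constructor arguments stay unchanged:

# use the torchao copy of the sparsifier instead of torch.ao.pruning
from torchao.prototype.sparsity.sparsifier.weight_norm_sparsifier import WeightNormSparsifier

sparsifier = WeightNormSparsifier(
    sparsity_level=1.0,
    sparse_block_shape=(1, 4),
    zeros_per_block=2,
)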
