Linear Model not saved time using to_sparse_semi_structured #1460

Open

phyllispeng123 opened this issue Dec 26, 2024 · 2 comments

phyllispeng123 commented Dec 26, 2024

@jcaip I generated a linear model, Model_more_linear, and tried to test the acceleration from to_sparse_semi_structured. However, the inference time actually increases. (I saw the sample code in https://pytorch.org/tutorials/prototype/semi_structured_sparse.html, where the dense model is not warmed up while the semi-sparse model runs after the dense model and so may already be warmed up. I wonder whether the reported acceleration is correct, or whether my inference procedure is wrong. Could you help me with it?)

The test result is: dense time: 0.38278400897979736 ms, semi-sparse time: 0.49686399102211 ms

class Model_more_linear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10240, 3072)
        self.activation = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(3072, 3072)
        self.linear3 = torch.nn.Linear(3072, 10240)
        self.linear4 = torch.nn.Linear(10240, 10240)
        self.linear5 = torch.nn.Linear(10240, 3072)
        self.softmax = torch.nn.Softmax(dim=-1)  # specify dim explicitly to avoid the implicit-dim warning

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        x = self.linear3(x)
        x = self.linear4(x)
        x = self.linear5(x)
        x = self.softmax(x)
        return x
def check_sparse_more_linear():
    import torch
    from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
    from torch.utils.benchmark import Timer
    import gc
    # force the CUTLASS backend instead of cuSPARSELt
    SparseSemiStructuredTensor._FORCE_CUTLASS = True
    
    
    gc.collect()
    torch.cuda.empty_cache()
    weight_dtype = torch.bfloat16
    device = torch.device('cuda:0')
    #### original dense model
    linear = Model_more_linear().to(device, weight_dtype)
    gt_x = torch.rand(1, 10240).to(device, weight_dtype)
    with torch.inference_mode():
        #### warmup dense model
        for i in range(3):
            x = torch.rand(1, 10240).to(device,weight_dtype)
            dense_output = linear(x)

        del dense_output
        gc.collect()
        torch.cuda.empty_cache()
        
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        dense_output = linear(gt_x) #### use the same input gt_x
        end_event.record() 
        torch.cuda.synchronize() 
        res_time = start_event.elapsed_time(end_event)
        print(f'dense time = {res_time}')
        del dense_output
        gc.collect()
        torch.cuda.empty_cache()
        
        for name, mod in linear.named_modules():
            if isinstance(mod, torch.nn.Linear):
                rows, cols = mod.weight.shape
                # tiling [0, 0, 1, 1] yields a 2:4 pattern: 2 nonzeros in every group of 4
                mask = torch.Tensor([0, 0, 1, 1]).tile((rows, cols // 4)).to(mod.weight.device).bool()
                # zero out the masked weights, then compress to the semi-structured format
                mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mask * mod.weight))
        
        #### warmup semi-sparse 
        for i in range(3):
            x = torch.rand(1, 10240).to(device,weight_dtype)
            sparse_output = linear(x)
        
        del sparse_output, name, mod, mask
        gc.collect()
        torch.cuda.empty_cache()
        
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        sparse_output = linear(gt_x) #### use the same input gt_x
        end_event.record()
        torch.cuda.synchronize() 
        res_time = start_event.elapsed_time(end_event)
        print(f'sparse time = {res_time}')
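
For a steadier measurement than a single pair of CUDA events, torch.utils.benchmark.Timer (imported above but unused) handles warmup, synchronization, and repetition internally. A minimal sketch, assuming the linear model and gt_x defined above:

def bench_ms(model, x):
    # blocked_autorange() warms up, synchronizes, and averages over many runs
    t = Timer(stmt="model(x)", globals={"model": model, "x": x}).blocked_autorange()
    return t.median * 1e3  # Measurement.median is reported in seconds

# e.g. print(f"median forward time = {bench_ms(linear, gt_x):.3f} ms")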
        

jcaip (Contributor) commented Dec 26, 2024

@phyllispeng123 this is due to your matmul shapes: I see you are multiplying (1, 10240) x (10240, 3072). This shape is not amenable to acceleration with 2:4 sparsity because there are minimum dimensions that we need to pad to, so some slowdown is expected. I would expect it to be faster for larger batch sizes (> 64).

Also, judging from the shapes, it looks like you're trying to accelerate LLM inference? If that's the case, I recommend using the MARLIN 2:4 sparse kernels in AO instead; see the SparseMarlin24(TestCase) tests in torchao. This will be much faster for LLM inference.
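
For reference, a rough sketch of that path, going by the torchao sparsity README (the layout kwarg has been named layout_type in older releases, so this may need adjusting for your torchao version):

import torch
from torchao.quantization import quantize_, int4_weight_only
from torchao.dtypes import MarlinSparseLayout

# assumes `model` already has 2:4-pruned weights; the Marlin kernel packs
# them together with int4 weight-only quantization (fp16 activations)
model = model.half().cuda()
quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))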

phyllispeng123 (Author) commented Dec 28, 2024

@jcaip Thank you for your reply! When I try different batch sizes >= 64, I get the following average inference times (10 runs after warmup):
batch_size = 64: dense time = 0.37827200889587403 ms, sparse time = 0.6369376182556152 ms
batch_size = 512: dense time = 0.9773280143737793 ms, sparse time = 1.4351903915405273 ms
batch_size = 1280: dense time = 4.596492767333984 ms, sparse time = 3.4081790924072264 ms

It seems that batch sizes >= 1280 start to show acceleration. Based on this result, I tried to accelerate the FLUX dev1.0 model (model structure: https://github.com/kohya-ss/sd-scripts/blob/sd3/library/flux_models.py). There are two linear layers in SingleStreamBlock:
self.linear1 = Linear(in_features=3072, out_features=21504, bias=True), input shape = (1, 4592, 3072)
self.linear2 = Linear(in_features=15360, out_features=3072, bias=True), input shape = (1, 4592, 15360)

Both layers see 4592 tokens per forward pass, so the effective matmul batch dimension is well above 1280. I applied a 2:4 semi-sparse mask with to_sparse_semi_structured to self.linear1 and self.linear2 and expected the model to get faster. However, when I test inference time on a single GPU (A800-40G, bf16), it is even slower (original dense: 466 ms, with 2:4 semi-sparse: 590 ms). Do you have any advice in case I got something wrong, or on which sparse mask I can use? Many thanks!!!
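
To check whether these exact shapes benefit from 2:4 sparsity at all, independent of the rest of FLUX, one can time the isolated matmul. A minimal sketch for the linear1 shape, reusing the masking recipe from the snippet above (the weights here are random placeholders, not the real FLUX weights):

import torch
import torch.nn.functional as F
from torch.sparse import to_sparse_semi_structured
from torch.utils.benchmark import Timer

# linear1 of SingleStreamBlock: (4592, 3072) x (3072 -> 21504), bf16
x = torch.rand(4592, 3072, dtype=torch.bfloat16, device="cuda")
w = torch.rand(21504, 3072, dtype=torch.bfloat16, device="cuda")

# 2:4 pattern: keep 2 of every 4 consecutive elements in each row
mask = torch.tensor([0, 0, 1, 1], dtype=torch.bool, device="cuda").tile(w.shape[0], w.shape[1] // 4)
w_sparse = to_sparse_semi_structured(w.masked_fill(~mask, 0))

for label, weight in [("dense", w), ("2:4 sparse", w_sparse)]:
    t = Timer(stmt="F.linear(x, w)", globals={"F": F, "x": x, "w": weight}).blocked_autorange()
    print(f"{label}: {t.median * 1e3:.3f} ms")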
