@jcaip I built a linear model Model_more_linear and tried to measure the speedup from to_sparse_semi_structured, but inference time actually increased. (In the sample code at https://pytorch.org/tutorials/prototype/semi_structured_sparse.html the dense model is not warmed up, whereas the semi-sparse model is run after the dense one and is therefore presumably already warmed up. Is the reported acceleration correct, or is my inference procedure wrong? Could you help me with it?)
The test result is: Dense time: 0.38278400897979736 ms, Semi-sparse time: 0.49686399102211 ms
class Model_more_linear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10240, 3072)
        self.activation = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(3072, 3072)
        self.linear3 = torch.nn.Linear(3072, 10240)
        self.linear4 = torch.nn.Linear(10240, 10240)
        self.linear5 = torch.nn.Linear(10240, 3072)
        self.softmax = torch.nn.Softmax()

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        x = self.linear3(x)
        x = self.linear4(x)
        x = self.linear5(x)
        x = self.softmax(x)
        return x
def check_sparse_more_linear():
    import torch
    from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
    from torch.utils.benchmark import Timer
    import time
    import gc

    SparseSemiStructuredTensor._FORCE_CUTLASS = True
    gc.collect()
    torch.cuda.empty_cache()

    weight_dtype = torch.bfloat16
    device = torch.device('cuda:0')

    #### original
    linear = Model_more_linear().to(device, weight_dtype)
    gt_x = torch.rand(1, 10240).to(device, weight_dtype)

    with torch.inference_mode():
        #### warmup dense model
        for i in range(3):
            x = torch.rand(1, 10240).to(device, weight_dtype)
            dense_output = linear(x)
        del dense_output
        gc.collect()
        torch.cuda.empty_cache()

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        dense_output = linear(gt_x)  #### use the same input gt_x
        end_event.record()
        torch.cuda.synchronize()
        res_time = start_event.elapsed_time(end_event)
        print(f'dense time = {res_time}')

        del dense_output
        gc.collect()
        torch.cuda.empty_cache()

        for name, mod in linear.named_modules():
            if isinstance(mod, torch.nn.Linear):
                left, right = mod.weight.shape[0], mod.weight.shape[1]
                mask = torch.Tensor([0, 0, 1, 1]).tile((left, right // 4)).to(mod.weight.device).bool()
                mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mask * mod.weight))

        #### warmup semi-sparse
        for i in range(3):
            x = torch.rand(1, 10240).to(device, weight_dtype)
            sparse_output = linear(x)
        del sparse_output, name, mod, mask
        gc.collect()
        torch.cuda.empty_cache()

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        sparse_output = linear(gt_x)  #### use the same input gt_x
        end_event.record()
        torch.cuda.synchronize()
        res_time = start_event.elapsed_time(end_event)
        print(f'sparse time = {res_time}')
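For a fairer comparison than a single timed forward pass, both the dense and the semi-sparse path can be warmed up and then averaged over many iterations. Below is a minimal sketch of such a helper (the name time_forward and the iteration counts are illustrative, not part of the original code); calling it once before and once after converting the weights gives an average per-call latency in ms for each path.

def time_forward(model, x, iters=50, warmup=10):
    # warm up kernels and the CUDA caching allocator before timing
    with torch.inference_mode():
        for _ in range(warmup):
            model(x)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            model(x)
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average ms per forward pass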
@phyllispeng123 this is due to your matmul shapes. I see you are multiplying (1, 10240) x (10240, 3072). This shape is not amenable to acceleration with 2:4 sparsity because we have minimum dimensions that we need to pad to, so some slowdown is expected. I would expect it to be faster for higher batch sizes (> 64).
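To see where the crossover happens for a single layer of this shape, one can sweep the batch dimension; the sketch below uses torch.utils.benchmark.Timer, and the batch sizes and the fixed [0, 0, 1, 1] mask are illustrative assumptions, not numbers from this thread.

import torch
from torch.sparse import to_sparse_semi_structured
from torch.utils.benchmark import Timer

dtype, device = torch.bfloat16, "cuda"
dense = torch.nn.Linear(10240, 3072).to(device, dtype)
sparse = torch.nn.Linear(10240, 3072).to(device, dtype)
# same fixed 2:4 pattern as above, good enough for timing purposes
mask = torch.Tensor([0, 0, 1, 1]).tile((3072, 10240 // 4)).to(device).bool()
sparse.weight = torch.nn.Parameter(to_sparse_semi_structured(mask * dense.weight))

for bs in (1, 64, 512, 1280, 4096):
    x = torch.rand(bs, 10240, device=device, dtype=dtype)
    t_dense = Timer("m(x)", globals={"m": dense, "x": x}).blocked_autorange().median * 1e3
    t_sparse = Timer("m(x)", globals={"m": sparse, "x": x}).blocked_autorange().median * 1e3
    print(f"batch={bs}: dense {t_dense:.3f} ms, 2:4 sparse {t_sparse:.3f} ms")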
Also, from the shapes it looks like you're trying to accelerate LLM inference? If that's the case I recommend using the MARLIN 2:4 sparse kernels in AO instead, see:
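For reference, this is roughly what that flow looks like in torchao; the import paths and config names below (quantize_, int4_weight_only, MarlinSparseLayout) are assumptions that have changed across torchao releases and should be checked against the torchao sparsity README, and Sparse-Marlin expects fp16 weights that are already pruned to a 2:4 pattern.

import torch
from torchao.quantization import quantize_, int4_weight_only
from torchao.dtypes import MarlinSparseLayout  # assumed name; verify against your torchao version

model = Model_more_linear().to("cuda", torch.float16).eval()
# Sparse-Marlin fuses int4 weight-only quantization with a 2:4 sparse layout;
# prune the weights to a 2:4 pattern before this call.
quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))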
@jcaip Thank you for your reply! Trying different batch sizes >= 64, I get the following average inference times (10 runs after warmup):
batch_size = 64: dense time = 0.37827200889587403, sparse time = 0.6369376182556152
batch_size = 512: dense time = 0.9773280143737793, sparse time = 1.4351903915405273
batch_size = 1280: dense time = 4.596492767333984, sparse time = 3.4081790924072264
It seems acceleration only shows up around batch size 1280 and above. Based on this result, I tried to accelerate the FLUX dev1.0 model (model structure: https://github.com/kohya-ss/sd-scripts/blob/sd3/library/flux_models.py). There are 2 linear layers in SingleStreamBlock:
self.linear1 = Linear(in_features=3072, out_features=21504, bias=True), input shape = (1, 4592, 3072)
self.linear2 = Linear(in_features=15360, out_features=3072, bias=True), input shape = (1, 4592, 3072)
I applied 2:4 semi-sparsity with to_sparse_semi_structured to self.linear1 and self.linear2 and expected the model to get faster. However, measuring inference time on a single GPU (A800-40G, bf16), it is actually slower (original dense: 466 ms, with 2:4 semi-sparse: 590 ms). Do you have any advice in case I got something wrong, or on which sparse mask I should use? Many thanks!!!
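For reference, the fixed [0, 0, 1, 1] mask used above zeroes half of each layer's weights regardless of their values; a common alternative that typically preserves accuracy better is to keep the two largest-magnitude weights in every contiguous group of four along the input dimension. A minimal sketch (the helper name is illustrative):

import torch

def magnitude_prune_2_4(weight):
    # weight: (out_features, in_features); in_features must be divisible by 4
    out_f, in_f = weight.shape
    groups = weight.reshape(out_f, in_f // 4, 4)
    # indices of the two smallest-magnitude entries in each group of four
    drop = groups.abs().argsort(dim=-1)[..., :2]
    mask = torch.ones_like(groups, dtype=torch.bool)
    mask.scatter_(-1, drop, False)
    return (groups * mask).reshape(out_f, in_f)

# e.g. instead of the fixed mask above:
# mod.weight = torch.nn.Parameter(to_sparse_semi_structured(magnitude_prune_2_4(mod.weight)))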