Description:
While segmenting larger CT scans using the whole-body-v2.0.1 model, we found that the logits2pred() function in auto3dseg_segresnet_inference.py can be optimized for memory usage. Specifically, removing the torch.softmax(logits, dim=dim) step has no impact on the final segmentation results but significantly reduces memory consumption during inference.
Current Code:
def logits2pred(logits, sigmoid=False, dim=1):
    if isinstance(logits, (list, tuple)):
        logits = logits[0]
    if sigmoid:
        pred = torch.sigmoid(logits)
        pred = (pred >= 0.5)
    else:
        pred = torch.softmax(logits, dim=dim)
        pred = torch.argmax(pred, dim=dim, keepdim=True).to(dtype=torch.uint8)
    return pred
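Since softmax is monotonic along the class dimension, argmax over the raw logits yields exactly the same labels as argmax over the softmax probabilities. A quick check along these lines illustrates the equivalence (the tensor shape below is only an illustrative example, not the model's actual output shape):

import torch

# Illustrative check: argmax over raw logits equals argmax over softmax
# probabilities, because softmax preserves the ordering along `dim`.
logits = torch.randn(1, 10, 32, 32, 32)  # hypothetical (batch, classes, D, H, W)
via_softmax = torch.argmax(torch.softmax(logits, dim=1), dim=1, keepdim=True)
direct = torch.argmax(logits, dim=1, keepdim=True)
assert torch.equal(via_softmax, direct)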
Suggested Change:
Replace the two lines in the else block with a single line:

    pred = torch.argmax(logits, dim=dim, keepdim=True).to(dtype=torch.uint8)
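For clarity, a sketch of how the full function would read after this change (only the else branch differs; the sigmoid path is untouched, and the torch import from the surrounding script is assumed):

def logits2pred(logits, sigmoid=False, dim=1):
    # The network may return a list/tuple of outputs; keep only the primary one.
    if isinstance(logits, (list, tuple)):
        logits = logits[0]
    if sigmoid:
        pred = torch.sigmoid(logits)
        pred = (pred >= 0.5)
    else:
        # argmax over raw logits gives the same labels as argmax over softmax
        # probabilities, without allocating the full-size probability tensor.
        pred = torch.argmax(logits, dim=dim, keepdim=True).to(dtype=torch.uint8)
    return pred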