
Memory Optimization in logits2pred() #92

Open
ivantyj opened this issue Dec 12, 2024 · 0 comments
ivantyj commented Dec 12, 2024

Description:

While segmenting larger CT scans with the whole-body-v2.0.1 model, we found that the logits2pred() function in auto3dseg_segresnet_inference.py can be optimized for memory usage. Specifically, the torch.softmax(logits, dim=dim) step can be removed without affecting the final segmentation: softmax is a monotonically increasing transform along dim, so the subsequent argmax returns the same indices either way, and skipping it avoids allocating a full-size intermediate probability tensor during inference.

Current Code:

def logits2pred(logits, sigmoid=False, dim=1):
    if isinstance(logits, (list, tuple)):
        logits = logits[0]

    if sigmoid:
        pred = torch.sigmoid(logits)
        pred = (pred >= 0.5)
    else:
        pred = torch.softmax(logits, dim=dim)
        pred = torch.argmax(pred, dim=dim, keepdim=True).to(dtype=torch.uint8)

    return pred

Suggested Change:

Replace the two lines in the else block with a single line (note: torch.uint8, not torch.unit8):

        pred = torch.argmax(logits, dim=dim, keepdim=True).to(dtype=torch.uint8)
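The equivalence rests on softmax being strictly increasing in each element along the reduced dimension. A minimal stdlib-only sketch (deliberately torch-free, with hand-rolled softmax/argmax helpers) checking that the argmax is unchanged by softmax:

```python
import math
import random

def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def argmax(xs):
    # Index of the largest element (first occurrence on ties).
    return max(range(len(xs)), key=lambda i: xs[i])

random.seed(0)
for _ in range(1000):
    logits = [random.uniform(-10.0, 10.0) for _ in range(5)]
    # exp() is monotonic, so softmax never reorders elements:
    assert argmax(logits) == argmax(softmax(logits))
```

The same argument applies per voxel in the segmentation output, which is why dropping the softmax leaves the predicted labels identical while saving the memory of the intermediate tensor.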