
loss.py BinaryDiceLoss use of dim = 1 when flattening the tensor. #83

Open

skapoor2024 opened this issue Aug 11, 2024 · 4 comments

@skapoor2024

In the provided loss function:

import torch
import torch.nn as nn

class BinaryDiceLoss(nn.Module):
    def __init__(self, smooth=1, p=2, reduction='mean'):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth
        self.p = p
        self.reduction = reduction

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
        predict = predict.contiguous().view(-1)
        target = target.contiguous().view(-1)

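        # note: after view(-1) above, predict and target are 1-D, so dim=1 is out of range here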
        num = torch.sum(torch.mul(predict, target), dim=1)
        den = torch.sum(predict, dim=1) + torch.sum(target, dim=1) + self.smooth

        dice_score = 2*num / den
        dice_loss = 1 - dice_score

        dice_loss_avg = dice_loss[target[:,0]!=-1].sum() / dice_loss[target[:,0]!=-1].shape[0]

        return dice_loss_avg
        

When we flatten the predict and target tensors, why do we sum with dim=1? Shouldn't it be dim=0, since the flattened tensor has only one dimension? Also, by the time predict and target are passed to the loss function, they have already been indexed by batch and by organ, leaving a 3-D (d, h, w) volume for both target and prediction.
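
For reference, a minimal repro (with hypothetical 16×16×16 shapes) confirming that this summation fails on a flattened 1-D tensor under recent PyTorch versions:

import torch

# stand-in for one (d, h, w) prediction/target volume, as described above
predict = torch.rand(16, 16, 16)
target = (torch.rand(16, 16, 16) > 0.5).float()

predict = predict.contiguous().view(-1)  # shape: (4096,), i.e. 1-D
target = target.contiguous().view(-1)

torch.sum(torch.mul(predict, target), dim=1)
# IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)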

@ljwztc (Owner) commented Aug 21, 2024

Makes sense. This should be 0. But why did we encounter no errors when running this code? Maybe an earlier version of PyTorch supported this summation?

@skapoor2024 (Author)

I believe we should make the following change:

        predict = predict.contiguous().view(1,-1)
        target = target.contiguous().view(1,-1)

This would make dim=1 work properly, since d, h, w are flattened together and the final shape becomes (1, d*h*w).
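
A quick shape check of this proposal (hypothetical sizes again):

import torch

predict = torch.rand(16, 16, 16)            # one (d, h, w) volume
predict = predict.contiguous().view(1, -1)  # shape: (1, 4096) == (1, d*h*w)
print(torch.sum(predict, dim=1).shape)      # torch.Size([1]) -- dim=1 is now valid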

@OxInsky1105

Yes, you are right!

@OxInsky1105

predict = predict.contiguous().view(predict.size(0), -1)
target = target.contiguous().view(predict.size(0), -1)

I think the change above may be better! The final shape would be (b, d*h*w), which keeps the batch dimension so that one Dice score is computed per sample.
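
For completeness, a minimal sketch of the whole forward with this batch-wise flattening applied (an illustration, not the repository's final fix; the smooth term placement and the -1 masking are carried over from the posted code):

import torch
import torch.nn as nn

class BinaryDiceLoss(nn.Module):
    def __init__(self, smooth=1, p=2, reduction='mean'):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth
        self.p = p
        self.reduction = reduction

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
        # flatten everything except the batch dimension: (b, d, h, w) -> (b, d*h*w)
        predict = predict.contiguous().view(predict.size(0), -1)
        target = target.contiguous().view(target.size(0), -1)

        # per-sample Dice: sum over the flattened voxel dimension (dim=1)
        num = torch.sum(torch.mul(predict, target), dim=1)
        den = torch.sum(predict, dim=1) + torch.sum(target, dim=1) + self.smooth

        dice_score = 2 * num / den
        dice_loss = 1 - dice_score

        # carry over the original masking: drop samples whose target is flagged with -1
        valid = target[:, 0] != -1
        dice_loss_avg = dice_loss[valid].sum() / dice_loss[valid].shape[0]

        return dice_loss_avg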
