How to run multi-label segmentation? #82

Open
gulubao opened this issue Apr 12, 2023 · 15 comments

Comments

@gulubao

gulubao commented Apr 12, 2023

I encountered some issues with multi-label segmentation, and I would like to ask for your help.

The demo ISIC dataset has a single label. The demo BRATS dataset has multiple labels, but they are merged into one label in the class BRATSDataset3D.

I am interested in performing multi-label segmentation on my own dataset, and I am wondering how to set the dataset and model for this purpose.

Could you please provide a demo for multi-label segmentation?

@theneao

theneao commented Apr 13, 2023

Did you try v2? The structure diagram for v2 appears to show multi-class output, but I cannot find any specific changes in the code that address this. I tried to modify v1, but I found it difficult to adapt the loss and the multi-class output. No errors are reported, yet the model still has no multi-class capability. Do you have any findings or ideas on this issue so far?

@gulubao
Author

gulubao commented Apr 14, 2023

I am attempting to train on v2. I am not sure if the author adjusted the loss function for multi-class classification, as MSE and VB don't seem to limit the number of categories.

Here are the adjustments I made for multi-label segmentation:
1. Preprocess the mask labels.
1.1 Relabel the masks across the entire dataset so that the classes take the values 0, 1, ..., n.
1.2 When loading masks in the custom Dataset, normalize them with mask = mask / n.
1.3 Use transforms.Resize((args.image_size, args.image_size), interpolation=InterpolationMode.NEAREST) for mask resizing.
1.4 Delete torch.where(mask > 0, 1, 0).
2. In the gaussian_diffusion.training_losses_segmentation function, change res = torch.where(mask > 0, 1, 0) to res = torch.softmax(mask, dim=0).
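
For readers who want to try steps 1.1 to 1.3, here is a minimal NumPy sketch of the mask preprocessing. It is only an illustration, not code from this repo: the function names, the example raw label values [0, 1, 2, 4], and the index-based nearest-neighbour resize (used here in place of transforms.Resize with InterpolationMode.NEAREST) are all assumptions.

```python
import numpy as np

def remap_labels(mask, raw_values):
    """Map arbitrary raw label values to consecutive indices 0..n (step 1.1)."""
    out = np.zeros(mask.shape, dtype=np.int64)
    for index, raw in enumerate(raw_values):
        out[mask == raw] = index
    return out

def resize_nearest(mask, size):
    """Nearest-neighbour resize of a 2D mask to (size, size) (step 1.3).

    Picking source pixels by index never blends two labels into a new value,
    which is why nearest-neighbour is used for masks.
    """
    h, w = mask.shape
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    return mask[rows[:, None], cols[None, :]]

def preprocess_mask(mask, raw_values, image_size):
    """Remap, resize, then scale to [0, 1] with mask / n (step 1.2)."""
    n = len(raw_values) - 1
    indexed = remap_labels(mask, raw_values)
    resized = resize_nearest(indexed, image_size)
    return resized.astype(np.float32) / n  # values in {0, 1/n, ..., 1}

# Hypothetical 2x2 mask with raw labels 0, 1, 2, 4.
raw_mask = np.array([[0, 1], [2, 4]])
encoded = preprocess_mask(raw_mask, raw_values=[0, 1, 2, 4], image_size=4)
```

With this encoding each class occupies one evenly spaced gray level in a single channel, which matches the approach described above but also means the model has to regress exact gray levels rather than choose a class.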

Due to limited computational resources, I had to reduce the batch_size, so I replaced all nn.BatchNorm2d with nn.InstanceNorm2d.
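
One possible way to do this swap without editing every layer by hand is to walk the module tree, as in the sketch below. The helper name is hypothetical, and affine=True is an assumption chosen so that the learnable scale and shift of BatchNorm2d are kept; nn.InstanceNorm2d defaults to affine=False.

```python
import torch.nn as nn

def batchnorm_to_instancenorm(module):
    """Recursively replace every nn.BatchNorm2d with nn.InstanceNorm2d in place."""
    for name, child in list(module.named_children()):
        if isinstance(child, nn.BatchNorm2d):
            # Keep the same channel count; affine=True is an assumption.
            setattr(module, name, nn.InstanceNorm2d(child.num_features, affine=True))
        else:
            batchnorm_to_instancenorm(child)
    return module

# Small stand-in network; the real model would be the repo's UNet.
net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.Sequential(nn.Conv2d(8, 8, kernel_size=3, padding=1), nn.BatchNorm2d(8)),
)
net = batchnorm_to_instancenorm(net)
```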

I am still in the training process, and I am not sure if these adjustments will work.

I hope the author could share the method for multilabel segmentation in Fig. 2 of the paper.

@theneao

theneao commented Apr 15, 2023

I made similar adjustments on v1 before. In the end, I found that the loss is computed on the predicted noise (this can be changed to predict x0), but the prediction has only one channel, so it cannot represent multiple classes.

I think the main reason may be that 'out_channels' in the UNet definition was not changed.

After changing it, however, I ran into the line model_output, model_var_values = th.split(model_output, C, dim=1). It is not clear how to divide the channels between the two outputs in a multi-class setting. I also found that assigning to "model_output" directly with something like "model_output[:, :0, :]" causes a dimension error in the next loop iteration, even though the dimensions are correct after the first run. Of course, it is also possible that I made mistakes because of my limited experience. You can try it yourself.

If it's convenient, you can contact me directly through the email on my profile page, or send me your other contact information by email.

@theneao

theneao commented Apr 15, 2023

On splitting the channels: I don't know why torch.split() behaves differently from indexing directly with model_output[:, :0, :, :], but the only way I can specify the split is with model_output, model_var_values = th.split(model_output, 4, dim=1) or model_output, model_var_values = th.split(model_output, [4, 1], dim=1). (I am working on a 5-class task, and I think the number of output channels could be set to 5 or 8.)
However, the steps that follow still raise errors.
In

def _predict_xstart_from_eps(self, x_t, t, eps):
    assert x_t.shape == eps.shape

a dimension error for 'x_t' is still reported. x_t is the forward-diffused image, which can be computed and passed in, so I believe it could be copied along the channel dimension to match eps. I have not tried this yet.
In addition, after comparing the two versions of the code, I found that v2 modifies the network structure as described in the v2 paper but makes no changes related to the number of classes.
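
A small shape experiment may help with the split question. It uses made-up tensor sizes and is not taken from the repo. It shows that torch.split with an integer cuts chunks of that size (the last chunk may be smaller), that a list gives explicit sizes, and that a slice ending at 0 selects no channels at all, which would explain a later dimension error.

```python
import torch

B, H, W = 2, 8, 8  # hypothetical batch size and spatial size

# 5 output channels: an integer split size gives chunks of 4 and 1.
out5 = torch.randn(B, 5, H, W)
a, b = torch.split(out5, 4, dim=1)        # shapes (B, 4, H, W) and (B, 1, H, W)
c, d = torch.split(out5, [4, 1], dim=1)   # same result with explicit sizes

# Slicing with :0 keeps zero channels; :1 keeps the first channel.
empty = out5[:, :0, :, :]                 # shape (B, 0, H, W)
first = out5[:, :1, :, :]                 # shape (B, 1, H, W)

# 8 output channels: two equal halves, e.g. 4 for eps and 4 for variance values.
out8 = torch.randn(B, 8, H, W)
eps, var_values = torch.split(out8, 4, dim=1)

# The assert in _predict_xstart_from_eps needs x_t and eps to have equal shapes,
# so a single-channel x_t cannot be paired with a 4-channel eps.
x_t_single = torch.randn(B, 1, H, W)
x_t_multi = torch.randn(B, 4, H, W)
single_matches = x_t_single.shape == eps.shape   # False
multi_matches = x_t_multi.shape == eps.shape     # True
```

Under these assumptions, 8 output channels with a 4-channel noised mask is the combination that keeps both the split and the shape assertion consistent; with 5 channels the two halves differ in size.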

@gulubao
Author

gulubao commented Apr 17, 2023

I did not try to add channels for multi-class segmentation. Instead, I tried to generate multiple pixel values representing the different classes in the last channel of the current code.

I ran experiments on BRATS, but the results were very poor. My adjustments are listed in my comment above.

The figures below show the 50th slice in folder slice0001. The dark one is the GT segmentation. The gray one is the model output.

[Images: slice0001_slice50_GT, slice0001_slice50_output_ens]

@theneao

theneao commented Apr 18, 2023

The results are indeed very poor. Binary segmentation on my dataset also performs poorly, with weak fine-grained detail, far below a standard U-Net. It is unclear what causes this. I feel like giving up on this project.

Recently, many diffusion-based segmentation networks have appeared for similar tasks. If you are interested, we can discuss them by private email.

@lixiang007666

@gulubao, I ran into some problems while working with the sample code for multi-label classification. Could we get in touch?

email: [email protected]

@theneao

theneao commented Apr 21, 2023

I think it may be the problem of loss function design

@jaceqin

jaceqin commented May 16, 2023

Have you solved your multi-label classification yet?

@saisusmitha

@gulubao @theneao Hi guys, I think the output sample includes the image too. That is, it also segments the brain border. Is this the case for you as well? Judging by your results, it looks the same. Kindly let me know, and correct me if I am missing something.

@agentdr1

Any updates on this? I am also interested in multi-class segmentation.

@smallboy-code

Here are some results of multi-class segmentation on the BRATS dataset. I don't know how to threshold the output mask so that it matches the ground truth.
[Images: 0, 2, 3, 5]
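
One possible way to threshold such an output, sketched below, is to snap every pixel to the nearest class level. This assumes the masks were encoded in a single channel as k / n (as in the mask = mask / n step described earlier in this thread) and that the sampled output lies roughly in [0, 1]. The function name and the raw label values [0, 1, 2, 4] are illustrative, not from the repo.

```python
import numpy as np

def decode_to_labels(output, n):
    """Snap a continuous single-channel output to the nearest of the levels k / n."""
    scaled = np.clip(output, 0.0, 1.0) * n
    # np.rint rounds to the nearest integer (exact halves go to the even one).
    return np.rint(scaled).astype(np.int64)

# Hypothetical model output for a 2x2 patch with n = 3 (four classes).
pred = np.array([[0.02, 0.30], [0.70, 0.95]])
labels = decode_to_labels(pred, n=3)       # class indices 0..3

# Optional: map the indices back to the dataset's raw label values.
raw_values = np.array([0, 1, 2, 4])
raw_labels = raw_values[labels]
```

Rounding is equivalent to placing thresholds halfway between neighbouring levels, so a noisy output near a midpoint can flip between classes.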

@thd2020

thd2020 commented Sep 26, 2024

@gulubao Bro, how are you handling the masks? Multi-channel or one-hot? And did you modify args.in_ch?

@Destinycjk

@smallboy-code Your results look great! Could you please explain how you adjusted the source code to achieve multi-class segmentation?

@YihangZHO

@smallboy-code How is your Dice score? I also tried to use this repo for multi-class segmentation, but the accuracy seems very low, much worse than nnU-Net.
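
For anyone comparing numbers, a per-class Dice score on integer label maps can be computed as in the sketch below. It assumes class 0 is background and is skipped, and that predictions have already been converted to class indices. The function name is hypothetical and this is not the repo's evaluation code.

```python
import numpy as np

def dice_per_class(pred, target, num_classes):
    """Dice = 2 * |P ∩ T| / (|P| + |T|) for each foreground class."""
    scores = {}
    for c in range(1, num_classes):  # skip background class 0
        p = pred == c
        t = target == c
        denom = p.sum() + t.sum()
        if denom == 0:
            # Class absent from both maps: Dice is undefined here.
            scores[c] = float("nan")
        else:
            scores[c] = 2.0 * np.logical_and(p, t).sum() / denom
    return scores

# Tiny made-up example with 4 classes.
pred_labels = np.array([[0, 1], [2, 2]])
true_labels = np.array([[0, 1], [2, 1]])
scores = dice_per_class(pred_labels, true_labels, num_classes=4)
```

Reporting the score per class rather than one pooled value makes it easier to see whether a small class is being missed entirely.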
