Description
Hey guys, I just implemented the generalised dice loss (multi-class version of dice loss), as described in ref :
(my targets are defined as: (batch_size, image_dim1, image_dim2, image_dim3, nb_of_classes))
def generalized_dice_loss_w(y_true, y_pred):
    # Compute weights: "the contribution of each label is corrected by the inverse of its volume"
    Ncl = y_pred.shape[-1]
    w = np.zeros((Ncl,))
    for l in range(0,Ncl): w[l] = np.sum( np.asarray(y_true[:,:,:,:,l]==1,np.int8) )
    w = 1/(w**2+0.00001)
    # Compute gen dice coef:
    numerator = y_true*y_pred
    numerator = w*K.sum(numerator,(0,1,2,3))
    numerator = K.sum(numerator)
    denominator = y_true+y_pred
    denominator = w*K.sum(denominator,(0,1,2,3))
    denominator = K.sum(denominator)
    gen_dice_coef = numerator/denominator
    return 1-2*gen_dice_coef
But something must be wrong. I'm working with 3D images that I have to segment into 4 classes (1 background class and 3 object classes; I have an imbalanced dataset). First odd thing: while my train loss and accuracy improve during training (and converge really fast), my validation loss/accuracy are constant through epochs (see image). Second, when predicting on test data, only the background class is predicted: I get a constant volume.
I used the exact same data and script but with categorical cross-entropy loss and get plausible results (object classes are segmented). Which means something is wrong with my implementation. Any idea what it could be?
Plus I believe it would be useful to the Keras community to have a generalised dice loss implementation, as it seems to be used in most recent semantic segmentation tasks (at least in the medical image community).
PS: it seems odd to me how the weights are defined; I get values around 10^-10. Has anyone else tried to implement this? I also tested my function without the weights but get the same problems.
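For checking values by hand, here is a minimal NumPy sketch of the same formula (inverse squared class volume as weight, then 1 - 2 * weighted intersection / weighted sum). It assumes one-hot targets of shape (batch, dim1, dim2, dim3, nb_of_classes); the function name, the toy data and the `eps` argument are illustrative. It is only a reference for comparing numbers, not a drop-in Keras loss, since it works on arrays rather than symbolic tensors.

```python
import numpy as np

def generalized_dice_loss_np(y_true, y_pred, eps=1e-5):
    """NumPy reference for the generalised dice loss on one-hot targets."""
    # Sum over every axis except the last (class) axis.
    spatial_axes = tuple(range(y_true.ndim - 1))
    # Weight of each class: inverse of its squared volume.
    volumes = y_true.sum(axis=spatial_axes)
    w = 1.0 / (volumes ** 2 + eps)
    numerator = np.sum(w * np.sum(y_true * y_pred, axis=spatial_axes))
    denominator = np.sum(w * np.sum(y_true + y_pred, axis=spatial_axes))
    return 1.0 - 2.0 * numerator / denominator

# Hypothetical toy batch: 4 classes, class 0 (background) dominates.
rng = np.random.default_rng(0)
labels = rng.choice(4, size=(2, 8, 8, 8), p=[0.93, 0.03, 0.02, 0.02])
y_true = np.eye(4)[labels]              # one-hot, shape (2, 8, 8, 8, 4)
y_pred = np.full(y_true.shape, 0.25)    # uniform softmax output

loss_perfect = generalized_dice_loss_np(y_true, y_true)   # perfect prediction
loss_uniform = generalized_dice_loss_np(y_true, y_pred)   # uninformative prediction
print(loss_perfect, loss_uniform)
```

Because the volumes are squared, the weight of a large background class becomes very small, which is in line with the tiny values mentioned in the PS.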
Activity
xychenunc commented on Mar 5, 2018
Hi, I came across a similar issue. The dataset is imbalanced, and the region which is small compared to the whole image cannot be well segmented. I think it has nothing to do with your loss function; maybe it is due to the patch-based approach. How large is your patch size?
xychenunc commented on Mar 5, 2018
And I think the problem with your loss function is that the weights are not normalized. I think normalized weights should be what you want. And w = 1/(w**2+0.00001) should maybe be rewritten as something like w = w/(np.sum(w)+0.00001). Otherwise, the generalized loss is not 'balanced': the region which takes a larger portion of the image accounts for a relatively small part of the total loss.
emoebel commented on Mar 6, 2018
Hey xychenunc, thanks for your answer! I tried normalizing the weights, but it didn't make any difference. While it is true that the weight values are easier to interpret (instead of values around 10^-10 I now have values between 0 and 1), it seems that numerically it does not change the loss behaviour.
Why are you asking about the patch size? Aren't the weights supposed to cope with the unbalanced class problem? Anyway, my patch size is 56x56x56 voxels, and my objects have a diameter of 10 voxels. In my patches, I have on average 7% of voxels labeled as object; the rest is background.
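One way to see why rescaling the weights leaves the loss unchanged: the loss is a ratio of two weighted sums, so any common positive factor on the weights cancels. The NumPy sketch below illustrates this, assuming "normalizing" means dividing the inverse-squared-volume weights by their sum (the thread does not say exactly which normalisation was tried); the function name and toy data are made up.

```python
import numpy as np

def weighted_dice_loss(y_true, y_pred, w):
    """Generalised dice loss with externally supplied per-class weights."""
    axes = tuple(range(y_true.ndim - 1))          # all axes except classes
    num = np.sum(w * np.sum(y_true * y_pred, axis=axes))
    den = np.sum(w * np.sum(y_true + y_pred, axis=axes))
    return 1.0 - 2.0 * num / den

# Hypothetical toy patch with 4 classes and a dominant background.
rng = np.random.default_rng(1)
labels = rng.choice(4, size=(1, 6, 6, 6), p=[0.85, 0.05, 0.05, 0.05])
y_true = np.eye(4)[labels]
logits = rng.normal(size=y_true.shape)
y_pred = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)  # softmax

volumes = y_true.sum(axis=(0, 1, 2, 3))
w_raw = 1.0 / (volumes ** 2 + 1e-5)   # tiny values
w_norm = w_raw / w_raw.sum()          # same weights rescaled to sum to 1

loss_raw = weighted_dice_loss(y_true, y_pred, w_raw)
loss_norm = weighted_dice_loss(y_true, y_pred, w_norm)
print(loss_raw, loss_norm)            # the two values agree
```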
jpcenteno80 commented on Mar 6, 2018
I am trying something similar for a 2D semantic segmentation project with 10 categories (label 0 is background). Before trying dice, I was using sparse categorical crossentropy with very good results. However, because label 0 was being included in the loss calculation, both training and validation accuracy were artificially high (> 0.98). My implementation of dice is based on this: https://github.com/Lasagne/Recipes/issues/99.
y_true has shape (batch,m,n,1) and y_pred has shape (batch,m,n,10). Here is my version of dice:
A model trained with the above implementation of dice tends to predict 4 out of the 9 categories and the segmentation is less than ideal and much worse than I got with sparse categorical crossentropy.
However, when I convert the segmentation task into a binary decision (merge all categories into one), the segmentation is pretty good. Here is the loss function changed for a binary problem:
Not sure if the binary results are better because it is an 'easier' task or because my dice loss function is wrong.
xychenunc commented on Mar 20, 2018
emoebel commented on Mar 27, 2018
@xychenunc thanks for your answer, I also realised that the problem is class imbalance. I just don't understand why, because the weights in the loss function are supposed to cope with that.
@jpcenteno80 does sparse-categorical-crossentropy work better than normal categorical-crossentropy in the case of segmentation? I tried it out myself, but I'm getting an error concerning array shapes:
ValueError: Error when checking target: expected conv3d_15 to have shape (56, 56, 56, 1) but got array with shape (56, 56, 56, 4)
Concerning your loss function 'dice_coef_9cat_loss': I don't think it is a good idea to ignore background. Examples of "non-objects" are as important as examples of "objects", the problem is class imbalance. If you completely ignore the background, chances are that you'll get a lot of false positives.
Concerning your function 'dice_coef_binary_loss': maybe it works better because by merging all your object classes there is a better balance between "background" and "objects".
Check out the ref I cited in my original post, they describe how to implement dice loss for multiple imbalanced classes. Maybe your implementation will work, because mine doesn't and I don't understand why.
jpcenteno80 commented on Mar 27, 2018
When using sparse categorical crossentropy, you don't need to one-hot encode your target. Here is a discussion comparing crossentropy vs sparse crossentropy in stackoverflow (google: TensorFlow: what's the difference between sparse_softmax_cross_entropy_with_logits and softmax_cross_entropy_with_logits?) <- Copying the link wasn't working for me.
So in my case (2D segmentation model) I feed in the raw mask as my target, which is an array 384x384 with pixels labeled from 0 to 9 (10 categories).
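To make the two target layouts concrete, here is a small NumPy sketch for the 3D case from the error message above: an integer mask with a trailing axis of length 1 (the layout the sparse loss expects) versus a one-hot mask with one channel per class. The variable names and the random mask are illustrative.

```python
import numpy as np

n_classes = 4
# Hypothetical integer mask: one class index per voxel.
rng = np.random.default_rng(2)
sparse_mask = rng.integers(0, n_classes, size=(56, 56, 56))

# Layout for sparse categorical crossentropy: trailing axis of length 1.
sparse_target = sparse_mask[..., np.newaxis]                       # (56, 56, 56, 1)

# Layout for categorical crossentropy: one channel per class.
one_hot_target = np.eye(n_classes, dtype=np.float32)[sparse_mask]  # (56, 56, 56, 4)

# Going back from one-hot to integer labels.
recovered = one_hot_target.argmax(axis=-1)
print(sparse_target.shape, one_hot_target.shape)
```

Feeding the one-hot array to a model compiled with the sparse loss is what produces a shape mismatch like the one quoted above.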
I ended up getting decent results with the 'dice_coef_9cat_loss' function after all (neglecting the background label). It just took longer training and starting with lower learning rates (Nadam, 1e-5 or 1e-4). I also tried cyclical learning rates, which I think helped: https://github.com/bckenstler/CLR. I decided to neglect background in the loss calculation because my class imbalance was pretty large and I could not figure out in Keras how to use 'sample_weight' in the fit method with 2D arrays.
I'm also transitioning into pytorch and I like that it seems more flexible in terms of setting up custom metrics or loss functions. I am using the tiramisu architecture for semantic segmentation which uses negative log likelihood as the loss (implementation here: https://github.com/bfortuner/pytorch_tiramisu). The results so far are great. I highly recommend using this architecture for semantic segmentation. Have not tried it in 3D though.
zdryan commented on Mar 28, 2018
I've taken the same approach as @jpcenteno80 as I am also unable to successfully implement the generalised dice loss. Would rather avoid using temporal sample weights.
emoebel commented on Apr 6, 2018
Hey guys, I found a way to implement multi-class dice loss, I get satisfying segmentations now. I implemented the loss as explained in ref : this paper describes the Tversky loss, a generalised form of dice loss, which is identical to dice loss when alpha=beta=0.5
Here is my implementation, for 3D images:
I would be curious to know if this works for your applications. To adapt from 3D images to 2D images, you should modify all "sum(...,(0,1,2,3))" to "sum(...,(0,1,2))".
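As an independent illustration (not the implementation posted in this comment), here is a NumPy sketch of a multi-class Tversky loss and a check that it matches dice loss when alpha=beta=0.5. It assumes one-hot targets with classes on the last axis, that alpha weights false positives and beta false negatives, and that the per-class losses are averaged; all function names are hypothetical. The summation axes are derived from the array rank, so the same code covers the 2D and 3D cases.

```python
import numpy as np

def tversky_loss_np(y_true, y_pred, alpha=0.5, beta=0.5, eps=1e-7):
    """Mean over classes of (1 - Tversky index)."""
    axes = tuple(range(y_true.ndim - 1))              # all axes except classes
    tp = np.sum(y_true * y_pred, axis=axes)           # soft true positives
    fp = np.sum((1.0 - y_true) * y_pred, axis=axes)   # soft false positives
    fn = np.sum(y_true * (1.0 - y_pred), axis=axes)   # soft false negatives
    tversky = (tp + eps) / (tp + alpha * fp + beta * fn + eps)
    return np.mean(1.0 - tversky)

def dice_loss_np(y_true, y_pred, eps=1e-7):
    """Mean over classes of (1 - dice coefficient)."""
    axes = tuple(range(y_true.ndim - 1))
    inter = np.sum(y_true * y_pred, axis=axes)
    total = np.sum(y_true, axis=axes) + np.sum(y_pred, axis=axes)
    # 2 * eps keeps the smoothing identical to the Tversky form above.
    return np.mean(1.0 - (2.0 * inter + 2.0 * eps) / (total + 2.0 * eps))

# Hypothetical 3D batch with 4 classes and softmax-like predictions.
rng = np.random.default_rng(3)
y_true = np.eye(4)[rng.integers(0, 4, size=(2, 6, 6, 6))]
logits = rng.normal(size=y_true.shape)
y_pred = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)

loss_tversky = tversky_loss_np(y_true, y_pred)                 # alpha = beta = 0.5
loss_dice = dice_loss_np(y_true, y_pred)
loss_fn_heavy = tversky_loss_np(y_true, y_pred, alpha=0.3, beta=0.7)
print(loss_tversky, loss_dice, loss_fn_heavy)
```

Raising beta above alpha penalises false negatives more than false positives, which is the knob that distinguishes Tversky loss from plain dice.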
gattia commented on Apr 6, 2018
@lazyleaf, I just stumbled upon this. I am doing 3D segmentation on multiclass. I will definitely try out the proposed method and see how it works. However, I also have another solution that has worked for me in the past:
This simply calculates the dice score for each individual label, and then sums them together, and includes the background. The best dice score you will ever get is equal to numLables*-1.0. When monitoring I always keep in mind that the dice for the background is almost always near 1.0.
kroskal commented on Apr 17, 2018
@lazyleaf, I was also struggling to implement this loss function. But with some inspiration from your code, here is my take on it (for 2D images).
Ps! I haven't tried to train a network with it yet.
lkshrsch commented on May 7, 2018
@lazyleaf thank you for pointing to the Tversky loss. I implemented your code (had to change K.shape --> K.int_shape) but it still complains that "TypeError: long() argument must be a string or a number, not 'NoneType'".
Do you know why this is happening, and do you see it in your own code?
rakeshakkineni commented on May 19, 2018
@lkshrsch To remove the compilation error I have replaced "ones = K.ones(K.shape(y_true))" with "ones = K.ones_like(y_true)".
DaanKuppens commented on May 31, 2018
@kroskal Thanks for the implementation. Did you try training a network yet with this loss?
When I use your generalized dice loss I somehow get loss values larger than 1 and dice coefficients smaller than 0. Do you know what may cause this?
DaanKuppens commented on May 31, 2018
I now have the generalized_dice_coef and generalized_dice_loss working within [0, 1] for 2D images. I normalized the weights to the presence of the class in the entire dataset instead of just the batch, using the following code:
xychenunc commented on Jul 4, 2020
I noticed that this implementation of multi-class dice loss could lead to sub-optimal performance in some cases (probably depending on the specific dataset and/or architecture design). That is, for some minority class(es), the prediction can be nothing. Have you encountered this problem in your work and how did you solve it? Thanks
gattia commented on Jul 4, 2020
This can sometimes be resolved by tuning other parameters (learning rate, optimizer, etc.)
If you have many imbalanced classes, it can cause issues or fail to weight them equally. In this case, you can create weighting schemes that can sometimes help.
maxvfischer commented on Aug 4, 2020
I've implemented a bunch of binary and multi-class loss functions for image segmentation (one-hot encoded masks) that you might find useful: https://github.com/maxvfischer/keras-image-segmentation-loss-functions
E.g.:
Tanimoto loss: https://github.com/maxvfischer/keras-image-segmentation-loss-functions/blob/master/losses/multiclass_losses.py#L8
Dice's coefficient loss: https://github.com/maxvfischer/keras-image-segmentation-loss-functions/blob/master/losses/multiclass_losses.py#L42
Squared Dice's coefficient loss: https://github.com/maxvfischer/keras-image-segmentation-loss-functions/blob/master/losses/multiclass_losses.py#L74
rohan19250 commented on Dec 10, 2020
Hello - I am working on 4 class segmentation problem, so I have 4 labels. I am able to get combined dice scores and losses using the functions below:
How can I get dice coefficient and dice loss per label instead of a combined dice coefficient (see below)?
maxvfischer commented on Dec 10, 2020
@rohan19250 I don't know what optimizer you're using during training, but presuming that you're using a gradient-based optimizer like SGD or ADAM, you want a single loss value to be able to optimize the network.
That being said, if you still want to compute dice_coef and dice_coef_loss for each label separately, I would add them as Keras metrics.
You could probably try to add something like this (NOTE: I have not tested this code. Think of it as pseudo-code):
Then compile your model in this fashion:
rohan19250 commented on Dec 10, 2020
Thanks a lot @maxvfischer! This worked. So we are not summing over the last axis (excluded "axis=-1") in this function for individual labels. Could you explain this a bit?
maxvfischer commented on Dec 10, 2020
@rohan19250
What axis=-1 means is that you're referring to the last axis (without knowing the actual index of the last axis). In your case when computing the dice loss for all labels, the tensors y_true and y_pred are of shape (None, <IMAGE_HEIGHT>, <IMAGE_WIDTH>, 4), where None is the unknown batch size and 4 is the number of classes you have. By running K.sum(..., axis=-1) you are summing over all the classes (last channel).
But in my code, we're extracting the true values and the predictions for a single class:
The shape will go from
y_true.shape = y_pred.shape = (None, <IMAGE_HEIGHT>, <IMAGE_WIDTH>, 4)
to
y_true_single_class.shape = y_pred_single_class.shape = (None, <IMAGE_HEIGHT>, <IMAGE_WIDTH>)
If you would keep axis=-1, that will refer to the last channel in the tensor, in our case <IMAGE_WIDTH>. By keeping it, you would've summed over the image width, something we don't want to do.
Hope it explains why I removed it.
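The shape argument can be checked with plain NumPy, whose np.sum follows the same axis convention as K.sum. The sketch below uses made-up sizes (batch of 2, 5x7 images, 4 classes) and a hypothetical class index; it is only meant to show what each call reduces over.

```python
import numpy as np

batch, height, width, n_classes = 2, 5, 7, 4
rng = np.random.default_rng(4)
# Hypothetical one-hot ground truth of shape (batch, height, width, classes).
y_true = np.eye(n_classes)[rng.integers(0, n_classes, size=(batch, height, width))]

class_idx = 1
y_true_single_class = y_true[..., class_idx]           # (2, 5, 7): class axis is gone

summed_all = np.sum(y_true_single_class)               # no axis: one scalar
summed_last = np.sum(y_true_single_class, axis=-1)     # (2, 5): image width summed away
summed_classes = np.sum(y_true, axis=-1)               # (2, 5, 7): classes summed away

print(y_true_single_class.shape, np.ndim(summed_all), summed_last.shape)
```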
EDIT:
I just saw that I didn't write that. When you don't supply an axis to K.sum(...), it will sum over all axes.
rohan19250 commented on Dec 10, 2020
@maxvfischer Got it! this is really helpful.
My Y_val is of shape (2880, 192, 192, 4). I am using the combined dice coefficient loss function for overall network, and just calculating individual dice coefficients.
The functions I defined based on your above code for individual classes:
It seems the combined validation dice coefficient is good, but the individual class dice coefficients are not that good (shown first few epochs). Perhaps I should explore other loss functions and data augmentation (Y train ~13000 images)?
maxvfischer commented on Dec 10, 2020
@rohan19250 Impossible for me to answer from the information you've provided. The combined validation dice coefficient should be converging to something "good", because that's what you're optimizing. For how long did you train it? Have you plotted the individual dice coefficients over more epochs? Do they converge to something?
I would probably think more about your problem:
In 3.2.4. Generalization to multiclass imbalanced problems in this paper, they explain a way to weight your classes: https://arxiv.org/pdf/1904.00592.pdf
Use ImageDataGenerator and its built-in image augmentations (https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator) or write your own ImageDataGenerator (by inheriting keras.utils.Sequence) and use this augmentation package: https://github.com/albumentations-team/albumentations .
rohan19250 commented on Dec 10, 2020
@maxvfischer - I have trained for 40 epochs (Adam optimizer, lr=1e-5). The individual dice coefficients converge to ~.29 for the first class, ~.46 for the second class, and ~.42 for the third class.
The combined validation dice coefficient converges to ~.98
Some of the details about the architecture of the U-net model.
Also, I am able to run predictions and compare the ground truth with the predicted labels for some of the test images (which is a bit satisfying), but I am concerned that the individual dice scores are not good and could definitely be improved.
I will explore the other suggestions you provided.
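A toy NumPy example of how a combined score near 1 can coexist with modest per-class scores. It assumes the combined coefficient pools all four channels (background included) into a single ratio, which may or may not match the functions used above; the image, block positions and names are made up. Each object is predicted with a 5-pixel shift, so every object class overlaps by exactly half.

```python
import numpy as np

n_classes = 4
true_labels = np.zeros((100, 100), dtype=int)
pred_labels = np.zeros((100, 100), dtype=int)
for c, row in zip((1, 2, 3), (10, 40, 70)):
    true_labels[row:row + 10, 20:30] = c     # 10x10 object of class c
    pred_labels[row:row + 10, 25:35] = c     # same object, shifted 5 pixels

y_true = np.eye(n_classes)[true_labels]      # (100, 100, 4)
y_pred = np.eye(n_classes)[pred_labels]

def dice(t, p):
    # Plain dice coefficient over every element of the arrays passed in.
    return 2.0 * np.sum(t * p) / (np.sum(t) + np.sum(p))

combined = dice(y_true, y_pred)                                            # 0.97
per_class = [dice(y_true[..., c], y_pred[..., c]) for c in range(n_classes)]
print(combined, per_class)   # background is about 0.98, each object class is 0.5
```

The pooled value is dominated by the 97% of pixels that are background, so it says little about the three object classes.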
maxvfischer commented on Dec 10, 2020
@rohan19250 For interpretability, you might want to use intersection over union/Jaccard Index (https://en.wikipedia.org/wiki/Jaccard_index) for each class as a metric instead of dice coefficient:
Good luck!
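For reference, a NumPy sketch of per-class IoU computed next to per-class dice, assuming one-hot y_true and y_pred with classes on the last axis. The function name and the random test data are hypothetical, and this is not the snippet from the comment above.

```python
import numpy as np

def per_class_dice_and_iou(y_true, y_pred, eps=1e-7):
    """Return (dice, iou), each an array with one value per class."""
    axes = tuple(range(y_true.ndim - 1))          # all axes except classes
    inter = np.sum(y_true * y_pred, axis=axes)
    total = np.sum(y_true, axis=axes) + np.sum(y_pred, axis=axes)
    dice = (2.0 * inter + eps) / (total + eps)
    iou = (inter + eps) / (total - inter + eps)   # union = total - intersection
    return dice, iou

# Hypothetical hard (one-hot) predictions on a small 2D batch.
rng = np.random.default_rng(5)
y_true = np.eye(4)[rng.integers(0, 4, size=(2, 16, 16))]
y_pred = np.eye(4)[rng.integers(0, 4, size=(2, 16, 16))]

dice, iou = per_class_dice_and_iou(y_true, y_pred)
print(dice, iou)
```

The two metrics are tied together by IoU = dice / (2 - dice), so they rank results identically; IoU is simply the stricter number for the same overlap.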
Tarandeep97 commented on Apr 17, 2022
@gattia
Won't there be a weight for each label multiplied with dice_coef, a weight that reflects the number of pixels for that label?
gattia commented on Apr 17, 2022
What you are describing is one version of the DSC that has been called generalized-dsc.
I think they inversely weighted based on the number of pixels labeled that class.
the above would replace the code in the for loop to inversely weight it based on the number of pixels labeled that particular class (in the ground truth).
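A self-contained NumPy sketch of that idea, under stated assumptions: per-label dice is computed in a loop, each label's score is multiplied by a weight proportional to the inverse of its ground-truth pixel count, the weights are normalised to sum to 1, and the loss is the negative weighted sum (so the best value is -1.0). The function name and toy mask are hypothetical, and the published generalized dice formulation may use a different weighting (for example inverse squared counts).

```python
import numpy as np

def inverse_weighted_dice_loss(y_true, y_pred, eps=1e-7):
    n_labels = y_true.shape[-1]
    # Ground-truth pixel count per label.
    counts = y_true.reshape(-1, n_labels).sum(axis=0)
    # Inverse-count weights, normalised to sum to 1.
    # Note: a label absent from y_true gets a very large weight here.
    weights = 1.0 / (counts + eps)
    weights = weights / weights.sum()
    loss = 0.0
    for label in range(n_labels):
        t = y_true[..., label]
        p = y_pred[..., label]
        dice = (2.0 * np.sum(t * p) + eps) / (np.sum(t) + np.sum(p) + eps)
        loss -= weights[label] * dice
    return loss

# Hypothetical mask: large background plus two small objects.
labels = np.zeros((1, 16, 16), dtype=int)
labels[0, 2:5, 2:5] = 1
labels[0, 8:10, 8:12] = 2
y_true = np.eye(3)[labels]
all_background = np.eye(3)[np.zeros_like(labels)]

loss_perfect = inverse_weighted_dice_loss(y_true, y_true)
loss_all_background = inverse_weighted_dice_loss(y_true, all_background)
print(loss_perfect, loss_all_background)
```

With this weighting, predicting only background scores close to 0 rather than close to the best value, because the background dice carries very little weight.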