diff --git a/lassonet/cox.py b/lassonet/cox.py index facc338..562975f 100644 --- a/lassonet/cox.py +++ b/lassonet/cox.py @@ -31,7 +31,9 @@ def forward(self, log_h, y): events = events[idx] event_ind = events.nonzero().flatten() - + if event_ind.nelement() == 0: + return torch.tensor(0.0) + # numerator log_num = log_h[event_ind].mean()