diff --git a/docs/source/losses.rst b/docs/source/losses.rst index baeebbbe9c..a5be560324 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -78,6 +78,11 @@ Segmentation Losses .. autoclass:: BarlowTwinsLoss :members: +`BoundaryLoss` +~~~~~~~~~~~~~~ +.. autoclass:: BoundaryLoss + :members: + `HausdorffDTLoss` ~~~~~~~~~~~~~~~~~ .. autoclass:: HausdorffDTLoss diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index 087a24f9d7..9f35e5f075 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -14,6 +14,7 @@ from .adversarial_loss import PatchAdversarialLoss from .aucm_loss import AUCMLoss from .barlow_twins import BarlowTwinsLoss +from .boundary_loss import BoundaryLoss from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss from .contrastive import ContrastiveLoss from .deform import BendingEnergyLoss, DiffusionLoss diff --git a/monai/losses/boundary_loss.py b/monai/losses/boundary_loss.py new file mode 100644 index 0000000000..169763e622 --- /dev/null +++ b/monai/losses/boundary_loss.py @@ -0,0 +1,233 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warnings +from collections.abc import Callable + +import torch +from torch.nn.modules.loss import _Loss + +from monai.networks import one_hot +from monai.transforms.utils import distance_transform_edt +from monai.utils import LossReduction + +__all__ = ["BoundaryLoss"] + + +class BoundaryLoss(_Loss): + """ + Compute the boundary loss for highly unbalanced segmentation. + + The boundary loss is a distance-based loss that operates on the interface between segmentation + regions rather than on the regions themselves. This makes it particularly effective for + highly imbalanced segmentation tasks (e.g., small lesions, thin structures), where standard + Dice or Cross-Entropy losses struggle due to foreground-background imbalance. + + The loss is formulated as a pixel-wise weighted sum of predicted probabilities and a + signed distance map derived from the ground truth. The signed distance map is negative + inside the foreground region and positive outside, with zero on the boundary. + + The data `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` + (BNHW[D]). + Note that axis N of `input` is expected to be logits or probabilities for each class, if passing logits as input, + must set `sigmoid=True` or `softmax=True`, or specifying `other_act`. And the same axis of `target` + can be 1 or N (one-hot format). + + The original paper: + Kervadec, H. et al. (2019) Boundary loss for highly unbalanced segmentation. MIDL 2019. + https://arxiv.org/abs/1812.07032 + + Example: + >>> import torch + >>> from monai.losses import BoundaryLoss + >>> B, C, H, W = 2, 3, 5, 5 + >>> input = torch.rand(B, C, H, W) + >>> target = torch.randint(0, C, size=(B, H, W)) + >>> bl = BoundaryLoss(softmax=True, to_onehot_y=True) + >>> loss = bl(input, target) + """ + + def __init__( + self, + include_background: bool = True, + to_onehot_y: bool = False, + sigmoid: bool = False, + softmax: bool = False, + other_act: Callable | None = None, + reduction: LossReduction | str = LossReduction.MEAN, + batch: bool = False, + ) -> None: + """ + Args: + include_background: if False, channel index 0 (background category) is excluded from the calculation. + if the non-background segmentations are small compared to the total image size they can get overwhelmed + by the signal from the background so excluding it in such cases helps convergence. + to_onehot_y: whether to convert the ``target`` into the one-hot format, + using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False. + sigmoid: if True, apply a sigmoid function to the prediction. + softmax: if True, apply a softmax function to the prediction. + other_act: callable function to execute other activation layers, Defaults to ``None``. for example: + ``other_act = torch.tanh``. + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + batch: whether to compute the distance map and loss over the batch dimension before the dividing. + Defaults to False, a boundary loss value is computed independently from each item in the batch + before any `reduction`. + + Raises: + TypeError: When ``other_act`` is not an ``Optional[Callable]``. + ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``]. + Incompatible values. + """ + super().__init__(reduction=LossReduction(reduction).value) + if other_act is not None and not callable(other_act): + raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.") + if int(sigmoid) + int(softmax) + int(other_act is not None) > 1: + raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].") + + self.include_background = include_background + self.to_onehot_y = to_onehot_y + self.sigmoid = sigmoid + self.softmax = softmax + self.other_act = other_act + self.batch = batch + + @torch.no_grad() + def compute_distance_map(self, target: torch.Tensor) -> torch.Tensor: + """ + Compute the signed distance map for each class in the target. + + The signed distance map is negative inside the foreground region and positive outside, + with zero on the boundary. + + Args: + target: target tensor of shape BNHW[D], with values in {0, 1} (one-hot encoded). + + Returns: + Signed distance map of the same shape as target. + """ + if target.dim() not in (4, 5): + raise ValueError("Only 2D (BNHW) and 3D (BNHWD) supported") + + distance_map = torch.zeros_like(target, dtype=torch.float32) + + for batch_idx in range(target.shape[0]): + for channel_idx in range(target.shape[1]): + mask = target[batch_idx, channel_idx : channel_idx + 1] > 0.5 + + # Empty or full masks do not have a foreground/background interface. + if not mask.any() or mask.all(): + continue + + fg_dist: torch.Tensor = distance_transform_edt(mask) # type: ignore + bg_dist: torch.Tensor = distance_transform_edt(~mask) # type: ignore + + signed = torch.zeros_like(mask, dtype=torch.float32) + signed[mask] = -(fg_dist[mask].to(torch.float32) - 1) + signed[~mask] = bg_dist[~mask].to(torch.float32) + + distance_map[batch_idx, channel_idx] = signed[0] + + return distance_map + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Args: + input: the shape should be BNHW[D], where N is the number of classes. + target: the shape should be BNHW[D] or B1HW[D], where N is the number of classes. + + Raises: + ValueError: If the input is not 2D (BNHW) or 3D (BNHWD). + AssertionError: When input and target (after one hot transform if set) + have different shapes. + ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. + + Example: + >>> import torch + >>> from monai.losses import BoundaryLoss + >>> B, C, H, W = 2, 3, 5, 5 + >>> input = torch.rand(B, C, H, W) + >>> target_idx = torch.randint(0, C, size=(B, H, W)).long() + >>> target = one_hot(target_idx[:, None, ...], num_classes=C) + >>> bl = BoundaryLoss(softmax=True) + >>> loss = bl(input, target) + """ + if input.dim() not in (4, 5): + raise ValueError("Only 2D (BNHW) and 3D (BNHWD) supported") + + n_pred_ch = input.shape[1] + + # Apply activation to input + if self.sigmoid: + input = torch.sigmoid(input) + + if self.softmax: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `softmax=True` ignored.", stacklevel=2) + else: + input = torch.softmax(input, dim=1) + + if self.other_act is not None: + input = self.other_act(input) + + # Convert target to one-hot if needed + if self.to_onehot_y: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2) + else: + if target.dim() == input.dim() - 1: + target = target.unsqueeze(dim=1) + target = one_hot(target, num_classes=n_pred_ch) + + # Validate shapes match + if input.shape != target.shape: + raise AssertionError(f"input and target shapes do not match: {input.shape} vs {target.shape}") + + # Exclude background if requested + if not self.include_background: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2) + else: + input = input[:, 1:] + target = target[:, 1:] + + # Compute signed distance maps from target + distance_map = self.compute_distance_map(target) + + # Compute boundary loss: sum over spatial dimensions of (probabilities * distance_map) + # Then average over classes and batch + spatial_axes = list(range(2, input.dim())) + + loss = torch.sum(input * distance_map, dim=spatial_axes) + + # Normalize by number of pixels per class per batch element + num_pixels = torch.prod(torch.as_tensor(input.shape[2:], device=input.device)) + loss = loss / num_pixels + if self.batch: + loss = loss.mean(dim=0) + + if self.reduction == LossReduction.MEAN.value: + loss = loss.mean() + elif self.reduction == LossReduction.SUM.value: + loss = loss.sum() + elif self.reduction == LossReduction.NONE.value: + # Return shape (B, C') unless batch=True reduces the batch dimension first. + pass + else: + raise ValueError(f"Unsupported reduction: {self.reduction}") + + return loss diff --git a/tests/losses/test_boundary_loss.py b/tests/losses/test_boundary_loss.py new file mode 100644 index 0000000000..156d4767c0 --- /dev/null +++ b/tests/losses/test_boundary_loss.py @@ -0,0 +1,334 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest.case import skipUnless + +import torch +from parameterized import parameterized + +from monai.losses import BoundaryLoss +from monai.utils import optional_import + +_, has_scipy = optional_import("scipy") + +# Reusable test tensors +ONES_2D = {"input": torch.ones((2, 2, 8, 8)), "target": torch.ones((2, 2, 8, 8))} +ONES_3D = {"input": torch.ones((2, 2, 8, 8, 8)), "target": torch.ones((2, 2, 8, 8, 8))} + +# Perfect match: target is a 2x2 square, input matches exactly +PERFECT_MATCH = { + "input": torch.tensor( + [[[[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]] + ), + "target": torch.tensor( + [[[[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]] + ), +} + +# Partial overlap: two 2x2 squares shifted by 1 pixel +PARTIAL_OVERLAP = { + "input": torch.tensor( + [[[[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]] + ), + "target": torch.tensor( + [[[[0.0, 1.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]] + ), +} + +# Empty foreground class: target has no foreground in class 1 +EMPTY_FOREGROUND = { + "input": torch.tensor( + [[[[0.9, 0.9, 0.9], [0.9, 0.9, 0.9], [0.9, 0.9, 0.9]], [[0.1, 0.1, 0.1], [0.1, 0.1, 0.1], [0.1, 0.1, 0.1]]]] + ), + "target": torch.tensor( + [[[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]] + ), +} + +TEST_CASES = [] +for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]: + # Basic 2D test with sigmoid + TEST_CASES.append( + [ + {"include_background": True, "sigmoid": True}, + { + "input": torch.tensor([[[[2.0, -2.0], [-2.0, 2.0]]]], device=device), + "target": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], device=device), + }, + None, # Just check it runs, value depends on distance map + ] + ) + # Basic 3D test with sigmoid + TEST_CASES.append( + [ + {"include_background": True, "sigmoid": True}, + { + "input": torch.tensor([[[[[2.0, -2.0], [-2.0, 2.0]], [[2.0, -2.0], [-2.0, 2.0]]]]], device=device), + "target": torch.tensor([[[[[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]]]], device=device), + }, + None, + ] + ) + # Multi-class 2D with softmax + TEST_CASES.append( + [ + {"include_background": True, "softmax": True}, + { + "input": torch.tensor([[[[2.0, 0.0], [0.0, 2.0]], [[-2.0, 0.0], [0.0, -2.0]]]], device=device), + "target": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]]], device=device), + }, + None, + ] + ) + # With to_onehot_y + TEST_CASES.append( + [ + {"include_background": True, "to_onehot_y": True, "softmax": True}, + { + "input": torch.tensor([[[[2.0, 0.0], [0.0, 2.0]], [[-2.0, 0.0], [0.0, -2.0]]]], device=device), + "target": torch.tensor([[[[0, 0], [0, 1]]]], device=device), + }, + None, + ] + ) + # With reduction="none" + TEST_CASES.append( + [ + {"include_background": True, "sigmoid": True, "reduction": "none"}, + { + "input": torch.tensor([[[[2.0, -2.0], [-2.0, 2.0]]]], device=device), + "target": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], device=device), + }, + None, + ] + ) + # With reduction="sum" + TEST_CASES.append( + [ + {"include_background": True, "sigmoid": True, "reduction": "sum"}, + { + "input": torch.tensor([[[[2.0, -2.0], [-2.0, 2.0]]]], device=device), + "target": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], device=device), + }, + None, + ] + ) + # Exclude background + TEST_CASES.append( + [ + {"include_background": False, "sigmoid": True}, + { + "input": torch.tensor([[[[2.0, -2.0], [-2.0, 2.0]], [[-2.0, 2.0], [2.0, -2.0]]]], device=device), + "target": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]]], device=device), + }, + None, + ] + ) + # With other_act + TEST_CASES.append( + [ + {"include_background": True, "other_act": torch.tanh}, + { + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], device=device), + }, + None, + ] + ) + + +def _describe_test_case(test_func, test_number, params): + input_param, input_data, _ = params.args + return f"params:{input_param}, shape:{input_data['input'].shape}, device:{input_data['input'].device}" + + +@skipUnless(has_scipy, "Scipy required") +class TestBoundaryLoss(unittest.TestCase): + + @parameterized.expand(TEST_CASES, doc_func=_describe_test_case) + def test_runs(self, input_param, input_data, _): + """Test that the loss runs without errors for various configurations.""" + loss = BoundaryLoss(**input_param) + result = loss(**input_data) + # Just verify it's a scalar tensor and finite + self.assertTrue(torch.isfinite(result).all()) + + def test_perfect_match(self): + """Test that perfect predictions yield lower loss than imperfect ones.""" + loss_fn = BoundaryLoss() + perfect_loss = loss_fn(PERFECT_MATCH["input"], PERFECT_MATCH["target"]) + partial_loss = loss_fn(PARTIAL_OVERLAP["input"], PARTIAL_OVERLAP["target"]) + # Perfect match should have lower loss than partial overlap + self.assertLess(perfect_loss.item(), partial_loss.item()) + + def test_reduction_shapes(self): + """Test that different reductions produce expected shapes.""" + input_tensor = torch.ones((4, 2, 8, 8)) + target = torch.ones((4, 2, 8, 8)) + + self.assertEqual(BoundaryLoss(reduction="mean")(input_tensor, target).shape, torch.Size([])) + self.assertEqual(BoundaryLoss(reduction="sum")(input_tensor, target).shape, torch.Size([])) + # With include_background=True and 2 classes, shape should be (4, 2) + self.assertEqual(BoundaryLoss(reduction="none")(input_tensor, target).shape, torch.Size([4, 2])) + + def test_reduction_shapes_exclude_background(self): + """Test shapes when background is excluded.""" + input_tensor = torch.ones((4, 3, 8, 8)) + target = torch.ones((4, 3, 8, 8)) + + # With include_background=False, shape should be (4, 2) for 3 classes + self.assertEqual( + BoundaryLoss(reduction="none", include_background=False)(input_tensor, target).shape, torch.Size([4, 2]) + ) + + def test_single_channel_options_warn_and_are_ignored(self): + """Test that single-channel-only options follow other MONAI loss behavior.""" + input_tensor = torch.randn((1, 1, 4, 4), requires_grad=True) + target = torch.zeros((1, 1, 4, 4)) + target[..., 1:3, 1:3] = 1 + + with self.assertWarns(Warning): + loss = BoundaryLoss(softmax=True)(input_tensor, target) + loss.backward() + self.assertGreater(input_tensor.grad.abs().sum().item(), 0.0) + + with self.assertWarns(Warning): + result = BoundaryLoss(include_background=False)(input_tensor.detach(), target) + self.assertTrue(torch.isfinite(result)) + + with self.assertWarns(Warning): + result = BoundaryLoss(to_onehot_y=True)(input_tensor.detach(), target) + self.assertTrue(torch.isfinite(result)) + + def test_to_onehot_y_accepts_channel_free_target(self): + """Test target labels can omit the singleton channel dimension.""" + input_tensor = torch.randn((2, 3, 4, 4)) + target = torch.randint(0, 3, size=(2, 4, 4)) + result = BoundaryLoss(to_onehot_y=True, softmax=True)(input_tensor, target) + self.assertTrue(torch.isfinite(result)) + + def test_degenerate_target_distance_map_is_zero(self): + """Test that empty and full classes don't create edge-biased distance maps.""" + loss_fn = BoundaryLoss() + empty_target = torch.zeros((1, 1, 4, 4)) + full_target = torch.ones((1, 1, 4, 4)) + + self.assertTrue(torch.equal(loss_fn.compute_distance_map(empty_target), torch.zeros_like(empty_target))) + self.assertTrue(torch.equal(loss_fn.compute_distance_map(full_target), torch.zeros_like(full_target))) + + def test_batch_reduction_changes_none_shape_and_values(self): + """Test that batch=True reduces the batch dimension before final reduction.""" + input_tensor = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]], [[[0.0, 1.0], [1.0, 0.0]]]]) + target = torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]) + + batch_false = BoundaryLoss(reduction="none", batch=False)(input_tensor, target) + batch_true = BoundaryLoss(reduction="none", batch=True)(input_tensor, target) + + self.assertEqual(batch_false.shape, torch.Size([2, 1])) + self.assertEqual(batch_true.shape, torch.Size([1])) + self.assertTrue(torch.allclose(batch_true, batch_false.mean(dim=0))) + + def test_ill_shape(self): + """Test that mismatched shapes raise an error.""" + loss = BoundaryLoss() + with self.assertRaisesRegex(AssertionError, "shapes do not match"): + loss(torch.ones((1, 1, 2, 3)), torch.ones((1, 4, 5, 6))) + + def test_ill_opts(self): + """Test that invalid options raise errors.""" + with self.assertRaisesRegex(ValueError, ""): + BoundaryLoss(sigmoid=True, softmax=True) + with self.assertRaisesRegex(ValueError, ""): + BoundaryLoss(sigmoid=True, other_act=torch.tanh) + with self.assertRaisesRegex(ValueError, ""): + BoundaryLoss(softmax=True, other_act=torch.tanh) + with self.assertRaisesRegex(ValueError, ""): + BoundaryLoss(sigmoid=True, softmax=True, other_act=torch.tanh) + + chn_input = torch.ones((1, 1, 3, 3)) + chn_target = torch.ones((1, 1, 3, 3)) + with self.assertRaisesRegex(ValueError, ""): + BoundaryLoss(reduction="unknown")(chn_input, chn_target) + + def test_invalid_other_act_type(self): + """Test that non-callable other_act raises TypeError.""" + with self.assertRaises(TypeError): + BoundaryLoss(other_act="invalid") + + def test_empty_foreground(self): + """Test that empty foreground classes don't crash the loss.""" + loss_fn = BoundaryLoss(sigmoid=False) + result = loss_fn(EMPTY_FOREGROUND["input"], EMPTY_FOREGROUND["target"]) + self.assertTrue(torch.isfinite(result)) + + def test_dimension_validation(self): + """Test that unsupported dimensions raise errors.""" + loss = BoundaryLoss() + with self.assertRaises(ValueError): + # 1D input should fail + loss(torch.ones((1, 1, 10)), torch.ones((1, 1, 10))) + with self.assertRaises(ValueError): + # 4D input (5D with batch+channel) should fail + loss(torch.ones((1, 1, 2, 2, 2, 2)), torch.ones((1, 1, 2, 2, 2, 2))) + + def test_distance_map_computation(self): + """Test that distance maps are computed correctly for a simple case.""" + # Simple 3x3 case: foreground in center pixel + target = torch.zeros((1, 1, 3, 3)) + target[0, 0, 1, 1] = 1.0 # Center pixel is foreground + + loss_fn = BoundaryLoss() + distance_map = loss_fn.compute_distance_map(target) + + # Center pixel is on the boundary (single-pixel object), so distance should be 0 or near 0 + self.assertAlmostEqual(distance_map[0, 0, 1, 1].item(), 0.0, places=5) + + # Corners should be positive (outside foreground) + self.assertGreater(distance_map[0, 0, 0, 0].item(), 0) + self.assertGreater(distance_map[0, 0, 0, 2].item(), 0) + self.assertGreater(distance_map[0, 0, 2, 0].item(), 0) + self.assertGreater(distance_map[0, 0, 2, 2].item(), 0) + + # Neighbors of center should also be positive (outside foreground) + self.assertGreater(distance_map[0, 0, 0, 1].item(), 0) + self.assertGreater(distance_map[0, 0, 1, 0].item(), 0) + + def test_loss_gradient_flow(self): + """Test that gradients flow through the loss.""" + input_tensor = torch.randn((2, 2, 8, 8), requires_grad=True) + target = torch.ones((2, 2, 8, 8)) + + loss_fn = BoundaryLoss(sigmoid=True) + loss = loss_fn(input_tensor, target) + loss.backward() + + self.assertIsNotNone(input_tensor.grad) + self.assertTrue(torch.isfinite(input_tensor.grad).all()) + + def test_consistency_with_hausdorff_loss(self): + """Test that BoundaryLoss behaves differently from HausdorffDTLoss on the same input.""" + from monai.losses import HausdorffDTLoss + + input_tensor = torch.tensor([[[[2.0, -2.0], [-2.0, 2.0]]]]) + target = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]) + + bl_loss = BoundaryLoss(sigmoid=True)(input_tensor, target) + hd_loss = HausdorffDTLoss(sigmoid=True)(input_tensor, target) + + # They should produce different values (different formulations) + self.assertNotAlmostEqual(bl_loss.item(), hd_loss.item(), places=3) + + +if __name__ == "__main__": + unittest.main()