From deba0f1fe6d48004acfa130abc6e641250689834 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 10 Jun 2026 09:21:04 +0000 Subject: [PATCH] fix(MetaTensor): astype with torch dtype now returns MetaTensor preserving metadata When calling MetaTensor.astype() with a torch dtype (e.g. torch.int32), the result was a plain torch.Tensor, silently losing all metadata (affine, spacing, applied transforms, etc.). The root cause was that out_type was hardcoded to torch.Tensor instead of the actual type of self. Fix by using type(self) as out_type when a torch dtype is requested, so that convert_data_type() receives output_type=MetaTensor, sets track_meta=True, and preserves metadata through the dtype cast. The analyzer module already annotated the result of astype(torch.int16) as MetaTensor, relying on this contract. Updated test to assert the result is an instance of MetaTensor and that the metadata key is preserved after the cast. Closes #8202 Signed-off-by: Oleksandr Sanin --- monai/data/meta_tensor.py | 5 +++-- tests/data/meta_tensor/test_meta_tensor.py | 8 ++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 12bd76ba60..06444a8769 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -442,7 +442,8 @@ def astype(self, dtype, device=None, *_args, **_kwargs): _kwargs: additional kwargs (currently unused). Returns: - data array instance + ``MetaTensor`` when a torch dtype is given (metadata is preserved), + or ``np.ndarray`` when a numpy dtype is given. """ if isinstance(dtype, str): mod_str, *dtype = dtype.split(".", 1) @@ -453,7 +454,7 @@ def astype(self, dtype, device=None, *_args, **_kwargs): out_type: type[torch.Tensor] | type[np.ndarray] | None if mod_str == "torch": - out_type = torch.Tensor + out_type = type(self) elif mod_str in ("numpy", "np"): out_type = np.ndarray else: diff --git a/tests/data/meta_tensor/test_meta_tensor.py b/tests/data/meta_tensor/test_meta_tensor.py index c0e53fd24c..a12f519f62 100644 --- a/tests/data/meta_tensor/test_meta_tensor.py +++ b/tests/data/meta_tensor/test_meta_tensor.py @@ -434,8 +434,12 @@ def test_astype(self): for np_types in ("float32", "np.float32", "numpy.float32", np.float32, float, "int", np.uint16): self.assertIsInstance(t.astype(np_types), np.ndarray) for pt_types in ("torch.float", torch.float, "torch.float64"): - self.assertIsInstance(t.astype(pt_types), torch.Tensor) - self.assertIsInstance(t.astype("torch.float", device="cpu"), torch.Tensor) + result = t.astype(pt_types) + self.assertIsInstance(result, MetaTensor) + self.assertEqual(result.meta.get("fname"), "filename") + result = t.astype("torch.float", device="cpu") + self.assertIsInstance(result, MetaTensor) + self.assertEqual(result.meta.get("fname"), "filename") def test_transforms(self): key = "im"