Skip to content

BackendUtilsTest fails for Torch backend under MPS (MacOS) #20708

@kas2020-commits

Description

@kas2020-commits

I've copied the stdout below. My understanding of the issue is this: NumPy defaults to 64-bit floats, Torch now supports MPS as a backend, and MPS (at least on my 2022 M2 MacBook Pro) doesn't support float64, so the conversion of the NumPy array to a torch tensor fails.

One easy fix would be to make `x` float32 in the test, e.g. `x = np.random.uniform(size=[1, 2, 3]).astype("float32")`. That seems fine for this test, since it is focused on the dynamic backend switching to torch?

System Details:

  • platform: aarch64-darwin
  • python version: 3.12
  • Keras version: 3.7.0
=================================== FAILURES ===================================
_________________ BackendUtilsTest.test_dynamic_backend_torch __________________

self = <keras.src.utils.backend_utils_test.BackendUtilsTest testMethod=test_dynamic_backend_torch>

name = 'torch'

    @parameterized.named_parameters(
        ("numpy", "numpy"),
        ("jax", "jax"),
        ("tensorflow", "tensorflow"),
        ("torch", "torch"),
    )
    def test_dynamic_backend(self, name):
        dynamic_backend = backend_utils.DynamicBackend()
        x = np.random.uniform(size=[1, 2, 3])
    
        if name == "numpy":
            dynamic_backend.set_backend(name)
            if backend.backend() != "numpy":
                with self.assertRaisesRegex(
                    NotImplementedError,
                    "Currently, we cannot dynamically import the numpy backend",
                ):
                    y = dynamic_backend.numpy.log10(x)
            else:
                y = dynamic_backend.numpy.log10(x)
                self.assertIsInstance(y, np.ndarray)
        elif name == "jax":
            import jax
    
            dynamic_backend.set_backend(name)
            y = dynamic_backend.numpy.log10(x)
            self.assertIsInstance(y, jax.Array)
        elif name == "tensorflow":
            import tensorflow as tf
    
            dynamic_backend.set_backend(name)
            y = dynamic_backend.numpy.log10(x)
            self.assertIsInstance(y, tf.Tensor)
        elif name == "torch":
            import torch
    
            dynamic_backend.set_backend(name)
>           y = dynamic_backend.numpy.log10(x)

keras/src/utils/backend_utils_test.py:47: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
keras/src/backend/torch/numpy.py:846: in log10
    x = convert_to_tensor(x)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

x = array([[[0.08590334, 0.14173587, 0.21748844],
        [0.41238207, 0.41767306, 0.8030875 ]]])
dtype = torch.float64, sparse = None

    def convert_to_tensor(x, dtype=None, sparse=None):
        if sparse:
            raise ValueError("`sparse=True` is not supported with torch backend")
        if isinstance(x, Variable):
            # TorchDynamo has bugs supporting nn.Parameter type check.
            # Return it directly instead of pass it to the rest of the logic in the
            # function.
            return x.value
        if is_tensor(x):
            device = get_device()
            if x.device != device:
                x = x.to(device)
            if dtype is None:
                return x
            return x.to(to_torch_dtype(dtype))
        if dtype is None:
            if isinstance(x, bool):
                return torch.as_tensor(x, dtype=torch.bool, device=get_device())
            elif isinstance(x, int):
                return torch.as_tensor(x, dtype=torch.int32, device=get_device())
            elif isinstance(x, float):
                return torch.as_tensor(
                    x, dtype=to_torch_dtype(floatx()), device=get_device()
                )
    
        # Convert to np in case of any array-like that is not list or tuple.
        if not isinstance(x, (list, tuple)):
            x = np.array(x)
        elif len(x) > 0 and any(isinstance(x1, torch.Tensor) for x1 in x):
            # Handle list or tuple of torch tensors
            return torch.stack([convert_to_tensor(x1) for x1 in x])
        if isinstance(x, np.ndarray):
            if x.dtype == np.uint32:
                # Torch backend does not support uint32.
                x = x.astype(np.int64)
            if standardize_dtype(x.dtype) == "bfloat16":
                # Torch backend does not support converting bfloat16 ndarray.
                x = x.astype(np.float32)
                dtype = "bfloat16"
            dtype = dtype or x.dtype
        if dtype is None:
            dtype = result_type(
                *[getattr(item, "dtype", type(item)) for item in tree.flatten(x)]
            )
        dtype = to_torch_dtype(dtype)
>       return torch.as_tensor(x, dtype=dtype, device=get_device())
E       TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

keras/src/backend/torch/core.py:233: TypeError

Metadata

Metadata

Assignees

Labels

No labels
No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions