
torch.conj behaves differently on cpu and mps #148599

Closed as duplicate of #148156
@arandono

Description


🐛 Describe the bug

torch.conj appears to behave differently on CPU and MPS devices. On CPU, when its result is passed to a matrix multiplication, it behaves as expected. On MPS, the conjugation is not applied before the matrix multiplication. Here's an example:

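import torch

N = 2
a = torch.rand(N, N, dtype=torch.cfloat)
A = a.to("mps")

b = torch.rand(N, N, dtype=torch.cfloat)
B = b.to("mps")

# Plain product, no conjugation
ab1 = torch.mm(a, b)
AB1 = torch.mm(A, B)

# Lazy conjugation: torch.conj returns a view with the conjugate bit set
ab2 = torch.mm(a, torch.conj(b))
AB2 = torch.mm(A, torch.conj(B))

# Materialized conjugation: conj_physical writes the conjugated values
ab3 = torch.mm(a, torch.conj_physical(b))
AB3 = torch.mm(A, torch.conj_physical(B))

print(ab1)
print(AB1)
print(ab2)
print(AB2)
print(ab3)
print(AB3)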

We should have ab1 = AB1 and ab2 = AB2 = ab3 = AB3 (up to floating-point tolerance). But ab2 ≠ AB2; instead, AB2 = AB1, suggesting that the conjugation requested by torch.conj is not applied on MPS devices. However, torch.conj_physical appears to work as expected.
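For context, torch.conj is documented as a lazy operation: for complex tensors it returns a view with the conjugate bit set (visible via Tensor.is_conj()), while torch.conj_physical returns new memory holding the conjugated values. That the MPS matmul path ignores this bit is a hypothesis consistent with the report, not something confirmed here. Below is a minimal sketch, assuming PyTorch 2.x, that contrasts the two forms on whatever device is available and wraps the conj_physical workaround in a helper. `pick_device` and `mm_conj` are hypothetical names for illustration, not PyTorch APIs.

```python
import torch


def pick_device() -> torch.device:
    # Prefer MPS when this build/machine supports it; otherwise fall back
    # to CPU so the sketch still runs elsewhere.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def mm_conj(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Hypothetical workaround helper: materialize conj(y) with
    # conj_physical (reported above to work on MPS) instead of relying
    # on the lazy conjugate bit that torch.conj sets.
    return torch.mm(x, torch.conj_physical(y))


if __name__ == "__main__":
    torch.manual_seed(0)
    a = torch.rand(2, 2, dtype=torch.cfloat)
    b = torch.rand(2, 2, dtype=torch.cfloat)

    # torch.conj flags a view; conj_physical produces plain conjugated data.
    print(torch.conj(b).is_conj())           # True
    print(torch.conj_physical(b).is_conj())  # False

    # CPU reference using the lazy conj, which the CPU matmul honors.
    ref = torch.mm(a, torch.conj(b))

    dev = pick_device()
    lazy = torch.mm(a.to(dev), torch.conj(b.to(dev))).cpu()
    eager = mm_conj(a.to(dev), b.to(dev)).cpu()

    print(dev, "lazy conj matches CPU:", torch.allclose(lazy, ref))
    print(dev, "conj_physical matches CPU:", torch.allclose(eager, ref))
```

On CPU both comparisons should print True. On an MPS machine running 2.6.0, the report suggests the lazy comparison would print False while the conj_physical one prints True.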

Versions

Collecting environment information...
PyTorch version: 2.6.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.5 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.6)
CMake version: Could not collect
Libc version: N/A

Python version: 3.11.5 (main, Sep 11 2023, 08:31:25) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Pro

Versions of relevant libraries:
[pip3] numpy==1.26.0
[pip3] torch==2.6.0
[pip3] torchaudio==2.6.0
[pip3] torchvision==0.21.0
[conda] numpy 1.26.0 py311he598dae_0
[conda] numpy-base 1.26.0 py311hfbfe69c_0
[conda] torch 2.6.0 pypi_0 pypi
[conda] torchaudio 2.6.0 pypi_0 pypi
[conda] torchvision 0.21.0 pypi_0 pypi

cc @ezyang @anjali411 @dylanbespalko @mruberry @nikitaved @amjames @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

Metadata

Assignees: No one assigned
Labels: module: complex, module: mps, triaged