
torch.mode when input has nans #46225

Open
Tracked by #61417
nikhilmishra000 opened this issue Oct 13, 2020 · 2 comments
Labels
module: docs (Related to our documentation, both in docs/ and docblocks)
module: NaNs and Infs (Problems related to NaN and Inf handling in floating point)
module: numpy (Related to numpy support, and also numpy compatibility of our operators)
module: reductions
module: sorting and selection
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

nikhilmishra000 commented Oct 13, 2020

🐛 Bug

torch.mode has inconsistent behavior when the input contains NaNs:

  • The torch docs do not say what the NaN policy is, whereas the scipy equivalent (scipy.stats.mode) lets the user choose via its nan_policy argument (see the sketch after this list)
  • On CPU, torch.mode behaves like scipy's nan_policy="omit"
  • On CUDA, it returns a nonsensical result
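
For reference, a minimal sketch of the scipy behavior mentioned above; it only relies on scipy.stats.mode and its documented nan_policy argument ('propagate', 'omit', 'raise'), for which torch.mode has no equivalent:

import numpy as np
from scipy import stats

a = np.array([1.0, 1.0, 2.0, np.nan])

# nan_policy='omit' drops NaNs before computing the mode,
# which is what torch.mode currently does on CPU.
print(stats.mode(a, nan_policy='omit'))

# The default, nan_policy='propagate', lets NaNs influence the result
# instead (the exact output depends on the scipy version).
print(stats.mode(a, nan_policy='propagate'))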

To Reproduce

import numpy as np
import torch

def test(device):
    # 1000 integers drawn uniformly from {0, ..., 4}.
    x = torch.rand(1000).mul(5).long().to(device)
    # Values sorted by frequency, most common first.
    s = torch.bincount(x, minlength=5).argsort(descending=True)

    mode = x.mode().values
    print(f'w/o nans, got {mode}, expected {s[0]}')

    # Replace every occurrence of the mode with NaN; if NaNs are omitted,
    # the second most common value should now be the mode.
    y = x.clone().float()
    y[y == mode] = np.nan
    mode = y.mode().values.long()
    print(f'w nans, got {mode}, expected {s[1]}')

When running test("cpu"), both lines always give the expected result:

In [17]: test('cpu')
w/o nans, got 3, expected 3
w nans, got 0, expected 0

whereas when running test("cuda"), the first line always gives the expected result, but the second line gives something seemingly random:

In [26]: test('cuda')
w/o nans, got 2, expected 2
w nans, got 4, expected 0

Expected behavior

Environment

Output of collect_env.py:

Collecting environment information...
PyTorch version: 1.6.0
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.3 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.10.2

Python version: 3.7 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.0.130
GPU models and configuration:
GPU 0: TITAN RTX
GPU 1: TITAN RTX
GPU 2: GeForce RTX 2080 Ti
GPU 3: GeForce GTX 1080 Ti

Nvidia driver version: 440.64.00
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] msgpack-numpy==0.4.4.3
[pip3] numpy==1.16.1
[pip3] numpy-quaternion==2020.10.2.17.17.31
[pip3] numpy-stl==2.10.1
[pip3] torch==1.6.0
[pip3] torchvision==0.6.0
[conda] msgpack-numpy             0.4.4.3                  pypi_0    pypi
[conda] numpy                     1.16.1                   pypi_0    pypi
[conda] numpy-quaternion          2020.10.2.17.17.31          pypi_0    pypi
[conda] numpy-stl                 2.10.1                   pypi_0    pypi
[conda] torch                     1.6.0                    pypi_0    pypi
[conda] torchvision               0.6.0                    pypi_0    pypi

cc @brianjo @mruberry @rgommers @heitorschueroff @ezyang @gchanan @zou3519 @bdhirsh @ejguan @jlin27

@bdhirsh added the high priority, module: docs, module: numpy, and shadow review labels Oct 13, 2020
bdhirsh (Contributor) commented Oct 13, 2020

Tested this on latest master and confirmed that this is still occurring (in particular, the results on CPU and CUDA differ).

@mruberry added the module: sorting and selection and module: NaNs and Infs labels Oct 13, 2020
mruberry (Collaborator) commented

Thanks for reporting this issue, @nikhilmishra000, we should definitely fix this and would take a PR updating our behavior. We should propagate NaNs in this case and note that it's a BC-breaking change to do so.
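
For illustration, a minimal sketch of what the NaN-propagating behavior suggested above could look like, assuming parity with reductions such as torch.max that already propagate NaN (this is not an existing torch.mode option):

import torch

x = torch.tensor([1.0, 1.0, 2.0, float('nan')])

# Existing reductions already propagate NaN:
print(torch.max(x))  # tensor(nan)

# Today this prints tensor(1.) on CPU (NaN is ignored, per the report above)
# and an arbitrary value on CUDA; under the proposed behavior it would
# print tensor(nan) on both devices.
print(x.mode().values)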

@ezyang removed the shadow review label Oct 13, 2020
@ezyang added the triaged label Oct 20, 2020