
[WIP] Use privateuseone dispatch key #7705

Open
wants to merge 2 commits into master
Conversation

will-cromar (Collaborator) commented Jul 17, 2024

Following https://github.com/bdhirsh/pytorch_open_registration_example/blob/master/cpp_extensions/open_registration_extension.cpp as an example for C++ code.

With this change, we can actually create XLATensor2s using the "jax" device without explicitly overriding the dispatcher:

>>> torch.tensor([0], device='jax:0')
XLATensor2(..., device='meta', size=(1,), dtype=torch.int64)
>>> torch.tensor([0], device='jax:0') + torch.tensor([2], device='jax:0').numpy()
tensor([2])

Base automatically changed from wcromar/fix-incorrect-registration to master July 17, 2024 21:30
will-cromar (Collaborator, Author)

This looks really promising so far.

We can at the very least get better device semantics by implementing torch.ops.aten._to_copy in Python and registering it for the PrivateUse1 dispatch key with torch.library.impl. We can then rename privateuseone to something more user-friendly (jax in my example) with torch.utils.rename_privateuse1_backend. Instead of having to wrap all tensor creation in a mode or use our custom to_xla utility, you could use tensor.to or create tensors directly on our device with jax:0, e.g.

>>> import torch
>>> import torch_xla2.custom_device
>>> torch.tensor([0]).to('jax:0')
XLATensor2(..., device='meta', size=(1,), dtype=torch.int64)
>>> torch.tensor([0], device='jax:0')
XLATensor2(..., device='meta', size=(1,), dtype=torch.int64)

Even that is a nice improvement.

What I'm experimenting with now is actually setting our XLATensor2's device to jax (instead of meta), effectively removing __torch_dispatch__ and relying on the torch dispatcher instead. I also registered all of the ops in jaten with torch.library.impl. What's really interesting here is that our lowerings still get passed our Python subclass XLATensor2, so we can easily pull out the wrapped JAX array.
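Registering a whole table of lowerings looks roughly like this sketch. `jaten_lowerings` is a stand-in for torch_xla2's actual jaten registry, and the torch builtins used as values are placeholders for real JAX-backed lowerings.

```python
import torch

# Stand-in registry mapping aten op names to lowerings; the real jaten
# table would map to functions that operate on the wrapped JAX array.
jaten_lowerings = {
    "abs": torch.abs,
    "neg": torch.neg,
}

aten_lib = torch.library.Library("aten", "IMPL")
for op_name, lowering in jaten_lowerings.items():
    # Route the PrivateUse1 key for each op to our lowering. Because the
    # dispatcher still hands these kernels the Python subclass, each
    # lowering can reach the wrapped array directly.
    aten_lib.impl(op_name, lowering, "PrivateUse1")
```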

The problem is that if I try to print an XLATensor2 when it has the jax:0 device, I get this error:

>>> torch.tensor([0], device='jax:0').cpu()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: Could not run 'aten::reshape' with arguments from the 'Autogradjax' backend. [...] 'aten::reshape' is only available for these backends: [list of every dispatch key that does not include an AutogradPrivateUse1....].

Why does printing require a reshape autograd implementation? No idea. More interestingly, why haven't we registered one? Looking at the PrivateUse1 doc, it looks like we should get the autograd implementations for "free" with TORCH_LIBRARY_IMPL. Importantly, I'm using the Python version, torch.library.impl, which may not do the same thing.

Using the torch dispatcher instead of __torch_dispatch__ might help us correctly handle some code that's in C++ (e.g. the DDP implementation) or keep Dynamo from trying to trace into our Python dispatch logic. I'll keep investigating.
