Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a different str for device instead of meta #7995

Merged
merged 2 commits into from
Sep 12, 2024
Merged

Use a different str for device instead of meta #7995

merged 2 commits into from
Sep 12, 2024

Conversation

qihqi
Copy link
Collaborator

@qihqi qihqi commented Sep 11, 2024

Implements this idea: #7705 but without the C++ piece.

We could use anything for the device (incl just a string: https://github.com/albanD/subclass_zoo/blob/main/new_device.py); however it's advantagous to use a torch.device with an existing registered key. (we could use 'xla:0' too).

The caveat is that we need to capture tensor constructors in jtorch instead of jaten (like it used to be) because otherwise torch will check if such device module is loaded or not (example for capturing 'cuda').

fixes #7966

@qihqi qihqi force-pushed the hanq_xla2_2 branch 2 times, most recently from d8fb3de to e099f54 Compare September 12, 2024 20:48
So XLATensor2's device property is no longer 'meta'

Capture tensor constructors and give option to override 'cuda' with 'jax'
device (specially useful for running GPU code on TPU).

Co-authored-by: Will Cromar <[email protected]>
@@ -1634,6 +1645,7 @@ def _aten_arange(
requires_grad=False,
device=None,
pin_memory=False,
**kwargs
):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need add **kwargs here to accept **kwargs from res2 = func(input2, *args2, **kwargs2)?

I seems met one situation that: func in jaten.py file didn't received kwargs2 from res2 = func(input2, *args2, **kwargs2)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah this is not needed. removed.

Copy link
Collaborator

@ManfeiBai ManfeiBai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, amazing change, LGTM

@qihqi qihqi merged commit 1b57d1e into master Sep 12, 2024
3 checks passed
@qihqi qihqi deleted the hanq_xla2_2 branch September 12, 2024 23:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Report real device for XLATensor2
2 participants