Use a different str for device instead of meta #7995

Merged · 2 commits · Sep 12, 2024
9 changes: 9 additions & 0 deletions experimental/torch_xla2/README.md
@@ -131,6 +131,15 @@ with env:
print(type(res)) # outputs XLATensor2
```

You can also enable the environment globally with
```python
import torch_xla2

torch_xla2.enable_globally()
```

After this call, everything runs with XLA.
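
A minimal sketch of what this means in practice (assuming the package is installed; the printed type follows the `XLATensor2` example above):

```python
import torch
import torch_xla2

torch_xla2.enable_globally()

a = torch.randn(2, 2)
b = a @ a.T
print(type(b))  # expected: XLATensor2, same as inside the `with env:` block

# The global environment can be turned off again if needed:
torch_xla2.disable_globally()
```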

## What is happening behind the scenes:

When a torch op is executed inside of `env` context manager, we can swap out the
25 changes: 10 additions & 15 deletions experimental/torch_xla2/examples/basic_training.py
@@ -7,18 +7,16 @@
"""

import torch
from torch.utils import _pytree as pytree
import torchvision
import torchvision.transforms as transforms
import torch_xla2.tensor


xla_env = torch_xla2.default_env()

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
#from torch.utils.tensorboard import SummaryWriter
#from datetime import datetime

# NOTE: add these lines to make it run on TPUs!
import torch_xla2
torch_xla2.enable_globally()

transform = transforms.Compose(
[transforms.ToTensor(),
@@ -83,7 +81,6 @@ def forward(self, x):


model = GarmentClassifier()
model = xla_env.to_xla(model)

loss_fn = torch.nn.CrossEntropyLoss()

@@ -102,7 +99,7 @@ def forward(self, x):
# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

def train_one_epoch(epoch_index, tb_writer):
def train_one_epoch(epoch_index, tb_writer=None):
running_loss = 0.
last_loss = 0.

@@ -112,7 +109,6 @@ def train_one_epoch(epoch_index, tb_writer):
for i, data in enumerate(training_loader):
# Every data instance is an input + label pair
# NEW: Move model to XLA device
data = xla_env.to_xla(data)
inputs, labels = data

# Zero your gradients for every batch!
@@ -135,16 +131,16 @@ def train_one_epoch(epoch_index, tb_writer):
last_loss = running_loss / 1000 # loss per batch
print(' batch {} loss: {}'.format(i + 1, last_loss))
tb_x = epoch_index * len(training_loader) + i + 1
tb_writer.add_scalar('Loss/train', last_loss, tb_x)
#tb_writer.add_scalar('Loss/train', last_loss, tb_x)
running_loss = 0.

return last_loss



# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
#timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
#writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0
EPOCHS = 2
best_vloss = 1_000_000.
@@ -156,7 +152,7 @@ def train_one_epoch(epoch_index, tb_writer):
model.train(True)


avg_loss = train_one_epoch(epoch_number, writer)
avg_loss = train_one_epoch(epoch_number)

running_vloss = 0.0
# Set the model to evaluation mode, disabling dropout and using population
@@ -167,7 +163,6 @@ def train_one_epoch(epoch_index, tb_writer):
with torch.no_grad():
for i, vdata in enumerate(validation_loader):
# NOTE: move to XLA device
vinputs, vlabels = xla_env.to_xla(vdata)
voutputs = model(vinputs) # call model's forward
vloss = loss_fn(voutputs, vlabels)
running_vloss += vloss
9 changes: 3 additions & 6 deletions experimental/torch_xla2/test/test_mutations.py
@@ -14,25 +14,22 @@ def test_add(self):
x = torch.tensor([1, 2, 3], dtype=torch.int32)
y = torch.tensor([4, 5, 6], dtype=torch.int32)
x.add_(y)
xt = torch_xla2.tensor.j2t(x._elem)
self.assertEqual(xt, torch.tensor([5, 7, 9], dtype=torch.int32))
self.assertEqual(x, torch.tensor([5, 7, 9], dtype=torch.int32))

def test_sub(self):
with self.env:
x = torch.tensor([1, 2, 3], dtype=torch.int32)
y = torch.tensor([4, 5, 6], dtype=torch.int32)
x.sub_(y)
xt = torch_xla2.tensor.j2t(x._elem)
self.assertEqual(xt, torch.tensor([-3, -3, -3], dtype=torch.int32))
self.assertEqual(x, torch.tensor([-3, -3, -3], dtype=torch.int32))

def test_mul(self):
with self.env:
x = torch.tensor([1, 2, 3], dtype=torch.int32)
y = torch.tensor([4, 5, 6], dtype=torch.int32)

x.mul_(y)
xt = torch_xla2.tensor.j2t(x._elem)
self.assertEqual(xt, torch.tensor([4, 10, 18], dtype=torch.int32))
self.assertEqual(x, torch.tensor([4, 10, 18], dtype=torch.int32))


if __name__ == '__main__':
1 change: 1 addition & 0 deletions experimental/torch_xla2/test/test_ops.py
@@ -391,6 +391,7 @@ def setUpClass(cls):

def setUp(self):
self.env = tensor.Environment()
torch.manual_seed(0)

# Replaces all values in the input torch_tensor that are less than the given threshold
# with the threshold value itself.
20 changes: 20 additions & 0 deletions experimental/torch_xla2/torch_xla2/__init__.py
@@ -11,6 +11,7 @@
__all__ = [
'default_env',
'extract_jax',
'enable_globally',
]

from jax._src import xla_bridge
@@ -61,3 +62,22 @@ def jax_func(states, inputs):
return env.t2j_iso(res)

return states, jax_func

def enable_globally():
global env
env = default_env().__enter__()
return env

def disable_globally():
global env
default_env().__exit__(None, None, None)


torch.utils.rename_privateuse1_backend('jax')
unsupported_dtype = [torch.quint8]
torch.utils.generate_methods_for_privateuse1_backend(
for_tensor=True, for_module=True, for_storage=True,
unsupported_dtype=unsupported_dtype)

import jax
torch._register_device_module('jax', jax)
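
These registrations are what make `'jax'` usable as a torch device string (the "different str for device instead of meta" from the PR title). A hedged sketch of the expected effect, assuming the standard PyTorch privateuse1 backend flow:

```python
import torch
import torch_xla2  # the registrations above run at import time

torch_xla2.enable_globally()

dev = torch.device('jax')   # 'jax' is now a recognized device type
x = torch.ones(3).to(dev)   # routed through the 'privateuseone' impls registered in jaten.py
print(x.device)             # expected to report the jax device rather than 'meta'
```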
6 changes: 5 additions & 1 deletion experimental/torch_xla2/torch_xla2/config.py
@@ -10,4 +10,8 @@ class Configuration:

# Flash attention
use_tpu_flash_attention: bool = False
shmap_flash_attention: bool = False
shmap_flash_attention: bool = False

# device
treat_cuda_as_jax_device: bool = True
use_torch_native_for_cpu_tensor: bool = False
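
A short sketch of what the two new flags describe, based on their names and defaults; how the environment consumes them is not part of this diff, and the import path plus the dataclass constructor are assumptions from the file location and field declarations:

```python
from torch_xla2.config import Configuration

# Defaults from the dataclass above:
# - treat_cuda_as_jax_device=True: tensors requested on 'cuda' are placed on the JAX device
# - use_torch_native_for_cpu_tensor=False: CPU tensors are wrapped rather than left as native torch tensors
cfg = Configuration(
    treat_cuda_as_jax_device=True,
    use_torch_native_for_cpu_tensor=False,
)
```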
2 changes: 1 addition & 1 deletion experimental/torch_xla2/torch_xla2/interop.py
@@ -91,7 +91,7 @@ def _jax_view(t: TorchValue) -> JaxValue:
assert isinstance(t, tensor.XLATensor2)
return t.jax()
if isinstance(t, type(torch.int32)):
return tensor.j2t_dtype(t)
return tensor.t2j_dtype(t)

# torch.nn.Module needs special handling
if not isinstance(t, torch.nn.Module) and callable(t): # t is a TorchCallable
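
The one-character fix above matters because `_jax_view` converts torch values into their JAX counterparts, so a torch dtype has to go through the torch-to-JAX mapping rather than the reverse one. A minimal sketch of the intended direction (module paths assumed):

```python
import torch
from torch_xla2 import tensor

jax_dtype = tensor.t2j_dtype(torch.int32)   # torch dtype -> JAX dtype
torch_dtype = tensor.j2t_dtype(jax_dtype)   # inverse mapping, back to torch.int32
```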
64 changes: 56 additions & 8 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -2,6 +2,7 @@

import sys
from typing import Optional, Sequence
import functools

import jax
from jax import numpy as jnp
@@ -38,6 +39,7 @@
torch.ops.aten.squeeze_: torch.ops.aten.squeeze,
torch.ops.aten.bernoulli_: torch.ops.aten.bernoulli.p,
torch.ops.aten.clamp_: torch.ops.aten.clamp,
torch.ops.aten.random_: torch.ops.aten.uniform,
}


@@ -55,6 +57,18 @@ def op(*aten, **kwargs):
def inner(func):
for a in aten:
ops_registry.register_torch_dispatch_op(a, func, **kwargs)
continue

if isinstance(a, torch._ops.OpOverloadPacket):
opname = a.default.name() if 'default' in a.overloads() else a._qualified_op_name
elif isinstance(a, torch._ops.OpOverload):
opname = a.name()
else:
raise RuntimeError(f'oops {a}')

torchfunc = functools.partial(interop.call_jax, func)
# HACK: to_copy is where we make the initial conversion from CPU tensor to JAX tensor
torch.library.impl(opname, 'privateuseone')(torchfunc if a != torch.ops.aten._to_copy else func)
return func
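
For context, a hypothetical registration showing the pattern the decorator now supports: the same lowering is registered in `ops_registry` for the tensor-subclass dispatch path and, through `torch.library.impl(..., 'privateuseone')`, for the `'jax'` device path. The op chosen here is only an illustration, not part of this diff:

```python
@op(torch.ops.aten.neg)
def _aten_neg(x):
  # x arrives as a JAX array; the decorator wraps this function with
  # interop.call_jax before exposing it to torch's 'privateuseone' backend.
  return -x
```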

return inner
Expand All @@ -80,14 +94,13 @@ def _aten_add(x, y, *, alpha=1):
return x + y * alpha


@op(torch.ops.aten.copy_, torch.ops.aten.copy_.default, is_jax_function=False)
@op(torch.ops.aten.copy_, is_jax_function=False)
def _aten_copy(x, y, memory_format=None):
x._elem = y._elem.astype(x._elem.dtype)
return x


@op(torch.ops.aten.clone)
@op(torch.ops.aten.clone.default)
def _aten_clone(x, memory_format=None):
return x

@@ -433,6 +446,8 @@ def _aten__to_copy(self, **kwargs):
return jnp.copy(self)




@op(torch.ops.aten.empty)
@op_base.convert_dtype()
def _aten_empty(size: Sequence[int], *, dtype=None, **kwargs):
@@ -465,7 +480,6 @@ def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs):


@op(torch.ops.aten.empty_permuted)
@op(torch.ops.aten.empty_permuted.default)
@op_base.convert_dtype()
def _aten_empty_permuted(sizes, physical_layout, dtype=None, **kwargs):
# Ignore the physical layout,
@@ -474,7 +488,6 @@ def _aten_empty_permuted(sizes, physical_layout, dtype=None, **kwargs):


@op(torch.ops.aten.empty_strided)
@op(torch.ops.aten.empty_strided.default)
@op_base.convert_dtype()
def _aten_empty_strided(sizes, stride, dtype=None, **kwargs):
# Ignore stride, since JAX and torch tensor doesn't share the same memory.
@@ -540,7 +553,6 @@ def permute(t, dims):

@op(torch.ops.aten.unsqueeze)
@op(torch.ops.aten.unsqueeze_copy)
@op(torch.ops.aten.unsqueeze.default)
def _aten_unsqueeze(self, dim):
if dim < 0:
dim += self.ndim + 1
@@ -1618,7 +1630,6 @@ def _with_reduction_scalar(jax_func, self, dim, keepdim):
def _aten_any(self, dim=None, keepdim=False):
return _with_reduction_scalar(jnp.any, self, dim, keepdim)


# aten.arange
@op(torch.ops.aten.arange.start_step)
@op(torch.ops.aten.arange.start)
@@ -1960,7 +1971,6 @@ def _aten_ge(self, other):


@op(torch.ops.aten.glu)
@op(torch.ops.aten.glu.default)
def _aten_glu(x, dim=-1):
return jax.nn.glu(x, dim)

@@ -2110,6 +2120,38 @@ def _aten_prod(self, dim=None, keepdim=False):


# aten.randperm
# randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None)
@op(torch.ops.aten.randperm, needs_env=True)
def _aten_randperm(
n, *,
generator=None,
dtype=None,
layout=None,
device=None,
pin_memory=None,
env=None):
"""
Generates a random permutation of integers from 0 to n-1.

Args:
n: The upper bound (exclusive) of the permutation range.
generator: An optional torch.Generator; it is consumed through the environment's PRNG-key handling (env.get_and_rotate_prng_key). If None, the environment's current key is used.
dtype: The desired data type of the output array. Defaults to jnp.int64.
layout: Accepted for API compatibility; ignored.
device: Accepted for API compatibility; ignored (placement follows the environment's JAX device).
pin_memory: Accepted for API compatibility; ignored.

Returns:
A jax.Array containing a random permutation of integers from 0 to n-1.
"""
if dtype:
dtype = mappings.t2j_dtype(dtype)
else:
dtype = jnp.int64.dtype
key = env.get_and_rotate_prng_key(generator)
indices = jnp.arange(n, dtype=dtype)
permutation = jax.random.permutation(key, indices)
return permutation
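
A brief usage sketch of the new lowering under the global environment (the layout/device/pin_memory arguments are accepted but ignored, as noted in the docstring):

```python
import torch
import torch_xla2

torch_xla2.enable_globally()

perm = torch.randperm(8)   # lowered to jax.random.permutation via the env's PRNG key
print(perm)                # some random ordering of 0..7
```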


# aten.reflection_pad3d
@@ -2467,6 +2509,12 @@ def _aten_normal(self, mean=0, std=1, generator=None, env=None):
res = _randn(*shape, generator=generator, env=env)
return res * std + mean

# TODO: not clear what this function should actually do
# https://github.com/pytorch/pytorch/blob/d96c80649f301129219469d8b4353e52edab3b78/aten/src/ATen/native/native_functions.yaml#L7933-L7940
@op(torch.ops.aten.lift_fresh)
def _aten_lift_fresh(self):
return self

@op(torch.ops.aten.uniform, needs_env=True)
def _aten_uniform(self, from_=0, to=1, *, generator=None, env=None):
assert from_ <= to, f'Uniform from(passed in {from_}) must be less than to(passed in {to})'
@@ -2476,7 +2524,7 @@ def _aten_uniform(self, from_=0, to=1, *, generator=None, env=None):

#func: randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

@op(torch.ops.aten.randint, torch.ops.aten.randint.generator, needs_env=True)
@op(torch.ops.aten.randint, needs_env=True)
@op_base.convert_dtype(use_default_dtype=False)
def _aten_randint(
*args,