Skip to content

Commit

Permalink
test fix
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi committed Sep 12, 2024
1 parent 7762dd4 commit d8fb3de
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
1 change: 1 addition & 0 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ def setUpClass(cls):

def setUp(self):
self.env = tensor.Environment()
torch.manual_seed(0)

# Replaces all values in the input torch_tensor that are less than the given threshold
# with the threshold value itself.
Expand Down
9 changes: 9 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def op(*aten, **kwargs):
def inner(func):
for a in aten:
ops_registry.register_torch_dispatch_op(a, func, **kwargs)
continue

if isinstance(a, torch._ops.OpOverloadPacket):
opname = a.default.name() if 'default' in a.overloads() else a._qualified_op_name
Expand Down Expand Up @@ -437,6 +438,14 @@ def _aten_dot(x, y):
return jnp.dot(x, y)


@op(torch.ops.aten._to_copy)
def _aten__to_copy(self, **kwargs):
dtype = mappings.t2j_dtype(kwargs["dtype"])
if dtype != self.dtype:
return self.astype(dtype)
return jnp.copy(self)




@op(torch.ops.aten.empty)
Expand Down

0 comments on commit d8fb3de

Please sign in to comment.