Skip to content

Commit

Permalink
Op info test for linalg.norm, linalg.pinv, linalg.solve (#7503) (#8044)
Browse files Browse the repository at this point in the history
  • Loading branch information
yenkwang committed Sep 19, 2024
1 parent c0501f0 commit 441098b
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 4 deletions.
5 changes: 1 addition & 4 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,6 @@
"linalg.matrix_norm",
"linalg.matrix_power",
"linalg.matrix_rank",
"linalg.norm",
"linalg.pinv",
"linalg.solve",
"linalg.solve_ex",
"linalg.solve_triangular",
"linalg.svd",
Expand Down Expand Up @@ -198,7 +195,7 @@
'nn.functional.feature_alpha_dropout',
}

atol_dict = {"matrix_exp": (2e-1, 2e-4)}
atol_dict = {"matrix_exp": (2e-1, 2e-4), "linalg.pinv": (8e-1, 2e0)}

def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True, check_output=True):
if isinstance(output1, torch.Tensor):
Expand Down
22 changes: 22 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -1194,6 +1194,10 @@ def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None):
# (Optional) dtype conversion
if dtype is not None:
result = jnp.astype(result, self.dtype)

new_dtype = mappings.t2j_dtype(torch.get_default_dtype())
if result.dtype == jax.numpy.int64:
result = result.astype(new_dtype)
return result


Expand Down Expand Up @@ -3921,6 +3925,24 @@ def _aten__linalg_slogdet(input):
return res.sign, res.logabsdet


# torch.linalg.svd
@op(torch.ops.aten._linalg_svd)
def _aten__linalg_svd(a, full_matrices=True):
  """Lower aten._linalg_svd to JAX.

  Returns the (U, S, Vh) triple produced by jnp.linalg.svd; `full_matrices`
  is forwarded unchanged to control whether U/Vh are square or thin.
  """
  decomposition = jnp.linalg.svd(a, full_matrices=full_matrices)
  return decomposition


# torch.linalg.pinv
@op(torch.ops.aten.linalg_pinv.atol_rtol_tensor)
def _aten_linalg_pinv_atol_rtol_tensor(a, rtol=None, **kwargs):
  """Lower aten.linalg_pinv.atol_rtol_tensor to jnp.linalg.pinv.

  Fix: the aten overload carries a `hermitian` flag; the previous lowering
  dropped it in **kwargs and always computed the general (non-hermitian)
  pseudo-inverse. Honor the flag when the caller supplies it.

  NOTE(review): the overload also accepts an `atol` threshold, which
  jnp.linalg.pinv does not expose; it is still ignored here — results may
  differ from torch when a caller passes a nonzero atol. TODO confirm.
  """
  hermitian = kwargs.get("hermitian", False)
  return jnp.linalg.pinv(a, rtol, hermitian=hermitian)


# torch.linalg.solve
@op(torch.ops.aten._linalg_solve_ex)
def _aten__linalg_solve_ex(a, b):
  """Solve the linear system a @ x = b via jnp.linalg.solve.

  The second element mirrors the op's auxiliary status output and is
  always a zero scalar here (success).
  """
  solution = jnp.linalg.solve(a, b)
  info = jnp.array(0)
  return solution, info


@op(torch.ops.aten.median)
def _aten_median(self, dim=None, keepdim=False):
output = _with_reduction_scalar(functools.partial(jnp.quantile, q=0.5, method='lower'), self, dim=dim, keepdim=keepdim).astype(self.dtype)
Expand Down
6 changes: 6 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jtorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,9 @@ def linalg_slogdet(input):
@register_function(torch.tensor_split)
def tensor_split(input, indices_or_sections, dim=0):
  """torch.tensor_split lowered onto jnp.array_split.

  torch's `dim` keyword maps directly to JAX's `axis`; both accept either
  a section count or explicit split indices.
  """
  pieces = jnp.array_split(input, indices_or_sections, axis=dim)
  return pieces


@register_function(torch.linalg.solve)
def linalg_solve(a, b):
  """torch.linalg.solve: delegate to the aten lowering, keep only the result.

  The aten helper also yields an auxiliary status value, which the
  functional torch.linalg.solve API does not expose — drop it.
  """
  solution, _info = jaten._aten__linalg_solve_ex(a, b)
  return solution

0 comments on commit 441098b

Please sign in to comment.