Forward XLATensorImpl::is_contiguous_custom to TensorImpl. #8032

Open · wants to merge 10 commits into master
62 changes: 61 additions & 1 deletion test/test_operations.py
@@ -58,7 +58,7 @@
DeviceSupport = collections.namedtuple('DeviceSupport', ['num_devices'])

XLA_DISABLE_FUNCTIONALIZATION = bool(
-    os.environ.get('XLA_DISABLE_FUNCTIONALIZATION', False))
+    int(os.environ.get('XLA_DISABLE_FUNCTIONALIZATION', '0')))


def _is_on_tpu():
@@ -2783,6 +2783,66 @@ def test_unsafe_buffer_pointer(self):
buf_ptr_3 = torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor_3)
self.assertGreaterEqual(buf_ptr_3, 0)

def test_consistent_strides(self):
# Tests whether the `is_contiguous()` method is consistent with the tensor's strides.
# In other words, if `is_contiguous()` is true, the tensor's strides should correspond
# to a contiguous storage layout.
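# For example, a contiguous tensor of shape (2, 3, 4) has strides (12, 4, 1):
# sorted by stride, the (size, stride) pairs are (4, 1), (3, 4), (2, 12), and
# 1 * 4 == 4 and 4 * 3 == 12, so the check below passes.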

def stride_is_contiguous(tensor):
# Order the sizes and strides tuple list in ascending stride order, so that the
# first element corresponds to the smallest stride.
sizes_and_strides = list(
sorted(zip(tensor.shape, tensor.stride()), key=lambda t: t[1]))

# A contiguous tensor's smallest stride should be 1.
if sizes_and_strides[0][1] != 1:
return False

# Check whether each next larger stride equals the current stride multiplied
# by the current size.
for i, (size, stride) in enumerate(sizes_and_strides[:-1]):
if sizes_and_strides[i + 1][1] != stride * size:
return False

return True

def assert_strides_consistent(tensor):
self.assertEqual(tensor.is_contiguous(), stride_is_contiguous(tensor))

# Obviously contiguous, since it was freshly created by torch.rand.
a = torch.rand(10).to(xm.xla_device())
assert_strides_consistent(a)

# Not contiguous, since we are skipping every other element.
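# For example, with `a` of shape (10,), `a[::2]` has shape (5,) and stride (2,),
# so its smallest stride is not 1.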
b = a[::2]
assert_strides_consistent(b)

# Still not contiguous, since 'b' is not contiguous.
c = b[1:]
assert_strides_consistent(c)

def test_contiguity_on_different_memory_format(self):
# Create contiguous strided tensor.
a = torch.rand(2, 3, 4, 5).to(xm.xla_device())
self.assertTrue(a.is_contiguous())
# When functionalization is disabled, we fall back to the old behavior, where
# `is_contiguous()` always returns True.
self.assertEqual(
a.is_contiguous(memory_format=torch.channels_last),
XLA_DISABLE_FUNCTIONALIZATION)

# Make `a` contiguous in torch.channels_last memory format.
#
# This should, in theory, be a no-op, since we can't really change the strides
# of XLA tensors. However, `contiguous` is a composite operation that checks the
# tensor's metadata. Therefore, it shall clone the tensor whenever its strides
# do not conform to the given memory format.
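# For example, for shape (2, 3, 4, 5), the torch.channels_last strides are
# (60, 1, 15, 3) rather than the default contiguous strides (60, 20, 5, 1).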
b = a.contiguous(memory_format=torch.channels_last)
# When functionalization is disabled, we fall back to the old behavior, where
# `is_contiguous()` always returns True.
self.assertEqual(b.is_contiguous(), XLA_DISABLE_FUNCTIONALIZATION)
self.assertTrue(b.is_contiguous(memory_format=torch.channels_last))


class TestDLPack(parameterized.TestCase):

23 changes: 20 additions & 3 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1227,11 +1227,28 @@ at::Tensor XLANativeFunctions::clamp_min(const at::Tensor& self,
}

at::Tensor XLANativeFunctions::clone(
-    const at::Tensor& self,
-    std::optional<at::MemoryFormat> /* memory_format */) {
+    const at::Tensor& self, std::optional<at::MemoryFormat> memory_format) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  return bridge::AtenFromXlaTensor(
+
+  at::Tensor out = bridge::AtenFromXlaTensor(
tensor_methods::clone(bridge::GetXlaTensor(self)));

if (!runtime::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)) {
at::Tensor ref;
if (memory_format.has_value() &&
*memory_format != at::MemoryFormat::Preserve) {
// We need to run the meta function as a reference, in order to set the
// correct strides on the output tensor.
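// For example, cloning a meta tensor of shape (2, 3, 4, 5) with
// at::MemoryFormat::ChannelsLast yields a reference with strides
// (60, 1, 15, 3), which are then copied onto the XLA output below.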
at::Tensor ref_self = self.to(at::kMeta);
ref = ref_self.clone(memory_format);
} else {
ref = self;
}
out.unsafeGetTensorImpl()->set_sizes_and_strides(ref.sym_sizes(),
ref.sym_strides());
}

return out;
}

at::Tensor XLANativeFunctions::constant_pad_nd(const at::Tensor& self,
9 changes: 8 additions & 1 deletion torch_xla/csrc/tensor_impl.cpp
@@ -173,9 +173,16 @@ int64_t XLATensorImpl::numel_custom() const {
}

bool XLATensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const {
// If functionalization is disabled, the tensors' metadata is not updated
// w.r.t. the output of meta functions. Therefore, we fall back to the old
// behavior of always returning true.
if (runtime::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)) {
return true;
}

// Storage is always contiguous, but the tensor metadata is_contiguous_ might
// be false due to updates in the functionalization layer.
-  return true;
+  return c10::TensorImpl::is_contiguous_custom(memory_format);
}

void XLATensorImpl::SetupSizeProperties() {
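
Taken together, these changes make `is_contiguous()` on XLA tensors honor the strides recorded by the functionalization layer instead of unconditionally returning true. Below is a minimal sketch of the expected user-visible behavior, assuming a PyTorch/XLA build with functionalization enabled (`XLA_DISABLE_FUNCTIONALIZATION` unset); it mirrors the new tests in test_operations.py above.

import torch
import torch_xla.core.xla_model as xm

# A freshly created tensor is contiguous in the default (row-major) sense...
a = torch.rand(2, 3, 4, 5).to(xm.xla_device())
assert a.is_contiguous()
# ...but not in channels_last, since its metadata carries row-major strides.
assert not a.is_contiguous(memory_format=torch.channels_last)

# `contiguous` is a composite op: since the metadata does not match the
# requested memory format, it clones, and the clone picks up channels_last
# strides from the meta reference used in XLANativeFunctions::clone.
b = a.contiguous(memory_format=torch.channels_last)
assert b.is_contiguous(memory_format=torch.channels_last)
assert not b.is_contiguous()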