Revert "Allow MpDeviceLoader to shard dictionaries of tensors" (#7964)
JackCaoG committed Sep 5, 2024
1 parent 282bf93 commit 59570c7
Showing 4 changed files with 13 additions and 207 deletions.
19 changes: 0 additions & 19 deletions docs/spmd_advanced.md
@@ -14,25 +14,6 @@ train_loader = pl.MpDeviceLoader(
input_sharding=xs.ShardingSpec(input_mesh, ('data', None, None, None)))
```

It is also possible to specify a different `input_sharding` for each element of the batch if they are different shapes:

```python
# if batch = next(train_loader) looks like
# {'x': <tensor of shape [s1, s2, s3, s4]>, 'y': <tensor of shape [s1, s2]>}

# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
train_loader = pl.MpDeviceLoader(
train_loader, # wraps PyTorch DataLoader
device,
# specify different sharding for each input of the batch.
input_sharding={
'x': xs.ShardingSpec(input_mesh, ('data', None, None, None)),
'y': xs.ShardingSpec(input_mesh, ('data', None))
}
)
```


### Virtual Device Optimization

PyTorch/XLA normally transfers tensor data asynchronously from host to device once the tensor is defined, which overlaps the data transfer with the graph tracing time. However, because GSPMD allows the user to modify the tensor sharding _after_ the tensor has been defined, we need an optimization to prevent unnecessary transfers of tensor data back and forth between host and device. We introduce Virtual Device Optimization, a technique that places tensor data on a virtual device, SPMD:0, first and uploads it to the physical devices only once all the sharding decisions are finalized. Every tensor in SPMD mode is placed on the virtual device SPMD:0, which is exposed to the user as the XLA device XLA:0, with the actual shards residing on physical devices such as TPU:0, TPU:1, etc.
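
As a minimal sketch of the virtual device behavior described above (not part of this diff; it assumes a torch_xla installation with SPMD support, and the tensor shape and mesh layout are illustrative), a tensor created after enabling SPMD reports the virtual XLA device, and its sharding can still be chosen afterwards:

```python
# Illustrative sketch only: shapes and mesh layout are assumptions.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # route all tensor data through the virtual SPMD:0 device

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))

t = torch.randn(16, 128).to(xm.xla_device())
print(t.device)  # reported as the virtual XLA device (e.g. xla:0)

# The sharding can still be decided after the tensor is defined; the data
# is only uploaded to the physical devices once this is finalized.
xs.mark_sharding(t, mesh, ('data', None))
```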
1 change: 0 additions & 1 deletion test/run_tests.sh
@@ -234,7 +234,6 @@ function run_xla_op_tests3 {
run_test "$CDIR/stablehlo/test_unbounded_dynamism.py"
run_test "$CDIR/quantized_ops/test_quantized_matmul.py"
run_test "$CDIR/quantized_ops/test_dot_general.py"
run_test "$CDIR/spmd/test_mp_input_sharding.py"
run_test "$CDIR/spmd/test_xla_sharding.py"
run_test "$CDIR/spmd/test_xla_sharding_hlo.py"
run_test "$CDIR/spmd/test_xla_virtual_device.py"
126 changes: 0 additions & 126 deletions test/spmd/test_mp_input_sharding.py

This file was deleted.

74 changes: 13 additions & 61 deletions torch_xla/distributed/parallel_loader.py
@@ -12,7 +12,7 @@ class PerDeviceQueue(object):

def __init__(self, device, loader_prefetch_size, device_prefetch_size):
self.device = device
self.cpu_loader_queue = kq.Queue(maxsize=loader_prefetch_size)
self.loader_queue = kq.Queue(maxsize=loader_prefetch_size)
self.queue = kq.Queue(maxsize=device_prefetch_size)
self.close_queue_count = itertools.count()

@@ -46,8 +46,6 @@ def next(self):
self._batches_yielded += 1

item = self._loader.next_item(self._device)
if isinstance(item, Exception):
raise item
if item is None:
xm.mark_step()
raise StopIteration
@@ -58,7 +56,7 @@ class ParallelLoader(object):
"""Wraps an existing PyTorch DataLoader with background data upload.
Args:
cpu_loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be
loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be
wrapped.
devices (`torch.device`...): The list of devices where the data has to be
sent. The i-th sample returned by the `loader` will be sent to `devices[i
@@ -76,21 +74,21 @@ class ParallelLoader(object):
host_to_device_transfer_threads (int, optional): The number of threads that
work in parallel to transfer data from loader queue to device queue.
Default: 1
input_sharding (ShardingSpec, Dict(str, ShardingSpec), optional): Sharding
spec to apply to compatible input tensors after loading.
input_sharding (ShardingSpec, optional): Sharding spec to apply to
compatible input tensors after loading.
Default: None
"""

def __init__(self,
cpu_loader,
loader,
devices,
batchdim=0,
batches_per_execution=1,
loader_prefetch_size=16,
device_prefetch_size=8,
host_to_device_transfer_threads=1,
input_sharding=None):
self._cpu_loader = cpu_loader
self._loader = loader
self._devices = [torch.device(x) for x in devices]
self._batchdim = batchdim
self._batches_per_execution = batches_per_execution
@@ -139,15 +137,15 @@ def close(self):
self._done = True
for dqueue in self._queues.values():
dqueue.queue.close()
dqueue.cpu_loader_queue.close()
dqueue.loader_queue.close()

@property
def batches_per_execution(self):
return self._batches_per_execution

def _loader_worker(self):
queues = list(self._queues.values())
data_iter = enumerate(self._cpu_loader)
data_iter = enumerate(self._loader)
batch = []
while not self._done:
try:
@@ -157,73 +155,27 @@ def _loader_worker(self):
batch.append(data)
if len(batch) == len(self._devices):
for queue_no, device_batch in enumerate(batch):
queues[queue_no].cpu_loader_queue.put(device_batch)
queues[queue_no].loader_queue.put(device_batch)
batch = []
for dqueue in queues:
dqueue.cpu_loader_queue.close_write()
dqueue.loader_queue.close_write()

def _get_batch(self, dqueue):
batch = []
while len(batch) < dqueue.queue.max_size():
item = dqueue.cpu_loader_queue.get()
while dqueue.queue.max_size() > len(batch):
item = dqueue.loader_queue.get()
if item is None:
break
batch.append(item)
return batch

def send_cpu_data_to_device(self, batches, device):
"""Move batch to device.
Args:
batch -> List(torch.Tensor), List(Dict(str: torch.Tensor)): Input batch
present in the cpu memory
device: TPU device where the batch should be moved
Returns:
result -> List(torch.Tensor), Dict(str: torch.Tensor): Returns a dict if the
input batch is a dict. Otherwise, returns a list of torch.Tensor.
"""
result = None
if isinstance(batches[0], dict):
if self._input_sharding and not isinstance(self._input_sharding, dict):
return [
ValueError(
f"input_sharding should be a dict or None when input batch is a dict."
)
]
result = []
for batch in batches:
xla_batch = {}
missing_keys = []
for key, tensor in batch.items():
assert type(tensor) == torch.Tensor
sharding_spec = None
if self._input_sharding:
if key not in self._input_sharding:
missing_keys.append(key)
continue
sharding_spec = self._input_sharding[key]

# xla_tensor is a list of tensors.
xla_tensor = xm.send_cpu_data_to_device(tensor, device, sharding_spec)
xla_batch[key] = xla_tensor[0]
if len(missing_keys) != 0:
# Returning exception as raising in the dataloading thread doesn't surface the problem in the main thread.
return [
KeyError(f"Keys: {missing_keys} are missing from input_sharding.")
]
result.append(xla_batch)
else:
result = xm.send_cpu_data_to_device(batches, device, self._input_sharding)
return result

def _worker(self, dqueue, host_to_device_transfer_threads):
device = torch.device(dqueue.device)
while True:
batch = self._get_batch(dqueue)
if not batch:
break
batch = self.send_cpu_data_to_device(batch, device)
batch = xm.send_cpu_data_to_device(batch, device, self._input_sharding)
for data in batch:
dqueue.queue.put(data)
close_queue_count = next(dqueue.close_queue_count)
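
As a rough sketch of the single-`input_sharding` path this revert restores (not part of the diff; the mesh setup, shapes, and partition spec below are assumed for illustration), `xm.send_cpu_data_to_device` applies one `ShardingSpec` uniformly to every tensor in the batch, which is what the restored `_worker` does:

```python
# Illustrative sketch only: mesh, shapes, and partition spec are assumptions.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

device = xm.xla_device()
# A single ShardingSpec (or None) is applied to every compatible tensor.
spec = xs.ShardingSpec(mesh, ('data', None))

cpu_batch = [torch.randn(8, 128), torch.randn(8, 128)]
xla_batch = xm.send_cpu_data_to_device(cpu_batch, device, spec)
```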
