Description

I’m using torch.compile with DistributedModelParallel. Running the code below raises ValueError: Tensors must be contiguous. The failure appears to be specific to this model and this world size; with other world sizes the same code runs without errors, which is what I would expect here as well.

Environment:

python=3.11.8, torch=2.2.2+cu121, torchrec=0.6.0+cu121
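For context, torch.distributed collectives such as dist.all_gather require their tensor arguments to be contiguous in memory, and a tensor produced through a view (e.g. a transpose) may not be. A minimal, torchrec-independent sketch of the property being checked:

import torch

t = torch.randn(4, 8).t()   # a transposed view is non-contiguous
assert not t.is_contiguous()
t_fixed = t.contiguous()    # materializes a contiguous copy
assert t_fixed.is_contiguous()

Passing a tensor like t to dist.all_gather is what triggers the "Tensors must be contiguous" error shown in the log below; the compiled graph is presumably producing such a tensor for one of the outputs.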
Reproduction code:

import multiprocessing
import os
from typing import Callable, List, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.test_utils.multi_process import MultiProcessContext
from torchrec.distributed.test_utils.test_model import ModelInput
from torchrec.distributed.test_utils.test_sharding import create_test_sharder
from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingPlan
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedTensor
from torchrec.test_utils import get_free_port

class TestModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Define model parameters.
        self.dense_in_feature = 820
        self.dense_out_feature = 784
        self.table_params = [
            [311, 108],
            [739, 408],
        ]
        self.weighted_table_params = [
            [159, 96],
            [69, 24],
            [412, 564],
            [940, 300],
        ]
        self.over_out_feature = 61
        # Sparse layer.
        self.tables = [
            EmbeddingBagConfig(
                num_embeddings=self.table_params[i][0],
                embedding_dim=self.table_params[i][1],
                name="table_" + str(i),
                feature_names=["feature_" + str(i)],
            )
            for i in range(len(self.table_params))
        ]
        self.sparse = EmbeddingBagCollection(
            tables=self.tables,
            is_weighted=False,
        )
        # Weighted sparse layer.
        self.weighted_tables = [
            EmbeddingBagConfig(
                num_embeddings=self.weighted_table_params[i][0],
                embedding_dim=self.weighted_table_params[i][1],
                name="weighted_table_" + str(i),
                feature_names=["weighted_feature_" + str(i)],
            )
            for i in range(len(self.weighted_table_params))
        ]
        self.sparse_weighted = EmbeddingBagCollection(
            tables=self.weighted_tables,
            is_weighted=True,
        )
        # Dense layer.
        self.dense = nn.Linear(
            in_features=self.dense_in_feature,
            out_features=self.dense_out_feature,
            bias=True,
        )
        # Over layer: consumes the concatenation of all preceding outputs.
        in_features_concat = (
            self.dense_out_feature
            + sum(table.embedding_dim * len(table.feature_names) for table in self.tables)
            + sum(table.embedding_dim * len(table.feature_names) for table in self.weighted_tables)
        )
        self.over = nn.Linear(
            in_features=in_features_concat,
            out_features=self.over_out_feature,
            bias=True,
        )

    def forward(
        self,
        input: ModelInput,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, KeyedTensor, KeyedTensor, torch.Tensor]]:
        # Dense, sparse, and weighted sparse layer outputs.
        dense_r = self.dense(input.float_features)
        sparse_r = self.sparse(input.idlist_features)
        sparse_weighted_r = self.sparse_weighted(input.idscore_features)
        # Concatenate the sparse and weighted sparse layer outputs.
        result = KeyedTensor(
            keys=sparse_r.keys() + sparse_weighted_r.keys(),
            length_per_key=sparse_r.length_per_key() + sparse_weighted_r.length_per_key(),
            values=torch.cat([sparse_r.values(), sparse_weighted_r.values()], dim=1),
        )
        _features = [feature for table in self.tables for feature in table.feature_names]
        _weighted_features = [
            feature for table in self.weighted_tables for feature in table.feature_names
        ]
        ret_list = [dense_r]
        for feature_name in _features:
            ret_list.append(result[feature_name])
        for feature_name in _weighted_features:
            ret_list.append(result[feature_name])
        ret_concat = torch.cat(ret_list, dim=1)
        # Over layer output.
        over_r = self.over(ret_concat)
        # Sigmoid output.
        pred = torch.sigmoid(torch.mean(over_r, dim=1))
        return pred, (dense_r, sparse_r, sparse_weighted_r, over_r)

def sharding_single_rank_test(
    rank: int,
    world_size: int,
    model,
    inputs,
    sharders: List[ModuleSharder[nn.Module]],
    backend: str,
    compiled: bool = True,
) -> None:
    with MultiProcessContext(rank, world_size, backend) as ctx:
        if compiled:
            model = torch.compile(model)
        local_model = model.to(ctx.device)
        planner = EmbeddingShardingPlanner(
            topology=Topology(world_size, ctx.device.type),
        )
        plan: ShardingPlan = planner.collective_plan(local_model, sharders, ctx.pg)
        local_model = DistributedModelParallel(
            local_model,
            env=ShardingEnv.from_process_group(ctx.pg),
            plan=plan,
            sharders=sharders,
            device=ctx.device,
        )
        # Run a single forward pass of the sharded model.
        local_input = inputs[0][1][rank].to(ctx.device)
        with torch.no_grad():
            local_pred, (dense_r, sparse_r, sparse_weighted_r, over_r) = local_model(local_input)
        # Record the local prediction.
        all_local_pred = []
        for _ in range(world_size):
            all_local_pred.append(torch.empty_like(local_pred))
        dist.all_gather(all_local_pred, local_pred, group=ctx.pg)
        # Record the local model's layer outputs.
        all_dense_r = []
        for _ in range(world_size):
            all_dense_r.append(torch.empty_like(dense_r))
        dist.all_gather(all_dense_r, dense_r, group=ctx.pg)
        sparse_r_dict = sparse_r.to_dict()
        all_sparse_r_dict = {}
        for key in sparse_r_dict:
            all_sparse_r_dict[key] = []
            for _ in range(world_size):
                all_sparse_r_dict[key].append(torch.empty_like(sparse_r_dict[key]))
            dist.all_gather(all_sparse_r_dict[key], sparse_r_dict[key].contiguous(), group=ctx.pg)
        sparse_weighted_r_dict = sparse_weighted_r.to_dict()
        all_sparse_weighted_r_dict = {}
        for key in sparse_weighted_r_dict:
            all_sparse_weighted_r_dict[key] = []
            for _ in range(world_size):
                all_sparse_weighted_r_dict[key].append(
                    torch.empty_like(sparse_weighted_r_dict[key])
                )
            dist.all_gather(
                all_sparse_weighted_r_dict[key],
                sparse_weighted_r_dict[key].contiguous(),
                group=ctx.pg,
            )
        all_over_r = []
        for _ in range(world_size):
            all_over_r.append(torch.empty_like(over_r))
        # NOTE: with torch.compile enabled, this is the collective that raises
        # "ValueError: Tensors must be contiguous" (see the log below).
        dist.all_gather(all_over_r, over_r, group=ctx.pg)

def setUp() -> None:
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(get_free_port())
    os.environ["GLOO_DEVICE_TRANSPORT"] = "TCP"
    os.environ["NCCL_SOCKET_IFNAME"] = "lo"
    torch.use_deterministic_algorithms(True)
    if torch.cuda.is_available():
        torch.backends.cudnn.allow_tf32 = False
        torch.backends.cuda.matmul.allow_tf32 = False
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"


def run_multi_process_test(
    callable: Callable[..., None],
    world_size: int,
    # pyre-ignore
    **kwargs,
) -> None:
    setUp()
    ctx = multiprocessing.get_context("forkserver")
    processes = []
    for rank in range(world_size):
        kwargs["rank"] = rank
        kwargs["world_size"] = world_size
        p = ctx.Process(
            target=callable,
            kwargs=kwargs,
        )
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

def main_test(
    sharders: List[ModuleSharder[nn.Module]],
    backend: str,
    world_size: int,
    compiled: bool,
) -> None:
    model = TestModel()
    inputs = [
        ModelInput.generate(
            batch_size=1200,
            world_size=world_size,
            num_float_features=model.dense_in_feature,
            tables=model.tables,
            weighted_tables=model.weighted_tables,
        )
    ]
    run_multi_process_test(
        callable=sharding_single_rank_test,
        world_size=world_size,
        model=model,
        inputs=inputs,
        sharders=sharders,
        backend=backend,
        compiled=compiled,
    )


if __name__ == "__main__":
    sharders = [create_test_sharder("embedding_bag_collection", "column_wise", "dense")]
    backend = "nccl"
    world_size = 2
    main_test(
        sharders=sharders,
        backend=backend,
        world_size=world_size,
        compiled=True,
    )
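As noted in the description, the failure seems to depend on the world size. If it helps triage, the same entry point can be swept over several world sizes to confirm which ones reproduce the error; a sketch only, assuming at least as many CUDA devices as the largest size for the nccl backend:

# Sketch: rerun the repro at several world sizes to see which fail.
for ws in (1, 2, 4):
    print(f"--- world_size={ws} ---")
    main_test(sharders=sharders, backend="nccl", world_size=ws, compiled=True)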
Log:
The error message is copied below.
Traceback (most recent call last):
  File "/root/miniconda3/envs/torchrec/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/root/miniconda3/envs/torchrec/lib/python3.11/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/mnt/tests/reproduce_nccl_tensor_must_be_contiguous.py", line 203, in sharding_single_rank_test
    dist.all_gather(all_over_r, over_r, group=ctx.pg)
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 72, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 2617, in all_gather
    work = group.allgather([tensor_list], [tensor])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Tensors must be contiguous
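A workaround that sidesteps the error (without addressing the underlying torch.compile interaction) is to force a contiguous copy before the collective, the same way the script already does for the sparse output dicts:

# Workaround sketch: .contiguous() is a no-op for tensors that are
# already contiguous and otherwise materializes a contiguous copy,
# mirroring the calls used for sparse_r_dict / sparse_weighted_r_dict.
dist.all_gather(all_over_r, over_r.contiguous(), group=ctx.pg)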