[FIX] Shuffle Inputs Globally in cuGraph-PyG #4606

Open · wants to merge 8 commits into base: branch-24.10
3 changes: 3 additions & 0 deletions python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_mnmg.py
@@ -204,6 +204,7 @@ def run_train(
directory=train_path,
shuffle=True,
drop_last=True,
global_shuffle=True,
**kwargs,
)

@@ -217,6 +218,7 @@ def run_train(
shuffle=True,
drop_last=True,
local_seeds_per_call=80000,
global_shuffle=False,
**kwargs,
)

@@ -229,6 +231,7 @@
directory=valid_path,
shuffle=True,
drop_last=True,
global_shuffle=False,
**kwargs,
)

3 changes: 3 additions & 0 deletions python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_snmg.py
@@ -127,6 +127,7 @@ def run_train(
directory=train_path,
shuffle=True,
drop_last=True,
global_shuffle=True,
**kwargs,
)

@@ -140,6 +141,7 @@ def run_train(
shuffle=True,
drop_last=True,
local_seeds_per_call=80000,
global_shuffle=False,
**kwargs,
)

@@ -152,6 +154,7 @@
directory=valid_path,
shuffle=True,
drop_last=True,
global_shuffle=False,
**kwargs,
)

70 changes: 70 additions & 0 deletions python/cugraph-pyg/cugraph_pyg/loader/loader_utils.py
@@ -0,0 +1,70 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from cugraph.utilities.utils import import_optional
from typing import List

torch = import_optional("torch")


def scatter(
t: "torch.Tensor", scatter_perm: List["torch.Tensor"], rank: int, world_size: int
):
"""
t: torch.Tensor
The local tensor being scattered.
scatter_perm: List[torch.Tensor]
The indices to send to each rank.
rank: int
The global rank of this worker.
world_size: int
The total number of workers.
"""

scatter_len = torch.tensor(
[s.numel() for s in scatter_perm], device="cuda", dtype=torch.int64
)

scatter_len_all = [
torch.empty((world_size,), device="cuda", dtype=torch.int64)
for _ in range(world_size)
]
torch.distributed.all_gather(scatter_len_all, scatter_len)

t = t.cuda()
local_tensors = [
torch.empty((scatter_len_all[r][rank],), device="cuda", dtype=torch.int64)
for r in range(world_size)
]

qx = []
for r in range(world_size):
send_rank = (rank + r) % world_size
send_op = torch.distributed.P2POp(
torch.distributed.isend,
t[scatter_perm[send_rank]],
send_rank,
)

recv_rank = (rank - r) % world_size
recv_op = torch.distributed.P2POp(
torch.distributed.irecv,
local_tensors[recv_rank],
recv_rank,
)
qx += torch.distributed.batch_isend_irecv([send_op, recv_op])

for x in qx:
x.wait()

return torch.concat(local_tensors)
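
For context, the helper above is what NodeLoader (changed below) uses to implement the global shuffle. The following is a minimal usage sketch, not part of this PR: it assumes torch.distributed has already been initialized with an NCCL backend and one GPU per rank (as in the test added at the end of this change), and globally_shuffle is a hypothetical wrapper.

    import torch
    import torch.distributed as dist

    from cugraph_pyg.loader.loader_utils import scatter


    def globally_shuffle(local_nodes: torch.Tensor) -> torch.Tensor:
        # Assumes an initialized process group; mirrors what
        # NodeLoader.__get_input does when global_shuffle is enabled.
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        # Split a random permutation of the local inputs into one chunk per rank.
        scatter_perm = torch.tensor_split(
            torch.randperm(local_nodes.numel(), dtype=torch.int64), world_size
        )

        # Exchange the chunks, then permute again locally so each batch mixes
        # inputs that originated on different ranks.
        shuffled = scatter(local_nodes, scatter_perm, rank, world_size)
        return shuffled[torch.randperm(shuffled.numel(), device=shuffled.device)]
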
54 changes: 50 additions & 4 deletions python/cugraph-pyg/cugraph_pyg/loader/node_loader.py
@@ -17,6 +17,7 @@
from typing import Union, Tuple, Callable, Optional

from cugraph.utilities.utils import import_optional
from .loader_utils import scatter

torch_geometric = import_optional("torch_geometric")
torch = import_optional("torch")
@@ -50,6 +51,7 @@ def __init__(
batch_size: int = 1,
shuffle: bool = False,
drop_last: bool = False,
global_shuffle: bool = True,
**kwargs,
):
"""
@@ -74,7 +76,17 @@
always return a Data or HeteroData object.
input_id: OptTensor
See torch_geometric.loader.NodeLoader.

batch_size: int
The size of each batch.
shuffle: bool
Whether to shuffle data into random batches.
drop_last: bool
Whether to drop remaining inputs that can't form a full
batch.
global_shuffle: bool
(cuGraph-PyG only) Whether to shuffle inputs across all workers
(globally) rather than only within each worker. Turning this off
may make sense if communication between workers is slow, but it
can come at a cost to accuracy.
"""
if not isinstance(data, (list, tuple)) or not isinstance(
data[1], cugraph_pyg.data.GraphStore
@@ -124,8 +136,41 @@ def __init__(
self.__batch_size = batch_size
self.__shuffle = shuffle
self.__drop_last = drop_last
self.__global_shuffle = global_shuffle

def __get_input(self):
_, graph_store = self.__data
if graph_store.is_multi_gpu and self.__shuffle and self.__global_shuffle:
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
scatter_perm = torch.tensor_split(
torch.randperm(
self.__input_data.node.numel(), device="cpu", dtype=torch.int64
),
world_size,
)

new_node = scatter(self.__input_data.node, scatter_perm, rank, world_size)
local_perm = torch.randperm(new_node.numel())
if self.__drop_last:
d = local_perm.numel() % self.__batch_size
local_perm = local_perm[:-d]

return torch_geometric.loader.node_loader.NodeSamplerInput(
input_id=None
if self.__input_data.input_id is None
else scatter(
self.__input_data.input_id, scatter_perm, rank, world_size
)[local_perm],
time=None
if self.__input_data.time is None
else scatter(self.__input_data.time, scatter_perm, rank, world_size)[
local_perm
],
node=new_node[local_perm],
input_type=self.__input_data.input_type,
)

def __iter__(self):
if self.__shuffle:
perm = torch.randperm(self.__input_data.node.numel())
else:
@@ -135,7 +180,7 @@ def __iter__(self):
d = perm.numel() % self.__batch_size
perm = perm[:-d]

input_data = torch_geometric.loader.node_loader.NodeSamplerInput(
return torch_geometric.loader.node_loader.NodeSamplerInput(
input_id=None
if self.__input_data.input_id is None
else self.__input_data.input_id[perm],
@@ -146,6 +191,7 @@ def __iter__(self):
input_type=self.__input_data.input_type,
)

def __iter__(self):
return cugraph_pyg.sampler.SampleIterator(
self.__data, self.__node_sampler.sample_from_nodes(input_data)
self.__data, self.__node_sampler.sample_from_nodes(self.__get_input())
)
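
From the caller's side, the new flag is just another loader keyword argument, as the example-script changes above show. A condensed sketch follows; the NeighborLoader arguments other than global_shuffle are illustrative, and feature_store, graph_store, and the index tensors are assumed to exist, so treat this as a usage hint rather than the exact API.

    import cugraph_pyg

    # Training loader: shuffle across all workers each epoch.
    train_loader = cugraph_pyg.loader.NeighborLoader(
        (feature_store, graph_store),
        num_neighbors=[25, 10],
        input_nodes=ix_train,
        batch_size=1024,
        shuffle=True,
        drop_last=True,
        global_shuffle=True,
    )

    # Test/validation loaders keep the shuffle local to each worker,
    # matching the example scripts changed above.
    valid_loader = cugraph_pyg.loader.NeighborLoader(
        (feature_store, graph_store),
        num_neighbors=[25, 10],
        input_nodes=ix_valid,
        batch_size=1024,
        shuffle=True,
        drop_last=True,
        global_shuffle=False,
    )
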
107 changes: 107 additions & 0 deletions python/cugraph-pyg/cugraph_pyg/tests/loader/test_loader_utils_mg.py
@@ -0,0 +1,107 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest

from cugraph.utilities.utils import import_optional, MissingModule


from cugraph.gnn import (
cugraph_comms_init,
cugraph_comms_shutdown,
cugraph_comms_create_unique_id,
)

from cugraph_pyg.loader.loader_utils import scatter

torch = import_optional("torch")
torch_geometric = import_optional("torch_geometric")


def init_pytorch_worker(rank, world_size, cugraph_id):
import rmm

rmm.reinitialize(
devices=rank,
)

import cupy

cupy.cuda.Device(rank).use()
from rmm.allocators.cupy import rmm_cupy_allocator

cupy.cuda.set_allocator(rmm_cupy_allocator)

from cugraph.testing.mg_utils import enable_spilling

enable_spilling()

torch.cuda.set_device(rank)

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)

cugraph_comms_init(rank=rank, world_size=world_size, uid=cugraph_id, device=rank)


def run_test_loader_utils_scatter(rank, world_size, uid):
init_pytorch_worker(rank, world_size, uid)

num_values_rank = (1 + rank) * 9
local_values = torch.arange(0, num_values_rank) + 9 * (
rank + ((rank * (rank - 1)) // 2)
)

scatter_perm = torch.tensor_split(torch.arange(local_values.numel()), world_size)

new_values = scatter(local_values, scatter_perm, rank, world_size)
print(
rank,
local_values,
new_values,
flush=True,
)

offset = 0
for send_rank in range(world_size):
num_values_send_rank = (1 + send_rank) * 9

expected_values = torch.tensor_split(
torch.arange(0, num_values_send_rank)
+ 9 * (send_rank + ((send_rank * (send_rank - 1)) // 2)),
world_size,
)[rank]

ix_sent = torch.arange(expected_values.numel())
values_rec = new_values[ix_sent + offset].cpu()
offset += values_rec.numel()

assert (values_rec == expected_values).all()

cugraph_comms_shutdown()


@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
@pytest.mark.mg
def test_loader_utils_scatter():
uid = cugraph_comms_create_unique_id()
world_size = torch.cuda.device_count()

torch.multiprocessing.spawn(
run_test_loader_utils_scatter,
args=(world_size, uid),
nprocs=world_size,
)
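
For intuition on the fixture above: rank r holds (1 + r) * 9 consecutive integers starting at 9 * r * (r + 1) / 2, splits them into world_size contiguous chunks, and sends chunk j to rank j, so after the scatter each rank should hold the rank-th chunk of every sender's range, concatenated in sender order; that is exactly what the verification loop checks. A quick standalone check of the layout (plain arithmetic, not part of the test):

    # With 3 GPUs: rank 0 holds 0..8, rank 1 holds 9..26, rank 2 holds 27..53.
    for rank in range(3):
        start = 9 * rank * (rank + 1) // 2
        count = (1 + rank) * 9
        print(f"rank {rank}: {start}..{start + count - 1}")
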