
[Question] pooling and aggregation operations #2257

Open
xiexbing opened this issue Jul 30, 2024 · 12 comments

Comments

@xiexbing

In the forward pass, with table-wise sharding, when is pooling executed? Is it after the all-to-all communication, and is it executed locally on the trainer? Where can I see the exact code in the torchrec code base?

In the backward pass, with table-wise sharding, when are sorting and aggregation executed? Can you please point to the relevant lines in the code base?

@JacoCheung

JacoCheung commented Jul 30, 2024

In the forward pass, with table-wise sharding, when is pooling executed?

The pooling is done in FBGEMM_GPU, not in torchrec itself. For example: https://pytorch.org/FBGEMM/fbgemm_gpu-python-api/table_batched_embedding_ops.html
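For illustration, here is a minimal sketch of calling the FBGEMM table-batched embedding op directly; the table sizes and index data are made up, and the import paths may differ across fbgemm_gpu versions:

```python
import torch

# Import paths vary across fbgemm_gpu versions; older releases expose these
# names from fbgemm_gpu.split_table_batched_embeddings_ops instead.
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
    EmbeddingLocation,
    PoolingMode,
)
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    SplitTableBatchedEmbeddingBagsCodegen,
)

# Two tables, each 1000 rows x 16 dims, held in GPU memory.
emb_op = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        (1000, 16, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
        (1000, 16, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
    ],
    pooling_mode=PoolingMode.SUM,  # pooling is fused into the lookup kernel
)

# CSR-style lookup: flattened indices plus bag offsets, batch size 2,
# so offsets has batch_size * num_tables + 1 entries.
indices = torch.tensor([1, 5, 7, 2, 3], dtype=torch.long, device="cuda")
offsets = torch.tensor([0, 2, 3, 4, 5], dtype=torch.long, device="cuda")

pooled = emb_op(indices=indices, offsets=offsets)  # shape [2, 32]: already pooled
```

The output is the pooled (summed) embedding per bag, so torchrec never has to run a separate pooling step after the lookup.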

@JacoCheung

The backward key sorting and gradient reduction are also done inside fbgemm_gpu.
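Conceptually, that backward step can be mimicked in plain PyTorch: sort the lookup indices so duplicates are adjacent, then segment-sum the incoming gradients per unique row. This is only a sketch of the idea; the real fbgemm_gpu backward is fused CUDA code:

```python
import torch

# Gradients arriving from the output collective: one gradient row per lookup.
indices = torch.tensor([3, 1, 3, 7, 1])   # embedding rows that were looked up
grad_out = torch.randn(5, 16)             # gradient routed back per lookup

# "Sort" step: group duplicate indices together.
sorted_idx, perm = torch.sort(indices)
sorted_grad = grad_out[perm]

# "Aggregation" step: sum the gradients that belong to the same embedding row.
unique_rows, inverse = torch.unique_consecutive(sorted_idx, return_inverse=True)
row_grads = torch.zeros(len(unique_rows), 16).index_add_(0, inverse, sorted_grad)
# row_grads[i] now holds the accumulated gradient for row unique_rows[i].
```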

@xiexbing
Author

xiexbing commented Jul 30, 2024 via email

@JacoCheung

JacoCheung commented Jul 30, 2024

It's written in C++, specifically CUDA source code.

  • The sort
  • The aggregation is done inside CUDA kernels. (The source code is a templatized file and not very readable; you can build fbgemm_gpu on your own and check the generated files.) For example:

This template file defines the autograd function that wraps the forward and backward entry points. A generated file example is produced from that template file.

@xiexbing
Author

xiexbing commented Jul 31, 2024 via email

@xiexbing
Author

xiexbing commented Jul 31, 2024 via email

@JacoCheung

Sorry, a follow-up question on the forward-pass communication and pooling. I profiled the forward pass with both Nsight and the torch profiler. I see the all2all calls at the Python level and the low-level NCCL SendReceive calls for the all2all, but I didn't see any calls that map to the pooling. Does the pooling actually happen within SendReceive? Or is it after SendReceive and within the all2all? Or is it after the all2all, but just not properly profiled?

The pooling is done before the all2all. Maybe this figure makes it clearer:
[figure omitted: forward-pass dataflow]

The first all2all is for the lookup keys, which are fed into fbgemm, and fbgemm does the pooling for you. The second all2all exchanges the pooled embeddings.
So the pooling actually happens in CUDA kernels.
[figure omitted]
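As a rough stand-in for what each rank does between those two collectives, the local lookup-plus-pooling step is equivalent to a bag-style embedding lookup. Here it is shown with plain torch.nn.EmbeddingBag; in torchrec this is the fused FBGEMM kernel instead:

```python
import torch

# Stand-in for one shard's table, after the first (key) all2all has delivered
# the indices this rank is responsible for looking up.
table = torch.nn.EmbeddingBag(num_embeddings=1000, embedding_dim=16, mode="sum")

values = torch.tensor([4, 9, 9, 2])   # lookup indices for this shard
offsets = torch.tensor([0, 2, 3])     # bag boundaries -> 3 bags: [4,9], [9], [2]

pooled = table(values, offsets)       # [3, 16] pooled embeddings
# 'pooled' is what enters the second collective: an all2all for table-wise
# sharding, or a reduce-scatter for row-wise sharding.
```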

@xiexbing
Author

xiexbing commented Jul 31, 2024 via email

@JacoCheung

JacoCheung commented Jul 31, 2024

One correction: for RW (row-wise) sharding there is no 2nd all2all; instead it should be a ReduceScatter. (Table-wise sharding does use an all2all.)

1. Can I assume the first all2all actually communicates the KJT for each batch?

AFAIK, yes. But note that a KJT all2all is usually composed of several tensor all2alls.
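For context, a KeyedJaggedTensor carries several flat tensors (values, lengths, and optionally weights), which is why its all2all decomposes into multiple tensor all2alls. A minimal construction using the torchrec API (shapes and values are illustrative):

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two features ("f1", "f2") over a batch of 2 samples.
kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([10, 11, 12, 20, 21]),  # one flat values tensor
    lengths=torch.tensor([2, 1, 1, 1]),         # per-key, per-sample bag lengths
)
# A KJT all2all exchanges these component tensors separately, roughly:
# splits/lengths first (so peers know how much data to expect), then values
# (and weights, if present).
```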

If a sample has indices spread across multiple shards (meaning its embedding vectors are stored on different GPUs), how is pooling done?

There has to be a collective. For RW sharding, it could be a reduce-scatter.

Why not have only one all2all, so that each KJT batch owner receives the individual embedding vectors (based on their keys) and launches a CUDA kernel to do the pooling locally?

Initially, each rank has its data-parallel input. The first all2all is used for sending keys to the corresponding shard for lookup; I don't think we can skip it. As I clarified, the 2nd collective should be an RS (ReduceScatter) for RW.
The local lookup result from fbgemm is a dense tensor that includes all bags, so it is ready for the RS. If fbgemm did not do the pooling and returned a jagged tensor, we would need a jagged tensor all2all plus a local reduce.
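A toy, single-process illustration of why the row-wise output collective is a reduce-scatter (in a real job torch.distributed.reduce_scatter_tensor would perform the summed-then-scattered step across ranks; all numbers below are made up):

```python
import torch

world_size, num_bags, dim = 2, 4, 8

# Under row-wise sharding each rank holds a slice of every table, so every rank
# produces a *partial* pooled result for all bags (zero where it held no rows).
partial_per_rank = [torch.randn(num_bags, dim) for _ in range(world_size)]

# Reduce step: sum the partials so each bag's pooled embedding is complete.
full = torch.stack(partial_per_rank).sum(dim=0)

# Scatter step: each rank keeps only the bags for its own local batch.
chunks = full.chunk(world_size, dim=0)  # chunks[r] is what rank r would receive
```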

@JacoCheung

JacoCheung commented Jul 31, 2024

Before the all2all, how are duplicated embedding indices removed per table?

For EmbeddingBag, I think there's no deduplication. But for Embedding, there is deduplication in input_dist. You need to explicitly set use_index_dedup=True for EmbeddingCollectionSharder to enable it.
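The deduplication that use_index_dedup=True enables can be pictured with torch.unique: only the unique indices travel through the key all2all, and the inverse mapping expands the looked-up rows back afterwards. This is purely illustrative, not torchrec's actual input_dist code:

```python
import torch

indices = torch.tensor([7, 3, 7, 7, 1, 3])   # raw lookup indices for one table
unique_idx, inverse = torch.unique(indices, return_inverse=True)

# Only unique_idx ([1, 3, 7]) needs to travel through the key all2all.
embedding_table = torch.randn(100, 16)       # toy table standing in for a shard
unique_rows = embedding_table[unique_idx]

# After the lookup, the inverse mapping restores one row per original index.
rows = unique_rows[inverse]                  # shape [6, 16]
```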

@xiexbing
Author

xiexbing commented Aug 1, 2024

Thanks for the clarification! Very helpful.

@xiexbing xiexbing closed this as completed Aug 1, 2024
@xiexbing xiexbing reopened this Aug 1, 2024
@xiexbing
Author

xiexbing commented Aug 1, 2024

Sorry, I forgot to ask some details about the collective communication operators. Just for confirmation: in the forward pass, for communicating the embedding indices, there are multiple all2all calls, each for one component of the jagged tensor (e.g., one call for values, one for lengths, one for lengths per key). In both the forward and backward passes, for communicating the embedding vectors, there is only one call covering all batches across all keys and tables. If the 2nd statement is correct, then as the system scales up, the data size of that call grows as n^2 (where n is the number of ranks in the system). Why not split the single call into multiple calls, each communicating about 64MB (for the best network bandwidth efficiency)?

@xiexbing xiexbing changed the title from "pooling and aggregation operations" to "[Question] pooling and aggregation operations" Aug 1, 2024