[Question] Is there FP8 embeddings support for training? #2264

Open
ShijieZZZZ opened this issue Jul 31, 2024 · 2 comments

Comments

@ShijieZZZZ

Hello, it looks like EmbeddingBagCollection only allows the table data type to be float32 or float16 during initialization:
https://github.com/pytorch/torchrec/blob/main/torchrec/modules/embedding_modules.py#L179

Is there any way to make the embeddings float8 (FP8)? Note that this is for training, not just inference.

[Screenshot 2024-07-31 160639]
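One reason the training case is harder than storage alone: FP8 has very few mantissa bits, so small optimizer updates can be rounded away entirely. The sketch below simulates rounding to E4M3, one of the two common FP8 formats (4 exponent bits with bias 7, 3 mantissa bits, largest finite value 448, the layout of PyTorch's `torch.float8_e4m3fn`); the other common format is E5M2. It is plain Python, the helpers `round_to_e4m3` and `sgd_steps` are hypothetical names (not torchrec or PyTorch APIs), and round-half-to-even plus saturation at ±448 are my own assumptions. It contrasts rounding after every SGD update with accumulating updates in a higher-precision copy and rounding only when storing, a common mixed-precision workaround that is not confirmed to be what torchrec would do.

```python
import math

E4M3_MAX = 448.0       # largest finite E4M3 ("fn" variant) value
E4M3_MIN_EXP = -6      # smallest normal exponent (bias 7); subnormals reuse it
MANTISSA_BITS = 3


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest E4M3 value (round-half-to-even, saturating at +/-448)."""
    if x == 0.0 or math.isnan(x):
        return x
    sign = -1.0 if x < 0 else 1.0
    mag = abs(x)
    if math.isinf(mag):
        return sign * E4M3_MAX             # assumption: saturate instead of overflow
    _, exp = math.frexp(mag)               # mag = m * 2**exp with 0.5 <= m < 1
    e = max(exp - 1, E4M3_MIN_EXP)         # exponent of the leading bit, clamped
    spacing = 2.0 ** (e - MANTISSA_BITS)   # gap between neighbouring E4M3 values
    return sign * min(round(mag / spacing) * spacing, E4M3_MAX)


def sgd_steps(w: float, grad: float, lr: float, steps: int, fp32_master: bool) -> float:
    """Apply `steps` identical SGD updates and return the FP8-rounded weight."""
    if fp32_master:
        for _ in range(steps):
            w -= lr * grad                 # accumulate in high precision (Python float)
        return round_to_e4m3(w)            # round only when storing
    w = round_to_e4m3(w)
    for _ in range(steps):
        w = round_to_e4m3(w - lr * grad)   # round after every single update
    return w


print(round_to_e4m3(0.9))                                               # 0.875
print(sgd_steps(1.0, grad=0.5, lr=1e-3, steps=100, fp32_master=False))  # 1.0
print(sgd_steps(1.0, grad=0.5, lr=1e-3, steps=100, fp32_master=True))   # 0.9375
```

Near 1.0 the E4M3 grid spacing is 0.0625 below 1.0 and 0.125 above it, so an update of 5e-4 is always rounded back to 1.0 when applied directly to the FP8 weight, while the accumulated update survives.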

@PaulZhang12
Contributor

Doesn't look like it, no. What is your use case? Feel free to put up a pull request.

@ShijieZZZZ
Author

ShijieZZZZ commented Aug 5, 2024

Hello @PaulZhang12, thanks for your reply. The use case is standard recommendation model training with all embedding tables stored in FP8. I don't use FP32 or FP16 embeddings because I want to save memory.
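To put rough numbers on the memory argument: the weight storage of one embedding table is num_embeddings × embedding_dim × bytes per element, with 4, 2 and 1 bytes per element for FP32, FP16 and FP8. The sketch below counts only the weights (no optimizer state, and no higher-precision copy of the weights if training keeps one). The helper `table_bytes` is a hypothetical name, and the 10M × 128 table is a made-up size for illustration only.

```python
BYTES_PER_ELEMENT = {"FP32": 4, "FP16": 2, "FP8": 1}


def table_bytes(num_embeddings: int, embedding_dim: int, dtype: str) -> int:
    """Bytes needed to store the weights of one embedding table."""
    return num_embeddings * embedding_dim * BYTES_PER_ELEMENT[dtype]


# The two toy tables t1 and t2 in the model code (32 rows x 8 dims each).
for dtype in ("FP32", "FP16", "FP8"):
    print(dtype, 2 * table_bytes(32, 8, dtype))                     # 2048, 1024, 512 bytes

# A hypothetical production-sized table: 10M rows x 128 dims.
for dtype in ("FP32", "FP16", "FP8"):
    print(dtype, table_bytes(10_000_000, 128, dtype) / 1e9, "GB")   # 5.12, 2.56, 1.28 GB
```

Going from FP32 to FP8 cuts weight storage by 4x, and from FP16 to FP8 by 2x, independent of table shape.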
A simple example of the model I would like to train:

import torch
import torchrec
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


class MyModel(torch.nn.Module):
    def __init__(self, input_size: int, output_size: int):
        super().__init__()

        self.linear = torch.nn.Linear(input_size, output_size)
        # Desired setup: both tables stored in FP8 to save memory. As noted above,
        # EmbeddingBagCollection currently only accepts FP32/FP16 here, so this
        # configuration is what the question asks to support.
        self.ebc = torchrec.EmbeddingBagCollection(
            device="cpu",
            tables=[
                torchrec.EmbeddingBagConfig(
                    name="t1",
                    embedding_dim=8,
                    num_embeddings=32,
                    feature_names=["f1"],
                    pooling=torchrec.PoolingType.SUM,
                    data_type=torchrec.modules.embedding_configs.DataType.FP8,
                ),
                torchrec.EmbeddingBagConfig(
                    name="t2",
                    embedding_dim=8,
                    num_embeddings=32,
                    feature_names=["f2"],
                    pooling=torchrec.PoolingType.SUM,
                    data_type=torchrec.modules.embedding_configs.DataType.FP8,
                ),
            ],
        )

    def forward(self, kjt: KeyedJaggedTensor) -> torch.Tensor:
        # KeyedTensor of pooled embeddings: one (batch_size, 8) tensor per feature
        embeddings = self.ebc(kjt)
        pooled = [embeddings["f1"], embeddings["f2"]]

        # (batch_size, 16) -> (batch_size, 1)
        return self.linear(torch.cat(pooled, dim=1))


# Training loop
model = MyModel(input_size=16, output_size=1)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for _ in range(1000):
    optimizer.zero_grad()
    # Batch size 2: lengths are key-major ([f1 sample 0, f1 sample 1,
    # f2 sample 0, f2 sample 1]) and sum to len(values) == 8.
    kjt = KeyedJaggedTensor(
        keys=["f1", "f2"],
        values=torch.randint(0, 32, (8,)),  # ids in [0, 32) so every row can be hit
        lengths=torch.tensor([2, 2, 1, 3]),
    )

    prediction = model(kjt)
    # Random binary targets; randint's upper bound is exclusive, so use 2, not 1.
    target = torch.randint(0, 2, (2, 1)).float()
    loss = loss_fn(prediction, target)
    loss.backward()
    optimizer.step()
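For reference, this is how the input in the loop above maps to shapes: KeyedJaggedTensor stores values key-major, so with 2 keys and 4 lengths the batch size is 4 / 2 = 2, and lengths [2, 2, 1, 3] split the 8 ids into f1 bags of 2 and 2 ids and f2 bags of 1 and 3 ids. With SUM pooling each bag becomes one 8-dim vector, so concatenating f1 and f2 gives a (2, 16) input to Linear(16, 1) and a (2, 1) prediction, matching the (2, 1) target. The plain-Python sketch below re-implements only that bookkeeping; `decode_kjt` and `sum_pool` are hypothetical helpers, not torchrec APIs, and the toy tables stand in for t1 and t2 with made-up contents (row i filled with i and 10·i respectively).

```python
from typing import Dict, List, Sequence


def decode_kjt(keys: Sequence[str], values: Sequence[int],
               lengths: Sequence[int]) -> Dict[str, List[List[int]]]:
    """Split key-major flat values/lengths into per-key, per-sample id bags."""
    assert len(lengths) % len(keys) == 0
    batch_size = len(lengths) // len(keys)
    assert sum(lengths) == len(values)
    bags: Dict[str, List[List[int]]] = {}
    offset = 0
    for k, key in enumerate(keys):
        bags[key] = []
        for b in range(batch_size):
            n = lengths[k * batch_size + b]
            bags[key].append(list(values[offset:offset + n]))
            offset += n
    return bags


def sum_pool(table: List[List[float]], bag: List[int]) -> List[float]:
    """SUM pooling: add up the rows selected by the ids in one bag."""
    pooled = [0.0] * len(table[0])
    for idx in bag:
        for j, v in enumerate(table[idx]):
            pooled[j] += v
    return pooled


# Toy 32 x 8 tables standing in for t1 and t2.
tables = {
    "f1": [[float(i)] * 8 for i in range(32)],
    "f2": [[10.0 * i] * 8 for i in range(32)],
}
bags = decode_kjt(["f1", "f2"], [3, 7, 1, 9, 4, 0, 2, 5], [2, 2, 1, 3])
print(bags)  # {'f1': [[3, 7], [1, 9]], 'f2': [[4], [0, 2, 5]]}

# Per sample: pooled f1 followed by pooled f2, like torch.cat(..., dim=1).
batch = [sum_pool(tables["f1"], bags["f1"][b]) + sum_pool(tables["f2"], bags["f2"][b])
         for b in range(2)]
print(len(batch), len(batch[0]))  # 2 16
```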
