
[Bug][Dynamic Embedding] improper optimizer state_dict momentum2 key while constructing PSCollection #2177

Open
JacoCheung opened this issue Jun 26, 2024 · 1 comment
Labels: bug Something isn't working


JacoCheung commented Jun 26, 2024

Describe the bug

A PSCollection should contain the optimizer states in addition to the weights. The optimizer state tensors are obtained directly from the EmbeddingCollection module.

However, sharded_module.fused_optimizer.state_dict()['state'] does not contain the key {table_name}.momentum2, because:

  1. TBE::get_optimizer_state(), which is used by PSCollection, does not return keys like xxx.momentum1 or xxx.momentum2; the key names are customized by TBE.
  2. The state keys are renamed by torchrec::EmbeddingFusedOptimizer. Only the first state falls back to xxx.momentum1, while the remaining keys are copied verbatim from the results retrieved above.

See the illustration below, where the optimizer is Adam. The expected number of state tensors is 2, but the state_dict eventually contains momentum1 and leaves momentum2 (equivalently, exp_avg_sq) out.

[Screenshot: opt report]

This affects every optimizer that maintains a momentum2 state.
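The renaming behaviour described in points 1 and 2 can be sketched in plain Python (a hypothetical illustration of the logic, not the actual torchrec source; the function and key names are assumptions for demonstration):

```python
# Sketch of the key-renaming bug: TBE's get_optimizer_state() returns
# per-table state dicts whose keys are TBE-specific (e.g. "exp_avg_sq"),
# and the fused optimizer only renames the *first* state to
# "<table>.momentum1", copying the remaining keys verbatim.

def build_state_dict(table_name, tbe_states):
    """tbe_states: list of single-entry dicts, as a hypothetical
    TBE.get_optimizer_state() might return for Adam:
    [{"exp_avg": t1}, {"exp_avg_sq": t2}]."""
    state = {}
    for i, s in enumerate(tbe_states):
        if i == 0:
            # The first state falls back to the momentum1 name...
            state[f"{table_name}.momentum1"] = next(iter(s.values()))
        else:
            # ...but later states keep TBE's own key names, so a
            # "<table>.momentum2" key is never produced.
            for k, v in s.items():
                state[f"{table_name}.{k}"] = v
    return state

state = build_state_dict("t1", [{"exp_avg": "m1"}, {"exp_avg_sq": "m2"}])
# "t1.momentum1" is present, but "t1.momentum2" is absent; the second
# Adam state surfaces as "t1.exp_avg_sq" instead.
```

A fix would presumably rename the i-th state to `{table_name}.momentum{i+1}` so that every momentum slot gets a stable, optimizer-agnostic key.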

@iamzainhuda (Contributor)

Yup, this shouldn't be intended behaviour. Feel free to submit a PR and I can go ahead and review.

@iamzainhuda iamzainhuda self-assigned this Jun 26, 2024
@iamzainhuda iamzainhuda added the bug Something isn't working label Jun 26, 2024