Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inconsistent results from the quantized sparse layer and the weighted sparse layer #2044

Open
jiannanWang opened this issue May 27, 2024 · 0 comments

Comments

@jiannanWang
Copy link

There are large inconsistencies between the results of a single forward pass of a torchrec model (a dense layer, a sparse layer, a weighted sparse layer, and an over layer) run under distributed and non-distributed settings.

Below is the code to reproduce the inconsistency. In the code, I create a model and inputs, and quantize the model with dtype = torch.qint8 and output_dtype = torch.qint8. I then run a forward pass with the distributed model and with the non-distributed model. Since the model's weights are copied, I expect their results to be the same. However, the results differ substantially, as shown in the logs below. The environment is Python 3.10.14, torch 2.3.0+cu121, and torchrec 0.7.0.

Note that this code was updated from a version written for torchrec 0.2.0. When that version is run under torchrec 0.2.0, the sparse layer outputs NaN (see the torchrec 0.2.0 logs below).

These inconsistencies are likely bugs: the distributed and non-distributed models have the same parameters and inputs, so a single forward pass should return the same results for both.

Reproduction code

import copy
import traceback
from typing import Any, Type, Dict, List, Optional, Union, Tuple

import torch
import torch.nn as nn
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import (
    EmbeddingShardingPlanner,
    Topology,
)
from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder
from torchrec.distributed.test_utils.test_model import (
    ModelInput,
)
from torchrec.distributed.types import (
    ModuleSharder,
    ShardingEnv,
    ShardingPlan,
)
from torchrec.sparse.jagged_tensor import KeyedTensor
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.distributed.test_utils.test_sharding import copy_state_dict
from torchrec.inference.modules import quantize_embeddings

class ReproduceModel(nn.Module):
    """Toy ranking model used to reproduce the sharded-vs-unsharded mismatch.

    Layers: a dense ``nn.Linear``, an unweighted ``EmbeddingBagCollection``,
    a weighted ``EmbeddingBagCollection``, and an "over" ``nn.Linear`` applied
    to the concatenation of the dense output and every pooled embedding.

    ``forward`` expects an object exposing ``float_features``,
    ``idlist_features``, ``idscore_features`` and (in training mode)
    ``label`` -- e.g. torchrec's test ``ModelInput``.
    """

    def __init__(self):
        super().__init__()

        # (num_embeddings, embedding_dim) for each unweighted table.
        table_params = [
            [187, 128],
            [844, 288],
            [310, 444],
            [870, 20],
            [704, 512],
        ]

        # (num_embeddings, embedding_dim) for each weighted table.
        weighted_table_params = [
            [975, 316],
            [439, 612],
            [855, 284],
        ]

        self.tables = [
            EmbeddingBagConfig(
                num_embeddings=table_params[i][0],
                embedding_dim=table_params[i][1],
                name="table_" + str(i),
                feature_names=["feature_" + str(i)],
                # NOTE(review): ``data_type`` normally takes a torchrec
                # ``DataType`` enum, not a torch dtype; ``torch.int64`` here is
                # probably not what was intended (the weighted tables below use
                # the default). Kept as-is to preserve the original repro.
                data_type=torch.int64,
            )
            for i in range(len(table_params))
        ]
        self.weighted_tables = [
            EmbeddingBagConfig(
                num_embeddings=weighted_table_params[i][0],
                embedding_dim=weighted_table_params[i][1],
                name="weighted_table_" + str(i),
                feature_names=["weighted_feature_" + str(i)],
            )
            for i in range(len(weighted_table_params))
        ]

        self.dense = nn.Linear(in_features=984, out_features=551, bias=True)

        self.sparse = EmbeddingBagCollection(
            tables=self.tables,
            is_weighted=False,
        )
        self.sparse_weighted = EmbeddingBagCollection(
            tables=self.weighted_tables,
            is_weighted=True,
        )

        # The over layer sees the dense output plus one pooled embedding of
        # width ``embedding_dim`` per feature of every (weighted) table.
        in_features_concat = self.dense.out_features + sum(
            table.embedding_dim * len(table.feature_names)
            for table in self.tables + self.weighted_tables
        )

        self.over = nn.Linear(in_features=in_features_concat, out_features=21, bias=True)

    def forward(
        self,
        input,
    ):
        """Run the model on one batch.

        Returns:
            In eval mode: ``(pred, (dense_r, sparse_r, sparse_weighted_r, over_r))``.
            In training mode: ``(loss, pred, (dense_r, sparse_r,
            sparse_weighted_r, over_r))``, where ``loss`` is the binary
            cross-entropy of the mean over-layer logit against ``input.label``.
        """
        dense_r = self.dense(input.float_features)
        sparse_r = self.sparse(input.idlist_features)
        sparse_weighted_r = self.sparse_weighted(input.idscore_features)
        result = KeyedTensor(
            keys=sparse_r.keys() + sparse_weighted_r.keys(),
            length_per_key=sparse_r.length_per_key()
            + sparse_weighted_r.length_per_key(),
            values=torch.cat([sparse_r.values(), sparse_weighted_r.values()], dim=1),
        )

        # Concatenate per-feature embeddings in table-declaration order
        # (unweighted tables first, then weighted), after the dense output.
        feature_order = [
            feature
            for table in self.tables + self.weighted_tables
            for feature in table.feature_names
        ]
        ret_concat = torch.cat(
            [dense_r] + [result[feature_name] for feature_name in feature_order],
            dim=1,
        )

        over_r = self.over(ret_concat)
        # One logit per example: the mean over the over-layer outputs.
        logits = torch.mean(over_r, dim=1)
        pred = torch.sigmoid(logits)
        intermediates = (dense_r, sparse_r, sparse_weighted_r, over_r)
        if self.training:
            # ``binary_cross_entropy_with_logits`` applies the sigmoid itself,
            # so it must receive raw logits; passing ``pred`` (already
            # sigmoided) would apply the sigmoid twice.
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                logits, input.label
            )
            return loss, pred, intermediates
        return pred, intermediates


def quantize_embeddings(
    module: nn.Module,
    dtype: torch.dtype,
    inplace: bool,
    additional_qconfig_spec_keys: Optional[List[Type[nn.Module]]] = None,
    additional_mapping: Optional[Dict[Type[nn.Module], Type[nn.Module]]] = None,
    output_dtype: torch.dtype = torch.float32,
) -> nn.Module:
    """Dynamically quantize the EmbeddingBagCollections inside ``module``.

    This local definition shadows the ``quantize_embeddings`` imported from
    ``torchrec.inference.modules`` at the top of the file. Every float
    ``torchrec.EmbeddingBagCollection`` (plus any extra types requested) gets
    one shared ``QConfig`` built from placeholder observers, and is swapped
    for ``torchrec.quant.EmbeddingBagCollection`` by
    ``torch.quantization.quantize_dynamic``.

    Args:
        module: model to quantize.
        dtype: dtype recorded by the weight placeholder observer.
        inplace: forwarded to ``quantize_dynamic``.
        additional_qconfig_spec_keys: extra module types that receive the
            same qconfig.
        additional_mapping: extra float->quantized type mappings, layered on
            top of the default EBC mapping.
        output_dtype: dtype recorded by the activation placeholder observer.

    Returns:
        The module returned by ``quantize_dynamic``.
    """
    import torch.quantization as tq
    import torchrec
    import torchrec.quant as torchrec_quant

    float_ebc = torchrec.EmbeddingBagCollection
    shared_qconfig = tq.QConfig(
        activation=tq.PlaceholderObserver.with_args(dtype=output_dtype),
        weight=tq.PlaceholderObserver.with_args(dtype=dtype),
    )

    # The float EBC always comes first; caller-supplied types follow.
    spec_types = [float_ebc, *(additional_qconfig_spec_keys or [])]
    qconfig_spec: Dict[Type[nn.Module], tq.QConfig] = {
        module_type: shared_qconfig for module_type in spec_types
    }

    # Caller-supplied mappings override the default EBC swap on key clashes.
    type_mapping: Dict[Type[nn.Module], Type[nn.Module]] = {
        float_ebc: torchrec_quant.EmbeddingBagCollection,
        **(additional_mapping or {}),
    }

    return tq.quantize_dynamic(
        module,
        qconfig_spec=qconfig_spec,
        mapping=type_mapping,
        inplace=inplace,
    )


def _max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Return ``max(|a - b|)``, promoting non-floating tensors to float first.

    Subtracting unsigned-integer tensors (e.g. uint8 quantized embedding
    outputs) wraps modulo 256, so a one-unit difference in either direction
    would otherwise be reported as 255.
    """
    if not a.is_floating_point():
        a = a.float()
    if not b.is_floating_point():
        b = b.float()
    return torch.max(torch.abs(a - b))


def sharding_single_rank_test(
    world_size: int,
    model,
    inputs,
    sharders: List[ModuleSharder[nn.Module]],
    quant_dtype=None,
    quant_output_dtype=None,
) -> None:
    """Compare a sharded (DMP) copy of ``model`` against an unsharded copy.

    The function does the following in order:
    - Moves ``model`` to ``cuda:0`` and quantizes its embedding collections
      in place.
    - Deep-copies the model twice.
    - Plans shards for one copy with ``EmbeddingShardingPlanner`` and wraps
      it in ``DistributedModelParallel`` as rank 0 of ``world_size``.
    - Calls ``copy_state_dict`` on the two state dicts (presumably copying
      the unsharded weights into the sharded model -- verify).
    - Runs one eval-mode forward pass on each copy with the same batch.
    - Prints the max absolute difference of the prediction and of every
      intermediate output.

    Args:
        world_size: world size used for the planner topology and sharding env.
        model: float model to quantize and compare (mutated in place).
        inputs: sequence whose ``inputs[0][0]`` is the batch fed to both models.
        sharders: sharders used both for planning and by DMP.
        quant_dtype: weight dtype forwarded to ``quantize_embeddings``.
        quant_output_dtype: output dtype forwarded to ``quantize_embeddings``.
    """
    device = torch.device("cuda:0")
    model = model.to(device)

    model = quantize_embeddings(
        model, dtype=quant_dtype, inplace=True, output_dtype=quant_output_dtype
    )

    global_model = copy.deepcopy(model)
    global_model = global_model.to(device)
    global_input = inputs[0][0].to(device)

    local_model = copy.deepcopy(model)

    planner = EmbeddingShardingPlanner(
        topology=Topology(
            world_size, device.type
        )
    )
    plan: ShardingPlan = planner.plan(local_model, sharders)

    local_model = DistributedModelParallel(
        local_model,
        env=ShardingEnv.from_local(world_size=world_size, rank=0),
        plan=plan,
        sharders=sharders,
        device=device,
        init_data_parallel=False,
    )

    copy_state_dict(local_model.state_dict(), global_model.state_dict())

    local_pred, (local_dense_r, local_sparse_r, local_sparse_weighted_r, local_over_r) = (
        gen_full_pred_after_one_step(local_model, global_input)
    )
    global_pred, (global_dense_r, global_sparse_r, global_sparse_weighted_r, global_over_r) = (
        gen_full_pred_after_one_step(global_model, global_input)
    )

    # NOTE(review): if the quantized embedding outputs pack per-row scale/bias
    # bytes next to the quantized values, a byte-wise difference is still not
    # a numeric error measure; dequantize first for a meaningful comparison.
    print("Linf: ", _max_abs_diff(global_pred, local_pred))
    print("Linf dense: ", _max_abs_diff(global_dense_r, local_dense_r))
    print("Linf sparse: ", _max_abs_diff(local_sparse_r.values(), global_sparse_r.values()))
    print(
        "Linf sparse weighted: ",
        _max_abs_diff(local_sparse_weighted_r.values(), global_sparse_weighted_r.values()),
    )
    print("Linf over: ", _max_abs_diff(global_over_r, local_over_r))

def gen_full_pred_after_one_step(
    model: nn.Module,
    input: "ModelInput",
) -> Tuple[torch.Tensor, Tuple[Any, ...]]:
    """Run one eval-mode forward pass without autograd and return its outputs.

    Despite the name, no optimizer step is taken. The model is put into eval
    mode (``train(False)``) and called once under ``torch.no_grad()``. This
    helper is used for both the sharded and the unsharded model.

    Args:
        model: module whose eval-mode forward returns ``(pred, intermediates)``.
        input: batch passed straight to ``model``.

    Returns:
        The ``(pred, intermediates)`` pair produced by the model.
    """
    # Side effect: the model is left in eval mode after this call.
    model.train(False)
    with torch.no_grad():
        full_pred, intermediate_list = model(input)
    return full_pred, intermediate_list

from torchrec.distributed.embedding_types import EmbeddingTableConfig
from typing import Protocol, cast
class ModelInputCallable(Protocol):
    """Typing-only signature for ``ModelInput.generate``.

    ``main_test_quant`` casts ``ModelInput.generate`` to this protocol so a
    type checker sees the keyword arguments it accepts and its return type.
    The return type is a ``ModelInput`` plus a list of ``ModelInput``s
    (presumably per-rank splits -- verify).
    NOTE(review): this stub is hand-written; confirm it still matches the
    installed torchrec's ``ModelInput.generate`` signature.
    """

    def __call__(
        self,
        batch_size: int,
        world_size: int,
        num_float_features: int,
        tables: Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]],
        weighted_tables: Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]],
        pooling_avg: int = 10,
        dedup_tables: Optional[
            Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]]
        ] = None,
        variable_batch_size: bool = False,
        long_indices: bool = True,
    ) -> Tuple["ModelInput", List["ModelInput"]]: ...

def main_test_quant(
    sharders: List[ModuleSharder[nn.Module]],
    world_size: int = 2,
    quant_dtype=None,
    quant_output_dtype=None,
) -> None:
    """Build the repro model and a batch, then run the sharded/unsharded comparison.

    Args:
        sharders: sharders handed to the planner and to DMP.
        world_size: world size used both for input generation and sharding.
        quant_dtype: weight dtype for embedding quantization.
        quant_output_dtype: output dtype for embedding quantization.
    """
    model = ReproduceModel()

    generate = cast(ModelInputCallable, ModelInput.generate)
    generated = generate(
        world_size=world_size,
        tables=model.tables,
        weighted_tables=model.weighted_tables,
        num_float_features=model.dense.in_features,
        batch_size=256,
    )

    # Downcast the sparse ids of the first (global) batch to int32 in place.
    # NOTE(review): this reassigns the private ``_values`` attribute of the
    # jagged feature containers; only ``generated[0]`` is converted.
    global_input = generated[0]
    for jagged in (global_input.idlist_features, global_input.idscore_features):
        jagged._values = jagged._values.to(dtype=torch.int32)

    sharding_single_rank_test(
        world_size=world_size,
        model=model,
        inputs=[generated],
        sharders=sharders,
        quant_dtype=quant_dtype,
        quant_output_dtype=quant_output_dtype,
    )

class TestQuantEBCSharder(QuantEmbeddingBagCollectionSharder):
    """Quantized-EBC sharder pinned to one sharding type and one compute kernel.

    ``sharding_types`` and ``compute_kernels`` each return a single-element
    list. This presumably leaves the planner exactly one option (e.g.
    ``table_wise`` + ``quant``). No fused params are supplied.
    """

    def __init__(self, sharding_type: str, kernel_type: str) -> None:
        super().__init__()
        self._sharding_type, self._kernel_type = sharding_type, kernel_type

    def sharding_types(self, compute_device_type: str) -> List[str]:
        """Return the configured sharding type; the device type is ignored."""
        only_choice = self._sharding_type
        return [only_choice]

    def compute_kernels(
        self, sharding_type: str, compute_device_type: str
    ) -> List[str]:
        """Return the configured kernel regardless of sharding/device type."""
        only_choice = self._kernel_type
        return [only_choice]

    @property
    def fused_params(self) -> Optional[Dict[str, Any]]:
        """Always ``None``: no fused parameters are passed to the kernels."""
        return None
    
def main():
    """Reproduce the quantized table-wise sharding mismatch with world size 3.

    Both the embedding weights and the embedding outputs are quantized to
    ``torch.qint8``.
    """
    pinned_sharder = TestQuantEBCSharder("table_wise", "quant")
    main_test_quant(
        sharders=[pinned_sharder],
        world_size=3,
        quant_dtype=torch.qint8,
        quant_output_dtype=torch.qint8,
    )


if __name__ == "__main__":
    # Run the reproduction only when executed as a script, not on import.
    main()

Logs

Linf:  tensor(0.9984, device='cuda:0')
Linf dense:  tensor(0., device='cuda:0')
Linf sparse:  tensor(255, device='cuda:0', dtype=torch.uint8)
Linf sparse weighted:  tensor(255, device='cuda:0', dtype=torch.uint8)
Linf over:  tensor(130.5735, device='cuda:0')

Logs in torchrec 0.2.0

Linf:  tensor(nan, device='cuda:0')
Linf dense:  tensor(0., device='cuda:0')
Linf sparse:  tensor(nan, device='cuda:0')
Traceback (most recent call last):
  File "/root/miniconda3/envs/pt112tr02/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/root/miniconda3/envs/pt112tr02/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/mnt/util/reproduce_quant_nan_2.py", line 312, in <module>
    main()
  File "/mnt/util/reproduce_quant_nan_2.py", line 303, in main
    main_test_quant(
  File "/mnt/util/reproduce_quant_nan_2.py", line 265, in main_test_quant
    sharding_single_rank_test(
  File "/mnt/util/reproduce_quant_nan_2.py", line 213, in sharding_single_rank_test
    print("Linf sparse weighted: ", torch.max(torch.abs(local_sparse_weighted_r.values() - global_sparse_weighted_r.values())))
RuntimeError: The size of tensor a (1212) must match the size of tensor b (1236) at non-singleton dimension 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant