
[Feature request]: Load and Use Wrapped Models 'As Is' From External Pretrained Checkpoint #223

Open
melo-gonzo opened this issue May 23, 2024 · 0 comments

Feature/behavior summary

MatSciML offers various models (M3GNet, TensorNet, MACE) that are wrappers around the upstream implementations; however, there is currently no clean way to load a pretrained checkpoint and use it 'as is' with the default model architecture. The hang-ups arise from:

  • MatSciML creating output heads, which are always added to models
  • MatSciML expecting Embeddings objects to be returned from the encoder forward pass (illustrated in the sketch below)
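
For context, here is a minimal sketch of the contract a MatSciML encoder normally satisfies. The field order follows the Embeddings(graph_embeddings, node_embeddings) call visible in the example further down; treat the exact signature as an assumption rather than a guarantee:

from matsciml.common.types import Embeddings

def _forward(self, graph, node_feats, pos, **kwargs):
    # Run the wrapped upstream model.
    outputs = self.encoder(...)
    node_embeddings = outputs["node_feats"]
    # Pool node embeddings into per-graph embeddings.
    graph_embeddings = self.readout(node_embeddings, graph.batch)
    # Tasks then attach output heads on top of these embeddings, which is
    # exactly what prevents using the raw model 'as is'.
    return Embeddings(graph_embeddings, node_embeddings)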

Request attributes

  • Would this be a refactor of existing code?
  • Does this proposal require new package dependencies?
  • Would this change break backwards compatibility?
  • Does this proposal include a new model?
  • Does this proposal include a new dataset?
  • Does this proposal include a new task/workflow?

Related issues

No response

Solution description

There are two options to work around this:

  1. Modify the existing behavior to toggle the creation of output heads on and off, as well as to return the default output from the wrapped model.
  2. Create a new 'task' that removes all of the output-head creation and expected forward-pass outputs, and runs the wrapped model 'as is' (see the sketch after this list).
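
For illustration, option 2 might look like the following minimal sketch (RawModelTask is a hypothetical name, not an existing MatSciML class):

import pytorch_lightning as pl


class RawModelTask(pl.LightningModule):
    # Hypothetical stripped-down task: no output heads are constructed and
    # no Embeddings objects are expected.
    def __init__(self, encoder_class, encoder_kwargs):
        super().__init__()
        self.encoder = encoder_class(**encoder_kwargs)

    def forward(self, batch):
        # Return whatever the wrapped model produces, untouched.
        return self.encoder(batch)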

Below is an example of how option 1 was implemented by subclassing MatSciML tasks and model wrappers. Note that this relies on #222 to load the proper MACE submodule (ScaleShiftMACE). The model checkpoint 2023-12-10-mace-128-L0_epoch-199.model may be used with this example.

import torch
from e3nn.o3 import Irreps
from mace.modules import ScaleShiftMACE
from mace.modules.blocks import RealAgnosticResidualInteractionBlock
from torch import nn

from matsciml.common.types import AbstractGraph, BatchDict
from matsciml.datasets import LiPSDataset
from matsciml.datasets.transforms import (
    PeriodicPropertiesTransform,
    PointCloudToGraphTransform,
)
from matsciml.models.base import ForceRegressionTask
from matsciml.models.pyg.mace import MACEWrapper


# Subclass that shadows the imported ForceRegressionTask and skips the usual
# output heads, returning the encoder outputs directly.
class ForceRegressionTask(ForceRegressionTask):
    def forward(self, batch):
        outputs = self.encoder(batch)
        return outputs


# MACE wrapper that returns the raw upstream outputs instead of building
# the Embeddings object MatSciML expects.
class OGMACE(MACEWrapper):
    def _forward(
        self,
        graph: AbstractGraph,
        node_feats: torch.Tensor,
        pos: torch.Tensor,
        **kwargs,
    ):
        # Repack the MatSciML graph into the dict format upstream MACE expects.
        mace_data = {
            "positions": pos,
            "node_attrs": node_feats,
            "ptr": graph.ptr,
            "cell": kwargs["cell"],
            "shifts": kwargs["shifts"],
            "batch": graph.batch,
            "edge_index": graph.edge_index,
        }
        outputs = self.encoder(
            mace_data,
            training=self.training,
            compute_force=True,
            compute_virials=False,
            compute_stress=False,
            compute_displacement=False,
        )
        # The default wrapper would construct Embeddings here:
        # node_embeddings = outputs["node_feats"]
        # graph_embeddings = self.readout(node_embeddings, graph.batch)
        # return Embeddings(graph_embeddings, node_embeddings)
        # Instead, return the raw MACE output dict unchanged.
        return outputs

    def forward(self, batch: BatchDict):
        input_data = self.read_batch(batch)
        outputs = self._forward(**input_data)
        return outputs


# Encoder configuration intended to match the pretrained checkpoint loaded below.
available_models = {
    "mace": {
        "encoder_class": OGMACE,
        "encoder_kwargs": {
            "mace_module": ScaleShiftMACE,
            "num_atom_embedding": 89,
            "r_max": 6.0,
            "num_bessel": 10,
            "num_polynomial_cutoff": 5.0,
            "max_ell": 3,
            "interaction_cls": RealAgnosticResidualInteractionBlock,
            "interaction_cls_first": RealAgnosticResidualInteractionBlock,
            "num_interactions": 2,
            "atom_embedding_dim": 128,
            "MLP_irreps": Irreps("16x0e"),
            "avg_num_neighbors": 10.0,
            "correlation": 3,
            "radial_type": "bessel",
            "gate": nn.Identity(),
            "atomic_inter_scale": 0.804154,
            "atomic_inter_shift": 0.164097,
            # Per-element atomic reference energies for the pretrained model.
            # fmt: off
            "atomic_energies": torch.Tensor([-3.6672, -1.3321, -3.4821, -4.7367, 
                                             -7.7249, -8.4056, -7.3601, -7.2846, 
                                             -4.8965, 0.0000, -2.7594, -2.8140, 
                                             -4.8469, -7.6948, -6.9633, -4.6726, 
                                             -2.8117, -0.0626, -2.6176, -5.3905, 
                                             -7.8858, -10.2684, -8.6651, -9.2331, 
                                             -8.3050, -7.0490, -5.5774, -5.1727, 
                                             -3.2521, -1.2902, -3.5271, -4.7085, 
                                             -3.9765, -3.8862, -2.5185, 6.7669, 
                                             -2.5635, -4.9380, -10.1498, -11.8469, 
                                             -12.1389, -8.7917, -8.7869, -7.7809,
                                             -6.8500, -4.8910, -2.0634, -0.6396, 
                                             -2.7887, -3.8186, -3.5871, -2.8804, 
                                             -1.6356, 9.8467, -2.7653, -4.9910, 
                                             -8.9337, -8.7356, -8.0190, -8.2515,
                                             -7.5917, -8.1697, -13.5927, -18.5175, 
                                             -7.6474, -8.1230, -7.6078, -6.8503,
                                             -7.8269, -3.5848, -7.4554, -12.7963,
                                             -14.1081, -9.3549, -11.3875, -9.6219, 
                                             -7.3244, -5.3047, -2.3801, 0.2495, -2.3240,
                                             -3.7300, -3.4388, -5.0629, -11.0246, 
                                             -12.2656, -13.8556, -14.9331, -15.2828])
            # fmt: on
        },
        "output_kwargs": {"lazy": False, "input_dim": 256, "hidden_dim": 256},
    }
}

ckpt = "2023-12-10-mace-128-L0_epoch-199.model"

model = ForceRegressionTask(**available_models["mace"])
# The MACE checkpoint file is a pickled module, so its state_dict can be
# loaded directly into the wrapped encoder (model.encoder is OGMACE, and
# model.encoder.encoder is the underlying ScaleShiftMACE module).
model.encoder.encoder.load_state_dict(
    torch.load(ckpt, map_location=torch.device("cpu")).state_dict(), strict=True
)

# Transforms that build the periodic graph structure the wrapper consumes.
transforms = [
    PeriodicPropertiesTransform(cutoff_radius=6.5, adaptive_cutoff=True),
    PointCloudToGraphTransform(
        "pyg",
        node_keys=["pos", "atomic_numbers"],
    ),
]

# Build a single-sample batch from the LiPS devset.
dset = LiPSDataset.from_devset(transforms=transforms)
sample = LiPSDataset.collate_fn([dset[0]])

outputs = model(sample)
print(outputs)
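
With compute_force=True, outputs here should be the raw dictionary produced by the upstream MACE implementation, so the available keys are defined by MACE rather than MatSciML. Assuming the usual upstream key names (an assumption, not something MatSciML guarantees), the predictions can be pulled out directly:

# "energy" and "forces" are assumed upstream MACE output keys.
energy = outputs["energy"]
forces = outputs["forces"]
print(energy.shape, forces.shape)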

Additional notes

No response

@melo-gonzo melo-gonzo added the enhancement New feature or request label May 23, 2024