Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

modify mlm example to support all backends #1856

Open
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

Mrutyunjay01
Copy link

With reference to the draft PR, here's an attempt to make to example truly backend-agnostic. In the process, I raised the following issues: keras-team/keras#18410, and keras-team/keras#19665. While the latter got resolved, the former persists still. Kindly review the changes, and let me know how shall we proceed with the issues mentioned.

cc: @fchollet

@@ -336,37 +376,48 @@ def metrics(self):


def create_masked_language_bert_model():
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, shall I subclass the MaskedLanguageModel instead, and get rid of the tf custom training step?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe you should be able to remove the train_step() and instead implement the compute_loss() method. If you write the compute_loss() method using only keras.ops functions, it will work with all backends.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that no additional logic is required here, and using SparseCategoricalCrossEntropy as Loss function; does it require an override for compute_loss()?

Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

import os

# set backend ["tensorflow", "jax", "torch"]
os.environ["KERAS_BACKEND"] = "tensorflow"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it works with all backends, you can just remove this line.

@@ -336,37 +376,48 @@ def metrics(self):


def create_masked_language_bert_model():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe you should be able to remove the train_step() and instead implement the compute_loss() method. If you write the compute_loss() method using only keras.ops functions, it will work with all backends.

@Mrutyunjay01
Copy link
Author

added subclassing implementation. Certain obstacles remain as mentioned:

  1. In BertEncoderLayer class, as it is inherited from layers.Layer, is it possible to access to the sub-layers here? as:
mlm_model = models.load_model("bert_mlm.keras", custom_objects={"MaskedLanguageModel": MaskedLanguageModel})

#mlm_model.layers : 
#[<Embedding name=word_embedding, built=True>,
#<PositionEmbedding name=position_embedding, built=True>,
#<BertEncoderLayer name=bert_encoder_layer, built=True>,
#<Dense name=mlm_cls, built=True>]

mlm_model.get_layer("bert_encoder_layer").get_layer("encoder_0_ffn_layernormalization")

This throws error saying

AttributeError: 'BertEncoderLayer' object has no attribute 'get_layer'
  1. The reason I want to do the above is to reproduce the following:
# Load pretrained bert model
mlm_model = models.load_model(
    "bert_mlm_imdb.keras", custom_objects={"MaskedLanguageBertModel": MaskedLanguageBertModel}
)
pretrained_bert_model = Model(
    mlm_model.input, mlm_model.get_layer("encoder_0_ffn_layernormalization").output
)

This was possible in functional declaration of the Model, but not in sub-classing (in last commit).

Kindly suggest w.r.t last commit

Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants