Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JAX FE] Support lax.argmax operation for JAX #26671

Open
wants to merge 2 commits into
base: master
Choose a base branch
from

Conversation

halm-zenger
Copy link

@halm-zenger halm-zenger commented Sep 19, 2024

Details:

  • Support lax.argmax for JAX and create relevant layer test

Tickets:

@halm-zenger halm-zenger requested review from a team as code owners September 19, 2024 02:06
@github-actions github-actions bot added category: TF FE OpenVINO TensorFlow FrontEnd category: JAX FE OpenVINO JAX FrontEnd labels Sep 19, 2024
@sys-openvino-ci sys-openvino-ci added the ExternalPR External contributor label Sep 19, 2024
@@ -16,6 +16,7 @@ namespace jax {
void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs) {
auto inputs = context.inputs();
FRONT_END_OP_CONVERSION_CHECK(inputs.size() >= min_inputs, "Got less inputs than expected");
FRONT_END_OP_CONVERSION_CHECK(inputs.size() <= max_inputs, "Got more inputs than expected");
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if I can fix this here, or maybe it's better to move it to a separate PR?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, let's do it.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @rkazants Just to clarify, do you mean I should keep it here or I should move it to a separate PR?

test_data = []
for shape in input_shapes:
rank = len(shape)
# Only [0, rank - 1] are valid axes for lax.argmax
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just note that jax.numpy.argmax can accept negative axis, but jax.lax.argmax cannot.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jax.numpy.argmax is a wrapper and it is represented with a sub-graph of lax operations. For us, it is sufficient to check lax operation.

test_data = generate_shape_axis_pairs(input_shapes)

@pytest.mark.parametrize("params", test_data)
@pytest.mark.parametrize("index_dtype", [np.int32, np.int64])
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't handle index_dtype in translate_argmax but everything seems fine. Is it okay?

@rkazants rkazants added this to the 2024.5 milestone Sep 19, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
category: JAX FE OpenVINO JAX FrontEnd category: TF FE OpenVINO TensorFlow FrontEnd ExternalPR External contributor
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Good First Issue][JAX FE]: Support jax.lax.argmax operation for JAX
3 participants