
batched_transpose with multiple batch dimensions #588

Open
AntonOresten opened this issue May 27, 2024 · 3 comments

Comments

@AntonOresten

AntonOresten commented May 27, 2024

Motivation and description

There exists a method for batched_mul that reshapes arrays to allow for an arbitrary number of batch dimensions:

function batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
    batch_size = size(x)[3:end]
    @assert batch_size == size(y)[3:end] "batch size has to be the same for the two arrays."
    x2 = reshape(x, size(x, 1), size(x, 2), :)
    y2 = reshape(y, size(y, 1), size(y, 2), :)
    z = batched_mul(x2, y2)
    return reshape(z, size(z, 1), size(z, 2), batch_size...)
end
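For example, with two trailing batch dimensions this method multiplies each 2×3 slice of `x` by the matching 3×6 slice of `y` (a quick sketch, assuming a version of NNlib that includes the method above):

```julia
using NNlib

x = rand(2, 3, 4, 5)  # two batch dimensions: 4 and 5
y = rand(3, 6, 4, 5)
z = batched_mul(x, y)
size(z)  # (2, 6, 4, 5)
```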

It would be useful to have support for this with batched_transpose and batched_adjoint as well.

Possible Implementation

The existing code is quite sophisticated and "lazy", so something like this wouldn't fly:

batched_transpose(A::AbstractArray{T, N}) where {T <: Real, N} = permutedims(A, (2, 1, 3:N...))

I imagine it would be possible to generalize the code beyond three dimensions, though. The indexing methods are currently hard-coded for three dimensions, and things like the strides would also need to be generalized:

function Base.strides(A::Union{BatchedTranspose, BatchedAdjoint{<:Real}})
    sp = strides(A.parent)
    (sp[2], sp[1], sp[3:end]...)
end

Is it better to just use PermutedDimsArray?

@AntonOresten
Author

AntonOresten commented May 28, 2024

After some thinking and tinkering, I've concluded that PermutedDimsArray works fine.

For my use case, however, where I use it to define a custom chain rule, I needed to use the inner constructor with all the type parameters, like so:

# permutation needs to be passed as type parameters directly so the type can be inferred
function _batched_transpose(A::AbstractArray{T, N}) where {T, N}
    perm = (2, 1, 3:N...)
    PermutedDimsArray{T, N, perm, perm, typeof(A)}(A)
end

or else I would get an error:

function _batched_transpose(A::AbstractArray{T, N}) where {T, N}
    perm = (2, 1, 3:N...)
    PermutedDimsArray(A, perm)
end

using Test
@inferred _batched_transpose(rand(4, 5, 6))

# output:
ERROR: return type PermutedDimsArray{Float64, 3, (2, 1, 3), (2, 1, 3), Array{Float64, 3}} does not match inferred return type PermutedDimsArray{Float64, 3, _A, _B, Array{Float64, 3}} where {_A, _B}

I suspect this is because of the splat in the regular constructor:

function PermutedDimsArray(data::AbstractArray{T,N}, perm) where {T,N}
    length(perm) == N || throw(ArgumentError(string(perm, " is not a valid permutation of dimensions 1:", N)))
    iperm = invperm(perm)
    PermutedDimsArray{T,N,(perm...,),(iperm...,),typeof(data)}(data)
end

This isn't really related to the issue, but I figured I'd include it for documentation purposes. 😄

EDIT: it's probably not the splatting itself, but the fact that the permutation is derived from the type parameter N, so it's essentially a constant.

EDIT 2: somewhat expectedly, CUDA doesn't like this, as it ends up wanting to do scalar indexing.
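One GPU-friendly fallback, at the cost of losing laziness, is an eager `permutedims`, which CUDA.jl dispatches to a kernel rather than scalar indexing (a sketch; `_batched_transpose_eager` is a hypothetical name, not part of NNlib):

```julia
# Eagerly swap the first two dimensions, keeping all trailing batch dims.
# Materializes a copy, unlike NNlib's lazy BatchedTranspose wrapper.
_batched_transpose_eager(A::AbstractArray{T,N}) where {T,N} =
    permutedims(A, (2, 1, ntuple(i -> i + 2, N - 2)...))

size(_batched_transpose_eager(rand(4, 5, 6, 7)))  # (5, 4, 6, 7)
```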

@AntonOresten
Author

AntonOresten commented Jun 7, 2024

When an array with multiple batch dimensions needs to be transposed for use in batched_mul, I found this to work alright:

function batched_mul_transpose1(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
    batch_size = size(x)[3:end]
    @assert batch_size == size(y)[3:end] "batch size has to be the same for the two arrays."
    x2 = reshape(x, size(x, 1), size(x, 2), :) |> batched_transpose # call batched_transpose after flattening batch dimensions
    y2 = reshape(y, size(y, 1), size(y, 2), :)
    z = batched_mul(x2, y2)
    return reshape(z, size(z, 1), size(z, 2), batch_size...)
end

This would be the same as batched_mul(batched_transpose(x), y).
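A quick sanity check of that equivalence against a hand-rolled per-slice loop (a sketch; assumes NNlib is available, with the definition repeated so the snippet runs standalone):

```julia
using NNlib

# Definition from above, repeated so this snippet is self-contained.
function batched_mul_transpose1(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
    batch_size = size(x)[3:end]
    @assert batch_size == size(y)[3:end]
    x2 = reshape(x, size(x, 1), size(x, 2), :) |> batched_transpose
    y2 = reshape(y, size(y, 1), size(y, 2), :)
    z = batched_mul(x2, y2)
    return reshape(z, size(z, 1), size(z, 2), batch_size...)
end

x, y = rand(3, 2, 4, 5), rand(3, 6, 4, 5)  # x's 3×2 slices are transposed to 2×3
z1 = batched_mul_transpose1(x, y)

# Reference: transpose and multiply each slice explicitly.
z2 = similar(z1)
for j in axes(x, 4), i in axes(x, 3)
    z2[:, :, i, j] = transpose(x[:, :, i, j]) * y[:, :, i, j]
end
z1 ≈ z2  # true
```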

@AntonOresten AntonOresten changed the title Multiple batch dimensions for batched_adjoint Multiple batch dimensions for batched_transpose Jun 7, 2024
@AntonOresten AntonOresten changed the title Multiple batch dimensions for batched_transpose batched_transpose with multiple batch dimensions Jun 7, 2024
@mcabbott
Member

It's tricky. Perhaps there need to be methods of batched_mul accepting these >3-dimensional BatchedAdjoint types, so that the reshape affects the wrapped Array (or CuArray) rather than composing another wrapper (which CUDA doesn't like, as you saw).

Or perhaps the reshaping to 3D should be done by a utility function which knows about BatchedAdjoint, not just reshape.

Xref #391 about other questions about batched_mul accepting >3 dimensions.

(Also, some regret that we didn't go with an interface like batched_mul(A, adjoint, B), instead of array wrappers!)
