Hello
Your issue here comes from the Einsum layer.
In your case, the operation that does not work is einsum("bhqk,bkhd->bqhd").

It seems that we don't currently fully support Einsum, so you need to replace the einsum layers (you have several in your model) with simple matrix operations instead. For example:
Equivalent Operations:
Instead of einsum("bhqk,bkhd->bqhd", A, B), use:
import torch
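# Example sizes (arbitrary values, for illustration only)
batch, heads, query, key, dim = 2, 4, 5, 6, 8
# Example tensors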
A = torch.randn(batch, heads, query, key) # (b, h, q, k)
B = torch.randn(batch, key, heads, dim) # (b, k, h, d)
# Transpose B to (b, h, k, d) so that k aligns for matmul
B_transposed = B.permute(0, 2, 1, 3) # (b, h, k, d)
# Perform batched matrix multiplication
result = torch.matmul(A, B_transposed) # (b, h, q, d)
# Swap axes to match expected output shape (b, q, h, d)
result = result.permute(0, 2, 1, 3) # (b, q, h, d)
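To check that the replacement really is equivalent before swapping it into your model, here is a small sketch that wraps the steps above in a helper and compares it against torch.einsum on random inputs. The function name bhqk_bkhd_to_bqhd and the tensor sizes are only illustrative, not part of any library.

```python
import torch


def bhqk_bkhd_to_bqhd(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Einsum-free equivalent of torch.einsum("bhqk,bkhd->bqhd", A, B).

    A: (b, h, q, k), B: (b, k, h, d) -> result: (b, q, h, d)
    Hypothetical helper name, used here for illustration only.
    """
    # Move heads ahead of key in B: (b, k, h, d) -> (b, h, k, d),
    # so the contracted axis k is second-to-last, as matmul expects
    B_t = B.permute(0, 2, 1, 3)
    # Batched matmul over the leading (b, h) axes: (q, k) @ (k, d) -> (q, d)
    out = torch.matmul(A, B_t)  # (b, h, q, d)
    # Reorder to the einsum output layout: (b, h, q, d) -> (b, q, h, d)
    return out.permute(0, 2, 1, 3)


# Illustrative sizes; any positive values work
batch, heads, query, key, dim = 2, 4, 5, 6, 8
torch.manual_seed(0)
A = torch.randn(batch, heads, query, key)
B = torch.randn(batch, key, heads, dim)

reference = torch.einsum("bhqk,bkhd->bqhd", A, B)
replacement = bhqk_bkhd_to_bqhd(A, B)

print(replacement.shape)  # torch.Size([2, 5, 4, 8])
print(torch.allclose(reference, replacement, atol=1e-5))  # True
```

Note that permute returns a non-contiguous view, so if a later layer calls .view() on the result you may need to add .contiguous() (or use .reshape() instead).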
Have a good day,
Julian
