zorro_mask has order of operations bug

https://github.com/lucidrains/zorro-pytorch/blob/fa2cd8137c15d82417637050ec872ef15cdb6982/zorro_pytorch/zorro_pytorch.py#L288

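This line is incorrect. In Python, the bitwise `|` operator binds more tightly than `==`, so the mask is first OR-ed with the token type matrix, and only that combined result is compared against the `FUSION` token enum value. Parenthesizing the comparison so it is evaluated first fixes it: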

```python
zorro_mask = zorro_mask | (token_types_attend_from == TokenTypes.FUSION.value)
```
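A minimal pure-Python sketch of the precedence problem, using scalar values instead of tensors. The `TokenTypes` integer values and the helper names `buggy` and `fixed` are hypothetical and chosen for illustration only, not taken from zorro-pytorch. Operator precedence is a language-level rule, so the same parsing applies when the operands are PyTorch tensors.

```python
from enum import Enum


# Hypothetical stand-in for zorro_pytorch's TokenTypes; these integer
# values are assumptions for illustration only.
class TokenTypes(Enum):
    AUDIO = 0
    VIDEO = 1
    FUSION = 2


def buggy(mask: bool, token_type: int) -> bool:
    # Parses as (mask | token_type) == FUSION, because | binds tighter than ==
    return mask | token_type == TokenTypes.FUSION.value


def fixed(mask: bool, token_type: int) -> bool:
    # Compare first, then OR the result into the existing mask
    return mask | (token_type == TokenTypes.FUSION.value)


# One (existing mask value, token type) pair per query position
cases = [
    (True, TokenTypes.AUDIO.value),
    (False, TokenTypes.FUSION.value),
    (False, TokenTypes.VIDEO.value),
]

print([buggy(m, t) for m, t in cases])  # → [False, True, False]
print([fixed(m, t) for m, t in cases])  # → [True, True, False]
```

The first case shows the bug: a position the mask already allows gets dropped, because `True | 0` evaluates to `1`, which is then compared against the `FUSION` value and comes out `False`.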