MPS backend support

#5
by pypry - opened

I'm trying to add support for the MPS backend, but I found that some parts of the code use amp.autocast to force certain computations to run in float32. However, autocast on MPS only supports float16 and bfloat16. Is there any way to work around this?

As a quick workaround, you can simply omit the fp32 casting. This might slightly reduce performance, but the impact should be minor.
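A minimal sketch of the suggested workaround, assuming the fp32 regions can be routed through one helper. `fp32_autocast` is a hypothetical name, not part of this repo. On MPS it returns a no-op context, so the fp32 casting is skipped as suggested above. On other devices it disables autocast for the region, which assumes the original code upcasts its inputs with `.float()` inside that region.

```python
import contextlib

import torch


def fp32_autocast(device_type: str):
    """Hypothetical helper for regions the original code runs in float32."""
    if device_type == "mps":
        # Workaround from this thread: MPS autocast rejects float32,
        # so skip the fp32 casting entirely and accept a small accuracy cost.
        return contextlib.nullcontext()
    # Elsewhere, turn autocast off for this region so ops run in the
    # dtype of their (explicitly upcast) inputs, i.e. float32.
    return torch.autocast(device_type=device_type, enabled=False)


# Usage: pick the device, then wrap the precision-sensitive computation.
device_type = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.randn(4, 4, device=device_type)
with fp32_autocast(device_type):
    y = x.float() @ x.float().T
```

Returning a no-op context rather than branching at every call site keeps the change to one place, so it is easy to revisit if a future PyTorch release adds float32 to MPS autocast.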
