mps backend support
#5, opened by pypry
I'm trying to make the code support the MPS backend, but some parts use amp.autocast to force computation to float32, while MPS autocast only supports float16 and bfloat16. Is there a way to work around this?
As a quick workaround, you can omit the fp32 casting on MPS. This may slightly reduce performance, but the impact should be minor.
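One way to apply that suggestion without touching the CUDA path is to route the fp32 regions through a small helper that becomes a no-op on MPS. This is a minimal sketch, not code from the repository. The helper name `fp32_autocast` is hypothetical, and it assumes the existing code wraps those regions in `torch.autocast(device_type=..., dtype=torch.float32)`. On MPS the region simply runs in whatever dtype the surrounding code already uses, which is the small precision trade-off mentioned above.

```python
from contextlib import nullcontext


def fp32_autocast(device_type: str):
    """Hypothetical helper: return the fp32 autocast context, or a no-op on MPS.

    Assumes the original code uses torch.autocast(..., dtype=torch.float32)
    for these regions. MPS autocast only accepts float16/bfloat16, so on MPS
    the fp32 cast is skipped entirely.
    """
    if device_type == "mps":
        # Skip the fp32 cast: ops keep the dtype set by the surrounding code.
        return nullcontext()
    # Imported here because only the non-MPS branch needs torch.
    import torch
    return torch.autocast(device_type=device_type, dtype=torch.float32)
```

Each existing `with amp.autocast(...)` block would then become `with fp32_autocast(x.device.type):`, where `x` is any tensor already on the target device.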