Mac MPS FP16 fixes

This PR is to allow FP16 precision to work on Macs with MPS. In addition, it centralizes the torch fixes/workarounds
required for MPS into a new backend utility file `mps_fixes.py`. This is conditionally imported in `api_app.py`/`cli_app.py`.

Many MANY thanks to StAlKeR7779 for patiently working to debug and fix these issues.
This commit is contained in:
gogurtenjoyer
2023-07-04 18:05:01 -04:00
parent 92b163e95c
commit 233869b56a
9 changed files with 103 additions and 73 deletions

View File

@ -32,7 +32,7 @@ def get_noise(
perlin: float = 0.0,
):
"""Generate noise for a given image size."""
noise_device_type = "cpu" if (use_cpu or device.type == "mps") else device.type
noise_device_type = "cpu" if use_cpu else device.type
# limit noise to only the diffusion image channels, not the mask channels
input_channels = min(latent_channels, 4)