Mac MPS FP16 fixes

This PR is to allow FP16 precision to work on Macs with MPS. In addition, it centralizes the torch fixes/workarounds
required for MPS into a new backend utility file `mps_fixes.py`. This is conditionally imported in `api_app.py`/`cli_app.py`.

Many MANY thanks to StAlKeR7779 for patiently working to debug and fix these issues.
This commit is contained in:
gogurtenjoyer
2023-07-04 18:05:01 -04:00
parent 92b163e95c
commit 233869b56a
9 changed files with 103 additions and 73 deletions

View File

@ -248,9 +248,6 @@ class InvokeAIDiffuserComponent:
x_twice, sigma_twice, both_conditionings, **kwargs,
)
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
if conditioned_next_x.device.type == "mps":
# prevent a result filled with zeros. seems to be a torch bug.
conditioned_next_x = conditioned_next_x.clone()
return unconditioned_next_x, conditioned_next_x
def _apply_standard_conditioning_sequentially(
@ -264,9 +261,6 @@ class InvokeAIDiffuserComponent:
# low-memory sequential path
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
if conditioned_next_x.device.type == "mps":
# prevent a result filled with zeros. seems to be a torch bug.
conditioned_next_x = conditioned_next_x.clone()
return unconditioned_next_x, conditioned_next_x
# TODO: looks unused