From 233869b56a9266b344950c6e8e5f873f7eaf284c Mon Sep 17 00:00:00 2001
From: gogurtenjoyer <36354352+gogurtenjoyer@users.noreply.github.com>
Date: Tue, 4 Jul 2023 18:05:01 -0400
Subject: [PATCH 1/2] Mac MPS FP16 fixes

This PR allows FP16 precision to work on Macs with MPS. In addition, it
centralizes the torch fixes/workarounds required for MPS into a new backend
utility file `mps_fixes.py`. This is conditionally imported in
`api_app.py`/`cli_app.py`.

Many MANY thanks to StAlKeR7779 for patiently working to debug and fix these
issues.
---
 invokeai/app/api_app.py                        |  4 ++
 invokeai/app/cli_app.py                        |  4 ++
 invokeai/app/invocations/noise.py              |  2 +-
 invokeai/backend/generator/base.py             | 32 +++------
 invokeai/backend/generator/img2img.py          |  5 +-
 .../stable_diffusion/diffusers_pipeline.py     | 68 ++++++++-----------
 .../diffusion/shared_invokeai_diffusion.py     |  6 --
 invokeai/backend/util/devices.py               |  2 +
 invokeai/backend/util/mps_fixes.py             | 53 +++++++++++++++
 9 files changed, 103 insertions(+), 73 deletions(-)
 create mode 100644 invokeai/backend/util/mps_fixes.py
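
Note: most of the workarounds in `mps_fixes.py` follow the same wrap-and-upcast
pattern: keep a reference to the original torch op, upcast float16 inputs on MPS
to float32, call the original op, then cast the result back to half (`permute`
instead forces a contiguous copy, and `torch.empty` is aliased to `torch.zeros`).
A minimal standalone sketch of that pattern, using a hypothetical `gelu`
override that is not part of this PR:

    import torch
    import torch.nn.functional as F

    _torch_gelu = F.gelu  # keep a handle to the original op


    def new_gelu(input, approximate="none"):
        # Upcast fp16 MPS tensors to fp32, run the original op, cast back to half.
        if input.device.type == "mps" and input.dtype == torch.float16:
            return _torch_gelu(input.float(), approximate=approximate).half()
        return _torch_gelu(input, approximate=approximate)


    F.gelu = new_gelu  # same monkey-patch style as the fixes below
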
diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py
index e14c58bab7..421ead5797 100644
--- a/invokeai/app/api_app.py
+++ b/invokeai/app/api_app.py
@@ -28,6 +28,10 @@ from .api.routers import sessions, models, images, boards, board_images
 from .api.sockets import SocketIO
 from .invocations.baseinvocation import BaseInvocation
 
+import torch
+if torch.backends.mps.is_available():
+    import invokeai.backend.util.mps_fixes
+
 # Create the app
 # TODO: create this all in a method so configuration/etc. can be passed in?
 app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)
diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py
index 07193c8500..116fa89ccc 100644
--- a/invokeai/app/cli_app.py
+++ b/invokeai/app/cli_app.py
@@ -53,6 +53,10 @@ from .services.processor import DefaultInvocationProcessor
 from .services.restoration_services import RestorationServices
 from .services.sqlite import SqliteItemStorage
 
+import torch
+if torch.backends.mps.is_available():
+    import invokeai.backend.util.mps_fixes
+
 
 class CliCommand(BaseModel):
     command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py
index c5866f3608..0d62ada34e 100644
--- a/invokeai/app/invocations/noise.py
+++ b/invokeai/app/invocations/noise.py
@@ -32,7 +32,7 @@ def get_noise(
     perlin: float = 0.0,
 ):
     """Generate noise for a given image size."""
-    noise_device_type = "cpu" if (use_cpu or device.type == "mps") else device.type
+    noise_device_type = "cpu" if use_cpu else device.type
 
     # limit noise to only the diffusion image channels, not the mask channels
     input_channels = min(latent_channels, 4)
diff --git a/invokeai/backend/generator/base.py b/invokeai/backend/generator/base.py
index 462b1a4f4b..d267e7fbba 100644
--- a/invokeai/backend/generator/base.py
+++ b/invokeai/backend/generator/base.py
@@ -570,28 +570,16 @@ class Generator:
         device = self.model.device
         # limit noise to only the diffusion image channels, not the mask channels
         input_channels = min(self.latent_channels, 4)
-        if self.use_mps_noise or device.type == "mps":
-            x = torch.randn(
-                [
-                    1,
-                    input_channels,
-                    height // self.downsampling_factor,
-                    width // self.downsampling_factor,
-                ],
-                dtype=self.torch_dtype(),
-                device="cpu",
-            ).to(device)
-        else:
-            x = torch.randn(
-                [
-                    1,
-                    input_channels,
-                    height // self.downsampling_factor,
-                    width // self.downsampling_factor,
-                ],
-                dtype=self.torch_dtype(),
-                device=device,
-            )
+        x = torch.randn(
+            [
+                1,
+                input_channels,
+                height // self.downsampling_factor,
+                width // self.downsampling_factor,
+            ],
+            dtype=self.torch_dtype(),
+            device=device,
+        )
         if self.perlin > 0.0:
             perlin_noise = self.get_perlin_noise(
                 width // self.downsampling_factor, height // self.downsampling_factor
diff --git a/invokeai/backend/generator/img2img.py b/invokeai/backend/generator/img2img.py
index 1cfbeb66c0..b3b0e8f510 100644
--- a/invokeai/backend/generator/img2img.py
+++ b/invokeai/backend/generator/img2img.py
@@ -88,10 +88,7 @@ class Img2Img(Generator):
 
     def get_noise_like(self, like: torch.Tensor):
         device = like.device
-        if device.type == "mps":
-            x = torch.randn_like(like, device="cpu").to(device)
-        else:
-            x = torch.randn_like(like, device=device)
+        x = torch.randn_like(like, device=device)
         if self.perlin > 0.0:
             shape = like.shape
             x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 8493b4286f..7cf3f22dda 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -361,37 +361,34 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         ):
             self.enable_xformers_memory_efficient_attention()
         else:
-            if torch.backends.mps.is_available():
-                # until pytorch #91617 is fixed, slicing is borked on MPS
-                # https://github.com/pytorch/pytorch/issues/91617
-                # fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline.
-                pass
+            if self.device.type == "cpu" or self.device.type == "mps":
+                mem_free = psutil.virtual_memory().free
+            elif self.device.type == "cuda":
+                mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device))
             else:
-                if self.device.type == "cpu" or self.device.type == "mps":
-                    mem_free = psutil.virtual_memory().free
-                elif self.device.type == "cuda":
-                    mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device))
-                else:
-                    raise ValueError(f"unrecognized device {self.device}")
-                # input tensor of [1, 4, h/8, w/8]
-                # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
-                bytes_per_element_needed_for_baddbmm_duplication = (
-                    latents.element_size() + 4
-                )
-                max_size_required_for_baddbmm = (
-                    16
-                    * latents.size(dim=2)
-                    * latents.size(dim=3)
-                    * latents.size(dim=2)
-                    * latents.size(dim=3)
-                    * bytes_per_element_needed_for_baddbmm_duplication
-                )
-                if max_size_required_for_baddbmm > (
-                    mem_free * 3.0 / 4.0
-                ): # 3.3 / 4.0 is from old Invoke code
-                    self.enable_attention_slicing(slice_size="max")
-                else:
-                    self.disable_attention_slicing()
+                raise ValueError(f"unrecognized device {self.device}")
+            # input tensor of [1, 4, h/8, w/8]
+            # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
+            bytes_per_element_needed_for_baddbmm_duplication = (
+                latents.element_size() + 4
+            )
+            max_size_required_for_baddbmm = (
+                16
+                * latents.size(dim=2)
+                * latents.size(dim=3)
+                * latents.size(dim=2)
+                * latents.size(dim=3)
+                * bytes_per_element_needed_for_baddbmm_duplication
+            )
+            if max_size_required_for_baddbmm > (
+                mem_free * 3.0 / 4.0
+            ): # 3.3 / 4.0 is from old Invoke code
+                self.enable_attention_slicing(slice_size="max")
+            elif torch.backends.mps.is_available():
+                # diffusers recommends always enabling for mps
+                self.enable_attention_slicing(slice_size="max")
+            else:
+                self.disable_attention_slicing()
 
     def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
         # overridden method; types match the superclass.
@@ -917,20 +914,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype):
         init_image = init_image.to(device=device, dtype=dtype)
         with torch.inference_mode():
-            if device.type == "mps":
-                # workaround for torch MPS bug that has been fixed in https://github.com/kulinseth/pytorch/pull/222
-                # TODO remove this workaround once kulinseth#222 is merged to pytorch mainline
-                self.vae.to(CPU_DEVICE)
-                init_image = init_image.to(CPU_DEVICE)
-            else:
-                self._model_group.load(self.vae)
+            self._model_group.load(self.vae)
             init_latent_dist = self.vae.encode(init_image).latent_dist
             init_latents = init_latent_dist.sample().to(
                 dtype=dtype
             ) # FIXME: uses torch.randn. make reproducible!
-            if device.type == "mps":
-                self.vae.to(device)
-                init_latents = init_latents.to(device)
         init_latents = 0.18215 * init_latents
         return init_latents
 
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index f3b09f6a9f..1175475bba 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -248,9 +248,6 @@ class InvokeAIDiffuserComponent:
             x_twice, sigma_twice, both_conditionings, **kwargs,
         )
         unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
-        if conditioned_next_x.device.type == "mps":
-            # prevent a result filled with zeros. seems to be a torch bug.
-            conditioned_next_x = conditioned_next_x.clone()
         return unconditioned_next_x, conditioned_next_x
 
     def _apply_standard_conditioning_sequentially(
@@ -264,9 +261,6 @@ class InvokeAIDiffuserComponent:
         # low-memory sequential path
         unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
         conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
-        if conditioned_next_x.device.type == "mps":
-            # prevent a result filled with zeros. seems to be a torch bug.
-            conditioned_next_x = conditioned_next_x.clone()
         return unconditioned_next_x, conditioned_next_x
 
     # TODO: looks unused
diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py
index 615209d98d..bf4eead3a0 100644
--- a/invokeai/backend/util/devices.py
+++ b/invokeai/backend/util/devices.py
@@ -28,6 +28,8 @@ def choose_precision(device: torch.device) -> str:
         device_name = torch.cuda.get_device_name(device)
         if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
             return "float16"
+    elif device.type == "mps":
+        return "float16"
     return "float32"
 
 
diff --git a/invokeai/backend/util/mps_fixes.py b/invokeai/backend/util/mps_fixes.py
new file mode 100644
index 0000000000..b9900dffc7
--- /dev/null
+++ b/invokeai/backend/util/mps_fixes.py
@@ -0,0 +1,53 @@
+import torch
+
+
+if torch.backends.mps.is_available():
+    torch.empty = torch.zeros
+
+
+_torch_layer_norm = torch.nn.functional.layer_norm
+def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
+    if input.device.type == "mps" and input.dtype == torch.float16:
+        input = input.float()
+        if weight is not None:
+            weight = weight.float()
+        if bias is not None:
+            bias = bias.float()
+        return _torch_layer_norm(input, normalized_shape, weight, bias, eps).half()
+    else:
+        return _torch_layer_norm(input, normalized_shape, weight, bias, eps)
+
+torch.nn.functional.layer_norm = new_layer_norm
+
+
+_torch_tensor_permute = torch.Tensor.permute
+def new_torch_tensor_permute(input, *dims):
+    result = _torch_tensor_permute(input, *dims)
+    if input.device.type == "mps" and input.dtype == torch.float16:
+        result = result.contiguous()
+    return result
+
+torch.Tensor.permute = new_torch_tensor_permute
+
+
+_torch_lerp = torch.lerp
+def new_torch_lerp(input, end, weight, *, out=None):
+    if input.device.type == "mps" and input.dtype == torch.float16:
+        input = input.float()
+        end = end.float()
+        if isinstance(weight, torch.Tensor):
+            weight = weight.float()
+        if out is not None:
+            out_fp32 = torch.zeros_like(out, dtype=torch.float32)
+        else:
+            out_fp32 = None
+        result = _torch_lerp(input, end, weight, out=out_fp32)
+        if out is not None:
+            out.copy_(out_fp32.half())
+            del out_fp32
+        return result.half()
+
+    else:
+        return _torch_lerp(input, end, weight, out=out)
+
+torch.lerp = new_torch_lerp
\ No newline at end of file

From 169ff6368b05a924247cfb0a4e7d7a47487a6886 Mon Sep 17 00:00:00 2001
From: gogurtenjoyer <36354352+gogurtenjoyer@users.noreply.github.com>
Date: Wed, 5 Jul 2023 17:47:23 -0400
Subject: [PATCH 2/2] Update mps_fixes.py - additional torch op for nodes

This fixes scaling in the nodes UI.
---
 invokeai/backend/util/mps_fixes.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)
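
`mps_fixes.py` applies its overrides at import time, so once
`invokeai.backend.util.mps_fixes` has been imported (as `api_app.py` and
`cli_app.py` do when MPS is available), every later float16 call to
`torch.nn.functional.interpolate` on an MPS tensor takes the float32 round
trip added here. A rough smoke test of the behaviour this commit targets; the
tensor shapes and the assertions are illustrative, not taken from the PR:

    import torch
    import torch.nn.functional as F
    import invokeai.backend.util.mps_fixes  # noqa: F401  (installs the overrides)

    if torch.backends.mps.is_available():
        # Upscale an fp16 latent-sized tensor on MPS.
        x = torch.randn(1, 4, 64, 64, dtype=torch.float16, device="mps")
        y = F.interpolate(x, scale_factor=2, mode="nearest")
        assert y.dtype == torch.float16
        assert y.shape == (1, 4, 128, 128)
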
diff --git a/invokeai/backend/util/mps_fixes.py b/invokeai/backend/util/mps_fixes.py
index b9900dffc7..1fc58f9c98 100644
--- a/invokeai/backend/util/mps_fixes.py
+++ b/invokeai/backend/util/mps_fixes.py
@@ -50,4 +50,14 @@ def new_torch_lerp(input, end, weight, *, out=None):
     else:
         return _torch_lerp(input, end, weight, out=out)
 
-torch.lerp = new_torch_lerp
\ No newline at end of file
+torch.lerp = new_torch_lerp
+
+
+_torch_interpolate = torch.nn.functional.interpolate
+def new_torch_interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):
+    if input.device.type == "mps" and input.dtype == torch.float16:
+        return _torch_interpolate(input.float(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias).half()
+    else:
+        return _torch_interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)
+
+torch.nn.functional.interpolate = new_torch_interpolate