diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 54e9d555af..388d7a3342 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -578,11 +578,20 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps) return self.check_for_safety(output, dtype=conditioning_data.dtype) - def non_noised_latents_from_image(self, init_image, *, device, dtype): + def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype): init_image = init_image.to(device=device, dtype=dtype) with torch.inference_mode(): + if device.type == 'mps': + # workaround for torch MPS bug that has been fixed in https://github.com/kulinseth/pytorch/pull/222 + # TODO remove this workaround once kulinseth#222 is merged to pytorch mainline + self.vae.to('cpu') + init_image = init_image.to('cpu') init_latent_dist = self.vae.encode(init_image).latent_dist init_latents = init_latent_dist.sample().to(dtype=dtype) # FIXME: uses torch.randn. make reproducible! + if device.type == 'mps': + self.vae.to(device) + init_latents = init_latents.to(device) + init_latents = 0.18215 * init_latents return init_latents