Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
fix img2img by working around pytorch bug (#2458)
Horribly, temporarily send the vae to `.cpu()` so that good latents can be produced.

Closes #2418
commit bd57793a65
@@ -597,11 +597,20 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps)
         return self.check_for_safety(output, dtype=conditioning_data.dtype)
 
-    def non_noised_latents_from_image(self, init_image, *, device, dtype):
+    def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype):
         init_image = init_image.to(device=device, dtype=dtype)
         with torch.inference_mode():
+            if device.type == 'mps':
+                # workaround for torch MPS bug that has been fixed in https://github.com/kulinseth/pytorch/pull/222
+                # TODO remove this workaround once kulinseth#222 is merged to pytorch mainline
+                self.vae.to('cpu')
+                init_image = init_image.to('cpu')
             init_latent_dist = self.vae.encode(init_image).latent_dist
             init_latents = init_latent_dist.sample().to(dtype=dtype)  # FIXME: uses torch.randn. make reproducible!
+            if device.type == 'mps':
+                self.vae.to(device)
+                init_latents = init_latents.to(device)
+
         init_latents = 0.18215 * init_latents
         return init_latents
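For context, the workaround amounts to running the VAE encode step on the CPU whenever the target device is MPS, then moving the VAE and the sampled latents back. A minimal, self-contained sketch of that pattern follows; it assumes a diffusers `AutoencoderKL` and a preprocessed `init_image` tensor scaled to `[-1, 1]`, and the helper name `encode_init_image` is illustrative rather than part of InvokeAI.

```python
import torch
from diffusers import AutoencoderKL


def encode_init_image(vae: AutoencoderKL, init_image: torch.Tensor,
                      device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Encode an init image to scaled latents, avoiding the MPS VAE bug.

    Hypothetical helper mirroring non_noised_latents_from_image; not InvokeAI API.
    """
    init_image = init_image.to(device=device, dtype=dtype)
    with torch.inference_mode():
        if device.type == 'mps':
            # On affected torch builds, VAE encoding on MPS yields bad latents,
            # so run the encoder on the CPU instead (see kulinseth/pytorch#222).
            vae.to('cpu')
            init_image = init_image.to('cpu')
        latent_dist = vae.encode(init_image).latent_dist
        latents = latent_dist.sample().to(dtype=dtype)  # non-deterministic: samples via torch.randn
        if device.type == 'mps':
            # Move the VAE and the freshly sampled latents back to the target device.
            vae.to(device)
            latents = latents.to(device)
    return 0.18215 * latents  # Stable Diffusion v1 latent scaling factor
```

Moving the VAE back to `device` immediately after the encode keeps the rest of the denoising loop on MPS, so the CPU round trip only costs time on the single encode call at the start of img2img.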