From 1bb5b4ab322da5a0021a8804687db6c68388cbe1 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Fri, 27 Jan 2023 11:52:05 -0800 Subject: [PATCH] fix dimension errors when inpainting model is used with hires-fix --- ldm/invoke/generator/txt2img2img.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py index 47692a6bbb..4923e7daf5 100644 --- a/ldm/invoke/generator/txt2img2img.py +++ b/ldm/invoke/generator/txt2img2img.py @@ -3,10 +3,10 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator ''' import math -from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error from typing import Callable, Optional import torch +from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error from ldm.invoke.generator.base import Generator from ldm.invoke.generator.diffusers_pipeline import trim_to_multiple_of, StableDiffusionGeneratorPipeline, \ @@ -116,16 +116,20 @@ class Txt2Img2Img(Generator): scaled_height = height device = self.model.device + + channels = self.latent_channels + if channels == 9: + channels = 4 # we don't really want noise for all the mask channels if self.use_mps_noise or device.type == 'mps': return torch.randn([1, - self.latent_channels, + channels, scaled_height // self.downsampling_factor, scaled_width // self.downsampling_factor], dtype=self.torch_dtype(), device='cpu').to(device) else: return torch.randn([1, - self.latent_channels, + channels, scaled_height // self.downsampling_factor, scaled_width // self.downsampling_factor], dtype=self.torch_dtype(),