diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py index 696ba03596..1c398fb95d 100644 --- a/ldm/invoke/generator/txt2img2img.py +++ b/ldm/invoke/generator/txt2img2img.py @@ -3,10 +3,10 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator ''' import math -from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error from typing import Callable, Optional import torch +from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error from ldm.invoke.generator.base import Generator from ldm.invoke.generator.diffusers_pipeline import trim_to_multiple_of, StableDiffusionGeneratorPipeline, \ @@ -128,18 +128,13 @@ class Txt2Img2Img(Generator): scaled_width = width scaled_height = height - device = self.model.device + device = self.model.device + channels = self.latent_channels + if channels == 9: + channels = 4 # we don't really want noise for all the mask channels + shape = (1, channels, + scaled_height // self.downsampling_factor, scaled_width // self.downsampling_factor) if self.use_mps_noise or device.type == 'mps': - return torch.randn([1, - self.latent_channels, - scaled_height // self.downsampling_factor, - scaled_width // self.downsampling_factor], - dtype=self.torch_dtype(), - device='cpu').to(device) + return torch.randn(shape, dtype=self.torch_dtype(), device='cpu').to(device) else: - return torch.randn([1, - self.latent_channels, - scaled_height // self.downsampling_factor, - scaled_width // self.downsampling_factor], - dtype=self.torch_dtype(), - device=device) + return torch.randn(shape, dtype=self.torch_dtype(), device=device)