mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix dimension errors when inpainting model is used with hires-fix (#2440)
This commit is contained in:
commit
bde5874707
@ -3,10 +3,10 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
|
|||||||
'''
|
'''
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error
|
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error
|
||||||
|
|
||||||
from ldm.invoke.generator.base import Generator
|
from ldm.invoke.generator.base import Generator
|
||||||
from ldm.invoke.generator.diffusers_pipeline import trim_to_multiple_of, StableDiffusionGeneratorPipeline, \
|
from ldm.invoke.generator.diffusers_pipeline import trim_to_multiple_of, StableDiffusionGeneratorPipeline, \
|
||||||
@ -129,17 +129,12 @@ class Txt2Img2Img(Generator):
|
|||||||
scaled_height = height
|
scaled_height = height
|
||||||
|
|
||||||
device = self.model.device
|
device = self.model.device
|
||||||
|
channels = self.latent_channels
|
||||||
|
if channels == 9:
|
||||||
|
channels = 4 # we don't really want noise for all the mask channels
|
||||||
|
shape = (1, channels,
|
||||||
|
scaled_height // self.downsampling_factor, scaled_width // self.downsampling_factor)
|
||||||
if self.use_mps_noise or device.type == 'mps':
|
if self.use_mps_noise or device.type == 'mps':
|
||||||
return torch.randn([1,
|
return torch.randn(shape, dtype=self.torch_dtype(), device='cpu').to(device)
|
||||||
self.latent_channels,
|
|
||||||
scaled_height // self.downsampling_factor,
|
|
||||||
scaled_width // self.downsampling_factor],
|
|
||||||
dtype=self.torch_dtype(),
|
|
||||||
device='cpu').to(device)
|
|
||||||
else:
|
else:
|
||||||
return torch.randn([1,
|
return torch.randn(shape, dtype=self.torch_dtype(), device=device)
|
||||||
self.latent_channels,
|
|
||||||
scaled_height // self.downsampling_factor,
|
|
||||||
scaled_width // self.downsampling_factor],
|
|
||||||
dtype=self.torch_dtype(),
|
|
||||||
device=device)
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user