diff --git a/ldm/invoke/generator/omnibus.py b/ldm/invoke/generator/omnibus.py index 43192cd152..b6ddbfdb03 100644 --- a/ldm/invoke/generator/omnibus.py +++ b/ldm/invoke/generator/omnibus.py @@ -2,6 +2,7 @@ import torch import numpy as np +from einops import repeat from PIL import Image from ldm.invoke.generator.base import downsampling from ldm.invoke.generator.img2img import Img2Img @@ -33,6 +34,7 @@ class Omnibus(Img2Img,Txt2Img): Return value depends on the seed at the time you call it. """ self.perlin = perlin + num_samples = 1 sampler.make_schedule( ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False @@ -77,7 +79,7 @@ class Omnibus(Img2Img,Txt2Img): masked_image, prompt=prompt, device=model.device, - num_samples=1 + num_samples=num_samples, ) c = model.cond_stage_model.encode(batch["txt"]) @@ -86,7 +88,7 @@ class Omnibus(Img2Img,Txt2Img): for ck in model.concat_keys: cc = batch[ck].float() if ck != model.masked_image_key: - bchw = [num_samples, 4, h//8, w//8] + bchw = [num_samples, 4, height//8, width//8] cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) else: cc = model.get_first_stage_encoding(model.encode_first_stage(cc)) @@ -99,7 +101,7 @@ class Omnibus(Img2Img,Txt2Img): # uncond cond uc_cross = model.get_unconditional_conditioning(num_samples, "") uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]} - shape = [model.channels, h//8, w//8] + shape = [model.channels, height//8, width//8] samples, = sampler.sample( batch_size = 1, @@ -121,6 +123,7 @@ class Omnibus(Img2Img,Txt2Img): return make_image def make_batch_sd( + self, image, mask, masked_image,