fixed synax errors; now channel mismatch issue

This commit is contained in:
Lincoln Stein 2022-10-25 00:47:13 -04:00
parent be8a992b85
commit a2e53892ec

View File

@ -2,6 +2,7 @@
import torch
import numpy as np
from einops import repeat
from PIL import Image
from ldm.invoke.generator.base import downsampling
from ldm.invoke.generator.img2img import Img2Img
@ -33,6 +34,7 @@ class Omnibus(Img2Img,Txt2Img):
Return value depends on the seed at the time you call it.
"""
self.perlin = perlin
num_samples = 1
sampler.make_schedule(
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
@ -77,7 +79,7 @@ class Omnibus(Img2Img,Txt2Img):
masked_image,
prompt=prompt,
device=model.device,
num_samples=1
num_samples=num_samples,
)
c = model.cond_stage_model.encode(batch["txt"])
@ -86,7 +88,7 @@ class Omnibus(Img2Img,Txt2Img):
for ck in model.concat_keys:
cc = batch[ck].float()
if ck != model.masked_image_key:
bchw = [num_samples, 4, h//8, w//8]
bchw = [num_samples, 4, height//8, width//8]
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
else:
cc = model.get_first_stage_encoding(model.encode_first_stage(cc))
@ -99,7 +101,7 @@ class Omnibus(Img2Img,Txt2Img):
# uncond cond
uc_cross = model.get_unconditional_conditioning(num_samples, "")
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
shape = [model.channels, h//8, w//8]
shape = [model.channels, height//8, width//8]
samples, = sampler.sample(
batch_size = 1,
@ -121,6 +123,7 @@ class Omnibus(Img2Img,Txt2Img):
return make_image
def make_batch_sd(
self,
image,
mask,
masked_image,