fixed synax errors; now channel mismatch issue

This commit is contained in:
Lincoln Stein 2022-10-25 00:47:13 -04:00
parent be8a992b85
commit a2e53892ec

View File

@ -2,6 +2,7 @@
import torch import torch
import numpy as np import numpy as np
from einops import repeat
from PIL import Image from PIL import Image
from ldm.invoke.generator.base import downsampling from ldm.invoke.generator.base import downsampling
from ldm.invoke.generator.img2img import Img2Img from ldm.invoke.generator.img2img import Img2Img
@ -33,6 +34,7 @@ class Omnibus(Img2Img,Txt2Img):
Return value depends on the seed at the time you call it. Return value depends on the seed at the time you call it.
""" """
self.perlin = perlin self.perlin = perlin
num_samples = 1
sampler.make_schedule( sampler.make_schedule(
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
@ -77,7 +79,7 @@ class Omnibus(Img2Img,Txt2Img):
masked_image, masked_image,
prompt=prompt, prompt=prompt,
device=model.device, device=model.device,
num_samples=1 num_samples=num_samples,
) )
c = model.cond_stage_model.encode(batch["txt"]) c = model.cond_stage_model.encode(batch["txt"])
@ -86,7 +88,7 @@ class Omnibus(Img2Img,Txt2Img):
for ck in model.concat_keys: for ck in model.concat_keys:
cc = batch[ck].float() cc = batch[ck].float()
if ck != model.masked_image_key: if ck != model.masked_image_key:
bchw = [num_samples, 4, h//8, w//8] bchw = [num_samples, 4, height//8, width//8]
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
else: else:
cc = model.get_first_stage_encoding(model.encode_first_stage(cc)) cc = model.get_first_stage_encoding(model.encode_first_stage(cc))
@ -99,7 +101,7 @@ class Omnibus(Img2Img,Txt2Img):
# uncond cond # uncond cond
uc_cross = model.get_unconditional_conditioning(num_samples, "") uc_cross = model.get_unconditional_conditioning(num_samples, "")
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]} uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
shape = [model.channels, h//8, w//8] shape = [model.channels, height//8, width//8]
samples, = sampler.sample( samples, = sampler.sample(
batch_size = 1, batch_size = 1,
@ -121,6 +123,7 @@ class Omnibus(Img2Img,Txt2Img):
return make_image return make_image
def make_batch_sd( def make_batch_sd(
self,
image, image,
mask, mask,
masked_image, masked_image,