mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fixed synax errors; now channel mismatch issue
This commit is contained in:
parent
be8a992b85
commit
a2e53892ec
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from einops import repeat
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from ldm.invoke.generator.base import downsampling
|
from ldm.invoke.generator.base import downsampling
|
||||||
from ldm.invoke.generator.img2img import Img2Img
|
from ldm.invoke.generator.img2img import Img2Img
|
||||||
@ -33,6 +34,7 @@ class Omnibus(Img2Img,Txt2Img):
|
|||||||
Return value depends on the seed at the time you call it.
|
Return value depends on the seed at the time you call it.
|
||||||
"""
|
"""
|
||||||
self.perlin = perlin
|
self.perlin = perlin
|
||||||
|
num_samples = 1
|
||||||
|
|
||||||
sampler.make_schedule(
|
sampler.make_schedule(
|
||||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||||
@ -77,7 +79,7 @@ class Omnibus(Img2Img,Txt2Img):
|
|||||||
masked_image,
|
masked_image,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
device=model.device,
|
device=model.device,
|
||||||
num_samples=1
|
num_samples=num_samples,
|
||||||
)
|
)
|
||||||
|
|
||||||
c = model.cond_stage_model.encode(batch["txt"])
|
c = model.cond_stage_model.encode(batch["txt"])
|
||||||
@ -86,7 +88,7 @@ class Omnibus(Img2Img,Txt2Img):
|
|||||||
for ck in model.concat_keys:
|
for ck in model.concat_keys:
|
||||||
cc = batch[ck].float()
|
cc = batch[ck].float()
|
||||||
if ck != model.masked_image_key:
|
if ck != model.masked_image_key:
|
||||||
bchw = [num_samples, 4, h//8, w//8]
|
bchw = [num_samples, 4, height//8, width//8]
|
||||||
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
|
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
|
||||||
else:
|
else:
|
||||||
cc = model.get_first_stage_encoding(model.encode_first_stage(cc))
|
cc = model.get_first_stage_encoding(model.encode_first_stage(cc))
|
||||||
@ -99,7 +101,7 @@ class Omnibus(Img2Img,Txt2Img):
|
|||||||
# uncond cond
|
# uncond cond
|
||||||
uc_cross = model.get_unconditional_conditioning(num_samples, "")
|
uc_cross = model.get_unconditional_conditioning(num_samples, "")
|
||||||
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
|
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
|
||||||
shape = [model.channels, h//8, w//8]
|
shape = [model.channels, height//8, width//8]
|
||||||
|
|
||||||
samples, = sampler.sample(
|
samples, = sampler.sample(
|
||||||
batch_size = 1,
|
batch_size = 1,
|
||||||
@ -121,6 +123,7 @@ class Omnibus(Img2Img,Txt2Img):
|
|||||||
return make_image
|
return make_image
|
||||||
|
|
||||||
def make_batch_sd(
|
def make_batch_sd(
|
||||||
|
self,
|
||||||
image,
|
image,
|
||||||
mask,
|
mask,
|
||||||
masked_image,
|
masked_image,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user