resolve conflicts between PR #1108 and #1243

This commit is contained in:
Lincoln Stein 2022-10-26 15:37:24 -04:00
commit 9b7159720f
17 changed files with 444 additions and 38 deletions

View File

@ -13,6 +13,13 @@ stable-diffusion-1.4:
width: 512
height: 512
default: true
inpainting-1.5:
description: runwayML tuned inpainting model v1.5
weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
config: configs/stable-diffusion/v1-inpainting-inference.yaml
# vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
width: 512
height: 512
stable-diffusion-1.5:
config: configs/stable-diffusion/v1-inference.yaml
weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt

View File

@ -0,0 +1,79 @@
model:
base_learning_rate: 7.5e-05
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: hybrid # important
monitor: val/loss_simple_ema
scale_factor: 0.18215
finetune_keys: null
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
personalization_config:
target: ldm.modules.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: ['face', 'man', 'photo', 'africanmale']
per_image_tokens: false
num_vectors_per_token: 1
progressive_words: False
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 9 # 4 data + 4 downscaled image + 1 mask
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

View File

@ -421,7 +421,10 @@ class Generate:
)
# TODO: Hacky selection of operation to perform. Needs to be refactored.
if (init_image is not None) and (mask_image is not None):
if self.sampler.conditioning_key() in ('hybrid','concat'):
print(f'** Inpainting model detected. Will try it! **')
generator = self._make_omnibus()
elif (init_image is not None) and (mask_image is not None):
generator = self._make_inpaint()
elif (embiggen != None or embiggen_tiles != None):
generator = self._make_embiggen()
@ -677,6 +680,7 @@ class Generate:
return init_image,init_mask
# lots o' repeated code here! Turn into a make_func()
def _make_base(self):
if not self.generators.get('base'):
from ldm.invoke.generator import Generator
@ -687,6 +691,7 @@ class Generate:
if not self.generators.get('img2img'):
from ldm.invoke.generator.img2img import Img2Img
self.generators['img2img'] = Img2Img(self.model, self.precision)
self.generators['img2img'].free_gpu_mem = self.free_gpu_mem
return self.generators['img2img']
def _make_embiggen(self):
@ -715,6 +720,15 @@ class Generate:
self.generators['inpaint'] = Inpaint(self.model, self.precision)
return self.generators['inpaint']
# "omnibus" supports the runwayML custom inpainting model, which does
# txt2img, img2img and inpainting using slight variations on the same code
def _make_omnibus(self):
if not self.generators.get('omnibus'):
from ldm.invoke.generator.omnibus import Omnibus
self.generators['omnibus'] = Omnibus(self.model, self.precision)
self.generators['omnibus'].free_gpu_mem = self.free_gpu_mem
return self.generators['omnibus']
def load_model(self):
'''
preload model identified in self.model_name

View File

@ -181,7 +181,9 @@ class Args(object):
switches_started = False
for element in elements:
if element[0] == '-' and not switches_started:
if len(element) == 0: # empty prompt
pass
elif element[0] == '-' and not switches_started:
switches_started = True
if switches_started:
switches.append(element)

View File

@ -123,8 +123,8 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
else:
conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt, log_tokens=log_tokens)
conditioning = flatten_hybrid_conditioning(unconditioning, conditioning)
return (
unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
cross_attention_control_args=cac_args
@ -166,4 +166,25 @@ def get_tokens_length(model, fragments: list[Fragment]):
tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
return sum([len(x) for x in tokens])
def flatten_hybrid_conditioning(uncond, cond):
'''
This handles the choice between a conditional conditioning
that is a tensor (used by cross attention) vs one that has additional
dimensions as well, as used by 'hybrid'
'''
if isinstance(cond, dict):
assert isinstance(uncond, dict)
cond_in = dict()
for k in cond:
if isinstance(cond[k], list):
cond_in[k] = [
torch.cat([uncond[k][i], cond[k][i]])
for i in range(len(cond[k]))
]
else:
cond_in[k] = torch.cat([uncond[k], cond[k]])
return cond_in
else:
return cond

View File

@ -6,6 +6,7 @@ import torch
import numpy as np
import random
import os
import traceback
from tqdm import tqdm, trange
from PIL import Image, ImageFilter
from einops import rearrange, repeat
@ -43,7 +44,7 @@ class Generator():
self.variation_amount = variation_amount
self.with_variations = with_variations
def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
safety_checker:dict=None,
**kwargs):
@ -51,6 +52,7 @@ class Generator():
self.safety_checker = safety_checker
make_image = self.get_make_image(
prompt,
sampler = sampler,
init_image = init_image,
width = width,
height = height,
@ -59,12 +61,14 @@ class Generator():
perlin = perlin,
**kwargs
)
results = []
seed = seed if seed is not None else self.new_seed()
first_seed = seed
seed, initial_noise = self.generate_initial_noise(seed, width, height)
with scope(self.model.device.type), self.model.ema_scope():
# There used to be an additional self.model.ema_scope() here, but it breaks
# the inpaint-1.5 model. Not sure what it did.... ?
with scope(self.model.device.type):
for n in trange(iterations, desc='Generating'):
x_T = None
if self.variation_amount > 0:
@ -79,7 +83,8 @@ class Generator():
try:
x_T = self.get_noise(width,height)
except:
pass
print('** An error occurred while getting initial noise **')
print(traceback.format_exc())
image = make_image(x_T)
@ -95,10 +100,10 @@ class Generator():
return results
def sample_to_image(self,samples):
def sample_to_image(self,samples)->Image.Image:
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
Given samples returned from a sampler, converts
it into a PIL Image
"""
x_samples = self.model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

View File

@ -15,7 +15,7 @@ from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserCompo
class Img2Img(Generator):
def __init__(self, model, precision):
super().__init__(model, precision)
self.init_latent = None # by get_noise()
self.init_latent = None # by get_noise()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
@ -80,7 +80,10 @@ class Img2Img(Generator):
def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
if len(image.shape) == 2: # 'L' image, as in a mask
image = image[None,None]
else: # 'RGB' image
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
if normalize:
image = 2.0 * image - 1.0

View File

@ -0,0 +1,151 @@
"""omnibus module to be used with the runwayml 9-channel custom inpainting model"""
import torch
import numpy as np
from einops import repeat
from PIL import Image, ImageOps
from ldm.invoke.devices import choose_autocast
from ldm.invoke.generator.base import downsampling
from ldm.invoke.generator.img2img import Img2Img
from ldm.invoke.generator.txt2img import Txt2Img
class Omnibus(Img2Img,Txt2Img):
def __init__(self, model, precision):
super().__init__(model, precision)
def get_make_image(
self,
prompt,
sampler,
steps,
cfg_scale,
ddim_eta,
conditioning,
width,
height,
init_image = None,
mask_image = None,
strength = None,
step_callback=None,
threshold=0.0,
perlin=0.0,
**kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it.
"""
self.perlin = perlin
num_samples = 1
sampler.make_schedule(
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
)
if isinstance(init_image, Image.Image):
init_image = self._image_to_tensor(init_image)
if isinstance(mask_image, Image.Image):
mask_image = self._image_to_tensor(ImageOps.invert(mask_image).convert('L'),normalize=False)
t_enc = steps
if init_image is not None and mask_image is not None: # inpainting
masked_image = init_image * (1 - mask_image) # masked image is the image masked by mask - masked regions zero
elif init_image is not None: # img2img
scope = choose_autocast(self.precision)
with scope(self.model.device.type):
self.init_latent = self.model.get_first_stage_encoding(
self.model.encode_first_stage(init_image)
) # move to latent space
# create a completely black mask (1s)
mask_image = torch.ones(1, 1, init_image.shape[2], init_image.shape[3], device=self.model.device)
# and the masked image is just a copy of the original
masked_image = init_image
else: # txt2img
init_image = torch.zeros(1, 3, height, width, device=self.model.device)
mask_image = torch.ones(1, 1, height, width, device=self.model.device)
masked_image = init_image
self.init_latent = init_image
height = init_image.shape[2]
width = init_image.shape[3]
model = self.model
def make_image(x_T):
with torch.no_grad():
scope = choose_autocast(self.precision)
with scope(self.model.device.type):
batch = self.make_batch_sd(
init_image,
mask_image,
masked_image,
prompt=prompt,
device=model.device,
num_samples=num_samples,
)
c = model.cond_stage_model.encode(batch["txt"])
c_cat = list()
for ck in model.concat_keys:
cc = batch[ck].float()
if ck != model.masked_image_key:
bchw = [num_samples, 4, height//8, width//8]
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
else:
cc = model.get_first_stage_encoding(model.encode_first_stage(cc))
c_cat.append(cc)
c_cat = torch.cat(c_cat, dim=1)
# cond
cond={"c_concat": [c_cat], "c_crossattn": [c]}
# uncond cond
uc_cross = model.get_unconditional_conditioning(num_samples, "")
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
shape = [model.channels, height//8, width//8]
samples, _ = sampler.sample(
batch_size = 1,
S = steps,
x_T = x_T,
conditioning = cond,
shape = shape,
verbose = False,
unconditional_guidance_scale = cfg_scale,
unconditional_conditioning = uc_full,
eta = 1.0,
img_callback = step_callback,
threshold = threshold,
)
if self.free_gpu_mem:
self.model.model.to("cpu")
return self.sample_to_image(samples)
return make_image
def make_batch_sd(
self,
image,
mask,
masked_image,
prompt,
device,
num_samples=1):
batch = {
"image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
"txt": num_samples * [prompt],
"mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
"masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
}
return batch
def get_noise(self, width:int, height:int):
if self.init_latent is not None:
height = self.init_latent.shape[2]
width = self.init_latent.shape[3]
return Txt2Img.get_noise(self,width,height)

View File

@ -13,6 +13,7 @@ import gc
import hashlib
import psutil
import transformers
import traceback
import os
from sys import getrefcount
from omegaconf import OmegaConf
@ -73,6 +74,7 @@ class ModelCache(object):
self.models[model_name]['hash'] = hash
except Exception as e:
print(f'** model {model_name} could not be loaded: {str(e)}')
print(traceback.format_exc())
print(f'** restoring {self.current_model}')
self.get_model(self.current_model)
return None

View File

@ -89,6 +89,9 @@ class Outcrop(object):
def _extend(self,image:Image,pixels:int)-> Image:
extended_img = Image.new('RGBA',(image.width,image.height+pixels))
mask_height = pixels if self.generate.model.model.conditioning_key in ('hybrid','concat') \
else pixels *2
# first paste places old image at top of extended image, stretch
# it, and applies a gaussian blur to it
# take the top half region, stretch and paste it
@ -105,7 +108,9 @@ class Outcrop(object):
# now make the top part transparent to use as a mask
alpha = extended_img.getchannel('A')
alpha.paste(0,(0,0,extended_img.width,pixels*2))
alpha.paste(0,(0,0,extended_img.width,mask_height))
extended_img.putalpha(alpha)
extended_img.save('outputs/curly_extended.png')
return extended_img

View File

@ -66,7 +66,7 @@ class VQModel(pl.LightningModule):
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self)
print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
print(f'>> Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

View File

@ -53,12 +53,14 @@ class DDIMSampler(Sampler):
# damian0815 would like to know when/if this code path is used
e_t = self.model.apply_model(x, t, c)
else:
# step_index counts in the opposite direction to index
step_index = step_count-(index+1)
e_t = self.invokeai_diffuser.do_diffusion_step(x, t,
unconditional_conditioning, c,
unconditional_guidance_scale,
step_index=step_index)
e_t = self.invokeai_diffuser.do_diffusion_step(
x, t,
unconditional_conditioning, c,
unconditional_guidance_scale,
step_index=step_index
)
if score_corrector is not None:
assert self.model.parameterization == 'eps'
e_t = score_corrector.modify_score(

View File

@ -19,6 +19,7 @@ from functools import partial
from tqdm import tqdm
from torchvision.utils import make_grid
from pytorch_lightning.utilities.distributed import rank_zero_only
from omegaconf import ListConfig
import urllib
from ldm.util import (
@ -120,7 +121,7 @@ class DDPM(pl.LightningModule):
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self.model)
print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
print(f' | Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
self.use_scheduler = scheduler_config is not None
if self.use_scheduler:
@ -1883,6 +1884,24 @@ class LatentDiffusion(DDPM):
return samples, intermediates
@torch.no_grad()
def get_unconditional_conditioning(self, batch_size, null_label=None):
if null_label is not None:
xc = null_label
if isinstance(xc, ListConfig):
xc = list(xc)
if isinstance(xc, dict) or isinstance(xc, list):
c = self.get_learned_conditioning(xc)
else:
if hasattr(xc, "to"):
xc = xc.to(self.device)
c = self.get_learned_conditioning(xc)
else:
# todo: get null label from cond_stage_model
raise NotImplementedError()
c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
return c
@torch.no_grad()
def log_images(
self,
@ -2147,8 +2166,8 @@ class DiffusionWrapper(pl.LightningModule):
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(x, t, context=cc)
elif self.conditioning_key == 'hybrid':
xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1)
xc = torch.cat([x] + c_concat, dim=1)
out = self.diffusion_model(xc, t, context=cc)
elif self.conditioning_key == 'adm':
cc = c_crossattn[0]
@ -2187,3 +2206,58 @@ class Layout2ImgDiffusion(LatentDiffusion):
cond_img = torch.stack(bbox_imgs, dim=0)
logs['bbox_image'] = cond_img
return logs
class LatentInpaintDiffusion(LatentDiffusion):
def __init__(
self,
concat_keys=("mask", "masked_image"),
masked_image_key="masked_image",
finetune_keys=None,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.masked_image_key = masked_image_key
assert self.masked_image_key in concat_keys
self.concat_keys = concat_keys
@torch.no_grad()
def get_input(
self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
):
# note: restricted to non-trainable encoders currently
assert (
not self.cond_stage_trainable
), "trainable cond stages not yet supported for inpainting"
z, c, x, xrec, xc = super().get_input(
batch,
self.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=True,
return_original_cond=True,
bs=bs,
)
assert exists(self.concat_keys)
c_cat = list()
for ck in self.concat_keys:
cc = (
rearrange(batch[ck], "b h w c -> b c h w")
.to(memory_format=torch.contiguous_format)
.float()
)
if bs is not None:
cc = cc[:bs]
cc = cc.to(self.device)
bchw = z.shape
if ck != self.masked_image_key:
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
else:
cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
c_cat.append(cc)
c_cat = torch.cat(c_cat, dim=1)
all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
if return_first_stage_outputs:
return z, all_conds, x, xrec, xc
return z, all_conds

View File

@ -23,9 +23,10 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
class CFGDenoiser(nn.Module):
def __init__(self, model, threshold = 0, warmup = 0):
def __init__(self, sampler, threshold = 0, warmup = 0):
super().__init__()
self.inner_model = model
self.inner_model = sampler.model
self.sampler = sampler
self.threshold = threshold
self.warmup_max = warmup
self.warmup = max(warmup / 10, 1)
@ -43,10 +44,14 @@ class CFGDenoiser(nn.Module):
def forward(self, x, sigma, uncond, cond, cond_scale):
next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
# apply threshold
if isinstance(cond,dict): # hybrid model
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = self.sampler.make_cond_in(uncond,cond)
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
next_x = uncond + (cond - uncond) * cond_scale
else: # cross attention model
next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
if self.warmup < self.warmup_max:
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
self.warmup += 1
@ -56,8 +61,6 @@ class CFGDenoiser(nn.Module):
thresh = self.threshold
return cfg_apply_threshold(next_x, thresh)
class KSampler(Sampler):
def __init__(self, model, schedule='lms', device=None, **kwargs):
denoiser = K.external.CompVisDenoiser(model)
@ -286,3 +289,6 @@ class KSampler(Sampler):
'''
return self.model.inner_model.q_sample(x0,ts)
def conditioning_key(self)->str:
return self.model.inner_model.model.conditioning_key

View File

@ -14,9 +14,6 @@ class PLMSSampler(Sampler):
def __init__(self, model, schedule='linear', device=None, **kwargs):
super().__init__(model,schedule,model.num_timesteps, device)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
def prepare_to_sample(self, t_enc, **kwargs):
super().prepare_to_sample(t_enc, **kwargs)
@ -67,7 +64,6 @@ class PLMSSampler(Sampler):
unconditional_conditioning, c,
unconditional_guidance_scale,
step_index=step_index)
if score_corrector is not None:
assert self.model.parameterization == 'eps'
e_t = score_corrector.modify_score(

View File

@ -11,6 +11,7 @@ import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.invoke.devices import choose_torch_device
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
@ -26,6 +27,8 @@ class Sampler(object):
self.ddpm_num_timesteps = steps
self.schedule = schedule
self.device = device or choose_torch_device()
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
@ -160,6 +163,18 @@ class Sampler(object):
**kwargs,
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
# check to see if make_schedule() has run, and if not, run it
if self.ddim_timesteps is None:
self.make_schedule(
@ -196,7 +211,7 @@ class Sampler(object):
)
return samples, intermediates
#torch.no_grad()
@torch.no_grad()
def do_sampling(
self,
cond,
@ -257,6 +272,7 @@ class Sampler(object):
)
if mask is not None:
print('DEBUG: in masking routine')
assert x0 is not None
img_orig = self.model.q_sample(
x0, ts
@ -313,7 +329,6 @@ class Sampler(object):
all_timesteps_count = None,
**kwargs
):
timesteps = (
np.arange(self.ddpm_num_timesteps)
if use_original_steps
@ -420,3 +435,27 @@ class Sampler(object):
'''
return self.model.q_sample(x0,ts)
def conditioning_key(self)->str:
return self.model.model.conditioning_key
# def make_cond_in(self, uncond, cond):
# '''
# This handles the choice between a conditional conditioning
# that is a tensor (used by cross attention) vs one that is a dict
# used by 'hybrid'
# '''
# if isinstance(cond, dict):
# assert isinstance(uncond, dict)
# cond_in = dict()
# for k in cond:
# if isinstance(cond[k], list):
# cond_in[k] = [
# torch.cat([uncond[k][i], cond[k][i]])
# for i in range(len(cond[k]))
# ]
# else:
# cond_in[k] = torch.cat([uncond[k], cond[k]])
# else:
# cond_in = torch.cat([uncond, cond])
# return cond_in

View File

@ -171,9 +171,9 @@ def main_loop(gen, opt):
except (OSError, AttributeError, KeyError):
pass
if len(opt.prompt) == 0:
print('\nTry again with a prompt!')
continue
# if len(opt.prompt) == 0:
# print('\nTry again with a prompt!')
# continue
# width and height are set by model if not specified
if not opt.width: