Merge branch 'fix-prompts' of https://github.com/damian0815/InvokeAI into merge-prompt-and-inpaint-model

Lincoln Stein committed 2022-10-26 08:50:55 -04:00
22 changed files with 2077 additions and 173 deletions


@@ -1,5 +1,5 @@
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
+import pyparsing
# Derived from source code carrying the following copyrights
# Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
@@ -24,6 +24,7 @@ from PIL import Image, ImageOps
from torch import nn
from pytorch_lightning import seed_everything, logging
+from ldm.invoke.prompt_parser import PromptParser
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
@@ -32,7 +33,7 @@ from ldm.invoke.pngwriter import PngWriter
from ldm.invoke.args import metadata_from_png
from ldm.invoke.image_util import InitImageResizer
from ldm.invoke.devices import choose_torch_device, choose_precision
-from ldm.invoke.conditioning import get_uc_and_c
+from ldm.invoke.conditioning import get_uc_and_c_and_ec
from ldm.invoke.model_cache import ModelCache
from ldm.invoke.seamless import configure_model_padding
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
@@ -404,7 +405,7 @@ class Generate:
mask_image = None
try:
-uc, c = get_uc_and_c(
+uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
prompt, model =self.model,
skip_normalize=skip_normalize,
log_tokens =self.log_tokenization
@@ -448,7 +449,7 @@
sampler=self.sampler,
steps=steps,
cfg_scale=cfg_scale,
-conditioning=(uc, c),
+conditioning=(uc, c, extra_conditioning_info),
ddim_eta=ddim_eta,
image_callback=image_callback, # called after the final image is generated
step_callback=step_callback, # called after each intermediate image is generated
@@ -481,14 +482,14 @@
save_original = save_original,
image_callback = image_callback)
-except RuntimeError as e:
-print(traceback.format_exc(), file=sys.stderr)
-print('>> Could not generate image.')
except KeyboardInterrupt:
if catch_interrupts:
print('**Interrupted** Partial results will be returned.')
else:
raise KeyboardInterrupt
+except RuntimeError as e:
+print(traceback.format_exc(), file=sys.stderr)
+print('>> Could not generate image.')
toc = time.time()
print('>> Usage stats:')
@@ -553,7 +554,8 @@ class Generate:
image = Image.open(image_path)
# used by multiple postfixers
-uc, c = get_uc_and_c(
+# todo: cross-attention control
+uc, c, _ = get_uc_and_c_and_ec(
prompt, model =self.model,
skip_normalize=opt.skip_normalize,
log_tokens =opt.log_tokenization
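
Taken together, these hunks change the conditioning helper from returning a two-element tuple (uc, c) to a three-element tuple (uc, c, extra_conditioning_info) and thread the third element through the conditioning= argument of the generation call, while the postprocessing path simply discards it with uc, c, _ = get_uc_and_c_and_ec(...). A minimal sketch of how a downstream consumer could accept either tuple shape; unpack_conditioning is a hypothetical helper for illustration only, not part of the InvokeAI codebase:

# Hypothetical helper (illustration only): accept either the old (uc, c) pair
# or the new (uc, c, extra_conditioning_info) triple shown in the diff above.
def unpack_conditioning(conditioning):
    if len(conditioning) == 3:
        uc, c, extra_conditioning_info = conditioning
    else:
        # older caller that still passes a two-element tuple
        uc, c = conditioning
        extra_conditioning_info = None  # no cross-attention control info supplied
    return uc, c, extra_conditioning_info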