From ed4c8f6a8a50e26363f7222955fe0d41b5c5edeb Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 20 Feb 2023 12:50:32 -0500 Subject: [PATCH] fix crash in CLI when --save_intermediates called Fixes #2733 --- ldm/invoke/CLI.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/ldm/invoke/CLI.py b/ldm/invoke/CLI.py index d639c16640..c47f19b6f2 100644 --- a/ldm/invoke/CLI.py +++ b/ldm/invoke/CLI.py @@ -15,17 +15,18 @@ if sys.platform == "darwin": import pyparsing # type: ignore import ldm.invoke -from ldm.generate import Generate -from ldm.invoke.args import (Args, dream_cmd_from_png, metadata_dumps, +from ..generate import Generate +from .args import (Args, dream_cmd_from_png, metadata_dumps, metadata_from_png) -from ldm.invoke.globals import Globals -from ldm.invoke.image_util import make_grid -from ldm.invoke.log import write_log -from ldm.invoke.model_manager import ModelManager -from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata -from ldm.invoke.prompt_parser import PromptParser -from ldm.invoke.readline import Completer, get_completer -from ldm.util import url_attachment_name +from .generator.diffusers_pipeline import PipelineIntermediateState +from .globals import Globals +from .image_util import make_grid +from .log import write_log +from .model_manager import ModelManager +from .pngwriter import PngWriter, retrieve_metadata, write_metadata +from .prompt_parser import PromptParser +from .readline import Completer, get_completer +from ..util import url_attachment_name # global used in multiple functions (fix) infile = None @@ -1262,10 +1263,13 @@ def make_step_callback(gen, opt, prefix): os.makedirs(destination, exist_ok=True) print(f">> Intermediate images will be written into {destination}") - def callback(img, step): + def callback(state: PipelineIntermediateState): + latents = state.latents + step = state.step if step % opt.save_intermediates == 0 or step == opt.steps - 1: filename = os.path.join(destination, f"{step:04}.png") - image = gen.sample_to_image(img) + image = gen.sample_to_lowres_estimated_image(latents) + image = image.resize((image.size[0]*8,image.size[1]*8)) image.save(filename, "PNG") return callback