fix crash in CLI when --save_intermediates called

Fixes #2733
This commit is contained in:
Lincoln Stein 2023-02-20 12:50:32 -05:00
parent 8352ab2076
commit ed4c8f6a8a

View File

@ -15,17 +15,18 @@ if sys.platform == "darwin":
import pyparsing # type: ignore import pyparsing # type: ignore
import ldm.invoke import ldm.invoke
from ldm.generate import Generate from ..generate import Generate
from ldm.invoke.args import (Args, dream_cmd_from_png, metadata_dumps, from .args import (Args, dream_cmd_from_png, metadata_dumps,
metadata_from_png) metadata_from_png)
from ldm.invoke.globals import Globals from .generator.diffusers_pipeline import PipelineIntermediateState
from ldm.invoke.image_util import make_grid from .globals import Globals
from ldm.invoke.log import write_log from .image_util import make_grid
from ldm.invoke.model_manager import ModelManager from .log import write_log
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata from .model_manager import ModelManager
from ldm.invoke.prompt_parser import PromptParser from .pngwriter import PngWriter, retrieve_metadata, write_metadata
from ldm.invoke.readline import Completer, get_completer from .prompt_parser import PromptParser
from ldm.util import url_attachment_name from .readline import Completer, get_completer
from ..util import url_attachment_name
# global used in multiple functions (fix) # global used in multiple functions (fix)
infile = None infile = None
@ -1262,10 +1263,13 @@ def make_step_callback(gen, opt, prefix):
os.makedirs(destination, exist_ok=True) os.makedirs(destination, exist_ok=True)
print(f">> Intermediate images will be written into {destination}") print(f">> Intermediate images will be written into {destination}")
def callback(img, step): def callback(state: PipelineIntermediateState):
latents = state.latents
step = state.step
if step % opt.save_intermediates == 0 or step == opt.steps - 1: if step % opt.save_intermediates == 0 or step == opt.steps - 1:
filename = os.path.join(destination, f"{step:04}.png") filename = os.path.join(destination, f"{step:04}.png")
image = gen.sample_to_image(img) image = gen.sample_to_lowres_estimated_image(latents)
image = image.resize((image.size[0]*8,image.size[1]*8))
image.save(filename, "PNG") image.save(filename, "PNG")
return callback return callback