fix crash in CLI when --save_intermediates called

Fixes #2733
This commit is contained in:
Lincoln Stein 2023-02-20 12:50:32 -05:00
parent 8352ab2076
commit ed4c8f6a8a

View File

@ -15,17 +15,18 @@ if sys.platform == "darwin":
import pyparsing # type: ignore
import ldm.invoke
from ldm.generate import Generate
from ldm.invoke.args import (Args, dream_cmd_from_png, metadata_dumps,
from ..generate import Generate
from .args import (Args, dream_cmd_from_png, metadata_dumps,
metadata_from_png)
from ldm.invoke.globals import Globals
from ldm.invoke.image_util import make_grid
from ldm.invoke.log import write_log
from ldm.invoke.model_manager import ModelManager
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
from ldm.invoke.prompt_parser import PromptParser
from ldm.invoke.readline import Completer, get_completer
from ldm.util import url_attachment_name
from .generator.diffusers_pipeline import PipelineIntermediateState
from .globals import Globals
from .image_util import make_grid
from .log import write_log
from .model_manager import ModelManager
from .pngwriter import PngWriter, retrieve_metadata, write_metadata
from .prompt_parser import PromptParser
from .readline import Completer, get_completer
from ..util import url_attachment_name
# global used in multiple functions (fix)
infile = None
@ -1262,10 +1263,13 @@ def make_step_callback(gen, opt, prefix):
os.makedirs(destination, exist_ok=True)
print(f">> Intermediate images will be written into {destination}")
def callback(img, step):
def callback(state: PipelineIntermediateState):
latents = state.latents
step = state.step
if step % opt.save_intermediates == 0 or step == opt.steps - 1:
filename = os.path.join(destination, f"{step:04}.png")
image = gen.sample_to_image(img)
image = gen.sample_to_lowres_estimated_image(latents)
image = image.resize((image.size[0]*8,image.size[1]*8))
image.save(filename, "PNG")
return callback