diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py
index b8140b11e9..85e8b41289 100644
--- a/invokeai/app/invocations/generate.py
+++ b/invokeai/app/invocations/generate.py
@@ -1,17 +1,13 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
 
-from datetime import datetime, timezone
-from typing import Any, Literal, Optional, Union
+from functools import partial
+from typing import Literal, Optional, Union
 
 import numpy as np
-from torch import Tensor
-from PIL import Image
 from pydantic import Field
-from skimage.exposure.histogram_matching import match_histograms
 
 from ..services.image_storage import ImageType
-from ..services.invocation_services import InvocationServices
 from .baseinvocation import BaseInvocation, InvocationContext
 from .image import ImageField, ImageOutput
 from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator
 
@@ -45,24 +41,27 @@ class TextToImageInvocation(BaseInvocation):
 
     # TODO: pass this an emitter method or something? or a session for dispatching?
     def dispatch_progress(
-        self, context: InvocationContext, sample: Tensor, step: int
-    ) -> None:
-        # TODO: only output a preview image when requested
-        image = Generator.sample_to_lowres_estimated_image(sample)
-
+        self, context: InvocationContext, intermediate_state: PipelineIntermediateState
+    ) -> None:
+        step = intermediate_state.step
+        if intermediate_state.predicted_original is not None:
+            # Some schedulers report not only the noisy latents at the current timestep,
+            # but also their estimate so far of what the de-noised latents will be.
+
+            sample = intermediate_state.predicted_original
+        else:
+            sample = intermediate_state.latents
+
+        image = Generator(context.services.model_manager.get_model()).sample_to_image(sample)
         (width, height) = image.size
-        width *= 8
-        height *= 8
-
         dataURL = image_to_dataURL(image, image_format="JPEG")
-
         context.services.events.emit_generator_progress(
             context.graph_execution_state_id,
             self.id,
             {
-                "width": width,
+                "width" : width,
                 "height": height,
-                "dataURL": dataURL
+                "dataURL": dataURL,
             },
             step,
             self.steps,
@@ -79,7 +78,7 @@ class TextToImageInvocation(BaseInvocation):
         model= context.services.model_manager.get_model()
         outputs = Txt2Img(model).generate(
             prompt=self.prompt,
-            step_callback=step_callback,
+            step_callback=partial(self.dispatch_progress, context),
             **self.dict(
                 exclude={"prompt"}
             ),  # Shorthand for passing all of the parameters above manually
@@ -126,9 +125,6 @@ class ImageToImageInvocation(TextToImageInvocation):
         )
         mask = None
 
-        def step_callback(sample, step=0):
-            self.dispatch_progress(context, sample, step)
-
         # Handle invalid model parameter
         # TODO: figure out if this can be done via a validator that uses the model_cache
         # TODO: How to get the default model name now?
@@ -138,7 +134,7 @@ class ImageToImageInvocation(TextToImageInvocation):
             prompt=self.prompt,
             init_image=image,
             init_mask=mask,
-            step_callback=step_callback,
+            step_callback=partial(self.dispatch_progress, context),
             **self.dict(
                 exclude={"prompt", "image", "mask"}
             ),  # Shorthand for passing all of the parameters above manually
@@ -187,19 +183,16 @@ class InpaintInvocation(ImageToImageInvocation):
             else context.services.images.get(self.mask.image_type, self.mask.image_name)
         )
 
-        def step_callback(sample, step=0):
-            self.dispatch_progress(context, sample, step)
-
         # Handle invalid model parameter
         # TODO: figure out if this can be done via a validator that uses the model_cache
        # TODO: How to get the default model name now?
-        manager = context.services.model_manager.get_model()
+        model = context.services.model_manager.get_model()
         generator_output = next(
             Inpaint(model).generate(
                 prompt=self.prompt,
-                init_image=image,
-                mask_image=mask,
-                step_callback=step_callback,
+                init_img=image,
+                init_mask=mask,
+                step_callback=partial(self.dispatch_progress, context),
                 **self.dict(
                     exclude={"prompt", "image", "mask"}
                 ),  # Shorthand for passing all of the parameters above manually
diff --git a/invokeai/backend/web/invoke_ai_web_server.py b/invokeai/backend/web/invoke_ai_web_server.py
index dc77ff4723..7209e31449 100644
--- a/invokeai/backend/web/invoke_ai_web_server.py
+++ b/invokeai/backend/web/invoke_ai_web_server.py
@@ -1022,7 +1022,7 @@ class InvokeAIWebServer:
                     "RGB"
                 )
 
-            def image_progress(sample, step):
+            def image_progress(intermediate_state: PipelineIntermediateState):
                 if self.canceled.is_set():
                     raise CanceledException
 
@@ -1030,6 +1030,14 @@
                 nonlocal generation_parameters
                 nonlocal progress
 
+                step = intermediate_state.step
+                if intermediate_state.predicted_original is not None:
+                    # Some schedulers report not only the noisy latents at the current timestep,
+                    # but also their estimate so far of what the de-noised latents will be.
+                    sample = intermediate_state.predicted_original
+                else:
+                    sample = intermediate_state.latents
+
                 generation_messages = {
                     "txt2img": "common.statusGeneratingTextToImage",
                     "img2img": "common.statusGeneratingImageToImage",
@@ -1302,16 +1310,9 @@
 
                     progress.set_current_iteration(progress.current_iteration + 1)
 
-            def diffusers_step_callback_adapter(*cb_args, **kwargs):
-                if isinstance(cb_args[0], PipelineIntermediateState):
-                    progress_state: PipelineIntermediateState = cb_args[0]
-                    return image_progress(progress_state.latents, progress_state.step)
-                else:
-                    return image_progress(*cb_args, **kwargs)
-
             self.generate.prompt2image(
                 **generation_parameters,
-                step_callback=diffusers_step_callback_adapter,
+                step_callback=image_progress,
                 image_callback=image_done,
             )