fix(nodes): fix cancel; fix callback for img2img, inpaint

psychedelicious 2023-03-18 21:37:50 +11:00
parent 5fe38f7c88
commit c34ac91ff0
3 changed files with 102 additions and 29 deletions

invokeai/app/invocations/generate.py

@@ -4,15 +4,16 @@ from functools import partial
from typing import Literal, Optional, Union
import numpy as np
from torch import Tensor
from pydantic import Field
from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util.util import image_to_dataURL
from ..util.util import diffusers_step_callback_adapter, CanceledException

SAMPLER_NAME_VALUES = Literal[
    tuple(InvokeAIGenerator.schedulers())
@@ -43,33 +44,24 @@ class TextToImageInvocation(BaseInvocation):
    def dispatch_progress(
        self, context: InvocationContext, intermediate_state: PipelineIntermediateState
    ) -> None:
        if (context.services.queue.is_canceled(context.graph_execution_state_id)):
            raise CanceledException

        step = intermediate_state.step
        if intermediate_state.predicted_original is not None:
            # Some schedulers report not only the noisy latents at the current timestep,
            # but also their estimate so far of what the de-noised latents will be.
            sample = intermediate_state.predicted_original
        else:
            sample = intermediate_state.latents

        image = Generator(context.services.model_manager.get_model()).sample_to_image(sample)
        (width, height) = image.size
        dataURL = image_to_dataURL(image, image_format="JPEG")
        context.services.events.emit_generator_progress(
            context.graph_execution_state_id,
            self.id,
            {
                "width": width,
                "height": height,
                "dataURL": dataURL,
            },
            step,
            self.steps,
        )

        sample = intermediate_state.latents
        diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)

    def invoke(self, context: InvocationContext) -> ImageOutput:
        def step_callback(state: PipelineIntermediateState):
            self.dispatch_progress(context, state.latents, state.step)

        # def step_callback(state: PipelineIntermediateState):
        #     if (context.services.queue.is_canceled(context.graph_execution_state_id)):
        #         raise CanceledException
        #     self.dispatch_progress(context, state.latents, state.step)

        # Handle invalid model parameter
        # TODO: figure out if this can be done via a validator that uses the model_cache
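
Note: the cancellation fix works by raising an exception out of the per-step callback; it unwinds through the denoising loop and is caught by the session processor (see processor.py below). A minimal, self-contained sketch of the pattern, with all names as illustrative stand-ins rather than the InvokeAI API:

class CanceledException(Exception):
    pass

def make_step_callback(is_canceled, on_progress):
    # Wrap a progress handler so every step first polls for cancellation.
    def step_callback(step, total_steps):
        if is_canceled():
            raise CanceledException
        on_progress(step, total_steps)
    return step_callback

callback = make_step_callback(lambda: True, print)
try:
    for step in range(10):  # stands in for the denoising loop
        callback(step, 10)
except CanceledException:
    pass  # cancellation is expected control flow, not an error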
@@ -115,6 +107,22 @@ class ImageToImageInvocation(TextToImageInvocation):
description="Whether or not the result should be fit to the aspect ratio of the input image",
)
def dispatch_progress(
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
) -> None:
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
raise CanceledException
step = intermediate_state.step
if intermediate_state.predicted_original is not None:
# Some schedulers report not only the noisy latents at the current timestep,
# but also their estimate so far of what the de-noised latents will be.
sample = intermediate_state.predicted_original
else:
sample = intermediate_state.latents
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = (
None
@@ -129,17 +137,19 @@ class ImageToImageInvocation(TextToImageInvocation):
        # TODO: figure out if this can be done via a validator that uses the model_cache
        # TODO: How to get the default model name now?
        model = context.services.model_manager.get_model()
        generator_output = next(
            Img2Img(model).generate(
        outputs = Img2Img(model).generate(
            prompt=self.prompt,
            init_image=image,
            init_mask=mask,
                step_callback=partial(self.dispatch_progress, context),
            step_callback=partial(self.dispatch_progress, context),
            **self.dict(
                exclude={"prompt", "image", "mask"}
            ),  # Shorthand for passing all of the parameters above manually
        )
        )
        # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
        # each time it is called. We only need the first one.
        generator_output = next(outputs)

        result_image = generator_output.image
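
The comment above describes the new generate() contract: it returns a lazy, unbounded iterator, and the caller pulls as many results as it needs. A toy illustration of that take-one pattern (GeneratorOutput and make_outputs are hypothetical stand-ins, not the InvokeAI API):

from dataclasses import dataclass
from itertools import count

@dataclass
class GeneratorOutput:
    seed: int

def make_outputs():
    # Yields a fresh output each time the caller advances the iterator;
    # it never terminates on its own.
    for seed in count():
        yield GeneratorOutput(seed=seed)

outputs = make_outputs()
generator_output = next(outputs)  # take only the first result
assert generator_output.seed == 0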
@@ -169,6 +179,22 @@ class InpaintInvocation(ImageToImageInvocation):
description="The amount by which to replace masked areas with latent noise",
)
def dispatch_progress(
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
) -> None:
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
raise CanceledException
step = intermediate_state.step
if intermediate_state.predicted_original is not None:
# Some schedulers report not only the noisy latents at the current timestep,
# but also their estimate so far of what the de-noised latents will be.
sample = intermediate_state.predicted_original
else:
sample = intermediate_state.latents
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = (
None
@@ -187,8 +213,7 @@ class InpaintInvocation(ImageToImageInvocation):
        # TODO: figure out if this can be done via a validator that uses the model_cache
        # TODO: How to get the default model name now?
        model = context.services.model_manager.get_model()
        generator_output = next(
            Inpaint(model).generate(
        outputs = Inpaint(model).generate(
            prompt=self.prompt,
            init_img=image,
            init_mask=mask,
@@ -197,7 +222,10 @@ class InpaintInvocation(ImageToImageInvocation):
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
# each time it is called. We only need the first one.
generator_output = next(outputs)
result_image = generator_output.image

invokeai/app/services/processor.py

@@ -4,7 +4,7 @@ from threading import Event, Thread
from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker
from ..util.util import CanceledException


class DefaultInvocationProcessor(InvocationProcessorABC):
    __invoker_thread: Thread
@@ -82,6 +82,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                    except KeyboardInterrupt:
                        pass

                    except CanceledException:
                        pass

                    except Exception as e:
                        error = traceback.format_exc()
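
For context, this except chain treats cancellation as ordinary control flow: KeyboardInterrupt and CanceledException are swallowed so the processor loop can continue, and only unexpected exceptions are recorded as errors. A condensed sketch of that structure (run_one is a hypothetical stand-in for the processor's loop body):

import traceback

class CanceledException(Exception):
    pass  # stands in for the class defined in invokeai/app/util/util.py below

def run_one(invoke, record_error):
    try:
        invoke()
    except KeyboardInterrupt:
        pass  # process is shutting down
    except CanceledException:
        pass  # the user canceled this session; not a failure
    except Exception:
        record_error(traceback.format_exc())  # only real errors are recorded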

invokeai/app/util/util.py (new file, 42 lines)

@@ -0,0 +1,42 @@
import torch
from PIL import Image
from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL
from ...backend.generator.base import Generator
from ...backend.stable_diffusion import PipelineIntermediateState


class CanceledException(Exception):
    pass


def fast_latents_step_callback(sample: torch.Tensor, step: int, steps: int, id: str, context: InvocationContext):
    # TODO: only output a preview image when requested
    image = Generator.sample_to_lowres_estimated_image(sample)

    # The estimate is rendered from latents, which are 1/8 the resolution of the
    # decoded output, so scale the reported dimensions back up to pixel size.
    (width, height) = image.size
    width *= 8
    height *= 8

    dataURL = image_to_dataURL(image, image_format="JPEG")

    context.services.events.emit_generator_progress(
        context.graph_execution_state_id,
        id,
        {
            "width": width,
            "height": height,
            "dataURL": dataURL
        },
        step,
        steps,
    )


def diffusers_step_callback_adapter(*cb_args, **kwargs):
    """
    txt2img gives us a Tensor in the step_callback, while img2img gives us a PipelineIntermediateState.
    This adapter grabs the needed data and passes it along to the callback function.
    """
    if isinstance(cb_args[0], PipelineIntermediateState):
        progress_state: PipelineIntermediateState = cb_args[0]
        return fast_latents_step_callback(progress_state.latents, progress_state.step, **kwargs)
    else:
        return fast_latents_step_callback(*cb_args, **kwargs)
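
To make the two call shapes concrete, here is a self-contained toy version of the same dispatch-by-type adapter, with State standing in for PipelineIntermediateState and on_progress for fast_latents_step_callback:

from dataclasses import dataclass

@dataclass
class State:
    latents: str
    step: int

def on_progress(sample, step, *, steps):
    print(f"step {step}/{steps}: {sample}")

def adapter(*cb_args, **kwargs):
    # Normalize both call shapes to (sample, step), as diffusers_step_callback_adapter does.
    if isinstance(cb_args[0], State):
        state = cb_args[0]
        return on_progress(state.latents, state.step, **kwargs)
    return on_progress(*cb_args, **kwargs)

adapter("latents", 3, steps=10)         # txt2img shape: (sample, step)
adapter(State("latents", 4), steps=10)  # img2img/inpaint shape: one state object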