fix(nodes): fix cancel; fix callback for img2img, inpaint

This commit is contained in:
psychedelicious 2023-03-18 21:37:50 +11:00
parent 5fe38f7c88
commit c34ac91ff0
3 changed files with 102 additions and 29 deletions

View File

@ -4,15 +4,16 @@ from functools import partial
from typing import Literal, Optional, Union from typing import Literal, Optional, Union
import numpy as np import numpy as np
from torch import Tensor
from pydantic import Field from pydantic import Field
from ..services.image_storage import ImageType from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput from .image import ImageField, ImageOutput
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util.util import image_to_dataURL from ..util.util import diffusers_step_callback_adapter, CanceledException
SAMPLER_NAME_VALUES = Literal[ SAMPLER_NAME_VALUES = Literal[
tuple(InvokeAIGenerator.schedulers()) tuple(InvokeAIGenerator.schedulers())
@ -43,33 +44,24 @@ class TextToImageInvocation(BaseInvocation):
def dispatch_progress( def dispatch_progress(
self, context: InvocationContext, intermediate_state: PipelineIntermediateState self, context: InvocationContext, intermediate_state: PipelineIntermediateState
) -> None: ) -> None:
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
raise CanceledException
step = intermediate_state.step step = intermediate_state.step
if intermediate_state.predicted_original is not None: if intermediate_state.predicted_original is not None:
# Some schedulers report not only the noisy latents at the current timestep, # Some schedulers report not only the noisy latents at the current timestep,
# but also their estimate so far of what the de-noised latents will be. # but also their estimate so far of what the de-noised latents will be.
sample = intermediate_state.predicted_original sample = intermediate_state.predicted_original
else: else:
sample = intermediate_state.latents sample = intermediate_state.latents
image = Generator(context.services.model_manager.get_model()).sample_to_image(sample) diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
(width, height) = image.size
dataURL = image_to_dataURL(image, image_format="JPEG")
context.services.events.emit_generator_progress(
context.graph_execution_state_id,
self.id,
{
"width" : width,
"height": height,
"dataURL": dataURL,
},
step,
self.steps,
)
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
def step_callback(state: PipelineIntermediateState): # def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, state.latents, state.step) # if (context.services.queue.is_canceled(context.graph_execution_state_id)):
# raise CanceledException
# self.dispatch_progress(context, state.latents, state.step)
# Handle invalid model parameter # Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache # TODO: figure out if this can be done via a validator that uses the model_cache
@ -115,6 +107,22 @@ class ImageToImageInvocation(TextToImageInvocation):
description="Whether or not the result should be fit to the aspect ratio of the input image", description="Whether or not the result should be fit to the aspect ratio of the input image",
) )
def dispatch_progress(
    self, context: InvocationContext, intermediate_state: PipelineIntermediateState
) -> None:
    """Abort on cancellation, otherwise publish a progress preview for this step.

    Args:
        context: Invocation context; its queue service is asked whether the
            current graph execution state has been canceled.
        intermediate_state: Pipeline state for the current step.

    Raises:
        CanceledException: If the queue reports the graph execution state as canceled.
    """
    queue = context.services.queue
    if queue.is_canceled(context.graph_execution_state_id):
        raise CanceledException()

    # Some schedulers report not only the noisy latents at the current timestep,
    # but also their estimate so far of what the de-noised latents will be.
    # Use that estimate when it is present, and the raw latents otherwise.
    predicted = intermediate_state.predicted_original
    sample = intermediate_state.latents if predicted is None else predicted

    diffusers_step_callback_adapter(
        sample,
        intermediate_state.step,
        steps=self.steps,
        id=self.id,
        context=context,
    )
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = ( image = (
None None
@ -129,17 +137,19 @@ class ImageToImageInvocation(TextToImageInvocation):
# TODO: figure out if this can be done via a validator that uses the model_cache # TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now? # TODO: How to get the default model name now?
model = context.services.model_manager.get_model() model = context.services.model_manager.get_model()
generator_output = next( outputs = Img2Img(model).generate(
Img2Img(model).generate(
prompt=self.prompt, prompt=self.prompt,
init_image=image, init_image=image,
init_mask=mask, init_mask=mask,
step_callback=partial(self.dispatch_progress, context), step_callback=partial(self.dispatch_progress, context),
**self.dict( **self.dict(
exclude={"prompt", "image", "mask"} exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually ), # Shorthand for passing all of the parameters above manually
) )
)
# `outputs` is an infinite iterator that yields a new InvokeAIGeneratorOutput object
# each time `next()` is called on it. We only need the first one.
generator_output = next(outputs)
result_image = generator_output.image result_image = generator_output.image
@ -169,6 +179,22 @@ class InpaintInvocation(ImageToImageInvocation):
description="The amount by which to replace masked areas with latent noise", description="The amount by which to replace masked areas with latent noise",
) )
def dispatch_progress(
    self, context: InvocationContext, intermediate_state: PipelineIntermediateState
) -> None:
    """Check for cancellation, then forward the current sample to the progress callback.

    Args:
        context: Invocation context providing the queue service and the graph
            execution state id used for the cancellation check.
        intermediate_state: Pipeline state for the current step.

    Raises:
        CanceledException: If the queue reports the graph execution state as canceled.
    """
    canceled = context.services.queue.is_canceled(context.graph_execution_state_id)
    if canceled:
        raise CanceledException()

    current_step = intermediate_state.step

    # Some schedulers report not only the noisy latents at the current timestep,
    # but also their estimate so far of what the de-noised latents will be.
    estimate = intermediate_state.predicted_original
    if estimate is None:
        estimate = intermediate_state.latents

    diffusers_step_callback_adapter(
        estimate, current_step, steps=self.steps, id=self.id, context=context
    )
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = ( image = (
None None
@ -187,8 +213,7 @@ class InpaintInvocation(ImageToImageInvocation):
# TODO: figure out if this can be done via a validator that uses the model_cache # TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now? # TODO: How to get the default model name now?
model = context.services.model_manager.get_model() model = context.services.model_manager.get_model()
generator_output = next( outputs = Inpaint(model).generate(
Inpaint(model).generate(
prompt=self.prompt, prompt=self.prompt,
init_img=image, init_img=image,
init_mask=mask, init_mask=mask,
@ -197,7 +222,10 @@ class InpaintInvocation(ImageToImageInvocation):
exclude={"prompt", "image", "mask"} exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually ), # Shorthand for passing all of the parameters above manually
) )
)
# `outputs` is an infinite iterator that yields a new InvokeAIGeneratorOutput object
# each time `next()` is called on it. We only need the first one.
generator_output = next(outputs)
result_image = generator_output.image result_image = generator_output.image

View File

@ -4,7 +4,7 @@ from threading import Event, Thread
from ..invocations.baseinvocation import InvocationContext from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker from .invoker import InvocationProcessorABC, Invoker
from ..util.util import CanceledException
class DefaultInvocationProcessor(InvocationProcessorABC): class DefaultInvocationProcessor(InvocationProcessorABC):
__invoker_thread: Thread __invoker_thread: Thread
@ -82,6 +82,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
except CanceledException:
pass
except Exception as e: except Exception as e:
error = traceback.format_exc() error = traceback.format_exc()

42
invokeai/app/util/util.py Normal file
View File

@ -0,0 +1,42 @@
import torch
from PIL import Image
from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL
from ...backend.generator.base import Generator
from ...backend.stable_diffusion import PipelineIntermediateState
class CanceledException(Exception):
    """Signal that the running invocation was canceled and should stop."""
def fast_latents_step_callback(
    sample: torch.Tensor,
    step: int,
    steps: int,
    id: str,
    context: InvocationContext,
):
    """Emit a generator-progress event carrying a low-res JPEG preview of ``sample``.

    Args:
        sample: Tensor handed to ``Generator.sample_to_lowres_estimated_image``.
        step: Current step, forwarded unchanged in the event.
        steps: Total number of steps, forwarded unchanged in the event.
        id: Invocation id, forwarded unchanged in the event.
        context: Invocation context providing the events service and the
            graph execution state id.
    """
    # TODO: only output a preview image when requested
    preview = Generator.sample_to_lowres_estimated_image(sample)
    preview_width, preview_height = preview.size

    progress_image = {
        # Reported dimensions are 8x the preview's own size
        # (presumably the latent-to-pixel scale factor; confirm).
        "width": preview_width * 8,
        "height": preview_height * 8,
        "dataURL": image_to_dataURL(preview, image_format="JPEG"),
    }

    context.services.events.emit_generator_progress(
        context.graph_execution_state_id,
        id,
        progress_image,
        step,
        steps,
    )
def diffusers_step_callback_adapter(*cb_args, **kwargs):
    """
    Normalize step-callback arguments and forward them to fast_latents_step_callback.

    txt2img gives us a Tensor in the step_callback, while img2img gives us a PipelineIntermediateState.
    This adapter grabs the needed data and passes it along to the callback function.

    Args:
        *cb_args: Either a single PipelineIntermediateState, or the positional
            arguments (sample, step) expected by fast_latents_step_callback.
        **kwargs: Forwarded unchanged to fast_latents_step_callback
            (callers in this codebase pass steps, id and context).

    Returns:
        Whatever fast_latents_step_callback returns.
    """
    if isinstance(cb_args[0], PipelineIntermediateState):
        progress_state: PipelineIntermediateState = cb_args[0]
        # Some schedulers also report their estimate of the de-noised latents.
        # Prefer it when present, matching what the dispatch_progress callers
        # do before invoking this adapter with a bare tensor.
        predicted = progress_state.predicted_original
        sample = progress_state.latents if predicted is None else predicted
        return fast_latents_step_callback(sample, progress_state.step, **kwargs)

    return fast_latents_step_callback(*cb_args, **kwargs)