Merge branch 'main' into fix/ui/viewer-localisation

This commit is contained in:
blessedcoolant 2023-03-26 20:35:12 +13:00 committed by GitHub
commit 3ba7e966b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 229 additions and 48 deletions

View File

@ -270,3 +270,18 @@ async def invoke_session(
ApiDependencies.invoker.invoke(session, invoke_all=all) ApiDependencies.invoker.invoke(session, invoke_all=all)
return Response(status_code=202) return Response(status_code=202)
@session_router.delete(
"/{session_id}/invoke",
operation_id="cancel_session_invoke",
responses={
202: {"description": "The invocation is canceled"}
},
)
async def cancel_session_invoke(
session_id: str = Path(description="The id of the session to cancel"),
) -> None:
"""Invokes a session"""
ApiDependencies.invoker.cancel(session_id)
return Response(status_code=202)

View File

@ -4,15 +4,16 @@ from functools import partial
from typing import Literal, Optional, Union from typing import Literal, Optional, Union
import numpy as np import numpy as np
from torch import Tensor
from pydantic import Field from pydantic import Field
from ..services.image_storage import ImageType from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput from .image import ImageField, ImageOutput
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util.util import image_to_dataURL from ..util.util import diffusers_step_callback_adapter, CanceledException
SAMPLER_NAME_VALUES = Literal[ SAMPLER_NAME_VALUES = Literal[
tuple(InvokeAIGenerator.schedulers()) tuple(InvokeAIGenerator.schedulers())
@ -43,33 +44,24 @@ class TextToImageInvocation(BaseInvocation):
def dispatch_progress( def dispatch_progress(
self, context: InvocationContext, intermediate_state: PipelineIntermediateState self, context: InvocationContext, intermediate_state: PipelineIntermediateState
) -> None: ) -> None:
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
raise CanceledException
step = intermediate_state.step step = intermediate_state.step
if intermediate_state.predicted_original is not None: if intermediate_state.predicted_original is not None:
# Some schedulers report not only the noisy latents at the current timestep, # Some schedulers report not only the noisy latents at the current timestep,
# but also their estimate so far of what the de-noised latents will be. # but also their estimate so far of what the de-noised latents will be.
sample = intermediate_state.predicted_original sample = intermediate_state.predicted_original
else: else:
sample = intermediate_state.latents sample = intermediate_state.latents
image = Generator(context.services.model_manager.get_model()).sample_to_image(sample) diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
(width, height) = image.size
dataURL = image_to_dataURL(image, image_format="JPEG")
context.services.events.emit_generator_progress(
context.graph_execution_state_id,
self.id,
{
"width" : width,
"height": height,
"dataURL": dataURL,
},
step,
self.steps,
)
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
def step_callback(state: PipelineIntermediateState): # def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, state.latents, state.step) # if (context.services.queue.is_canceled(context.graph_execution_state_id)):
# raise CanceledException
# self.dispatch_progress(context, state.latents, state.step)
# Handle invalid model parameter # Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache # TODO: figure out if this can be done via a validator that uses the model_cache
@ -115,6 +107,22 @@ class ImageToImageInvocation(TextToImageInvocation):
description="Whether or not the result should be fit to the aspect ratio of the input image", description="Whether or not the result should be fit to the aspect ratio of the input image",
) )
def dispatch_progress(
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
) -> None:
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
raise CanceledException
step = intermediate_state.step
if intermediate_state.predicted_original is not None:
# Some schedulers report not only the noisy latents at the current timestep,
# but also their estimate so far of what the de-noised latents will be.
sample = intermediate_state.predicted_original
else:
sample = intermediate_state.latents
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = ( image = (
None None
@ -129,17 +137,19 @@ class ImageToImageInvocation(TextToImageInvocation):
# TODO: figure out if this can be done via a validator that uses the model_cache # TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now? # TODO: How to get the default model name now?
model = context.services.model_manager.get_model() model = context.services.model_manager.get_model()
generator_output = next( outputs = Img2Img(model).generate(
Img2Img(model).generate(
prompt=self.prompt, prompt=self.prompt,
init_image=image, init_image=image,
init_mask=mask, init_mask=mask,
step_callback=partial(self.dispatch_progress, context), step_callback=partial(self.dispatch_progress, context),
**self.dict( **self.dict(
exclude={"prompt", "image", "mask"} exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually ), # Shorthand for passing all of the parameters above manually
) )
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
# each time it is called. We only need the first one.
generator_output = next(outputs)
result_image = generator_output.image result_image = generator_output.image
@ -169,6 +179,22 @@ class InpaintInvocation(ImageToImageInvocation):
description="The amount by which to replace masked areas with latent noise", description="The amount by which to replace masked areas with latent noise",
) )
def dispatch_progress(
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
) -> None:
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
raise CanceledException
step = intermediate_state.step
if intermediate_state.predicted_original is not None:
# Some schedulers report not only the noisy latents at the current timestep,
# but also their estimate so far of what the de-noised latents will be.
sample = intermediate_state.predicted_original
else:
sample = intermediate_state.latents
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = ( image = (
None None
@ -187,8 +213,7 @@ class InpaintInvocation(ImageToImageInvocation):
# TODO: figure out if this can be done via a validator that uses the model_cache # TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now? # TODO: How to get the default model name now?
model = context.services.model_manager.get_model() model = context.services.model_manager.get_model()
generator_output = next( outputs = Inpaint(model).generate(
Inpaint(model).generate(
prompt=self.prompt, prompt=self.prompt,
init_img=image, init_img=image,
init_mask=mask, init_mask=mask,
@ -197,7 +222,10 @@ class InpaintInvocation(ImageToImageInvocation):
exclude={"prompt", "image", "mask"} exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually ), # Shorthand for passing all of the parameters above manually
) )
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
# each time it is called. We only need the first one.
generator_output = next(outputs)
result_image = generator_output.image result_image = generator_output.image

View File

@ -28,12 +28,28 @@ class ImageOutput(BaseInvocationOutput):
image: ImageField = Field(default=None, description="The output image") image: ImageField = Field(default=None, description="The output image")
#fmt: on #fmt: on
class Config:
schema_extra = {
'required': [
'type',
'image',
]
}
class MaskOutput(BaseInvocationOutput): class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask""" """Base class for invocations that output a mask"""
#fmt: off #fmt: off
type: Literal["mask"] = "mask" type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask") mask: ImageField = Field(default=None, description="The output mask")
#fomt: on #fmt: on
class Config:
schema_extra = {
'required': [
'type',
'mask',
]
}
# TODO: this isn't really necessary anymore # TODO: this isn't really necessary anymore
class LoadImageInvocation(BaseInvocation): class LoadImageInvocation(BaseInvocation):

View File

@ -12,3 +12,11 @@ class PromptOutput(BaseInvocationOutput):
prompt: str = Field(default=None, description="The output prompt") prompt: str = Field(default=None, description="The output prompt")
#fmt: on #fmt: on
class Config:
schema_extra = {
'required': [
'type',
'prompt',
]
}

View File

@ -127,6 +127,13 @@ class NodeAlreadyExecutedError(Exception):
class GraphInvocationOutput(BaseInvocationOutput): class GraphInvocationOutput(BaseInvocationOutput):
type: Literal["graph_output"] = "graph_output" type: Literal["graph_output"] = "graph_output"
class Config:
schema_extra = {
'required': [
'type',
'image',
]
}
# TODO: Fill this out and move to invocations # TODO: Fill this out and move to invocations
class GraphInvocation(BaseInvocation): class GraphInvocation(BaseInvocation):
@ -147,6 +154,13 @@ class IterateInvocationOutput(BaseInvocationOutput):
item: Any = Field(description="The item being iterated over") item: Any = Field(description="The item being iterated over")
class Config:
schema_extra = {
'required': [
'type',
'item',
]
}
# TODO: Fill this out and move to invocations # TODO: Fill this out and move to invocations
class IterateInvocation(BaseInvocation): class IterateInvocation(BaseInvocation):
@ -169,6 +183,13 @@ class CollectInvocationOutput(BaseInvocationOutput):
collection: list[Any] = Field(description="The collection of input items") collection: list[Any] = Field(description="The collection of input items")
class Config:
schema_extra = {
'required': [
'type',
'collection',
]
}
class CollectInvocation(BaseInvocation): class CollectInvocation(BaseInvocation):
"""Collects values into a collection""" """Collects values into a collection"""

View File

@ -2,6 +2,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from queue import Queue from queue import Queue
import time
# TODO: make this serializable # TODO: make this serializable
@ -10,6 +11,7 @@ class InvocationQueueItem:
graph_execution_state_id: str graph_execution_state_id: str
invocation_id: str invocation_id: str
invoke_all: bool invoke_all: bool
timestamp: float
def __init__( def __init__(
self, self,
@ -22,6 +24,7 @@ class InvocationQueueItem:
self.graph_execution_state_id = graph_execution_state_id self.graph_execution_state_id = graph_execution_state_id
self.invocation_id = invocation_id self.invocation_id = invocation_id
self.invoke_all = invoke_all self.invoke_all = invoke_all
self.timestamp = time.time()
class InvocationQueueABC(ABC): class InvocationQueueABC(ABC):
@ -35,15 +38,44 @@ class InvocationQueueABC(ABC):
def put(self, item: InvocationQueueItem | None) -> None: def put(self, item: InvocationQueueItem | None) -> None:
pass pass
@abstractmethod
def cancel(self, graph_execution_state_id: str) -> None:
pass
@abstractmethod
def is_canceled(self, graph_execution_state_id: str) -> bool:
pass
class MemoryInvocationQueue(InvocationQueueABC): class MemoryInvocationQueue(InvocationQueueABC):
__queue: Queue __queue: Queue
__cancellations: dict[str, float]
def __init__(self): def __init__(self):
self.__queue = Queue() self.__queue = Queue()
self.__cancellations = dict()
def get(self) -> InvocationQueueItem: def get(self) -> InvocationQueueItem:
return self.__queue.get() item = self.__queue.get()
while isinstance(item, InvocationQueueItem) \
and item.graph_execution_state_id in self.__cancellations \
and self.__cancellations[item.graph_execution_state_id] > item.timestamp:
item = self.__queue.get()
# Clear old items
for graph_execution_state_id in list(self.__cancellations.keys()):
if self.__cancellations[graph_execution_state_id] < item.timestamp:
del self.__cancellations[graph_execution_state_id]
return item
def put(self, item: InvocationQueueItem | None) -> None: def put(self, item: InvocationQueueItem | None) -> None:
self.__queue.put(item) self.__queue.put(item)
def cancel(self, graph_execution_state_id: str) -> None:
if graph_execution_state_id not in self.__cancellations:
self.__cancellations[graph_execution_state_id] = time.time()
def is_canceled(self, graph_execution_state_id: str) -> bool:
return graph_execution_state_id in self.__cancellations

View File

@ -50,6 +50,10 @@ class Invoker:
new_state = GraphExecutionState(graph=Graph() if graph is None else graph) new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
self.services.graph_execution_manager.set(new_state) self.services.graph_execution_manager.set(new_state)
return new_state return new_state
def cancel(self, graph_execution_state_id: str) -> None:
"""Cancels the given execution state"""
self.services.queue.cancel(graph_execution_state_id)
def __start_service(self, service) -> None: def __start_service(self, service) -> None:
# Call start() method on any services that have it # Call start() method on any services that have it

View File

@ -4,7 +4,7 @@ from threading import Event, Thread
from ..invocations.baseinvocation import InvocationContext from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker from .invoker import InvocationProcessorABC, Invoker
from ..util.util import CanceledException
class DefaultInvocationProcessor(InvocationProcessorABC): class DefaultInvocationProcessor(InvocationProcessorABC):
__invoker_thread: Thread __invoker_thread: Thread
@ -58,6 +58,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
) )
) )
# Check queue to see if this is canceled, and skip if so
if self.__invoker.services.queue.is_canceled(
graph_execution_state.id
):
continue
# Save outputs and history # Save outputs and history
graph_execution_state.complete(invocation.id, outputs) graph_execution_state.complete(invocation.id, outputs)
@ -76,6 +82,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
except CanceledException:
pass
except Exception as e: except Exception as e:
error = traceback.format_exc() error = traceback.format_exc()
@ -95,6 +104,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
) )
pass pass
# Check queue to see if this is canceled, and skip if so
if self.__invoker.services.queue.is_canceled(
graph_execution_state.id
):
continue
# Queue any further commands if invoking all # Queue any further commands if invoking all
is_complete = graph_execution_state.is_complete() is_complete = graph_execution_state.is_complete()

42
invokeai/app/util/util.py Normal file
View File

@ -0,0 +1,42 @@
import torch
from PIL import Image
from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL
from ...backend.generator.base import Generator
from ...backend.stable_diffusion import PipelineIntermediateState
class CanceledException(Exception):
pass
def fast_latents_step_callback(sample: torch.Tensor, step: int, steps: int, id: str, context: InvocationContext, ):
# TODO: only output a preview image when requested
image = Generator.sample_to_lowres_estimated_image(sample)
(width, height) = image.size
width *= 8
height *= 8
dataURL = image_to_dataURL(image, image_format="JPEG")
context.services.events.emit_generator_progress(
context.graph_execution_state_id,
id,
{
"width": width,
"height": height,
"dataURL": dataURL
},
step,
steps,
)
def diffusers_step_callback_adapter(*cb_args, **kwargs):
"""
txt2img gives us a Tensor in the step_callbak, while img2img gives us a PipelineIntermediateState.
This adapter grabs the needed data and passes it along to the callback function.
"""
if isinstance(cb_args[0], PipelineIntermediateState):
progress_state: PipelineIntermediateState = cb_args[0]
return fast_latents_step_callback(progress_state.latents, progress_state.step, **kwargs)
else:
return fast_latents_step_callback(*cb_args, **kwargs)

View File

@ -21,7 +21,7 @@ from PIL import Image, ImageChops, ImageFilter
from accelerate.utils import set_seed from accelerate.utils import set_seed
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from tqdm import trange from tqdm import trange
from typing import List, Iterator, Type from typing import Callable, List, Iterator, Optional, Type
from dataclasses import dataclass, field from dataclasses import dataclass, field
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
@ -35,23 +35,23 @@ downsampling = 8
@dataclass @dataclass
class InvokeAIGeneratorBasicParams: class InvokeAIGeneratorBasicParams:
seed: int=None seed: Optional[int]=None
width: int=512 width: int=512
height: int=512 height: int=512
cfg_scale: int=7.5 cfg_scale: float=7.5
steps: int=20 steps: int=20
ddim_eta: float=0.0 ddim_eta: float=0.0
scheduler: int='ddim' scheduler: str='ddim'
precision: str='float16' precision: str='float16'
perlin: float=0.0 perlin: float=0.0
threshold: int=0.0 threshold: float=0.0
seamless: bool=False seamless: bool=False
seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y']) seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y'])
h_symmetry_time_pct: float=None h_symmetry_time_pct: Optional[float]=None
v_symmetry_time_pct: float=None v_symmetry_time_pct: Optional[float]=None
variation_amount: float = 0.0 variation_amount: float = 0.0
with_variations: list=field(default_factory=list) with_variations: list=field(default_factory=list)
safety_checker: SafetyChecker=None safety_checker: Optional[SafetyChecker]=None
@dataclass @dataclass
class InvokeAIGeneratorOutput: class InvokeAIGeneratorOutput:
@ -61,10 +61,10 @@ class InvokeAIGeneratorOutput:
and the model hash, as well as all the generate() parameters that went into and the model hash, as well as all the generate() parameters that went into
generating the image (in .params, also available as attributes) generating the image (in .params, also available as attributes)
''' '''
image: Image image: Image.Image
seed: int seed: int
model_hash: str model_hash: str
attention_maps_images: List[Image] attention_maps_images: List[Image.Image]
params: Namespace params: Namespace
# we are interposing a wrapper around the original Generator classes so that # we are interposing a wrapper around the original Generator classes so that
@ -92,8 +92,8 @@ class InvokeAIGenerator(metaclass=ABCMeta):
def generate(self, def generate(self,
prompt: str='', prompt: str='',
callback: callable=None, callback: Optional[Callable]=None,
step_callback: callable=None, step_callback: Optional[Callable]=None,
iterations: int=1, iterations: int=1,
**keyword_args, **keyword_args,
)->Iterator[InvokeAIGeneratorOutput]: )->Iterator[InvokeAIGeneratorOutput]:
@ -206,10 +206,10 @@ class Txt2Img(InvokeAIGenerator):
# ------------------------------------ # ------------------------------------
class Img2Img(InvokeAIGenerator): class Img2Img(InvokeAIGenerator):
def generate(self, def generate(self,
init_image: Image | torch.FloatTensor, init_image: Image.Image | torch.FloatTensor,
strength: float=0.75, strength: float=0.75,
**keyword_args **keyword_args
)->List[InvokeAIGeneratorOutput]: )->Iterator[InvokeAIGeneratorOutput]:
return super().generate(init_image=init_image, return super().generate(init_image=init_image,
strength=strength, strength=strength,
**keyword_args **keyword_args
@ -223,7 +223,7 @@ class Img2Img(InvokeAIGenerator):
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff # Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
class Inpaint(Img2Img): class Inpaint(Img2Img):
def generate(self, def generate(self,
mask_image: Image | torch.FloatTensor, mask_image: Image.Image | torch.FloatTensor,
# Seam settings - when 0, doesn't fill seam # Seam settings - when 0, doesn't fill seam
seam_size: int = 0, seam_size: int = 0,
seam_blur: int = 0, seam_blur: int = 0,
@ -236,7 +236,7 @@ class Inpaint(Img2Img):
inpaint_height=None, inpaint_height=None,
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF), inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
**keyword_args **keyword_args
)->List[InvokeAIGeneratorOutput]: )->Iterator[InvokeAIGeneratorOutput]:
return super().generate( return super().generate(
mask_image=mask_image, mask_image=mask_image,
seam_size=seam_size, seam_size=seam_size,
@ -263,7 +263,7 @@ class Embiggen(Txt2Img):
embiggen: list=None, embiggen: list=None,
embiggen_tiles: list = None, embiggen_tiles: list = None,
strength: float=0.75, strength: float=0.75,
**kwargs)->List[InvokeAIGeneratorOutput]: **kwargs)->Iterator[InvokeAIGeneratorOutput]:
return super().generate(embiggen=embiggen, return super().generate(embiggen=embiggen,
embiggen_tiles=embiggen_tiles, embiggen_tiles=embiggen_tiles,
strength=strength, strength=strength,