Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit 3ba7e966b5: Merge branch 'main' into fix/ui/viewer-localisation
@@ -270,3 +270,18 @@ async def invoke_session
 
     ApiDependencies.invoker.invoke(session, invoke_all=all)
     return Response(status_code=202)
+
+
+@session_router.delete(
+    "/{session_id}/invoke",
+    operation_id="cancel_session_invoke",
+    responses={
+        202: {"description": "The invocation is canceled"}
+    },
+)
+async def cancel_session_invoke(
+    session_id: str = Path(description="The id of the session to cancel"),
+) -> None:
+    """Invokes a session"""
+    ApiDependencies.invoker.cancel(session_id)
+    return Response(status_code=202)
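For illustration, this is roughly how a client could exercise the new cancel route once the API server is running. The base URL prefix, the HTTP method of the existing invoke route, and the session id below are assumptions for the sketch, not values taken from this diff.

# Hypothetical client-side sketch; assumes the sessions router is mounted
# under /api/v1/sessions and that "abc123" is an existing session id.
import requests

BASE = "http://localhost:9090/api/v1/sessions"
session_id = "abc123"

# Kick off processing of the session's graph (method/params assumed).
requests.put(f"{BASE}/{session_id}/invoke", params={"all": True})

# Cancel it via the DELETE route added above; the handler calls
# ApiDependencies.invoker.cancel(session_id) and replies 202 Accepted.
resp = requests.delete(f"{BASE}/{session_id}/invoke")
print(resp.status_code)  # expected: 202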
@@ -4,15 +4,16 @@ from functools import partial
 from typing import Literal, Optional, Union
 
 import numpy as np
+from torch import Tensor
 
 from pydantic import Field
 
 from ..services.image_storage import ImageType
 from .baseinvocation import BaseInvocation, InvocationContext
 from .image import ImageField, ImageOutput
-from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator
+from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
 from ...backend.stable_diffusion import PipelineIntermediateState
-from ...backend.util.util import image_to_dataURL
+from ..util.util import diffusers_step_callback_adapter, CanceledException
 
 SAMPLER_NAME_VALUES = Literal[
     tuple(InvokeAIGenerator.schedulers())
@@ -43,33 +44,24 @@ class TextToImageInvocation(BaseInvocation):
     def dispatch_progress(
         self, context: InvocationContext, intermediate_state: PipelineIntermediateState
     ) -> None:
+        if (context.services.queue.is_canceled(context.graph_execution_state_id)):
+            raise CanceledException
+
         step = intermediate_state.step
         if intermediate_state.predicted_original is not None:
             # Some schedulers report not only the noisy latents at the current timestep,
             # but also their estimate so far of what the de-noised latents will be.
-
             sample = intermediate_state.predicted_original
         else:
             sample = intermediate_state.latents
 
-        image = Generator(context.services.model_manager.get_model()).sample_to_image(sample)
-        (width, height) = image.size
-        dataURL = image_to_dataURL(image, image_format="JPEG")
-        context.services.events.emit_generator_progress(
-            context.graph_execution_state_id,
-            self.id,
-            {
-                "width" : width,
-                "height": height,
-                "dataURL": dataURL,
-            },
-            step,
-            self.steps,
-        )
+        diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
 
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        def step_callback(state: PipelineIntermediateState):
-            self.dispatch_progress(context, state.latents, state.step)
+        # def step_callback(state: PipelineIntermediateState):
+        #     if (context.services.queue.is_canceled(context.graph_execution_state_id)):
+        #         raise CanceledException
+        #     self.dispatch_progress(context, state.latents, state.step)
 
         # Handle invalid model parameter
         # TODO: figure out if this can be done via a validator that uses the model_cache
@@ -115,6 +107,22 @@ class ImageToImageInvocation(TextToImageInvocation):
         description="Whether or not the result should be fit to the aspect ratio of the input image",
     )
 
+    def dispatch_progress(
+        self, context: InvocationContext, intermediate_state: PipelineIntermediateState
+    ) -> None:
+        if (context.services.queue.is_canceled(context.graph_execution_state_id)):
+            raise CanceledException
+
+        step = intermediate_state.step
+        if intermediate_state.predicted_original is not None:
+            # Some schedulers report not only the noisy latents at the current timestep,
+            # but also their estimate so far of what the de-noised latents will be.
+            sample = intermediate_state.predicted_original
+        else:
+            sample = intermediate_state.latents
+
+        diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
+
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = (
             None
@@ -129,17 +137,19 @@ class ImageToImageInvocation(TextToImageInvocation):
         # TODO: figure out if this can be done via a validator that uses the model_cache
         # TODO: How to get the default model name now?
         model = context.services.model_manager.get_model()
-        generator_output = next(
-            Img2Img(model).generate(
+        outputs = Img2Img(model).generate(
             prompt=self.prompt,
             init_image=image,
             init_mask=mask,
             step_callback=partial(self.dispatch_progress, context),
             **self.dict(
                 exclude={"prompt", "image", "mask"}
             ), # Shorthand for passing all of the parameters above manually
         )
-        )
+
+        # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
+        # each time it is called. We only need the first one.
+        generator_output = next(outputs)
 
         result_image = generator_output.image
 
@@ -169,6 +179,22 @@ class InpaintInvocation(ImageToImageInvocation):
         description="The amount by which to replace masked areas with latent noise",
     )
 
+    def dispatch_progress(
+        self, context: InvocationContext, intermediate_state: PipelineIntermediateState
+    ) -> None:
+        if (context.services.queue.is_canceled(context.graph_execution_state_id)):
+            raise CanceledException
+
+        step = intermediate_state.step
+        if intermediate_state.predicted_original is not None:
+            # Some schedulers report not only the noisy latents at the current timestep,
+            # but also their estimate so far of what the de-noised latents will be.
+            sample = intermediate_state.predicted_original
+        else:
+            sample = intermediate_state.latents
+
+        diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
+
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = (
             None
@@ -187,8 +213,7 @@ class InpaintInvocation(ImageToImageInvocation):
         # TODO: figure out if this can be done via a validator that uses the model_cache
         # TODO: How to get the default model name now?
         model = context.services.model_manager.get_model()
-        generator_output = next(
-            Inpaint(model).generate(
+        outputs = Inpaint(model).generate(
             prompt=self.prompt,
             init_img=image,
             init_mask=mask,
@@ -197,7 +222,10 @@ class InpaintInvocation(ImageToImageInvocation):
                 exclude={"prompt", "image", "mask"}
             ), # Shorthand for passing all of the parameters above manually
         )
-        )
+
+        # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
+        # each time it is called. We only need the first one.
+        generator_output = next(outputs)
 
         result_image = generator_output.image
 
@@ -28,12 +28,28 @@ class ImageOutput(BaseInvocationOutput):
     image: ImageField = Field(default=None, description="The output image")
     #fmt: on
 
+    class Config:
+        schema_extra = {
+            'required': [
+                'type',
+                'image',
+            ]
+        }
+
 class MaskOutput(BaseInvocationOutput):
     """Base class for invocations that output a mask"""
     #fmt: off
     type: Literal["mask"] = "mask"
     mask: ImageField = Field(default=None, description="The output mask")
-    #fomt: on
+    #fmt: on
+
+    class Config:
+        schema_extra = {
+            'required': [
+                'type',
+                'mask',
+            ]
+        }
 
 # TODO: this isn't really necessary anymore
 class LoadImageInvocation(BaseInvocation):
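The same Config/schema_extra pattern is repeated for the other output classes below. As a minimal, self-contained sketch of what it does (assuming pydantic v1, with the InvokeAI-specific ImageField replaced by a plain string purely for illustration):

# Pydantic v1 sketch: schema_extra is merged into the generated JSON schema,
# so "type" and "image" are listed as required even though both fields have
# defaults. ImageOutputSketch is a stand-in, not the real ImageOutput class.
from typing import Literal, Optional
from pydantic import BaseModel, Field

class ImageOutputSketch(BaseModel):
    type: Literal["image"] = "image"
    image: Optional[str] = Field(default=None, description="The output image")

    class Config:
        schema_extra = {
            'required': [
                'type',
                'image',
            ]
        }

print(ImageOutputSketch.schema()["required"])  # ['type', 'image']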
@@ -12,3 +12,11 @@ class PromptOutput(BaseInvocationOutput):
 
     prompt: str = Field(default=None, description="The output prompt")
     #fmt: on
+
+    class Config:
+        schema_extra = {
+            'required': [
+                'type',
+                'prompt',
+            ]
+        }
|
@ -127,6 +127,13 @@ class NodeAlreadyExecutedError(Exception):
|
|||||||
class GraphInvocationOutput(BaseInvocationOutput):
|
class GraphInvocationOutput(BaseInvocationOutput):
|
||||||
type: Literal["graph_output"] = "graph_output"
|
type: Literal["graph_output"] = "graph_output"
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
'required': [
|
||||||
|
'type',
|
||||||
|
'image',
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
# TODO: Fill this out and move to invocations
|
# TODO: Fill this out and move to invocations
|
||||||
class GraphInvocation(BaseInvocation):
|
class GraphInvocation(BaseInvocation):
|
||||||
@@ -147,6 +154,13 @@ class IterateInvocationOutput(BaseInvocationOutput):
 
     item: Any = Field(description="The item being iterated over")
 
+    class Config:
+        schema_extra = {
+            'required': [
+                'type',
+                'item',
+            ]
+        }
 
 # TODO: Fill this out and move to invocations
 class IterateInvocation(BaseInvocation):
@@ -169,6 +183,13 @@ class CollectInvocationOutput(BaseInvocationOutput):
 
     collection: list[Any] = Field(description="The collection of input items")
 
+    class Config:
+        schema_extra = {
+            'required': [
+                'type',
+                'collection',
+            ]
+        }
 
 class CollectInvocation(BaseInvocation):
     """Collects values into a collection"""
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
# TODO: make this serializable
|
# TODO: make this serializable
|
||||||
@@ -10,6 +11,7 @@ class InvocationQueueItem:
     graph_execution_state_id: str
     invocation_id: str
     invoke_all: bool
+    timestamp: float
 
     def __init__(
         self,
@@ -22,6 +24,7 @@ class InvocationQueueItem:
         self.graph_execution_state_id = graph_execution_state_id
         self.invocation_id = invocation_id
         self.invoke_all = invoke_all
+        self.timestamp = time.time()
 
 
 class InvocationQueueABC(ABC):
@@ -35,15 +38,44 @@ class InvocationQueueABC(ABC):
     def put(self, item: InvocationQueueItem | None) -> None:
         pass
 
+    @abstractmethod
+    def cancel(self, graph_execution_state_id: str) -> None:
+        pass
+
+    @abstractmethod
+    def is_canceled(self, graph_execution_state_id: str) -> bool:
+        pass
+
 
 class MemoryInvocationQueue(InvocationQueueABC):
     __queue: Queue
+    __cancellations: dict[str, float]
 
     def __init__(self):
         self.__queue = Queue()
+        self.__cancellations = dict()
 
     def get(self) -> InvocationQueueItem:
-        return self.__queue.get()
+        item = self.__queue.get()
+
+        while isinstance(item, InvocationQueueItem) \
+            and item.graph_execution_state_id in self.__cancellations \
+            and self.__cancellations[item.graph_execution_state_id] > item.timestamp:
+            item = self.__queue.get()
+
+        # Clear old items
+        for graph_execution_state_id in list(self.__cancellations.keys()):
+            if self.__cancellations[graph_execution_state_id] < item.timestamp:
+                del self.__cancellations[graph_execution_state_id]
+
+        return item
 
     def put(self, item: InvocationQueueItem | None) -> None:
         self.__queue.put(item)
+
+    def cancel(self, graph_execution_state_id: str) -> None:
+        if graph_execution_state_id not in self.__cancellations:
+            self.__cancellations[graph_execution_state_id] = time.time()
+
+    def is_canceled(self, graph_execution_state_id: str) -> bool:
+        return graph_execution_state_id in self.__cancellations
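To make the timestamp-based cancellation above concrete, here is a small usage sketch. The module path is an assumption (the diff does not show file names), and the short sleeps only exist to make the time.time() ordering unambiguous.

# Illustrative only: items queued before cancel() are skipped by get(),
# items queued after it are still returned, and stale cancellation entries
# are then cleared.
import time
from invokeai.app.services.invocation_queue import (
    InvocationQueueItem,
    MemoryInvocationQueue,
)

queue = MemoryInvocationQueue()
queue.put(InvocationQueueItem(
    graph_execution_state_id="graph-1", invocation_id="node-a", invoke_all=False))
time.sleep(0.01)

# Cancelling records time.time() for this graph id ...
queue.cancel("graph-1")
assert queue.is_canceled("graph-1")
time.sleep(0.01)

# ... so "node-a" (older timestamp) is dropped, while "node-b" (newer) survives.
queue.put(InvocationQueueItem(
    graph_execution_state_id="graph-1", invocation_id="node-b", invoke_all=False))
print(queue.get().invocation_id)  # node-b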
@@ -50,6 +50,10 @@ class Invoker:
         new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
         self.services.graph_execution_manager.set(new_state)
         return new_state
 
+    def cancel(self, graph_execution_state_id: str) -> None:
+        """Cancels the given execution state"""
+        self.services.queue.cancel(graph_execution_state_id)
+
     def __start_service(self, service) -> None:
         # Call start() method on any services that have it
@@ -4,7 +4,7 @@ from threading import Event, Thread
 from ..invocations.baseinvocation import InvocationContext
 from .invocation_queue import InvocationQueueItem
 from .invoker import InvocationProcessorABC, Invoker
-
+from ..util.util import CanceledException
 
 class DefaultInvocationProcessor(InvocationProcessorABC):
     __invoker_thread: Thread
@@ -58,6 +58,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                     )
                 )
 
+                # Check queue to see if this is canceled, and skip if so
+                if self.__invoker.services.queue.is_canceled(
+                    graph_execution_state.id
+                ):
+                    continue
+
                 # Save outputs and history
                 graph_execution_state.complete(invocation.id, outputs)
 
@@ -76,6 +82,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
             except KeyboardInterrupt:
                 pass
 
+            except CanceledException:
+                pass
+
             except Exception as e:
                 error = traceback.format_exc()
 
@@ -95,6 +104,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                     )
 
                     pass
 
+                # Check queue to see if this is canceled, and skip if so
+                if self.__invoker.services.queue.is_canceled(
+                    graph_execution_state.id
+                ):
+                    continue
+
                 # Queue any further commands if invoking all
                 is_complete = graph_execution_state.is_complete()
invokeai/app/util/util.py (new file, 42 lines)
@@ -0,0 +1,42 @@
+import torch
+from PIL import Image
+from ..invocations.baseinvocation import InvocationContext
+from ...backend.util.util import image_to_dataURL
+from ...backend.generator.base import Generator
+from ...backend.stable_diffusion import PipelineIntermediateState
+
+class CanceledException(Exception):
+    pass
+
+def fast_latents_step_callback(sample: torch.Tensor, step: int, steps: int, id: str, context: InvocationContext, ):
+    # TODO: only output a preview image when requested
+    image = Generator.sample_to_lowres_estimated_image(sample)
+
+    (width, height) = image.size
+    width *= 8
+    height *= 8
+
+    dataURL = image_to_dataURL(image, image_format="JPEG")
+
+    context.services.events.emit_generator_progress(
+        context.graph_execution_state_id,
+        id,
+        {
+            "width": width,
+            "height": height,
+            "dataURL": dataURL
+        },
+        step,
+        steps,
+    )
+
+def diffusers_step_callback_adapter(*cb_args, **kwargs):
+    """
+    txt2img gives us a Tensor in the step_callbak, while img2img gives us a PipelineIntermediateState.
+    This adapter grabs the needed data and passes it along to the callback function.
+    """
+    if isinstance(cb_args[0], PipelineIntermediateState):
+        progress_state: PipelineIntermediateState = cb_args[0]
+        return fast_latents_step_callback(progress_state.latents, progress_state.step, **kwargs)
+    else:
+        return fast_latents_step_callback(*cb_args, **kwargs)
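As a rough, self-contained illustration of the adapter's two calling conventions, a toy sketch with stand-in types (none of the names below are the real backend classes):

# Toy sketch only: FakeIntermediateState stands in for PipelineIntermediateState
# and fake_callback for fast_latents_step_callback, to show how the adapter
# either unwraps a state object or forwards a (sample, step) pair unchanged.
from dataclasses import dataclass

@dataclass
class FakeIntermediateState:
    latents: list
    step: int

def fake_callback(sample, step, *, steps, id, context):
    print(f"node {id}: step {step}/{steps}, {len(sample)} latent values")

def adapter(*cb_args, **kwargs):
    if isinstance(cb_args[0], FakeIntermediateState):
        state = cb_args[0]
        return fake_callback(state.latents, state.step, **kwargs)
    return fake_callback(*cb_args, **kwargs)

# img2img-style call: a single state object
adapter(FakeIntermediateState(latents=[0.1, 0.2], step=3), steps=20, id="n1", context=None)
# txt2img-style call: a raw (sample, step) pair
adapter([0.1, 0.2], 3, steps=20, id="n1", context=None)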
@@ -21,7 +21,7 @@ from PIL import Image, ImageChops, ImageFilter
 from accelerate.utils import set_seed
 from diffusers import DiffusionPipeline
 from tqdm import trange
-from typing import List, Iterator, Type
+from typing import Callable, List, Iterator, Optional, Type
 from dataclasses import dataclass, field
 from diffusers.schedulers import SchedulerMixin as Scheduler
 
@@ -35,23 +35,23 @@ downsampling = 8
 
 @dataclass
 class InvokeAIGeneratorBasicParams:
-    seed: int=None
+    seed: Optional[int]=None
     width: int=512
     height: int=512
-    cfg_scale: int=7.5
+    cfg_scale: float=7.5
     steps: int=20
     ddim_eta: float=0.0
-    scheduler: int='ddim'
+    scheduler: str='ddim'
     precision: str='float16'
     perlin: float=0.0
-    threshold: int=0.0
+    threshold: float=0.0
     seamless: bool=False
     seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y'])
-    h_symmetry_time_pct: float=None
-    v_symmetry_time_pct: float=None
+    h_symmetry_time_pct: Optional[float]=None
+    v_symmetry_time_pct: Optional[float]=None
     variation_amount: float = 0.0
     with_variations: list=field(default_factory=list)
-    safety_checker: SafetyChecker=None
+    safety_checker: Optional[SafetyChecker]=None
 
 @dataclass
 class InvokeAIGeneratorOutput:
@@ -61,10 +61,10 @@ class InvokeAIGeneratorOutput:
     and the model hash, as well as all the generate() parameters that went into
     generating the image (in .params, also available as attributes)
     '''
-    image: Image
+    image: Image.Image
     seed: int
     model_hash: str
-    attention_maps_images: List[Image]
+    attention_maps_images: List[Image.Image]
     params: Namespace
 
 # we are interposing a wrapper around the original Generator classes so that
|
|||||||
|
|
||||||
def generate(self,
|
def generate(self,
|
||||||
prompt: str='',
|
prompt: str='',
|
||||||
callback: callable=None,
|
callback: Optional[Callable]=None,
|
||||||
step_callback: callable=None,
|
step_callback: Optional[Callable]=None,
|
||||||
iterations: int=1,
|
iterations: int=1,
|
||||||
**keyword_args,
|
**keyword_args,
|
||||||
)->Iterator[InvokeAIGeneratorOutput]:
|
)->Iterator[InvokeAIGeneratorOutput]:
|
||||||
@@ -206,10 +206,10 @@ class Txt2Img(InvokeAIGenerator):
 # ------------------------------------
 class Img2Img(InvokeAIGenerator):
     def generate(self,
-                 init_image: Image | torch.FloatTensor,
+                 init_image: Image.Image | torch.FloatTensor,
                  strength: float=0.75,
                  **keyword_args
-                 )->List[InvokeAIGeneratorOutput]:
+                 )->Iterator[InvokeAIGeneratorOutput]:
         return super().generate(init_image=init_image,
                                 strength=strength,
                                 **keyword_args
@@ -223,7 +223,7 @@ class Img2Img(InvokeAIGenerator):
 # Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
 class Inpaint(Img2Img):
     def generate(self,
-                 mask_image: Image | torch.FloatTensor,
+                 mask_image: Image.Image | torch.FloatTensor,
                  # Seam settings - when 0, doesn't fill seam
                  seam_size: int = 0,
                  seam_blur: int = 0,
@@ -236,7 +236,7 @@ class Inpaint(Img2Img):
                  inpaint_height=None,
                  inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
                  **keyword_args
-                 )->List[InvokeAIGeneratorOutput]:
+                 )->Iterator[InvokeAIGeneratorOutput]:
         return super().generate(
             mask_image=mask_image,
             seam_size=seam_size,
@@ -263,7 +263,7 @@ class Embiggen(Txt2Img):
                  embiggen: list=None,
                  embiggen_tiles: list = None,
                  strength: float=0.75,
-                 **kwargs)->List[InvokeAIGeneratorOutput]:
+                 **kwargs)->Iterator[InvokeAIGeneratorOutput]:
         return super().generate(embiggen=embiggen,
                                 embiggen_tiles=embiggen_tiles,
                                 strength=strength,