2022-12-01 05:33:20 +00:00
|
|
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
|
|
|
2023-03-10 04:28:06 +00:00
|
|
|
from functools import partial
|
2023-03-25 20:07:18 +00:00
|
|
|
from typing import Literal, Optional, Union
|
2023-03-03 06:02:00 +00:00
|
|
|
|
2022-12-01 05:33:20 +00:00
|
|
|
import numpy as np
|
2023-03-18 10:37:50 +00:00
|
|
|
from torch import Tensor
|
2023-03-15 12:50:26 +00:00
|
|
|
|
2023-03-03 06:02:00 +00:00
|
|
|
from pydantic import Field
|
|
|
|
|
2023-04-04 01:05:15 +00:00
|
|
|
from invokeai.app.models.image import ImageField, ImageType
|
2023-03-03 06:02:00 +00:00
|
|
|
from .baseinvocation import BaseInvocation, InvocationContext
|
2023-04-04 01:05:15 +00:00
|
|
|
from .image import ImageOutput
|
2023-03-18 10:37:50 +00:00
|
|
|
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
2023-03-10 04:28:06 +00:00
|
|
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
2023-04-04 01:05:15 +00:00
|
|
|
from ..models.exceptions import CanceledException
|
|
|
|
from ..util.step_callback import diffusers_step_callback_adapter
|
2022-12-01 05:33:20 +00:00
|
|
|
|
2023-03-03 06:02:00 +00:00
|
|
|
# Literal type whose permitted values are exactly the scheduler names returned
# by InvokeAIGenerator.schedulers(). Subscripting Literal with a tuple expands
# the tuple into the individual allowed values. Used to type the `scheduler`
# input field of the generation invocations.
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
2022-12-01 05:33:20 +00:00
|
|
|
|
|
|
|
# Text to image
class TextToImageInvocation(BaseInvocation):
    """Generates an image using text2img."""

    type: Literal["txt2img"] = "txt2img"

    # Inputs
    # TODO: consider making prompt optional to enable providing prompt through a link
    # fmt: off
    prompt: Optional[str] = Field(description="The prompt to generate an image from")
    seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
    steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
    width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
    height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
    cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
    scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
    seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
    model: str = Field(default="", description="The model to use (currently ignored)")
    progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
    # fmt: on

    # TODO: pass this an emitter method or something? or a session for dispatching?
    def dispatch_progress(
        self, context: InvocationContext, intermediate_state: PipelineIntermediateState
    ) -> None:
        """Per-step generator callback: honor cancellation, then report progress.

        Raises ``CanceledException`` if the queue reports this graph execution
        as canceled. Otherwise it passes the current step's sample to
        ``diffusers_step_callback_adapter`` together with ``self.steps`` and
        ``self.id``.
        """
        if context.services.queue.is_canceled(context.graph_execution_state_id):
            raise CanceledException

        # Some schedulers report not only the noisy latents at the current
        # timestep, but also their estimate so far of what the de-noised
        # latents will be; prefer that estimate when it is present.
        sample = (
            intermediate_state.predicted_original
            if intermediate_state.predicted_original is not None
            else intermediate_state.latents
        )

        diffusers_step_callback_adapter(
            sample, intermediate_state.step, steps=self.steps, id=self.id, context=context
        )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Generate one image with Txt2Img, save it, and return a reference to it.

        The generator is built from whatever model the model manager currently
        returns. This method does not use ``self.model`` to pick it; that field
        is only forwarded through ``self.dict()``. The first generator output
        is saved as an ``ImageType.RESULT`` image.
        """
        # Handle invalid model parameter
        # TODO: figure out if this can be done via a validator that uses the model_cache
        # TODO: How to get the default model name now?
        # (right now uses whatever current model is set in model manager)
        model = context.services.model_manager.get_model()
        outputs = Txt2Img(model).generate(
            prompt=self.prompt,
            step_callback=partial(self.dispatch_progress, context),
            **self.dict(
                exclude={"prompt"}
            ),  # Shorthand for passing all of the parameters above manually
        )
        # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
        # each time it is called. We only need the first one.
        generator_output = next(outputs)

        # Results are image and seed, unwrap for now and ignore the seed
        # TODO: pre-seed?
        # TODO: can this return multiple results? Should it?
        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        context.services.images.save(image_type, image_name, generator_output.image)
        return ImageOutput(
            image=ImageField(image_type=image_type, image_name=image_name)
        )
|
|
|
|
|
|
|
|
|
|
|
|
class ImageToImageInvocation(TextToImageInvocation):
    """Generates an image using img2img."""

    type: Literal["img2img"] = "img2img"

    # Inputs
    image: Union[ImageField, None] = Field(description="The input image")
    strength: float = Field(
        default=0.75, gt=0, le=1, description="The strength of the original image"
    )
    fit: bool = Field(
        default=True,
        description="Whether or not the result should be fit to the aspect ratio of the input image",
    )

    # dispatch_progress (cancellation check + per-step progress reporting) is
    # inherited from TextToImageInvocation; this class needs no override.

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Generate one image with Img2Img from ``self.image``, save it, return it.

        The input image is fetched from the image service when ``self.image``
        is set. Otherwise ``None`` is passed to the generator as-is. The model
        is whatever the model manager currently returns. The first generator
        output is saved as an ``ImageType.RESULT`` image.
        """
        image = (
            None
            if self.image is None
            else context.services.images.get(
                self.image.image_type, self.image.image_name
            )
        )

        # Handle invalid model parameter
        # TODO: figure out if this can be done via a validator that uses the model_cache
        # TODO: How to get the default model name now?
        model = context.services.model_manager.get_model()
        outputs = Img2Img(model).generate(
            prompt=self.prompt,
            init_image=image,
            init_mask=None,  # plain img2img: no mask
            step_callback=partial(self.dispatch_progress, context),
            **self.dict(
                exclude={"prompt", "image", "mask"}
            ),  # Shorthand for passing all of the parameters above manually
        )

        # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
        # each time it is called. We only need the first one.
        generator_output = next(outputs)
        result_image = generator_output.image

        # Results are image and seed, unwrap for now and ignore the seed
        # TODO: pre-seed?
        # TODO: can this return multiple results? Should it?
        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        context.services.images.save(image_type, image_name, result_image)
        return ImageOutput(
            image=ImageField(image_type=image_type, image_name=image_name)
        )
|
|
|
|
|
|
|
|
class InpaintInvocation(ImageToImageInvocation):
    """Generates an image using inpaint."""

    type: Literal["inpaint"] = "inpaint"

    # Inputs
    mask: Union[ImageField, None] = Field(description="The mask")
    inpaint_replace: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="The amount by which to replace masked areas with latent noise",
    )

    # dispatch_progress (cancellation check + per-step progress reporting) is
    # inherited from the parent invocation classes; no override is needed.

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Inpaint ``self.image`` under ``self.mask``, save the result, return it.

        Both the image and the mask are fetched from the image service when
        set. Otherwise ``None`` is passed for them. The model is whatever the
        model manager currently returns. The first generator output is saved
        as an ``ImageType.RESULT`` image.
        """
        image = (
            None
            if self.image is None
            else context.services.images.get(
                self.image.image_type, self.image.image_name
            )
        )
        mask = (
            None
            if self.mask is None
            else context.services.images.get(self.mask.image_type, self.mask.image_name)
        )

        # Handle invalid model parameter
        # TODO: figure out if this can be done via a validator that uses the model_cache
        # TODO: How to get the default model name now?
        model = context.services.model_manager.get_model()
        # The source image uses the same `init_image` keyword that
        # ImageToImageInvocation passes to Img2Img (Inpaint builds on Img2Img).
        # NOTE(review): the mask keyword follows Inpaint.generate's `mask_image`
        # parameter in invokeai/backend/generator; confirm if that signature changes.
        outputs = Inpaint(model).generate(
            prompt=self.prompt,
            init_image=image,
            mask_image=mask,
            step_callback=partial(self.dispatch_progress, context),
            **self.dict(
                exclude={"prompt", "image", "mask"}
            ),  # Shorthand for passing all of the parameters above manually
        )

        # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
        # each time it is called. We only need the first one.
        generator_output = next(outputs)
        result_image = generator_output.image

        # Results are image and seed, unwrap for now and ignore the seed
        # TODO: pre-seed?
        # TODO: can this return multiple results? Should it?
        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        context.services.images.save(image_type, image_name, result_image)
        return ImageOutput(
            image=ImageField(image_type=image_type, image_name=image_name)
        )
|