feat(nodes): add invocation schema customisation, add model selection

- add invocation schema customisation

done via pydantic's `Config` class and `schema_extra`. when using `Config`, inherit from `InvocationConfig` to get type hints.

where it makes sense - like for all math invocations - define a shared `MathInvocationConfig` class and have those invocations inherit from it (see the sketch below).
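for instance, a minimal sketch of the pattern (the `a`/`b` field declarations are illustrative; `invoke()` is elided):

```python
from typing import Literal

from pydantic import BaseModel, Field

from invokeai.app.invocations.models.config import InvocationConfig
from .baseinvocation import BaseInvocation


class MathInvocationConfig(BaseModel):
    """Helper class to provide all math invocations with additional config"""

    # anything under "ui" is passed through to the UI as-is
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["math"],
            }
        }


class AddInvocation(BaseInvocation, MathInvocationConfig):
    """Adds two numbers"""

    type: Literal["add"] = "add"
    a: int = Field(default=0, description="The first number")
    b: int = Field(default=0, description="The second number")

    # invoke() elided - returns IntOutput(a=self.a + self.b)
```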

this customisation can provide any arbitrary additional data to the UI. currently it provides tags and field type hints.
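since pydantic v1 merges a dict-valued `schema_extra` straight into the generated JSON schema, the UI can read the extra data off the OpenAPI document directly - continuing the sketch above:

```python
schema = AddInvocation.schema()
print(schema["ui"]["tags"])  # ['math']
```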

this is necessary for `model` type fields, which are actually string fields. without something like this, we can't reliably differentiate `model` fields from normal `string` fields.
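e.g., the shared `SDImageInvocation` config in this diff supplies the hint; the `model` field declaration on `TextToImageInvocation` is assumed here (the diff doesn't show it), mirroring the one on `LatentsToImageInvocation`:

```python
class SDImageInvocation(BaseModel):
    """Helper class to provide all Stable Diffusion raster image invocations with additional config"""

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["stable-diffusion", "image"],
                "type_hints": {
                    # tells the UI to render a model picker for this field
                    "model": "model",
                },
            },
        }


class TextToImageInvocation(BaseInvocation, SDImageInvocation):
    """Generates an image using text2img."""

    type: Literal["txt2img"] = "txt2img"
    # a plain `str` as far as pydantic/fastapi are concerned; only the
    # type hint above distinguishes it from a normal string field
    model: str = Field(default="", description="The model to use")
```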

can also be used for future field types.

all invocations now have tags, and all `model` fields have ui type hints.

- fix model handling for invocations

added a helper to fall back to the default model if an invalid model name is chosen. model names in graphs now work.
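usage, as it now appears throughout the diff:

```python
from invokeai.app.invocations.util.get_model import choose_model

# valid name -> that model; invalid name -> warn and fall back to the
# model manager's current default
model = choose_model(context.services.model_manager, self.model)
```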

- fix latents progress callback

noticed this wasn't correct while working on everything else.
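the step callbacks now hand the whole `PipelineIntermediateState` to `dispatch_progress`, which picks the scheduler's de-noised estimate when one is reported - excerpted from the diff:

```python
def step_callback(state: PipelineIntermediateState):
    self.dispatch_progress(context, state)

# inside dispatch_progress: some schedulers report their estimate of the
# final de-noised latents alongside the noisy latents for the current step
if intermediate_state.predicted_original is not None:
    sample = intermediate_state.predicted_original
else:
    sample = intermediate_state.latents
```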
psychedelicious 2023-04-10 19:07:48 +10:00
parent 427db7c7e2
commit 07e3a0ec15
9 changed files with 228 additions and 64 deletions

invokeai/app/invocations/cv.py

@@ -5,14 +5,27 @@ from typing import Literal
 import cv2 as cv
 import numpy
 from PIL import Image, ImageOps
-from pydantic import Field
+from pydantic import BaseModel, Field
+
+from invokeai.app.invocations.models.config import InvocationConfig
 from invokeai.app.models.image import ImageField, ImageType
 from .baseinvocation import BaseInvocation, InvocationContext
 from .image import ImageOutput


-class CvInpaintInvocation(BaseInvocation):
+class CVInvocation(BaseModel):
+    """Helper class to provide all OpenCV invocations with additional config"""
+
+    # Schema customisation
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["cv", "image"],
+            },
+        }
+
+
+class CvInpaintInvocation(BaseInvocation, CVInvocation):
     """Simple inpaint using opencv."""
     #fmt: off
     type: Literal["cv_inpaint"] = "cv_inpaint"

invokeai/app/invocations/generate.py

@@ -6,9 +6,13 @@ from typing import Literal, Optional, Union
 import numpy as np
 from torch import Tensor
-from pydantic import Field
+from pydantic import BaseModel, Field
+
+from invokeai.app.invocations.models.config import (
+    InvocationConfig,
+)
 from invokeai.app.models.image import ImageField, ImageType
+from invokeai.app.invocations.util.get_model import choose_model
 from .baseinvocation import BaseInvocation, InvocationContext
 from .image import ImageOutput
 from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
@@ -16,12 +20,26 @@ from ...backend.stable_diffusion import PipelineIntermediateState
 from ..models.exceptions import CanceledException
 from ..util.step_callback import diffusers_step_callback_adapter

-SAMPLER_NAME_VALUES = Literal[
-    tuple(InvokeAIGenerator.schedulers())
-]
+SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
+
+
+class SDImageInvocation(BaseModel):
+    """Helper class to provide all Stable Diffusion raster image invocations with additional config"""
+
+    # Schema customisation
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["stable-diffusion", "image"],
+                "type_hints": {
+                    "model": "model",
+                },
+            },
+        }


 # Text to image
-class TextToImageInvocation(BaseInvocation):
+class TextToImageInvocation(BaseInvocation, SDImageInvocation):
     """Generates an image using text2img."""

     type: Literal["txt2img"] = "txt2img"
@@ -59,16 +77,9 @@ class TextToImageInvocation(BaseInvocation):
         diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)

     def invoke(self, context: InvocationContext) -> ImageOutput:
-        # def step_callback(state: PipelineIntermediateState):
-        #     if (context.services.queue.is_canceled(context.graph_execution_state_id)):
-        #         raise CanceledException
-        #     self.dispatch_progress(context, state.latents, state.step)
-
         # Handle invalid model parameter
         # TODO: figure out if this can be done via a validator that uses the model_cache
-        # TODO: How to get the default model name now?
-        # (right now uses whatever current model is set in model manager)
-        model= context.services.model_manager.get_model()
+        model = choose_model(context.services.model_manager, self.model)

         outputs = Txt2Img(model).generate(
             prompt=self.prompt,
             step_callback=partial(self.dispatch_progress, context),
@@ -135,9 +146,8 @@ class ImageToImageInvocation(TextToImageInvocation):
             mask = None

         # Handle invalid model parameter
         # TODO: figure out if this can be done via a validator that uses the model_cache
-        # TODO: How to get the default model name now?
-        model = context.services.model_manager.get_model()
+        model = choose_model(context.services.model_manager, self.model)

         outputs = Img2Img(model).generate(
             prompt=self.prompt,
             init_image=image,
@@ -211,9 +221,8 @@ class InpaintInvocation(ImageToImageInvocation):
         )

         # Handle invalid model parameter
         # TODO: figure out if this can be done via a validator that uses the model_cache
-        # TODO: How to get the default model name now?
-        model = context.services.model_manager.get_model()
+        model = choose_model(context.services.model_manager, self.model)

         outputs = Inpaint(model).generate(
             prompt=self.prompt,
             init_img=image,

invokeai/app/invocations/image.py

@@ -7,10 +7,23 @@ import numpy
 from PIL import Image, ImageFilter, ImageOps
 from pydantic import BaseModel, Field

+from invokeai.app.invocations.models.config import InvocationConfig
 from ..models.image import ImageField, ImageType
 from ..services.invocation_services import InvocationServices
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext


+class PILInvocationConfig(BaseModel):
+    """Helper class to provide all PIL invocations with additional config"""
+
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["PIL", "image"],
+            },
+        }
+
+
 class ImageOutput(BaseInvocationOutput):
     """Base class for invocations that output an image"""
     #fmt: off
@@ -82,7 +95,7 @@ class ShowImageInvocation(BaseInvocation):
         )


-class CropImageInvocation(BaseInvocation):
+class CropImageInvocation(BaseInvocation, PILInvocationConfig):
     """Crops an image to a specified box. The box can be outside of the image."""
     #fmt: off
     type: Literal["crop"] = "crop"
@@ -115,7 +128,7 @@ class CropImageInvocation(BaseInvocation):
         )


-class PasteImageInvocation(BaseInvocation):
+class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
     """Pastes an image into another image."""
     #fmt: off
     type: Literal["paste"] = "paste"
@@ -165,7 +178,7 @@ class PasteImageInvocation(BaseInvocation):
         )


-class MaskFromAlphaInvocation(BaseInvocation):
+class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
     """Extracts the alpha channel of an image as a mask."""
     #fmt: off
     type: Literal["tomask"] = "tomask"
@@ -192,7 +205,7 @@ class MaskFromAlphaInvocation(BaseInvocation):
         return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))


-class BlurInvocation(BaseInvocation):
+class BlurInvocation(BaseInvocation, PILInvocationConfig):
     """Blurs an image"""
     #fmt: off
@@ -226,7 +239,7 @@ class BlurInvocation(BaseInvocation):
         )


-class LerpInvocation(BaseInvocation):
+class LerpInvocation(BaseInvocation, PILInvocationConfig):
     """Linear interpolation of all pixels of an image"""
     #fmt: off
     type: Literal["lerp"] = "lerp"
@@ -257,7 +270,7 @@ class LerpInvocation(BaseInvocation):
         )


-class InverseLerpInvocation(BaseInvocation):
+class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
     """Inverse linear interpolation of all pixels of an image"""
     #fmt: off
     type: Literal["ilerp"] = "ilerp"

invokeai/app/invocations/latent.py

@@ -2,9 +2,13 @@
 from typing import Literal, Optional
 from pydantic import BaseModel, Field
 from torch import Tensor
 import torch

+from invokeai.app.invocations.models.config import InvocationConfig
+from invokeai.app.models.exceptions import CanceledException
+from invokeai.app.invocations.util.get_model import choose_model
+from invokeai.app.util.step_callback import diffusers_step_callback_adapter
 from ...backend.model_management.model_manager import ModelManager
 from ...backend.util.devices import choose_torch_device, torch_dtype
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
@@ -13,13 +17,10 @@ from ...backend.prompting.conditioning import get_uc_and_c_and_ec
 from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
 import numpy as np
 from accelerate.utils import set_seed
 from ..services.image_storage import ImageType
-from .baseinvocation import BaseInvocation, InvocationContext
 from .image import ImageField, ImageOutput
-from ...backend.generator import Generator
 from ...backend.stable_diffusion import PipelineIntermediateState
-from ...backend.util.util import image_to_dataURL
 from diffusers.schedulers import SchedulerMixin as Scheduler
 import diffusers
 from diffusers import DiffusionPipeline
@@ -109,6 +110,15 @@ class NoiseInvocation(BaseInvocation):
     width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", )
     height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )

+    # Schema customisation
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["latents", "noise"],
+            },
+        }
+
     def invoke(self, context: InvocationContext) -> NoiseOutput:
         device = torch.device(choose_torch_device())
         noise = get_noise(self.width, self.height, device, self.seed)
@@ -143,33 +153,37 @@ class TextToLatentsInvocation(BaseInvocation):
     progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
     # fmt: on

+    # Schema customisation
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["latents", "image"],
+                "type_hints": {
+                    "model": "model"
+                }
+            },
+        }
+
     # TODO: pass this an emitter method or something? or a session for dispatching?
     def dispatch_progress(
-        self, context: InvocationContext, sample: Tensor, step: int
+        self, context: InvocationContext, intermediate_state: PipelineIntermediateState
     ) -> None:
-        # TODO: only output a preview image when requested
-        image = Generator.sample_to_lowres_estimated_image(sample)
         if (context.services.queue.is_canceled(context.graph_execution_state_id)):
             raise CanceledException

-        (width, height) = image.size
-        width *= 8
-        height *= 8
+        step = intermediate_state.step
+        if intermediate_state.predicted_original is not None:
+            # Some schedulers report not only the noisy latents at the current timestep,
+            # but also their estimate so far of what the de-noised latents will be.
+            sample = intermediate_state.predicted_original
+        else:
+            sample = intermediate_state.latents

-        dataURL = image_to_dataURL(image, image_format="JPEG")
+        diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)

-        context.services.events.emit_generator_progress(
-            context.graph_execution_state_id,
-            self.id,
-            {
-                "width": width,
-                "height": height,
-                "dataURL": dataURL
-            },
-            step,
-            self.steps,
-        )

     def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
-        model_info = model_manager.get_model(self.model)
+        model_info = choose_model(model_manager, self.model)
         model_name = model_info['model_name']
         model_hash = model_info['hash']
         model: StableDiffusionGeneratorPipeline = model_info['model']
@@ -214,7 +228,7 @@ class TextToLatentsInvocation(BaseInvocation):
         noise = context.services.latents.get(self.noise.latents_name)

         def step_callback(state: PipelineIntermediateState):
-            self.dispatch_progress(context, state.latents, state.step)
+            self.dispatch_progress(context, state)

         model = self.get_model(context.services.model_manager)
         conditioning_data = self.get_conditioning_data(model)
@@ -244,6 +258,17 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
     type: Literal["l2l"] = "l2l"

+    # Schema customisation
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["latents"],
+                "type_hints": {
+                    "model": "model"
+                }
+            },
+        }
+
     # Inputs
     latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
     strength: float = Field(default=0.5, description="The strength of the latents to use")
@@ -253,7 +278,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         latent = context.services.latents.get(self.latents.latents_name)

         def step_callback(state: PipelineIntermediateState):
-            self.dispatch_progress(context, state.latents, state.step)
+            self.dispatch_progress(context, state)

         model = self.get_model(context.services.model_manager)
         conditioning_data = self.get_conditioning_data(model)
@@ -299,12 +324,23 @@ class LatentsToImageInvocation(BaseInvocation):
     latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
     model: str = Field(default="", description="The model to use")

+    # Schema customisation
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["latents", "image"],
+                "type_hints": {
+                    "model": "model"
+                }
+            },
+        }
+
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
         latents = context.services.latents.get(self.latents.latents_name)

         # TODO: this only really needs the vae
-        model_info = context.services.model_manager.get_model(self.model)
+        model_info = choose_model(context.services.model_manager, self.model)
         model: StableDiffusionGeneratorPipeline = model_info['model']

         with torch.inference_mode():

invokeai/app/invocations/math.py

@@ -1,17 +1,26 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

-from datetime import datetime, timezone
-from typing import Literal, Optional
+from typing import Literal

-import numpy
-from PIL import Image, ImageFilter, ImageOps
 from pydantic import BaseModel, Field

-from ..services.image_storage import ImageType
-from ..services.invocation_services import InvocationServices
+from invokeai.app.invocations.models.config import InvocationConfig
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext


+class MathInvocationConfig(BaseModel):
+    """Helper class to provide all math invocations with additional config"""
+
+    # Schema customisation
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["math"],
+            }
+        }
+
+
 class IntOutput(BaseInvocationOutput):
     """An integer output"""
     #fmt: off
@@ -20,7 +29,7 @@ class IntOutput(BaseInvocationOutput):
     #fmt: on


-class AddInvocation(BaseInvocation):
+class AddInvocation(BaseInvocation, MathInvocationConfig):
     """Adds two numbers"""
     #fmt: off
     type: Literal["add"] = "add"
@@ -32,7 +41,7 @@ class AddInvocation(BaseInvocation):
         return IntOutput(a=self.a + self.b)


-class SubtractInvocation(BaseInvocation):
+class SubtractInvocation(BaseInvocation, MathInvocationConfig):
     """Subtracts two numbers"""
     #fmt: off
     type: Literal["sub"] = "sub"
@@ -44,7 +53,7 @@ class SubtractInvocation(BaseInvocation):
         return IntOutput(a=self.a - self.b)


-class MultiplyInvocation(BaseInvocation):
+class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
     """Multiplies two numbers"""
     #fmt: off
     type: Literal["mul"] = "mul"
@@ -56,7 +65,7 @@ class MultiplyInvocation(BaseInvocation):
         return IntOutput(a=self.a * self.b)


-class DivideInvocation(BaseInvocation):
+class DivideInvocation(BaseInvocation, MathInvocationConfig):
     """Divides two numbers"""
     #fmt: off
     type: Literal["div"] = "div"

invokeai/app/invocations/models/config.py

@@ -0,0 +1,54 @@
+from typing import Dict, List, Literal, TypedDict
+
+from pydantic import BaseModel
+
+
+# TODO: when we can upgrade to python 3.11, we can use the `NotRequired` type instead of `total=False`
+class UIConfig(TypedDict, total=False):
+    type_hints: Dict[
+        str,
+        Literal[
+            "integer",
+            "float",
+            "boolean",
+            "string",
+            "enum",
+            "image",
+            "latents",
+            "model",
+        ],
+    ]
+    tags: List[str]
+
+
+class CustomisedSchemaExtra(TypedDict):
+    ui: UIConfig
+
+
+class InvocationConfig(BaseModel.Config):
+    """Customizes pydantic's BaseModel.Config class for use by Invocations.
+
+    Provide `schema_extra` a `ui` dict to add hints for generated UIs.
+
+    `tags`
+    - A list of strings, used to categorise invocations.
+
+    `type_hints`
+    - A dict of field types which override the types in the invocation definition.
+    - Each key should be the name of one of the invocation's fields.
+    - Each value should be one of the valid types:
+    - `integer`, `float`, `boolean`, `string`, `enum`, `image`, `latents`, `model`
+
+    ```python
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["stable-diffusion", "image"],
+                "type_hints": {
+                    "initial_image": "image",
+                },
+            },
+        }
+    ```
+    """
+
+    schema_extra: CustomisedSchemaExtra

invokeai/app/invocations/reconstruct.py

@@ -2,6 +2,7 @@ from datetime import datetime, timezone
 from typing import Literal, Union
 from pydantic import Field

+from invokeai.app.invocations.models.config import InvocationConfig
 from invokeai.app.models.image import ImageField, ImageType
 from ..services.invocation_services import InvocationServices
@@ -18,6 +19,14 @@ class RestoreFaceInvocation(BaseInvocation):
     strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
     #fmt: on

+    # Schema customisation
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["restoration", "image"],
+            },
+        }
+
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.services.images.get(
             self.image.image_type, self.image.image_name

invokeai/app/invocations/upscale.py

@@ -4,6 +4,7 @@ from datetime import datetime, timezone
 from typing import Literal, Union
 from pydantic import Field

+from invokeai.app.invocations.models.config import InvocationConfig
 from invokeai.app.models.image import ImageField, ImageType
 from ..services.invocation_services import InvocationServices
@@ -22,6 +23,15 @@ class UpscaleInvocation(BaseInvocation):
     level: Literal[2, 4] = Field(default=2, description="The upscale level")
     #fmt: on

+    # Schema customisation
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["upscaling", "image"],
+            },
+        }
+
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.services.images.get(
             self.image.image_type, self.image.image_name

invokeai/app/invocations/util/get_model.py

@@ -0,0 +1,11 @@
+from invokeai.app.invocations.baseinvocation import InvocationContext
+from invokeai.backend.model_management.model_manager import ModelManager
+
+
+def choose_model(model_manager: ModelManager, model_name: str):
+    """Returns the default model if the `model_name` is not a valid model, else returns the selected model."""
+    if model_manager.valid_model(model_name):
+        return model_manager.get_model(model_name)
+    else:
+        print(f"* Warning: '{model_name}' is not a valid model name. Using default model instead.")
+        return model_manager.get_model()