Switching to ControlField for output from controlnet nodes.

This commit is contained in:
user1 2023-05-04 14:21:11 -07:00 committed by Kent Keirsey
parent 78cd106c23
commit 5e4c0217c7
3 changed files with 75 additions and 22 deletions

View File

@ -1,7 +1,4 @@
from typing import Literal, Optional from typing import Literal, Optional, Union, List
import numpy
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from ..models.image import ImageField, ImageType from ..models.image import ImageField, ImageType
@ -12,24 +9,57 @@ from .baseinvocation import (
InvocationConfig, InvocationConfig,
) )
from controlnet_aux import CannyDetector, HEDdetector, LineartDetector from controlnet_aux import CannyDetector
from .image import ImageOutput, build_image_output, PILInvocationConfig from .image import ImageOutput, build_image_output, PILInvocationConfig
# Canny Image Processor class ControlField(BaseModel):
class CannyProcessorInvocation(BaseInvocation, PILInvocationConfig):
"""Applies Canny edge detection to image""" image: ImageField = Field(default=None, description="processed image")
# width: Optional[int] = Field(default=None, description="The width of the image in pixels")
# height: Optional[int] = Field(default=None, description="The height of the image in pixels")
# mode: Optional[str] = Field(default=None, description="The mode of the image")
control_model: Optional[str] = Field(default=None, description="The control model used")
control_weight: Optional[float] = Field(default=None, description="The control weight used")
class Config:
schema_extra = {
"required": ["image", "control_model", "control_weight"]
# "required": ["type", "image", "width", "height", "mode"]
}
class ControlOutput(BaseInvocationOutput):
"""Base class for invocations that output ControlNet info"""
# fmt: off # fmt: off
type: Literal["canny"] = "canny" type: Literal["control_output"] = "control_output"
control: Optional[ControlField] = Field(default=None, description="The control info dict")
# image: ImageField = Field(default=None, description="outputs just the image info (which is also included in control output)")
# fmt: on
class CannyControlInvocation(BaseInvocation, PILInvocationConfig):
"""Canny edge detection for ControlNet"""
# fmt: off
type: Literal["cannycontrol"] = "cannycontrol"
# Inputs # Inputs
image: ImageField = Field(default=None, description="image to process") image: ImageField = Field(default=None, description="image to process")
low_threshold: float = Field(default=100, ge=0, description="low threshold of Canny pixel gradient") control_model: str = Field(default=None, description="control model to use")
control_weight: float = Field(default=0.5, ge=0, le=1, description="control weight")
# begin_step_percent: float = Field(default=0, ge=0, le=1,
# description="% of total steps at which controlnet is first applied")
# end_step_percent: float = Field(default=1, ge=0, le=1,
# description="% of total steps at which controlnet is last applied")
# guess_mode: bool = Field(default=False, description="use guess mode (controlnet ignores prompt)")
low_threshold: float = Field(default=100, ge=0, description="low threshold of Canny pixel gradient")
high_threshold: float = Field(default=200, ge=0, description="high threshold of Canny pixel gradient") high_threshold: float = Field(default=200, ge=0, description="high threshold of Canny pixel gradient")
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ControlOutput:
image = context.services.images.get( image = context.services.images.get(
self.image.image_type, self.image.image_name self.image.image_type, self.image.image_name
) )
@ -43,6 +73,17 @@ class CannyProcessorInvocation(BaseInvocation, PILInvocationConfig):
session_id=context.graph_execution_state_id, node=self session_id=context.graph_execution_state_id, node=self
) )
context.services.images.save(image_type, image_name, processed_image, metadata) context.services.images.save(image_type, image_name, processed_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=processed_image """Builds an ImageOutput and its ImageField"""
image_field = ImageField(
image_name=image_name,
image_type=image_type,
) )
return ControlOutput(
control=ControlField(
image=image_field,
control_model=self.control_model,
control_weight=self.control_weight,
)
)

View File

@ -1,10 +1,10 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import random import random
from typing import Literal, Optional, Union
import einops import einops
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
import torch import torch
from typing import Literal, Optional, Union, List
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
@ -13,6 +13,7 @@ from invokeai.app.models.image import ImageCategory
from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from .controlnet_image_processors import ControlField
from ...backend.model_management.model_manager import ModelManager from ...backend.model_management.model_manager import ModelManager
from ...backend.util.devices import choose_torch_device, torch_dtype from ...backend.util.devices import choose_torch_device, torch_dtype
@ -174,8 +175,7 @@ class TextToLatentsInvocation(BaseInvocation):
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'") # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", ) progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
control_model: Optional[str] = Field(default=None, description="The control model to use") control: Optional[ControlField] = Field(default=None, description="The control to use")
control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
# fmt: on # fmt: on
# Schema customisation # Schema customisation
@ -257,21 +257,32 @@ class TextToLatentsInvocation(BaseInvocation):
model = self.get_model(context.services.model_manager) model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(context, model) conditioning_data = self.get_conditioning_data(context, model)
# loading controlnet model print("type of control input: ", type(self.control))
if (self.control_model is None or self.control_model==''):
control_model = None if (self.control is None):
control_model_name = None
control_image_field = None
control_weight = None
else: else:
control_model_name = self.control.control_model
control_image_field = self.control.image
control_weight = self.control.control_weight
# # loading controlnet model
# if (self.control_model is None or self.control_model==''):
# control_model = None
# else:
# FIXME: change this to dropdown menu? # FIXME: change this to dropdown menu?
# FIXME: generalize so don't have to hardcode torch_dtype and device # FIXME: generalize so don't have to hardcode torch_dtype and device
control_model = ControlNetModel.from_pretrained(self.control_model, control_model = ControlNetModel.from_pretrained(control_model_name,
torch_dtype=torch.float16).to("cuda") torch_dtype=torch.float16).to("cuda")
model.control_model = control_model model.control_model = control_model
# loading controlnet image (currently requires pre-processed image) # loading controlnet image (currently requires pre-processed image)
control_image = ( control_image = (
None if self.control_image is None None if control_image_field is None
else context.services.images.get( else context.services.images.get(
self.control_image.image_type, self.control_image.image_name control_image_field.image_type, control_image_field.image_name
) )
) )

View File

@ -993,6 +993,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
def prepare_control_image( def prepare_control_image(
self, self,
image, image,
# FIXME: width and height are hard-wired here; derive them from the latents' dimensions instead?
width=512, width=512,
height=512, height=512,
batch_size=1, batch_size=1,