Fixed controlnet preprocessors and controlnet handling in TextToLatents to work with revised Image services.

This commit is contained in:
user1 2023-05-26 16:47:27 -07:00 committed by Kent Keirsey
parent 1ad4eb3a7b
commit 9a796364da
2 changed files with 111 additions and 69 deletions
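In short, the commit migrates every touched invocation from the old image/latents service calls to the revised ones. A minimal before/after sketch of the pattern, using only calls that appear in the diffs below (`context` is the usual InvocationContext; the keyword values are illustrative):

# Old pattern (removed in this commit):
raw_image = context.services.images.get(image_type, image_name)
image_name = context.services.images.create_name(session_id, node_id)
context.services.images.save(image_type, image_name, processed_image, metadata)
context.services.latents.set(name, latents)

# New pattern (added in this commit):
raw_image = context.services.images.get_pil_image(image_type, image_name)
image_dto = context.services.images.create(
    image=processed_image,
    image_type=ImageType.RESULT,
    image_category=ImageCategory.GENERAL,
    session_id=session_id,
    node_id=node_id,
    is_intermediate=False,
)
context.services.latents.save(name, latents)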

View File

@@ -7,14 +7,13 @@ from typing import Literal, Optional, Union, List
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field
from ..models.image import ImageField, ImageType
from ..models.image import ImageField, ImageType, ImageCategory
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
InvocationConfig,
)
from controlnet_aux import (
CannyDetector,
HEDdetector,
@@ -26,10 +25,11 @@ from controlnet_aux import (
OpenposeDetector,
PidiNetDetector,
ContentShuffleDetector,
# ZoeDetector, # FIXME: uncomment once ZoeDetector is available in official controlnet_aux release
ZoeDetector,
MediapipeFaceDetector,
)
from .image import ImageOutput, build_image_output, PILInvocationConfig
from .image import ImageOutput, PILInvocationConfig
CONTROLNET_DEFAULT_MODELS = [
###########################################
@@ -161,33 +161,41 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
return image
def invoke(self, context: InvocationContext) -> ImageOutput:
raw_image = context.services.images.get(
raw_image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
# image type should be PIL.PngImagePlugin.PngImageFile ?
processed_image = self.run_processor(raw_image)
# FIXME: what happened to image metadata?
# metadata = context.services.metadata.build_metadata(
# session_id=context.graph_execution_state_id, node=self
# )
# currently can't see processed image in node UI without a showImage node,
# so for now setting image_type to RESULT instead of INTERMEDIATE so it will be saved to the gallery
# image_type = ImageType.INTERMEDIATE
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=processed_image,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id,
node_id=self.id,
is_intermediate=self.is_intermediate
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, processed_image, metadata)
"""Builds an ImageOutput and its ImageField"""
processed_image_field = ImageField(
image_name=image_name,
image_type=image_type,
image_name=image_dto.image_name,
image_type=image_dto.image_type,
)
return ImageOutput(
image=processed_image_field,
width=processed_image.width,
height=processed_image.height,
mode=processed_image.mode,
# width=processed_image.width,
width=image_dto.width,
# height=processed_image.height,
height=image_dto.height,
# mode=processed_image.mode,
)
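The image_dto returned by images.create carries everything the node output needs, so the node no longer reads dimensions off the PIL image. A hedged sketch of the DTO shape implied by the usage above; the field names are inferred from this hunk, and the real class is defined by the revised Image services, not shown in this diff:

from dataclasses import dataclass
from invokeai.app.models.image import ImageType

@dataclass
class ImageDTO:  # assumed shape only, inferred from image_dto usage above
    image_name: str
    image_type: ImageType
    width: int
    height: int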
@@ -392,18 +400,17 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvoca
return processed_image
# # FIXME: ZoeDetector was implemented _after_ most recent official release of controlnet_aux (v0.0.3)
# # so it is commented out until a new release is made
# class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
# """Applies Zoe depth processing to image"""
# # fmt: off
# type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
# # fmt: on
#
# def run_processor(self, image):
# zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
# processed_image = zoe_depth_processor(image)
# return processed_image
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies Zoe depth processing to image"""
# fmt: off
type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
# fmt: on
def run_processor(self, image):
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = zoe_depth_processor(image)
return processed_image
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
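For reference, the newly enabled Zoe processor can be exercised outside the node graph; a minimal sketch using the same pretrained weights as run_processor above ("input.png" is a hypothetical local file):

from PIL import Image
from controlnet_aux import ZoeDetector

zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")  # same weights as above
depth_map = zoe(Image.open("input.png"))  # returns a PIL depth map
depth_map.save("zoe_depth.png")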

View File

@@ -6,10 +6,11 @@ from typing import Literal, Optional, Union, List
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, validator
import torch
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.models.image import ImageCategory
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.app.util.step_callback import stable_diffusion_step_callback
@@ -27,9 +28,9 @@ from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
import numpy as np
from ..services.image_storage import ImageType
from ..services.image_file_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput, build_image_output
from .image import ImageField, ImageOutput
from .compel import ConditioningField
from ...backend.stable_diffusion import PipelineIntermediateState
from diffusers.schedulers import SchedulerMixin as Scheduler
@@ -146,12 +147,17 @@ class NoiseInvocation(BaseInvocation):
},
}
@validator("seed", pre=True)
def modulo_seed(cls, v):
"""Returns the seed modulo SEED_MAX to ensure it is within the valid range."""
return v % SEED_MAX
def invoke(self, context: InvocationContext) -> NoiseOutput:
device = torch.device(choose_torch_device())
noise = get_noise(self.width, self.height, device, self.seed)
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, noise)
context.services.latents.save(name, noise)
return build_noise_output(latents_name=name, latents=noise)
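The new modulo_seed validator runs before type coercion (pre=True) and folds any integer into the valid range. A self-contained sketch of the behavior; SEED_MAX is stubbed to a plausible value here, while the real constant comes from invokeai.app.util.misc:

from pydantic import BaseModel, validator

SEED_MAX = 2**32 - 1  # assumption: stand-in for invokeai.app.util.misc.SEED_MAX

class Noise(BaseModel):
    seed: int = 0

    @validator("seed", pre=True)
    def modulo_seed(cls, v):
        return v % SEED_MAX

print(Noise(seed=SEED_MAX + 5).seed)  # -> 5, wrapped back into the valid range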
@@ -168,19 +174,18 @@ class TextToLatentsInvocation(BaseInvocation):
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance scale; higher values may yield results closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
model: str = Field(default="", description="The model to use (currently ignored)")
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
control: Union[ControlField, List[ControlField]] = Field(default=None, description="The controlnet(s) to use")
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
# fmt: on
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents"],
"tags": ["latents", "image"],
"type_hints": {
"model": "model",
"control": "control",
@@ -209,17 +214,17 @@ class TextToLatentsInvocation(BaseInvocation):
scheduler_name=self.scheduler
)
if isinstance(model, DiffusionPipeline):
for component in [model.unet, model.vae]:
configure_model_padding(component,
self.seamless,
self.seamless_axes
)
else:
configure_model_padding(model,
self.seamless,
self.seamless_axes
)
# if isinstance(model, DiffusionPipeline):
# for component in [model.unet, model.vae]:
# configure_model_padding(component,
# self.seamless,
# self.seamless_axes
# )
# else:
# configure_model_padding(model,
# self.seamless,
# self.seamless_axes
# )
return model
@@ -292,7 +297,9 @@ class TextToLatentsInvocation(BaseInvocation):
torch_dtype=model.unet.dtype).to(model.device)
control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get(control_image_field.image_type, control_image_field.image_name)
input_image = context.services.images.get_pil_image(control_image_field.image_type,
control_image_field.image_name)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
# and do real check for classifier_free_guidance?
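The control-model construction above follows the standard diffusers loading pattern; only the torch_dtype/.to() tail of the call is visible in this hunk, so here is a hedged sketch with a hypothetical checkpoint id (the node actually resolves the model name from control_info):

import torch
from diffusers import ControlNetModel

control_model = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny",  # hypothetical repo id for illustration
    torch_dtype=torch.float16,
).to("cuda")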
@@ -348,7 +355,7 @@ class TextToLatentsInvocation(BaseInvocation):
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents)
context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents)
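The set-to-save rename recurs in every latents-producing node below; the naming convention itself is unchanged: one tensor keyed per (session, node) pair, with retrieval still going through get, as in the invocations above:

name = f'{context.graph_execution_state_id}__{self.id}'  # key per session + node
context.services.latents.save(name, result_latents)      # formerly .set()
latents = context.services.latents.get(name)             # retrieval unchanged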
@@ -361,6 +368,18 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
strength: float = Field(default=0.5, description="The strength of the latents to use")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents"],
"type_hints": {
"model": "model",
"control": "control",
}
},
}
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name)
@@ -402,7 +421,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents)
context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents)
@@ -439,20 +458,30 @@ class LatentsToImageInvocation(BaseInvocation):
np_image = model.decode_latents(latents)
image = model.numpy_to_pil(np_image)[0]
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
# what happened to metadata?
# metadata = context.services.metadata.build_metadata(
# session_id=context.graph_execution_state_id, node=self
# )
torch.cuda.empty_cache()
context.services.images.save(image_type, image_name, image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=image
# new (post Image service refactor) way of using services to save image
# and generate a unique image_name
image_dto = context.services.images.create(
image=image,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id,
node_id=self.id,
is_intermediate=self.is_intermediate
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)
@@ -487,7 +516,8 @@ class ResizeLatentsInvocation(BaseInvocation):
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, resized_latents)
# context.services.latents.set(name, resized_latents)
context.services.latents.save(name, resized_latents)
return build_latents_output(latents_name=name, latents=resized_latents)
@@ -517,7 +547,8 @@ class ScaleLatentsInvocation(BaseInvocation):
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, resized_latents)
# context.services.latents.set(name, resized_latents)
context.services.latents.save(name, resized_latents)
return build_latents_output(latents_name=name, latents=resized_latents)
@@ -541,7 +572,10 @@ class ImageToLatentsInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.services.images.get(
# image = context.services.images.get(
# self.image.image_type, self.image.image_name
# )
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
@@ -561,5 +595,6 @@ class ImageToLatentsInvocation(BaseInvocation):
)
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, latents)
# context.services.latents.set(name, latents)
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents)