Fixed controlnet preprocessors and controlnet handling in TextToLatents to work with revised Image services.

user1 2023-05-26 16:47:27 -07:00 committed by Kent Keirsey
parent 1ad4eb3a7b
commit 9a796364da
2 changed files with 111 additions and 69 deletions

View File

@@ -7,14 +7,13 @@ from typing import Literal, Optional, Union, List
 from PIL import Image, ImageFilter, ImageOps
 from pydantic import BaseModel, Field

-from ..models.image import ImageField, ImageType
+from ..models.image import ImageField, ImageType, ImageCategory
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
     InvocationContext,
     InvocationConfig,
 )

 from controlnet_aux import (
     CannyDetector,
     HEDdetector,
@@ -26,10 +25,11 @@ from controlnet_aux import (
     OpenposeDetector,
     PidiNetDetector,
     ContentShuffleDetector,
-    # ZoeDetector, # FIXME: uncomment once ZoeDetector is availabel in official controlnet_aux release
+    ZoeDetector,
+    MediapipeFaceDetector,
 )

-from .image import ImageOutput, build_image_output, PILInvocationConfig
+from .image import ImageOutput, PILInvocationConfig

 CONTROLNET_DEFAULT_MODELS = [
     ###########################################
@@ -161,33 +161,41 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
         return image

     def invoke(self, context: InvocationContext) -> ImageOutput:
-        raw_image = context.services.images.get(
+        raw_image = context.services.images.get_pil_image(
             self.image.image_type, self.image.image_name
         )
         # image type should be PIL.PngImagePlugin.PngImageFile ?
         processed_image = self.run_processor(raw_image)

+        # FIXME: what happened to image metadata?
+        # metadata = context.services.metadata.build_metadata(
+        #     session_id=context.graph_execution_state_id, node=self
+        # )
+
         # currently can't see processed image in node UI without a showImage node,
         # so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
-        # image_type = ImageType.INTERMEDIATE
-        image_type = ImageType.RESULT
-        image_name = context.services.images.create_name(
-            context.graph_execution_state_id, self.id
-        )
-        metadata = context.services.metadata.build_metadata(
-            session_id=context.graph_execution_state_id, node=self
-        )
-        context.services.images.save(image_type, image_name, processed_image, metadata)
+        image_dto = context.services.images.create(
+            image=processed_image,
+            image_type=ImageType.RESULT,
+            image_category=ImageCategory.GENERAL,
+            session_id=context.graph_execution_state_id,
+            node_id=self.id,
+            is_intermediate=self.is_intermediate
+        )

         """Builds an ImageOutput and its ImageField"""
         processed_image_field = ImageField(
-            image_name=image_name,
-            image_type=image_type,
+            image_name=image_dto.image_name,
+            image_type=image_dto.image_type,
         )
         return ImageOutput(
             image=processed_image_field,
-            width=processed_image.width,
-            height=processed_image.height,
-            mode=processed_image.mode,
+            # width=processed_image.width,
+            width = image_dto.width,
+            # height=processed_image.height,
+            height = image_dto.height,
+            # mode=processed_image.mode,
         )
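
For readers unfamiliar with the revised Image services this commit targets, here is a minimal, self-contained sketch of the pattern the processor node now follows. The DTO fields and the FakeImageService below are illustrative stand-ins, not the actual InvokeAI API: the node fetches a PIL image, processes it, hands it back to the service, and reads the saved image's name and dimensions off the returned DTO instead of off the PIL object.

from dataclasses import dataclass
from PIL import Image

@dataclass
class ImageDTO:
    # illustrative subset of fields; the real DTO carries more
    image_name: str
    image_type: str
    width: int
    height: int

class FakeImageService:
    """Illustrative stand-in for the revised images service (not the real implementation)."""
    def __init__(self) -> None:
        self._store: dict[tuple[str, str], Image.Image] = {}
        self._counter = 0

    def get_pil_image(self, image_type: str, image_name: str) -> Image.Image:
        return self._store[(image_type, image_name)]

    def create(self, image: Image.Image, image_type: str, **metadata) -> ImageDTO:
        # The service, not the node, now generates the unique name, persists the
        # image, and reports the result back as a DTO.
        self._counter += 1
        name = f"image_{self._counter}.png"
        self._store[(image_type, name)] = image
        return ImageDTO(name, image_type, image.width, image.height)

service = FakeImageService()
dto = service.create(Image.new("RGB", (64, 64)), image_type="results")
print(dto.image_name, dto.width, dto.height)  # image_1.png 64 64

This is why the node's ImageOutput is built from image_dto.width and image_dto.height in the hunk above: the dimensions come back from the service that saved the image.
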
@@ -392,18 +400,17 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
         return processed_image

-# # FIXME: ZoeDetector was implemented _after_ most recent official release of controlnet_aux (v0.0.3)
-# # so it is commented out until a new release is made
-# class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
-#     """Applies Zoe depth processing to image"""
-#     # fmt: off
-#     type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
-#     # fmt: on
-#
-#     def run_processor(self, image):
-#         zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
-#         processed_image = zoe_depth_processor(image)
-#         return processed_image
+# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
+class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+    """Applies Zoe depth processing to image"""
+    # fmt: off
+    type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
+    # fmt: on
+
+    def run_processor(self, image):
+        zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
+        processed_image = zoe_depth_processor(image)
+        return processed_image


 class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):

View File

@@ -6,10 +6,11 @@ from typing import Literal, Optional, Union, List
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, validator
 import torch

 from invokeai.app.invocations.util.choose_model import choose_model
+from invokeai.app.models.image import ImageCategory
 from invokeai.app.util.misc import SEED_MAX, get_random_seed
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
@@ -27,9 +28,9 @@ from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
 import numpy as np
-from ..services.image_storage import ImageType
+from ..services.image_file_storage import ImageType
 from .baseinvocation import BaseInvocation, InvocationContext
-from .image import ImageField, ImageOutput, build_image_output
+from .image import ImageField, ImageOutput
 from .compel import ConditioningField
 from ...backend.stable_diffusion import PipelineIntermediateState
 from diffusers.schedulers import SchedulerMixin as Scheduler
@@ -146,12 +147,17 @@ class NoiseInvocation(BaseInvocation):
             },
         }

+    @validator("seed", pre=True)
+    def modulo_seed(cls, v):
+        """Returns the seed modulo SEED_MAX to ensure it is within the valid range."""
+        return v % SEED_MAX
+
     def invoke(self, context: InvocationContext) -> NoiseOutput:
         device = torch.device(choose_torch_device())
         noise = get_noise(self.width, self.height, device, self.seed)

         name = f'{context.graph_execution_state_id}__{self.id}'
-        context.services.latents.set(name, noise)
+        context.services.latents.save(name, noise)
         return build_noise_output(latents_name=name, latents=noise)
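
The added @validator wraps out-of-range seeds instead of rejecting them. A standalone sketch of the same idea follows; the SEED_MAX value used here is assumed for illustration (the real constant is imported from invokeai.app.util.misc).

from pydantic import BaseModel, validator

SEED_MAX = 2**32 - 1  # assumed value for illustration

class NoiseParams(BaseModel):
    seed: int = 0

    @validator("seed", pre=True)
    def modulo_seed(cls, v):
        """Return the seed modulo SEED_MAX so oversized values wrap instead of failing validation."""
        return v % SEED_MAX

print(NoiseParams(seed=2**40).seed)  # wraps into the valid range
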
@@ -168,19 +174,18 @@ class TextToLatentsInvocation(BaseInvocation):
     noise: Optional[LatentsField] = Field(description="The noise to use")
     steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
     cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
-    scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
+    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
     model: str = Field(default="", description="The model to use (currently ignored)")
-    seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
-    seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
-    progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
-    control: Union[ControlField, List[ControlField]] = Field(default=None, description="The controlnet(s) to use")
+    control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
+    # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
+    # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
     # fmt: on

     # Schema customisation
     class Config(InvocationConfig):
         schema_extra = {
             "ui": {
-                "tags": ["latents"],
+                "tags": ["latents", "image"],
                 "type_hints": {
                     "model": "model",
                     "control": "control",
@@ -209,17 +214,17 @@ class TextToLatentsInvocation(BaseInvocation):
             scheduler_name=self.scheduler
         )

-        if isinstance(model, DiffusionPipeline):
-            for component in [model.unet, model.vae]:
-                configure_model_padding(component,
-                                        self.seamless,
-                                        self.seamless_axes
-                                        )
-        else:
-            configure_model_padding(model,
-                                    self.seamless,
-                                    self.seamless_axes
-                                    )
+        # if isinstance(model, DiffusionPipeline):
+        #     for component in [model.unet, model.vae]:
+        #         configure_model_padding(component,
+        #                                 self.seamless,
+        #                                 self.seamless_axes
+        #                                 )
+        # else:
+        #     configure_model_padding(model,
+        #                             self.seamless,
+        #                             self.seamless_axes
+        #                             )

         return model
@@ -292,7 +297,9 @@ class TextToLatentsInvocation(BaseInvocation):
                                                   torch_dtype=model.unet.dtype).to(model.device)
             control_models.append(control_model)
             control_image_field = control_info.image
-            input_image = context.services.images.get(control_image_field.image_type, control_image_field.image_name)
+            input_image = context.services.images.get_pil_image(control_image_field.image_type,
+                                                                 control_image_field.image_name)
+            # self.image.image_type, self.image.image_name
             # FIXME: still need to test with different widths, heights, devices, dtypes
             # and add in batch_size, num_images_per_prompt?
             # and do real check for classifier_free_guidance?
@@ -348,7 +355,7 @@ class TextToLatentsInvocation(BaseInvocation):
        torch.cuda.empty_cache()

        name = f'{context.graph_execution_state_id}__{self.id}'
-       context.services.latents.set(name, result_latents)
+       context.services.latents.save(name, result_latents)
        return build_latents_output(latents_name=name, latents=result_latents)
@@ -361,6 +368,18 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
     latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
     strength: float = Field(default=0.5, description="The strength of the latents to use")

+    # Schema customisation
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["latents"],
+                "type_hints": {
+                    "model": "model",
+                    "control": "control",
+                }
+            },
+        }
+
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         noise = context.services.latents.get(self.noise.latents_name)
         latent = context.services.latents.get(self.latents.latents_name)
@@ -402,7 +421,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
        torch.cuda.empty_cache()

        name = f'{context.graph_execution_state_id}__{self.id}'
-       context.services.latents.set(name, result_latents)
+       context.services.latents.save(name, result_latents)
        return build_latents_output(latents_name=name, latents=result_latents)
@@ -439,20 +458,30 @@ class LatentsToImageInvocation(BaseInvocation):
         np_image = model.decode_latents(latents)
         image = model.numpy_to_pil(np_image)[0]

-        image_type = ImageType.RESULT
-        image_name = context.services.images.create_name(
-            context.graph_execution_state_id, self.id
-        )
-        metadata = context.services.metadata.build_metadata(
-            session_id=context.graph_execution_state_id, node=self
-        )
+        # what happened to metadata?
+        # metadata = context.services.metadata.build_metadata(
+        #     session_id=context.graph_execution_state_id, node=self

         torch.cuda.empty_cache()

-        context.services.images.save(image_type, image_name, image, metadata)
-        return build_image_output(
-            image_type=image_type, image_name=image_name, image=image
+        # new (post Image service refactor) way of using services to save image
+        # and gnenerate unique image_name
+        image_dto = context.services.images.create(
+            image=image,
+            image_type=ImageType.RESULT,
+            image_category=ImageCategory.GENERAL,
+            session_id=context.graph_execution_state_id,
+            node_id=self.id,
+            is_intermediate=self.is_intermediate
+        )
+
+        return ImageOutput(
+            image=ImageField(
+                image_name=image_dto.image_name,
+                image_type=image_dto.image_type,
+            ),
+            width=image_dto.width,
+            height=image_dto.height,
         )
@@ -487,7 +516,8 @@ class ResizeLatentsInvocation(BaseInvocation):
        torch.cuda.empty_cache()

        name = f"{context.graph_execution_state_id}__{self.id}"
-       context.services.latents.set(name, resized_latents)
+       # context.services.latents.set(name, resized_latents)
+       context.services.latents.save(name, resized_latents)
        return build_latents_output(latents_name=name, latents=resized_latents)
@@ -517,7 +547,8 @@ class ScaleLatentsInvocation(BaseInvocation):
        torch.cuda.empty_cache()

        name = f"{context.graph_execution_state_id}__{self.id}"
-       context.services.latents.set(name, resized_latents)
+       # context.services.latents.set(name, resized_latents)
+       context.services.latents.save(name, resized_latents)
        return build_latents_output(latents_name=name, latents=resized_latents)
@@ -541,7 +572,10 @@ class ImageToLatentsInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        image = context.services.images.get(
+        # image = context.services.images.get(
+        #     self.image.image_type, self.image.image_name
+        # )
+        image = context.services.images.get_pil_image(
             self.image.image_type, self.image.image_name
         )
@@ -561,5 +595,6 @@ class ImageToLatentsInvocation(BaseInvocation):
        )

        name = f"{context.graph_execution_state_id}__{self.id}"
-       context.services.latents.set(name, latents)
+       # context.services.latents.set(name, latents)
+       context.services.latents.save(name, latents)
        return build_latents_output(latents_name=name, latents=latents)