From 6ab9a5e108b09df8a9295f324ef15fbc74b67149 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Wed, 5 Jul 2023 20:00:43 +0300 Subject: [PATCH] Draft --- .../controlnet_image_processors.py | 17 ++- invokeai/app/invocations/latent.py | 101 +++++++++++------- 2 files changed, 77 insertions(+), 41 deletions(-) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index c37dcda998..c9fad11987 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -9,6 +9,7 @@ from typing import Literal, Optional, Union, List, Dict from PIL import Image from pydantic import BaseModel, Field, validator +from ...backend.model_management import BaseModelType, ModelType from ..models.image import ImageField, ImageCategory, ResourceOrigin from .baseinvocation import ( BaseInvocation, @@ -105,9 +106,15 @@ CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control # CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])] +class ControlNetModelField(BaseModel): + """ControlNet model field""" + + model_name: str = Field(description="Name of the ControlNet model") + base_model: BaseModelType = Field(description="Base model") + class ControlField(BaseModel): image: ImageField = Field(default=None, description="The control image") - control_model: Optional[str] = Field(default=None, description="The ControlNet model to use") + control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use") # control_weight: Optional[float] = Field(default=1, description="weight given to controlnet") control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") begin_step_percent: float = Field(default=0, ge=0, le=1, @@ -154,7 +161,7 @@ class ControlNetInvocation(BaseInvocation): type: Literal["controlnet"] = "controlnet" # Inputs image: ImageField = Field(default=None, description="The control image") - control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny", + control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny", description="control model used") control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet") begin_step_percent: float = Field(default=0, ge=0, le=1, @@ -182,7 +189,11 @@ class ControlNetInvocation(BaseInvocation): return ControlOutput( control=ControlField( image=self.image, - control_model=self.control_model, + #control_model=self.control_model, + control_model=ControlNetModelField( + model_name="canny", + base_model=BaseModelType.StableDiffusion1, + ), control_weight=self.control_weight, begin_step_percent=self.begin_step_percent, end_step_percent=self.end_step_percent, diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index b3f95f3658..1e41a9c96f 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -71,16 +71,21 @@ def get_scheduler( scheduler_name: str, ) -> Scheduler: scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get( - scheduler_name, SCHEDULER_MAP['ddim']) + scheduler_name, SCHEDULER_MAP['ddim'] + ) orig_scheduler_info = context.services.model_manager.get_model( - **scheduler_info.dict()) + **scheduler_info.dict() + ) with orig_scheduler_info as orig_scheduler: scheduler_config = orig_scheduler.config if "_backup" in scheduler_config: scheduler_config = scheduler_config["_backup"] - scheduler_config = {**scheduler_config, ** - scheduler_extra_config, "_backup": scheduler_config} + scheduler_config = { + **scheduler_config, + **scheduler_extra_config, + "_backup": scheduler_config, + } scheduler = scheduler_class.from_config(scheduler_config) # hack copied over from generate.py @@ -137,8 +142,11 @@ class TextToLatentsInvocation(BaseInvocation): # TODO: pass this an emitter method or something? or a session for dispatching? def dispatch_progress( - self, context: InvocationContext, source_node_id: str, - intermediate_state: PipelineIntermediateState) -> None: + self, + context: InvocationContext, + source_node_id: str, + intermediate_state: PipelineIntermediateState, + ) -> None: stable_diffusion_step_callback( context=context, intermediate_state=intermediate_state, @@ -147,11 +155,16 @@ class TextToLatentsInvocation(BaseInvocation): ) def get_conditioning_data( - self, context: InvocationContext, scheduler) -> ConditioningData: + self, + context: InvocationContext, + scheduler, + ) -> ConditioningData: c, extra_conditioning_info = context.services.latents.get( - self.positive_conditioning.conditioning_name) + self.positive_conditioning.conditioning_name + ) uc, _ = context.services.latents.get( - self.negative_conditioning.conditioning_name) + self.negative_conditioning.conditioning_name + ) conditioning_data = ConditioningData( unconditioned_embeddings=uc, @@ -178,7 +191,10 @@ class TextToLatentsInvocation(BaseInvocation): return conditioning_data def create_pipeline( - self, unet, scheduler) -> StableDiffusionGeneratorPipeline: + self, + unet, + scheduler, + ) -> StableDiffusionGeneratorPipeline: # TODO: # configure_model_padding( # unet, @@ -213,6 +229,7 @@ class TextToLatentsInvocation(BaseInvocation): model: StableDiffusionGeneratorPipeline, control_input: List[ControlField], latents_shape: List[int], + exit_stack: ExitStack, do_classifier_free_guidance: bool = True, ) -> List[ControlNetData]: @@ -238,25 +255,19 @@ class TextToLatentsInvocation(BaseInvocation): control_data = [] control_models = [] for control_info in control_list: - # handle control models - if ("," in control_info.control_model): - control_model_split = control_info.control_model.split(",") - control_name = control_model_split[0] - control_subfolder = control_model_split[1] - print("Using HF model subfolders") - print(" control_name: ", control_name) - print(" control_subfolder: ", control_subfolder) - control_model = ControlNetModel.from_pretrained( - control_name, subfolder=control_subfolder, - torch_dtype=model.unet.dtype).to( - model.device) - else: - control_model = ControlNetModel.from_pretrained( - control_info.control_model, torch_dtype=model.unet.dtype).to(model.device) + control_model = exit_stack.enter_context( + context.model_manager.get_model( + model_name=control_info.control_model.model_name, + model_type=ModelType.ControlNet, + base_model=control_info.control_model.base_model, + ) + ) + control_models.append(control_model) control_image_field = control_info.image input_image = context.services.images.get_pil_image( - control_image_field.image_name) + control_image_field.image_name + ) # self.image.image_type, self.image.image_name # FIXME: still need to test with different widths, heights, devices, dtypes # and add in batch_size, num_images_per_prompt? @@ -278,7 +289,8 @@ class TextToLatentsInvocation(BaseInvocation): weight=control_info.control_weight, begin_step_percent=control_info.begin_step_percent, end_step_percent=control_info.end_step_percent, - control_mode=control_info.control_mode,) + control_mode=control_info.control_mode, + ) control_data.append(control_item) # MultiControlNetModel has been refactored out, just need list[ControlNetData] return control_data @@ -289,7 +301,8 @@ class TextToLatentsInvocation(BaseInvocation): # Get the source node id (we are invoking the prepared node) graph_execution_state = context.services.graph_execution_manager.get( - context.graph_execution_state_id) + context.graph_execution_state_id + ) source_node_id = graph_execution_state.prepared_source_mapping[self.id] def step_callback(state: PipelineIntermediateState): @@ -298,14 +311,17 @@ class TextToLatentsInvocation(BaseInvocation): def _lora_loader(): for lora in self.unet.loras: lora_info = context.services.model_manager.get_model( - **lora.dict(exclude={"weight"})) + **lora.dict(exclude={"weight"}) + ) yield (lora_info.context.model, lora.weight) del lora_info return unet_info = context.services.model_manager.get_model( - **self.unet.unet.dict()) - with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ + **self.unet.unet.dict() + ) + with ExitStack() as exit_stack,\ + ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ unet_info as unet: scheduler = get_scheduler( @@ -322,6 +338,7 @@ class TextToLatentsInvocation(BaseInvocation): latents_shape=noise.shape, # do_classifier_free_guidance=(self.cfg_scale >= 1.0)) do_classifier_free_guidance=True, + exit_stack=exit_stack, ) # TODO: Verify the noise is the right size @@ -374,7 +391,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): # Get the source node id (we are invoking the prepared node) graph_execution_state = context.services.graph_execution_manager.get( - context.graph_execution_state_id) + context.graph_execution_state_id + ) source_node_id = graph_execution_state.prepared_source_mapping[self.id] def step_callback(state: PipelineIntermediateState): @@ -383,14 +401,17 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): def _lora_loader(): for lora in self.unet.loras: lora_info = context.services.model_manager.get_model( - **lora.dict(exclude={"weight"})) + **lora.dict(exclude={"weight"}) + ) yield (lora_info.context.model, lora.weight) del lora_info return unet_info = context.services.model_manager.get_model( - **self.unet.unet.dict()) - with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ + **self.unet.unet.dict() + ) + with ExitStack() as exit_stack,\ + ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ unet_info as unet: scheduler = get_scheduler( @@ -407,11 +428,13 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): latents_shape=noise.shape, # do_classifier_free_guidance=(self.cfg_scale >= 1.0)) do_classifier_free_guidance=True, + exit_stack=exit_stack, ) # TODO: Verify the noise is the right size initial_latents = latent if self.strength < 1.0 else torch.zeros_like( - latent, device=unet.device, dtype=latent.dtype) + latent, device=unet.device, dtype=latent.dtype + ) timesteps, _ = pipeline.get_img2img_timesteps( self.steps, @@ -535,7 +558,8 @@ class ResizeLatentsInvocation(BaseInvocation): resized_latents = torch.nn.functional.interpolate( latents, size=(self.height // 8, self.width // 8), mode=self.mode, antialias=self.antialias - if self.mode in ["bilinear", "bicubic"] else False,) + if self.mode in ["bilinear", "bicubic"] else False, + ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -569,7 +593,8 @@ class ScaleLatentsInvocation(BaseInvocation): resized_latents = torch.nn.functional.interpolate( latents, scale_factor=self.scale_factor, mode=self.mode, antialias=self.antialias - if self.mode in ["bilinear", "bicubic"] else False,) + if self.mode in ["bilinear", "bicubic"] else False, + ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache()