diff --git a/invokeai/app/invocations/control_adapter.py b/invokeai/app/invocations/control_adapter.py
index 7119ea0c34..36053e3b1c 100644
--- a/invokeai/app/invocations/control_adapter.py
+++ b/invokeai/app/invocations/control_adapter.py
@@ -19,6 +19,7 @@ from .baseinvocation import (
     invocation_output,
 )
 
+CONTROL_ADAPTER_TYPES = Literal["ControlNet", "IP-Adapter", "T2I-Adapter"]
 
 CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]
 CONTROLNET_RESIZE_VALUES = Literal[
@@ -36,8 +37,15 @@ class ControlNetModelField(BaseModel):
 
 
 class ControlField(BaseModel):
+    control_type: CONTROL_ADAPTER_TYPES = Field(default="ControlNet", description="The type of control adapter")
     image: ImageField = Field(description="The control image")
-    control_model: ControlNetModelField = Field(description="The ControlNet model to use")
+    # control_model and ip_adapter_model are both optional,
+    #    but exactly one of the two must be present:
+    #    if control_type == "ControlNet", then control_model must be set
+    #    if control_type == "IP-Adapter", then ip_adapter_model must be set
+    control_model: Optional[ControlNetModelField] = Field(description="The ControlNet model to use")
+    ip_adapter_model: Optional[str] = Field(description="The IP-Adapter model to use")
+    image_encoder_model: Optional[str] = Field(description="The clip_image_encoder model to use")
     control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
     begin_step_percent: float = Field(
         default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
@@ -97,8 +105,12 @@ class ControlNetInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> ControlOutput:
         return ControlOutput(
             control=ControlField(
+                control_type="ControlNet",
                 image=self.image,
                 control_model=self.control_model,
+                # ip_adapter_model is currently optional
+                #    must be either a control_model or an ip_adapter_model
+                # ip_adapter_model=None,
                 control_weight=self.control_weight,
                 begin_step_percent=self.begin_step_percent,
                 end_step_percent=self.end_step_percent,
@@ -106,3 +118,63 @@ class ControlNetInvocation(BaseInvocation):
                 resize_mode=self.resize_mode,
             ),
         )
+
+IP_ADAPTER_MODELS = Literal[
+    "models_ip_adapter/models/ip-adapter_sd15.bin",
+    "models_ip_adapter/models/ip-adapter-plus_sd15.bin",
+    "models_ip_adapter/models/ip-adapter-plus-face_sd15.bin",
+    "models_ip_adapter/sdxl_models/ip-adapter_sdxl.bin"
+]
+
+IP_ADAPTER_IMAGE_ENCODER_MODELS = Literal[
+    "models_ip_adapter/models/image_encoder/",
+    "./models_ip_adapter/models/image_encoder/",
+    "models_ip_adapter/sdxl_models/image_encoder/"
+]
+
+@invocation("ipadapter", title="IP-Adapter", tags=["ipadapter"], category="ipadapter")
+class IPAdapterInvocation(BaseInvocation):
+    """Collects IP-Adapter info to pass to other nodes"""
+
+    type: Literal["ipadapter"] = "ipadapter"
+
+    # Inputs
+    image: ImageField = InputField(description="The control image")
+    #control_model: ControlNetModelField = InputField(
+    #    default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
+    #)
+    ip_adapter_model: IP_ADAPTER_MODELS = InputField(default="./models_ip_adapter/models/ip-adapter_sd15.bin",
+                                                     description="The IP-Adapter model")
+    image_encoder_model: IP_ADAPTER_IMAGE_ENCODER_MODELS = InputField(
+        default="./models_ip_adapter/models/image_encoder/",
+        description="The image encoder model")
+    control_weight: Union[float, List[float]] = InputField(
+        default=1.0, description="The weight given to the 
ControlNet", ui_type=UIType.Float + ) + # begin_step_percent: float = InputField( + # default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)" + # ) + # end_step_percent: float = InputField( + # default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)" + # ) + # control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used") + # resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used") + + def invoke(self, context: InvocationContext) -> ControlOutput: + return ControlOutput( + control=ControlField( + control_type="IP-Adapter", + image=self.image, + # control_model is currently optional + # must be either a control_model or ip_adapter_model + # control_model=None, + ip_adapter_model=self.ip_adapter_model, + image_encoder_model=self.image_encoder_model, + control_weight=self.control_weight, + # rest are currently ignored + #begin_step_percent=self.begin_step_percent, + #end_step_percent=self.end_step_percent, + #control_mode=self.control_mode, + #resize_mode=self.resize_mode, + ), + ) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index ca400cd7b7..007f7bc9a2 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -40,6 +40,7 @@ from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion.diffusers_pipeline import ( ConditioningData, ControlNetData, + IPAdapterData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor, ) @@ -216,9 +217,9 @@ class DenoiseLatentsInvocation(BaseInvocation): default=None, description=FieldDescriptions.mask, ) - ip_adapter_image: Optional[ImageField] = InputField(input=Input.Connection, title="IP Adapter Image", ui_order=6) - ip_adapter_strength: float = InputField(default=1.0, ge=0, le=2, ui_type=UIType.Float, - title="IP Adapter Strength", ui_order=7) + # ip_adapter_image: Optional[ImageField] = InputField(input=Input.Connection, title="IP Adapter Image", ui_order=6) + # ip_adapter_strength: float = InputField(default=1.0, ge=0, le=2, ui_type=UIType.Float, + # title="IP Adapter Strength", ui_order=7) @validator("cfg_scale") def ge_one(cls, v): @@ -340,57 +341,71 @@ class DenoiseLatentsInvocation(BaseInvocation): else: control_list = None if control_list is None: - control_data = None + controlnet_data = None + ip_adapter_data = None # from above handling, any control that is not None should now be of type list[ControlField] else: # FIXME: add checks to skip entry if model or image is None # and if weight is None, populate with default 1.0? 
- control_data = [] - control_models = [] + controlnet_data = [] + ip_adapter_data = [] + # control_models = [] for control_info in control_list: - control_model = exit_stack.enter_context( - context.services.model_manager.get_model( - model_name=control_info.control_model.model_name, - model_type=ModelType.ControlNet, - base_model=control_info.control_model.base_model, - context=context, + if control_info.control_type == "ControlNet": + control_model = exit_stack.enter_context( + context.services.model_manager.get_model( + model_name=control_info.control_model.model_name, + model_type=ModelType.ControlNet, + base_model=control_info.control_model.base_model, + context=context, + ) ) - ) - control_models.append(control_model) - control_image_field = control_info.image - input_image = context.services.images.get_pil_image(control_image_field.image_name) - # self.image.image_type, self.image.image_name - # FIXME: still need to test with different widths, heights, devices, dtypes - # and add in batch_size, num_images_per_prompt? - # and do real check for classifier_free_guidance? - # prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width) - control_image = prepare_control_image( - image=input_image, - do_classifier_free_guidance=do_classifier_free_guidance, - width=control_width_resize, - height=control_height_resize, - # batch_size=batch_size * num_images_per_prompt, - # num_images_per_prompt=num_images_per_prompt, - device=control_model.device, - dtype=control_model.dtype, - control_mode=control_info.control_mode, - resize_mode=control_info.resize_mode, - ) - control_item = ControlNetData( - model=control_model, - image_tensor=control_image, - weight=control_info.control_weight, - begin_step_percent=control_info.begin_step_percent, - end_step_percent=control_info.end_step_percent, - control_mode=control_info.control_mode, - # any resizing needed should currently be happening in prepare_control_image(), - # but adding resize_mode to ControlNetData in case needed in the future - resize_mode=control_info.resize_mode, - ) - control_data.append(control_item) - # MultiControlNetModel has been refactored out, just need list[ControlNetData] - return control_data + # control_models.append(control_model) + control_image_field = control_info.image + input_image = context.services.images.get_pil_image(control_image_field.image_name) + # self.image.image_type, self.image.image_name + # FIXME: still need to test with different widths, heights, devices, dtypes + # and add in batch_size, num_images_per_prompt? + # and do real check for classifier_free_guidance? 
+ # prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width) + control_image = prepare_control_image( + image=input_image, + do_classifier_free_guidance=do_classifier_free_guidance, + width=control_width_resize, + height=control_height_resize, + # batch_size=batch_size * num_images_per_prompt, + # num_images_per_prompt=num_images_per_prompt, + device=control_model.device, + dtype=control_model.dtype, + control_mode=control_info.control_mode, + resize_mode=control_info.resize_mode, + ) + control_item = ControlNetData( + model=control_model, # model object + image_tensor=control_image, + weight=control_info.control_weight, + begin_step_percent=control_info.begin_step_percent, + end_step_percent=control_info.end_step_percent, + control_mode=control_info.control_mode, + # any resizing needed should currently be happening in prepare_control_image(), + # but adding resize_mode to ControlNetData in case needed in the future + resize_mode=control_info.resize_mode, + ) + controlnet_data.append(control_item) + # MultiControlNetModel has been refactored out, just need list[ControlNetData] + elif control_info.control_type == "IP-Adapter": + control_image_field = control_info.image + input_image = context.services.images.get_pil_image(control_image_field.image_name) + control_item = IPAdapterData( + ip_adapter_model=control_info.ip_adapter_model, # name of model (NOT model object) + image_encoder_model=control_info.image_encoder_model, # name of model (NOT model obj) + image=input_image, + weight=control_info.control_weight, + ) + ip_adapter_data.append(control_item) + + return controlnet_data, ip_adapter_data # original idea by https://github.com/AmericanPresidentJimmyCarter # TODO: research more for second order schedulers timesteps @@ -499,14 +514,14 @@ class DenoiseLatentsInvocation(BaseInvocation): pipeline = self.create_pipeline(unet, scheduler) conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed) - if self.ip_adapter_image is not None: - print("ip_adapter_image:", self.ip_adapter_image) - unwrapped_ip_adapter_image = context.services.images.get_pil_image(self.ip_adapter_image.image_name) - print("unwrapped ip_adapter_image:", unwrapped_ip_adapter_image) - else: - unwrapped_ip_adapter_image = None + # if self.ip_adapter_image is not None: + # print("ip_adapter_image:", self.ip_adapter_image) + # unwrapped_ip_adapter_image = context.services.images.get_pil_image(self.ip_adapter_image.image_name) + # print("unwrapped ip_adapter_image:", unwrapped_ip_adapter_image) + # else: + # unwrapped_ip_adapter_image = None - control_data = self.prep_control_data( + controlnet_data, ip_adapter_data = self.prep_control_data( model=pipeline, context=context, control_input=self.control, @@ -515,6 +530,8 @@ class DenoiseLatentsInvocation(BaseInvocation): do_classifier_free_guidance=True, exit_stack=exit_stack, ) + print("controlnet_data:", controlnet_data) + print("ip_adapter_data:", ip_adapter_data) num_inference_steps, timesteps, init_timestep = self.init_scheduler( scheduler, @@ -534,9 +551,10 @@ class DenoiseLatentsInvocation(BaseInvocation): masked_latents=masked_latents, num_inference_steps=num_inference_steps, conditioning_data=conditioning_data, - control_data=control_data, # list[ControlNetData], - ip_adapter_image=unwrapped_ip_adapter_image, - ip_adapter_strength=self.ip_adapter_strength, + control_data=controlnet_data, # list[ControlNetData], + ip_adapter_data=ip_adapter_data, # list[IPAdapterData], +# ip_adapter_image=unwrapped_ip_adapter_image, 
+# ip_adapter_strength=self.ip_adapter_strength, callback=step_callback, ) diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index edada58c8a..6bc9848247 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -170,6 +170,15 @@ class ControlNetData: resize_mode: str = Field(default="just_resize") +@dataclass +class IPAdapterData: + ip_adapter_model: str = Field(default=None) + image_encoder_model: str = Field(default=None) + image: PIL.Image = Field(default=None) + # TODO: change to polymorphic so can do different weights per step (once implemented...) + # weight: Union[float, List[float]] = Field(default=1.0) + weight: float = Field(default=1.0) + @dataclass class ConditioningData: unconditioned_embeddings: BasicConditioningInfo @@ -358,8 +367,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): additional_guidance: List[Callable] = None, callback: Callable[[PipelineIntermediateState], None] = None, control_data: List[ControlNetData] = None, - ip_adapter_image: Optional[PIL.Image] = None, - ip_adapter_strength: float = 1.0, + ip_adapter_data: IPAdapterData = None, mask: Optional[torch.Tensor] = None, masked_latents: Optional[torch.Tensor] = None, seed: Optional[int] = None, @@ -411,8 +419,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): conditioning_data, additional_guidance=additional_guidance, control_data=control_data, - ip_adapter_image=ip_adapter_image, - ip_adapter_strength=ip_adapter_strength, + ip_adapter_data=ip_adapter_data, callback=callback, ) finally: @@ -432,8 +439,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): *, additional_guidance: List[Callable] = None, control_data: List[ControlNetData] = None, - ip_adapter_image: Optional[PIL.Image] = None, - ip_adapter_strength: float = 1.0, + ip_adapter_data: List[IPAdapterData] = None, callback: Callable[[PipelineIntermediateState], None] = None, ): @@ -447,26 +453,26 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): if timesteps.shape[0] == 0: return latents, attention_map_saver - print("ip_adapter_image: ", type(ip_adapter_image)) - if ip_adapter_image is not None: + # print("ip_adapter_image: ", type(ip_adapter_image)) + if ip_adapter_data is not None and len(ip_adapter_data) > 0: + ip_adapter_info = ip_adapter_data[0] + ip_adapter_image = ip_adapter_info.image # initialize IPAdapter print(" width:", ip_adapter_image.width, " height:", ip_adapter_image.height) - clip_image_encoder_path = "ip_adapter_models_sd_1.5/image_encoder/" - ip_adapter_model_path = "ip_adapter_models_sd_1.5/ip-adapter_sd15.bin" # FIXME: # WARNING! # IPAdapter constructor modifies UNet model in-place # Adds additional cross-attention layers to UNet model for image embedding - # need to figure out how to only do this if UNet hasn't already been modified by prior IPAdapter - # and how to undo if ip_adapter_image is removed - # use existing model management context etc? + # need to figure out how to only do this if UNet hasn't already been modified by prior IPAdapter + # and how to undo if ip_adapter_image is removed + # Should reimplement to use existing model management context etc. 
# ip_adapter = IPAdapter(self, # IPAdapter first arg is StableDiffusionPipeline - clip_image_encoder_path, # hardwiring to manually downloaded dir for first pass - ip_adapter_model_path, # hardwiring to manually downloaded loc for first pass - "cuda") # hardwiring CUDA GPU for first pass + ip_adapter_info.image_encoder_model, + ip_adapter_info.ip_adapter_model, + self.unet.device) # IP-Adapter ==> add additional cross-attention layers to UNet model here? - ip_adapter.set_scale(ip_adapter_strength) + ip_adapter.set_scale(ip_adapter_info.weight) print("ip_adapter:", ip_adapter) # get image embedding from CLIP and ImageProjModel
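Note on the ControlField change at the top of this diff: the comments describe a rule that the code does not yet enforce, namely that control_model and ip_adapter_model are both optional but the one matching control_type must be present. A minimal sketch of enforcing that rule with the pydantic v1 validators already used elsewhere in the codebase could look like the following; ControlFieldSketch and its simplified field types are illustrative stand-ins, not code from this change.

from typing import Optional

from pydantic import BaseModel, root_validator


class ControlFieldSketch(BaseModel):
    # Simplified stand-in for ControlField, illustrating the "model must match control_type" rule.
    control_type: str = "ControlNet"
    control_model: Optional[str] = None  # ControlNetModelField in the real ControlField
    ip_adapter_model: Optional[str] = None

    @root_validator
    def check_model_matches_type(cls, values):
        if values.get("control_type") == "ControlNet" and values.get("control_model") is None:
            raise ValueError("control_model is required when control_type is 'ControlNet'")
        if values.get("control_type") == "IP-Adapter" and values.get("ip_adapter_model") is None:
            raise ValueError("ip_adapter_model is required when control_type is 'IP-Adapter'")
        return values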
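The WARNING block in diffusers_pipeline.py notes that the IPAdapter constructor modifies the UNet in place by adding cross-attention layers, and that there is no guard yet against patching the same UNet twice across invocations. One possible guard, sketched under the assumption that caching the wrapper on the UNet object is acceptable, is shown below; get_or_create_ip_adapter and the _ip_adapter_patched marker attribute are hypothetical names, and IPAdapter is assumed to be imported as it already is for generate_latents_from_embeddings. The separate question of undoing the patch when no IP-Adapter is requested is not addressed here.

def get_or_create_ip_adapter(pipeline, image_encoder_model: str, ip_adapter_model: str, device):
    # Reuse an existing IPAdapter wrapper instead of patching the same UNet a second time.
    existing = getattr(pipeline.unet, "_ip_adapter_patched", None)  # hypothetical marker attribute
    if existing is not None:
        return existing
    # Argument order mirrors the call in generate_latents_from_embeddings above.
    ip_adapter = IPAdapter(pipeline, image_encoder_model, ip_adapter_model, device)
    pipeline.unet._ip_adapter_patched = ip_adapter
    return ip_adapter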