diff --git a/invokeai/backend/generator/base.py b/invokeai/backend/generator/base.py
index 8f5b1a8395..6f2f33e6af 100644
--- a/invokeai/backend/generator/base.py
+++ b/invokeai/backend/generator/base.py
@@ -75,9 +75,11 @@ class InvokeAIGenerator(metaclass=ABCMeta):
     def __init__(self,
                  model_info: dict,
                  params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
+                 **kwargs,
                  ):
         self.model_info=model_info
         self.params=params
+        self.kwargs = kwargs
 
     def generate(self,
                  prompt: str='',
@@ -118,9 +120,12 @@ class InvokeAIGenerator(metaclass=ABCMeta):
                 model=model,
                 scheduler_name=generator_args.get('scheduler')
             )
-            uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
+
+            # get conditioning from prompt via Compel package
+            uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt, model=model)
+
             gen_class = self._generator_class()
-            generator = gen_class(model, self.params.precision)
+            generator = gen_class(model, self.params.precision, **self.kwargs)
             if self.params.variation_amount > 0:
                 generator.set_variation(generator_args.get('seed'),
                                         generator_args.get('variation_amount'),
@@ -276,7 +281,7 @@ class Generator:
     precision: str
     model: DiffusionPipeline
 
-    def __init__(self, model: DiffusionPipeline, precision: str):
+    def __init__(self, model: DiffusionPipeline, precision: str, **kwargs):
         self.model = model
         self.precision = precision
         self.seed = None
diff --git a/invokeai/backend/generator/txt2img.py b/invokeai/backend/generator/txt2img.py
index e5a96212f0..a83a8e0c31 100644
--- a/invokeai/backend/generator/txt2img.py
+++ b/invokeai/backend/generator/txt2img.py
@@ -4,6 +4,10 @@ invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
 import PIL.Image
 import torch
 
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
+
 from ..stable_diffusion import (
     ConditioningData,
     PostprocessingSettings,
@@ -13,8 +17,13 @@ from .base import Generator
 
 
 class Txt2Img(Generator):
-    def __init__(self, model, precision):
-        super().__init__(model, precision)
+    def __init__(self, model, precision,
+                 control_model: Optional[Union[ControlNetModel, List[ControlNetModel]]] = None,
+                 **kwargs):
+        self.control_model = control_model
+        if isinstance(self.control_model, list):
+            self.control_model = MultiControlNetModel(self.control_model)
+        super().__init__(model, precision, **kwargs)
 
     @torch.no_grad()
     def get_make_image(
@@ -42,9 +51,12 @@ class Txt2Img(Generator):
         kwargs are 'width' and 'height'
         """
         self.perlin = perlin
+        control_image = kwargs.get("control_image", None)
+        do_classifier_free_guidance = cfg_scale > 1.0
 
         # noinspection PyTypeChecker
         pipeline: StableDiffusionGeneratorPipeline = self.model
+        pipeline.control_model = self.control_model
         pipeline.scheduler = sampler
 
         uc, c, extra_conditioning_info = conditioning
@@ -61,6 +73,37 @@ class Txt2Img(Generator):
             ),
         ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
 
+        # FIXME: still need to test with different widths, heights, devices, dtypes
+        #        and add in batch_size, num_images_per_prompt?
+        if control_image is not None:
+            if isinstance(self.control_model, ControlNetModel):
+                control_image = pipeline.prepare_control_image(
+                    image=control_image,
+                    do_classifier_free_guidance=do_classifier_free_guidance,
+                    width=width,
+                    height=height,
+                    # batch_size=batch_size * num_images_per_prompt,
+                    # num_images_per_prompt=num_images_per_prompt,
+                    device=self.control_model.device,
+                    dtype=self.control_model.dtype,
+                )
+            elif isinstance(self.control_model, MultiControlNetModel):
+                images = []
+                for image_ in control_image:
+                    image_ = self.model.prepare_control_image(
+                        image=image_,
+                        do_classifier_free_guidance=do_classifier_free_guidance,
+                        width=width,
+                        height=height,
+                        # batch_size=batch_size * num_images_per_prompt,
+                        # num_images_per_prompt=num_images_per_prompt,
+                        device=self.control_model.device,
+                        dtype=self.control_model.dtype,
+                    )
+                    images.append(image_)
+                control_image = images
+            kwargs["control_image"] = control_image
+
         def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
             pipeline_output = pipeline.image_from_embeddings(
                 latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
@@ -68,6 +111,7 @@ class Txt2Img(Generator):
                 num_inference_steps=steps,
                 conditioning_data=conditioning_data,
                 callback=step_callback,
+                **kwargs,
             )
             if (
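
Reviewer note: a minimal usage sketch of the new plumbing, not part of the patch. It assumes a loaded StableDiffusionGeneratorPipeline (`pipeline`), a scheduler instance (`sampler`), a Compel conditioning tuple (`conditioning`), and a preprocessed canny map (`canny_image`), all placeholders, and it abbreviates the existing get_make_image() signature; only control_model and control_image are introduced by this diff.

# Sketch only: `pipeline`, `sampler`, `conditioning`, and `canny_image`
# are assumed to exist elsewhere; they are not defined by this patch.
import torch
from diffusers.models.controlnet import ControlNetModel

from invokeai.backend.generator.txt2img import Txt2Img

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)

# control_model flows through Generator.__init__'s new **kwargs; a list of
# ControlNetModels would be wrapped in MultiControlNetModel automatically.
generator = Txt2Img(pipeline, precision="float16", control_model=controlnet)

# control_image rides along in **kwargs; get_make_image() resizes it via
# prepare_control_image() and forwards it to image_from_embeddings().
make_image = generator.get_make_image(
    prompt="a cabin in the woods",
    sampler=sampler,
    conditioning=conditioning,  # (uc, c, extra_conditioning_info)
    cfg_scale=7.5,              # > 1.0 enables classifier-free guidance
    steps=30,
    ddim_eta=0.0,
    width=512,
    height=512,
    control_image=canny_image,
)

One wrinkle worth noting for review: inside get_make_image(), `pipeline` and `self.model` are the same object, so the single- and multi-ControlNet branches call the same prepare_control_image() despite the two spellings.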