mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Added support for ControlNet and MultiControlNet to the legacy non-nodal Txt2Img in backend/generator. Although backend/generator will likely disappear by v3.x, right now it is very useful for testing core ControlNet and MultiControlNet functionality while the node codebase is rapidly evolving.
This commit is contained in:
parent
5ff98a4179
commit
a91dee87d0
@ -75,9 +75,11 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
def __init__(self,
|
||||
model_info: dict,
|
||||
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
||||
**kwargs,
|
||||
):
|
||||
self.model_info=model_info
|
||||
self.params=params
|
||||
self.kwargs = kwargs
|
||||
|
||||
def generate(self,
|
||||
prompt: str='',
|
||||
@ -118,9 +120,12 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
model=model,
|
||||
scheduler_name=generator_args.get('scheduler')
|
||||
)
|
||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
|
||||
|
||||
# get conditioning from prompt via Compel package
|
||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt, model=model)
|
||||
|
||||
gen_class = self._generator_class()
|
||||
generator = gen_class(model, self.params.precision)
|
||||
generator = gen_class(model, self.params.precision, **self.kwargs)
|
||||
if self.params.variation_amount > 0:
|
||||
generator.set_variation(generator_args.get('seed'),
|
||||
generator_args.get('variation_amount'),
|
||||
@ -276,7 +281,7 @@ class Generator:
|
||||
precision: str
|
||||
model: DiffusionPipeline
|
||||
|
||||
def __init__(self, model: DiffusionPipeline, precision: str):
|
||||
def __init__(self, model: DiffusionPipeline, precision: str, **kwargs):
|
||||
self.model = model
|
||||
self.precision = precision
|
||||
self.seed = None
|
||||
|
@ -4,6 +4,10 @@ invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
|
||||
|
||||
from ..stable_diffusion import (
|
||||
ConditioningData,
|
||||
PostprocessingSettings,
|
||||
@ -13,8 +17,13 @@ from .base import Generator
|
||||
|
||||
|
||||
class Txt2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
def __init__(self, model, precision,
|
||||
control_model: Optional[Union[ControlNetModel, List[ControlNetModel]]] = None,
|
||||
**kwargs):
|
||||
self.control_model = control_model
|
||||
if isinstance(self.control_model, list):
|
||||
self.control_model = MultiControlNetModel(self.control_model)
|
||||
super().__init__(model, precision, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(
|
||||
@ -42,9 +51,12 @@ class Txt2Img(Generator):
|
||||
kwargs are 'width' and 'height'
|
||||
"""
|
||||
self.perlin = perlin
|
||||
control_image = kwargs.get("control_image", None)
|
||||
do_classifier_free_guidance = cfg_scale > 1.0
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||
pipeline.control_model = self.control_model
|
||||
pipeline.scheduler = sampler
|
||||
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
@ -61,6 +73,37 @@ class Txt2Img(Generator):
|
||||
),
|
||||
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
||||
|
||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||
# and add in batch_size, num_images_per_prompt?
|
||||
if control_image is not None:
|
||||
if isinstance(self.control_model, ControlNetModel):
|
||||
control_image = pipeline.prepare_control_image(
|
||||
image=control_image,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
width=width,
|
||||
height=height,
|
||||
# batch_size=batch_size * num_images_per_prompt,
|
||||
# num_images_per_prompt=num_images_per_prompt,
|
||||
device=self.control_model.device,
|
||||
dtype=self.control_model.dtype,
|
||||
)
|
||||
elif isinstance(self.control_model, MultiControlNetModel):
|
||||
images = []
|
||||
for image_ in control_image:
|
||||
image_ = self.model.prepare_control_image(
|
||||
image=image_,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
width=width,
|
||||
height=height,
|
||||
# batch_size=batch_size * num_images_per_prompt,
|
||||
# num_images_per_prompt=num_images_per_prompt,
|
||||
device=self.control_model.device,
|
||||
dtype=self.control_model.dtype,
|
||||
)
|
||||
images.append(image_)
|
||||
control_image = images
|
||||
kwargs["control_image"] = control_image
|
||||
|
||||
def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
|
||||
pipeline_output = pipeline.image_from_embeddings(
|
||||
latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
|
||||
@ -68,6 +111,7 @@ class Txt2Img(Generator):
|
||||
num_inference_steps=steps,
|
||||
conditioning_data=conditioning_data,
|
||||
callback=step_callback,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if (
|
||||
|
Loading…
Reference in New Issue
Block a user