mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Added support for ControlNet and MultiControlNet to legacy non-nodal Txt2Img in backend/generator. Although backend/generator will likely disappear by v3.x, right now they are very useful for testing core ControlNet and MultiControlNet functionality while node codebase is rapidly evolving.
This commit is contained in:
parent
5ff98a4179
commit
a91dee87d0
@ -75,9 +75,11 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
model_info: dict,
|
model_info: dict,
|
||||||
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.model_info=model_info
|
self.model_info=model_info
|
||||||
self.params=params
|
self.params=params
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
def generate(self,
|
def generate(self,
|
||||||
prompt: str='',
|
prompt: str='',
|
||||||
@ -118,9 +120,12 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
model=model,
|
model=model,
|
||||||
scheduler_name=generator_args.get('scheduler')
|
scheduler_name=generator_args.get('scheduler')
|
||||||
)
|
)
|
||||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
|
|
||||||
|
# get conditioning from prompt via Compel package
|
||||||
|
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt, model=model)
|
||||||
|
|
||||||
gen_class = self._generator_class()
|
gen_class = self._generator_class()
|
||||||
generator = gen_class(model, self.params.precision)
|
generator = gen_class(model, self.params.precision, **self.kwargs)
|
||||||
if self.params.variation_amount > 0:
|
if self.params.variation_amount > 0:
|
||||||
generator.set_variation(generator_args.get('seed'),
|
generator.set_variation(generator_args.get('seed'),
|
||||||
generator_args.get('variation_amount'),
|
generator_args.get('variation_amount'),
|
||||||
@ -276,7 +281,7 @@ class Generator:
|
|||||||
precision: str
|
precision: str
|
||||||
model: DiffusionPipeline
|
model: DiffusionPipeline
|
||||||
|
|
||||||
def __init__(self, model: DiffusionPipeline, precision: str):
|
def __init__(self, model: DiffusionPipeline, precision: str, **kwargs):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.precision = precision
|
self.precision = precision
|
||||||
self.seed = None
|
self.seed = None
|
||||||
|
@ -4,6 +4,10 @@ invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
|
|||||||
import PIL.Image
|
import PIL.Image
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
|
||||||
|
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
|
||||||
|
|
||||||
from ..stable_diffusion import (
|
from ..stable_diffusion import (
|
||||||
ConditioningData,
|
ConditioningData,
|
||||||
PostprocessingSettings,
|
PostprocessingSettings,
|
||||||
@ -13,8 +17,13 @@ from .base import Generator
|
|||||||
|
|
||||||
|
|
||||||
class Txt2Img(Generator):
|
class Txt2Img(Generator):
|
||||||
def __init__(self, model, precision):
|
def __init__(self, model, precision,
|
||||||
super().__init__(model, precision)
|
control_model: Optional[Union[ControlNetModel, List[ControlNetModel]]] = None,
|
||||||
|
**kwargs):
|
||||||
|
self.control_model = control_model
|
||||||
|
if isinstance(self.control_model, list):
|
||||||
|
self.control_model = MultiControlNetModel(self.control_model)
|
||||||
|
super().__init__(model, precision, **kwargs)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def get_make_image(
|
def get_make_image(
|
||||||
@ -42,9 +51,12 @@ class Txt2Img(Generator):
|
|||||||
kwargs are 'width' and 'height'
|
kwargs are 'width' and 'height'
|
||||||
"""
|
"""
|
||||||
self.perlin = perlin
|
self.perlin = perlin
|
||||||
|
control_image = kwargs.get("control_image", None)
|
||||||
|
do_classifier_free_guidance = cfg_scale > 1.0
|
||||||
|
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
pipeline: StableDiffusionGeneratorPipeline = self.model
|
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||||
|
pipeline.control_model = self.control_model
|
||||||
pipeline.scheduler = sampler
|
pipeline.scheduler = sampler
|
||||||
|
|
||||||
uc, c, extra_conditioning_info = conditioning
|
uc, c, extra_conditioning_info = conditioning
|
||||||
@ -61,6 +73,37 @@ class Txt2Img(Generator):
|
|||||||
),
|
),
|
||||||
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
||||||
|
|
||||||
|
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||||
|
# and add in batch_size, num_images_per_prompt?
|
||||||
|
if control_image is not None:
|
||||||
|
if isinstance(self.control_model, ControlNetModel):
|
||||||
|
control_image = pipeline.prepare_control_image(
|
||||||
|
image=control_image,
|
||||||
|
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
# batch_size=batch_size * num_images_per_prompt,
|
||||||
|
# num_images_per_prompt=num_images_per_prompt,
|
||||||
|
device=self.control_model.device,
|
||||||
|
dtype=self.control_model.dtype,
|
||||||
|
)
|
||||||
|
elif isinstance(self.control_model, MultiControlNetModel):
|
||||||
|
images = []
|
||||||
|
for image_ in control_image:
|
||||||
|
image_ = self.model.prepare_control_image(
|
||||||
|
image=image_,
|
||||||
|
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
# batch_size=batch_size * num_images_per_prompt,
|
||||||
|
# num_images_per_prompt=num_images_per_prompt,
|
||||||
|
device=self.control_model.device,
|
||||||
|
dtype=self.control_model.dtype,
|
||||||
|
)
|
||||||
|
images.append(image_)
|
||||||
|
control_image = images
|
||||||
|
kwargs["control_image"] = control_image
|
||||||
|
|
||||||
def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
|
def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
|
||||||
pipeline_output = pipeline.image_from_embeddings(
|
pipeline_output = pipeline.image_from_embeddings(
|
||||||
latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
|
latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
|
||||||
@ -68,6 +111,7 @@ class Txt2Img(Generator):
|
|||||||
num_inference_steps=steps,
|
num_inference_steps=steps,
|
||||||
conditioning_data=conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
callback=step_callback,
|
callback=step_callback,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
Loading…
x
Reference in New Issue
Block a user