Draft
This commit is contained in: parent 5d5a497ed4, commit 6ab9a5e108
@@ -9,6 +9,7 @@ from typing import Literal, Optional, Union, List, Dict
 from PIL import Image
 from pydantic import BaseModel, Field, validator
 
+from ...backend.model_management import BaseModelType, ModelType
 from ..models.image import ImageField, ImageCategory, ResourceOrigin
 from .baseinvocation import (
     BaseInvocation,
@@ -105,9 +106,15 @@ CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control
 # CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
 
 
+class ControlNetModelField(BaseModel):
+    """ControlNet model field"""
+
+    model_name: str = Field(description="Name of the ControlNet model")
+    base_model: BaseModelType = Field(description="Base model")
+
 class ControlField(BaseModel):
     image: ImageField = Field(default=None, description="The control image")
-    control_model: Optional[str] = Field(default=None, description="The ControlNet model to use")
+    control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
     # control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
     control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
     begin_step_percent: float = Field(default=0, ge=0, le=1,
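
The new ControlNetModelField replaces the old comma-delimited model string with a structured reference. A minimal, self-contained sketch of how it composes with ControlField, stubbing BaseModelType as a plain Enum (the real one lives in invokeai.backend.model_management; the member values here are assumptions):

    from enum import Enum
    from typing import Optional

    from pydantic import BaseModel, Field

    class BaseModelType(str, Enum):
        # stand-in for invokeai's BaseModelType; member values are assumed
        StableDiffusion1 = "sd-1"
        StableDiffusion2 = "sd-2"

    class ControlNetModelField(BaseModel):
        """ControlNet model field"""
        model_name: str = Field(description="Name of the ControlNet model")
        base_model: BaseModelType = Field(description="Base model")

    class ControlField(BaseModel):
        control_model: Optional[ControlNetModelField] = Field(default=None)

    # a structured reference replaces the old "name,subfolder" string
    control = ControlField(
        control_model=ControlNetModelField(
            model_name="canny",
            base_model=BaseModelType.StableDiffusion1,
        )
    )
    print(control.control_model.model_name)  # -> canny
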
@@ -154,7 +161,7 @@ class ControlNetInvocation(BaseInvocation):
     type: Literal["controlnet"] = "controlnet"
     # Inputs
     image: ImageField = Field(default=None, description="The control image")
-    control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
+    control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny",
                                                   description="control model used")
     control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
     begin_step_percent: float = Field(default=0, ge=0, le=1,
@@ -182,7 +189,11 @@ class ControlNetInvocation(BaseInvocation):
         return ControlOutput(
             control=ControlField(
                 image=self.image,
-                control_model=self.control_model,
+                #control_model=self.control_model,
+                control_model=ControlNetModelField(
+                    model_name="canny",
+                    base_model=BaseModelType.StableDiffusion1,
+                ),
                 control_weight=self.control_weight,
                 begin_step_percent=self.begin_step_percent,
                 end_step_percent=self.end_step_percent,

@@ -71,16 +71,21 @@ def get_scheduler(
     scheduler_name: str,
 ) -> Scheduler:
     scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
-        scheduler_name, SCHEDULER_MAP['ddim'])
+        scheduler_name, SCHEDULER_MAP['ddim']
+    )
     orig_scheduler_info = context.services.model_manager.get_model(
-        **scheduler_info.dict())
+        **scheduler_info.dict()
+    )
     with orig_scheduler_info as orig_scheduler:
         scheduler_config = orig_scheduler.config
 
     if "_backup" in scheduler_config:
         scheduler_config = scheduler_config["_backup"]
-    scheduler_config = {**scheduler_config, **
-                        scheduler_extra_config, "_backup": scheduler_config}
+    scheduler_config = {
+        **scheduler_config,
+        **scheduler_extra_config,
+        "_backup": scheduler_config,
+    }
     scheduler = scheduler_class.from_config(scheduler_config)
 
     # hack copied over from generate.py
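
The reflowed dictionary literal above is a merge-with-backup pattern: scheduler-specific overrides are layered over the loaded config, while the untouched original is stashed under "_backup" so the preceding if-branch can restore it on a later call. A small sketch with illustrative values:

    # illustrative config values; the real ones come from the loaded scheduler
    scheduler_config = {"num_train_timesteps": 1000, "beta_schedule": "linear"}
    scheduler_extra_config = {"beta_schedule": "scaled_linear"}

    # restore the pristine config if a previous call already layered overrides
    if "_backup" in scheduler_config:
        scheduler_config = scheduler_config["_backup"]
    scheduler_config = {
        **scheduler_config,           # original settings
        **scheduler_extra_config,     # per-scheduler overrides win
        "_backup": scheduler_config,  # still the original at evaluation time
    }

    assert scheduler_config["beta_schedule"] == "scaled_linear"
    assert scheduler_config["_backup"]["beta_schedule"] == "linear"
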
@@ -137,8 +142,11 @@ class TextToLatentsInvocation(BaseInvocation):
 
     # TODO: pass this an emitter method or something? or a session for dispatching?
     def dispatch_progress(
-            self, context: InvocationContext, source_node_id: str,
-            intermediate_state: PipelineIntermediateState) -> None:
+        self,
+        context: InvocationContext,
+        source_node_id: str,
+        intermediate_state: PipelineIntermediateState,
+    ) -> None:
         stable_diffusion_step_callback(
             context=context,
             intermediate_state=intermediate_state,
@@ -147,11 +155,16 @@ class TextToLatentsInvocation(BaseInvocation):
         )
 
     def get_conditioning_data(
-            self, context: InvocationContext, scheduler) -> ConditioningData:
+        self,
+        context: InvocationContext,
+        scheduler,
+    ) -> ConditioningData:
         c, extra_conditioning_info = context.services.latents.get(
-            self.positive_conditioning.conditioning_name)
+            self.positive_conditioning.conditioning_name
+        )
         uc, _ = context.services.latents.get(
-            self.negative_conditioning.conditioning_name)
+            self.negative_conditioning.conditioning_name
+        )
 
         conditioning_data = ConditioningData(
             unconditioned_embeddings=uc,
@@ -178,7 +191,10 @@ class TextToLatentsInvocation(BaseInvocation):
         return conditioning_data
 
     def create_pipeline(
-            self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
+        self,
+        unet,
+        scheduler,
+    ) -> StableDiffusionGeneratorPipeline:
         # TODO:
         # configure_model_padding(
         #    unet,
@@ -213,6 +229,7 @@ class TextToLatentsInvocation(BaseInvocation):
         model: StableDiffusionGeneratorPipeline,
         control_input: List[ControlField],
         latents_shape: List[int],
+        exit_stack: ExitStack,
         do_classifier_free_guidance: bool = True,
     ) -> List[ControlNetData]:
 
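
The new exit_stack parameter lets the helper open model contexts on the caller's ExitStack, so the models stay loaded after the helper returns and are only released when the caller's stack unwinds. A minimal sketch of the pattern; load_model and prep_models are hypothetical stand-ins for the model-manager plumbing:

    from contextlib import ExitStack, contextmanager

    @contextmanager
    def load_model(name):  # hypothetical stand-in for model_manager.get_model()
        print(f"load {name}")
        try:
            yield f"<model:{name}>"
        finally:
            print(f"unload {name}")

    def prep_models(names, exit_stack: ExitStack):
        # enter each context on the *caller's* stack, so the models remain
        # loaded after this helper returns
        return [exit_stack.enter_context(load_model(n)) for n in names]

    with ExitStack() as stack:
        models = prep_models(["canny", "depth"], stack)
        print(models)  # both still loaded here, e.g. for the denoise loop
    # leaving the with-block unloads them in reverse order
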
@@ -238,25 +255,19 @@ class TextToLatentsInvocation(BaseInvocation):
         control_data = []
         control_models = []
         for control_info in control_list:
-            # handle control models
-            if ("," in control_info.control_model):
-                control_model_split = control_info.control_model.split(",")
-                control_name = control_model_split[0]
-                control_subfolder = control_model_split[1]
-                print("Using HF model subfolders")
-                print("    control_name: ", control_name)
-                print("    control_subfolder: ", control_subfolder)
-                control_model = ControlNetModel.from_pretrained(
-                    control_name, subfolder=control_subfolder,
-                    torch_dtype=model.unet.dtype).to(
-                        model.device)
-            else:
-                control_model = ControlNetModel.from_pretrained(
-                    control_info.control_model, torch_dtype=model.unet.dtype).to(model.device)
+            control_model = exit_stack.enter_context(
+                context.model_manager.get_model(
+                    model_name=control_info.control_model.model_name,
+                    model_type=ModelType.ControlNet,
+                    base_model=control_info.control_model.base_model,
+                )
+            )
+
             control_models.append(control_model)
             control_image_field = control_info.image
             input_image = context.services.images.get_pil_image(
-                control_image_field.image_name)
+                control_image_field.image_name
+            )
             # self.image.image_type, self.image.image_name
             # FIXME: still need to test with different widths, heights, devices, dtypes
             #        and add in batch_size, num_images_per_prompt?
@@ -278,7 +289,8 @@ class TextToLatentsInvocation(BaseInvocation):
                 weight=control_info.control_weight,
                 begin_step_percent=control_info.begin_step_percent,
                 end_step_percent=control_info.end_step_percent,
-                control_mode=control_info.control_mode,)
+                control_mode=control_info.control_mode,
+            )
             control_data.append(control_item)
         # MultiControlNetModel has been refactored out, just need list[ControlNetData]
         return control_data
@@ -289,7 +301,8 @@ class TextToLatentsInvocation(BaseInvocation):
 
         # Get the source node id (we are invoking the prepared node)
         graph_execution_state = context.services.graph_execution_manager.get(
-            context.graph_execution_state_id)
+            context.graph_execution_state_id
+        )
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]
 
         def step_callback(state: PipelineIntermediateState):
@@ -298,14 +311,17 @@ class TextToLatentsInvocation(BaseInvocation):
         def _lora_loader():
             for lora in self.unet.loras:
                 lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"}))
+                    **lora.dict(exclude={"weight"})
+                )
                 yield (lora_info.context.model, lora.weight)
                 del lora_info
             return
 
         unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.dict())
-        with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
+            **self.unet.unet.dict()
+        )
+        with ExitStack() as exit_stack,\
+             ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
              unet_info as unet:
 
             scheduler = get_scheduler(
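
Note that _lora_loader is a generator: each LoRA is fetched from the model manager only when ModelPatcher consumes it, and the reference is dropped as soon as it has been applied, so at most one extra set of weights needs to be resident at a time. A stripped-down sketch (LoraDescriptor and the string stand-ins for loaded weights are illustrative):

    from dataclasses import dataclass

    @dataclass
    class LoraDescriptor:  # hypothetical; stands in for self.unet.loras entries
        name: str
        weight: float

    loras = [LoraDescriptor("detail", 0.6), LoraDescriptor("style", 0.8)]

    def _lora_loader():
        for lora in loras:
            model = f"<weights:{lora.name}>"  # stand-in for a model-manager fetch
            yield (model, lora.weight)
            del model  # drop the reference once the patcher is done with it

    # the consumer pulls models one at a time instead of loading all up front
    for model, weight in _lora_loader():
        print(model, weight)
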
@@ -322,6 +338,7 @@ class TextToLatentsInvocation(BaseInvocation):
                 latents_shape=noise.shape,
                 # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
                 do_classifier_free_guidance=True,
+                exit_stack=exit_stack,
             )
 
             # TODO: Verify the noise is the right size
@@ -374,7 +391,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
 
         # Get the source node id (we are invoking the prepared node)
         graph_execution_state = context.services.graph_execution_manager.get(
-            context.graph_execution_state_id)
+            context.graph_execution_state_id
+        )
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]
 
         def step_callback(state: PipelineIntermediateState):
@@ -383,14 +401,17 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         def _lora_loader():
             for lora in self.unet.loras:
                 lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"}))
+                    **lora.dict(exclude={"weight"})
+                )
                 yield (lora_info.context.model, lora.weight)
                 del lora_info
             return
 
         unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.dict())
-        with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
+            **self.unet.unet.dict()
+        )
+        with ExitStack() as exit_stack,\
+             ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
              unet_info as unet:
 
             scheduler = get_scheduler(
@@ -407,11 +428,13 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
                 latents_shape=noise.shape,
                 # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
                 do_classifier_free_guidance=True,
+                exit_stack=exit_stack,
             )
 
             # TODO: Verify the noise is the right size
             initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
-                latent, device=unet.device, dtype=latent.dtype)
+                latent, device=unet.device, dtype=latent.dtype
+            )
 
             timesteps, _ = pipeline.get_img2img_timesteps(
                 self.steps,
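
For the initial_latents branch above: below full strength, img2img denoising starts from the source latents; at strength >= 1.0 the source presumably no longer matters and zeros are used as the starting point instead, with noise supplied during the denoise loop. A tiny illustration of the branch:

    import torch

    latent = torch.randn(1, 4, 64, 64)  # shape chosen for illustration
    strength = 1.0

    # below full strength, start from the source latents;
    # at strength >= 1.0 the source is discarded and zeros are used instead
    initial_latents = latent if strength < 1.0 else torch.zeros_like(
        latent, device=latent.device, dtype=latent.dtype
    )
    assert not initial_latents.any()
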
@@ -535,7 +558,8 @@ class ResizeLatentsInvocation(BaseInvocation):
         resized_latents = torch.nn.functional.interpolate(
             latents, size=(self.height // 8, self.width // 8),
             mode=self.mode, antialias=self.antialias
-            if self.mode in ["bilinear", "bicubic"] else False,)
+            if self.mode in ["bilinear", "bicubic"] else False,
+        )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()
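
The conditional exists because torch.nn.functional.interpolate accepts antialias=True only for the bilinear and bicubic modes; other modes have to receive False. A runnable sketch:

    import torch

    latents = torch.randn(1, 4, 64, 64)
    mode = "bilinear"

    resized = torch.nn.functional.interpolate(
        latents,
        size=(32, 32),
        mode=mode,
        # antialias is only valid for bilinear/bicubic; pass False otherwise
        antialias=True if mode in ["bilinear", "bicubic"] else False,
    )
    print(resized.shape)  # torch.Size([1, 4, 32, 32])
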
@@ -569,7 +593,8 @@ class ScaleLatentsInvocation(BaseInvocation):
         resized_latents = torch.nn.functional.interpolate(
             latents, scale_factor=self.scale_factor, mode=self.mode,
             antialias=self.antialias
-            if self.mode in ["bilinear", "bicubic"] else False,)
+            if self.mode in ["bilinear", "bicubic"] else False,
+        )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()