mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
resolve conflicts
This commit is contained in:
@ -1,5 +1,6 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from contextlib import ExitStack
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import einops
|
||||
@ -9,9 +10,10 @@ from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.model_management.models.base import ModelType
|
||||
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
@ -21,6 +23,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
|
||||
PostprocessingSettings
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from ...backend.util.devices import torch_dtype
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
from .compel import ConditioningField
|
||||
@ -77,16 +80,21 @@ def get_scheduler(
|
||||
scheduler_name: str,
|
||||
) -> Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
|
||||
scheduler_name, SCHEDULER_MAP['ddim'])
|
||||
scheduler_name, SCHEDULER_MAP['ddim']
|
||||
)
|
||||
orig_scheduler_info = context.services.model_manager.get_model(
|
||||
**scheduler_info.dict())
|
||||
**scheduler_info.dict()
|
||||
)
|
||||
with orig_scheduler_info as orig_scheduler:
|
||||
scheduler_config = orig_scheduler.config
|
||||
|
||||
if "_backup" in scheduler_config:
|
||||
scheduler_config = scheduler_config["_backup"]
|
||||
scheduler_config = {**scheduler_config, **
|
||||
scheduler_extra_config, "_backup": scheduler_config}
|
||||
scheduler_config = {
|
||||
**scheduler_config,
|
||||
**scheduler_extra_config,
|
||||
"_backup": scheduler_config,
|
||||
}
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
|
||||
# hack copied over from generate.py
|
||||
@ -143,8 +151,11 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
def dispatch_progress(
|
||||
self, context: InvocationContext, source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState) -> None:
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
@ -153,11 +164,16 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def get_conditioning_data(
|
||||
self, context: InvocationContext, scheduler) -> ConditioningData:
|
||||
self,
|
||||
context: InvocationContext,
|
||||
scheduler,
|
||||
) -> ConditioningData:
|
||||
c, extra_conditioning_info = context.services.latents.get(
|
||||
self.positive_conditioning.conditioning_name)
|
||||
self.positive_conditioning.conditioning_name
|
||||
)
|
||||
uc, _ = context.services.latents.get(
|
||||
self.negative_conditioning.conditioning_name)
|
||||
self.negative_conditioning.conditioning_name
|
||||
)
|
||||
|
||||
conditioning_data = ConditioningData(
|
||||
unconditioned_embeddings=uc,
|
||||
@ -184,7 +200,10 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
return conditioning_data
|
||||
|
||||
def create_pipeline(
|
||||
self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
|
||||
self,
|
||||
unet,
|
||||
scheduler,
|
||||
) -> StableDiffusionGeneratorPipeline:
|
||||
# TODO:
|
||||
# configure_model_padding(
|
||||
# unet,
|
||||
@ -219,6 +238,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
model: StableDiffusionGeneratorPipeline,
|
||||
control_input: List[ControlField],
|
||||
latents_shape: List[int],
|
||||
exit_stack: ExitStack,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
) -> List[ControlNetData]:
|
||||
|
||||
@ -244,25 +264,19 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
control_data = []
|
||||
control_models = []
|
||||
for control_info in control_list:
|
||||
# handle control models
|
||||
if ("," in control_info.control_model):
|
||||
control_model_split = control_info.control_model.split(",")
|
||||
control_name = control_model_split[0]
|
||||
control_subfolder = control_model_split[1]
|
||||
print("Using HF model subfolders")
|
||||
print(" control_name: ", control_name)
|
||||
print(" control_subfolder: ", control_subfolder)
|
||||
control_model = ControlNetModel.from_pretrained(
|
||||
control_name, subfolder=control_subfolder,
|
||||
torch_dtype=model.unet.dtype).to(
|
||||
model.device)
|
||||
else:
|
||||
control_model = ControlNetModel.from_pretrained(
|
||||
control_info.control_model, torch_dtype=model.unet.dtype).to(model.device)
|
||||
control_model = exit_stack.enter_context(
|
||||
context.services.model_manager.get_model(
|
||||
model_name=control_info.control_model.model_name,
|
||||
model_type=ModelType.ControlNet,
|
||||
base_model=control_info.control_model.base_model,
|
||||
)
|
||||
)
|
||||
|
||||
control_models.append(control_model)
|
||||
control_image_field = control_info.image
|
||||
input_image = context.services.images.get_pil_image(
|
||||
control_image_field.image_name)
|
||||
control_image_field.image_name
|
||||
)
|
||||
# self.image.image_type, self.image.image_name
|
||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||
# and add in batch_size, num_images_per_prompt?
|
||||
@ -284,7 +298,8 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
weight=control_info.control_weight,
|
||||
begin_step_percent=control_info.begin_step_percent,
|
||||
end_step_percent=control_info.end_step_percent,
|
||||
control_mode=control_info.control_mode,)
|
||||
control_mode=control_info.control_mode,
|
||||
)
|
||||
control_data.append(control_item)
|
||||
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
||||
return control_data
|
||||
@ -295,7 +310,8 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id)
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
@ -304,14 +320,17 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
def _lora_loader():
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}))
|
||||
**lora.dict(exclude={"weight"})
|
||||
)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict())
|
||||
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||
**self.unet.unet.dict()
|
||||
)
|
||||
with ExitStack() as exit_stack,\
|
||||
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||
unet_info as unet:
|
||||
|
||||
scheduler = get_scheduler(
|
||||
@ -328,6 +347,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
latents_shape=noise.shape,
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,
|
||||
exit_stack=exit_stack,
|
||||
)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
@ -380,7 +400,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id)
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
@ -389,14 +410,17 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
def _lora_loader():
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}))
|
||||
**lora.dict(exclude={"weight"})
|
||||
)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict())
|
||||
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||
**self.unet.unet.dict()
|
||||
)
|
||||
with ExitStack() as exit_stack,\
|
||||
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||
unet_info as unet:
|
||||
|
||||
scheduler = get_scheduler(
|
||||
@ -413,11 +437,13 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
latents_shape=noise.shape,
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,
|
||||
exit_stack=exit_stack,
|
||||
)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
||||
latent, device=unet.device, dtype=latent.dtype)
|
||||
latent, device=unet.device, dtype=latent.dtype
|
||||
)
|
||||
|
||||
timesteps, _ = pipeline.get_img2img_timesteps(
|
||||
self.steps,
|
||||
@ -457,6 +483,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
default=False,
|
||||
description="Decode latents by overlaping tiles(less memory consumption)")
|
||||
fp32: bool = Field(False, description="Decode in full precision")
|
||||
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
@ -526,7 +553,8 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata.dict() if self.metadata else None,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -548,9 +576,9 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(
|
||||
description="The latents to resize")
|
||||
width: int = Field(
|
||||
width: Union[int, None] = Field(default=512,
|
||||
ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||
height: int = Field(
|
||||
height: Union[int, None] = Field(default=512,
|
||||
ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||
mode: LATENTS_INTERPOLATION_MODE = Field(
|
||||
default="bilinear", description="The interpolation mode")
|
||||
@ -564,7 +592,8 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
latents, size=(self.height // 8, self.width // 8),
|
||||
mode=self.mode, antialias=self.antialias
|
||||
if self.mode in ["bilinear", "bicubic"] else False,)
|
||||
if self.mode in ["bilinear", "bicubic"] else False,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
@ -598,7 +627,8 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
latents, scale_factor=self.scale_factor, mode=self.mode,
|
||||
antialias=self.antialias
|
||||
if self.mode in ["bilinear", "bicubic"] else False,)
|
||||
if self.mode in ["bilinear", "bicubic"] else False,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
Reference in New Issue
Block a user