fix: Slow loading of Loras

Co-Authored-By: StAlKeR7779 <7768370+StAlKeR7779@users.noreply.github.com>
blessedcoolant 2023-07-05 14:37:16 +12:00 committed by psychedelicious
parent 0f0336b6ef
commit c0501ed5c2
4 changed files with 253 additions and 206 deletions
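The change below replaces the eager list comprehensions that called model_manager.get_model() for every LoRA before patching with small _lora_loader() generators, so each LoRA is fetched only when ModelPatcher actually consumes it, and the patch is applied to the managed model (unet_info.context.model / text_encoder_info.context.model) before the model context is entered. A minimal, self-contained sketch of that difference (load_lora is an illustrative stand-in for the model-manager call, not an InvokeAI API):

def load_lora(name):
    print(f"loading {name}")  # stands in for an expensive model_manager.get_model() call
    return name

# before: every LoRA is loaded up front, even though the patcher consumes them one at a time
loras_eager = [(load_lora(n), 0.75) for n in ["a", "b", "c"]]  # all three load immediately

# after: a generator defers each load until the patcher iterates over it
def _lora_loader(names):
    for n in names:
        lora = load_lora(n)  # loaded only when the consumer asks for it
        yield (lora, 0.75)
        del lora  # drop the reference as soon as the patcher has applied it

patch_inputs = _lora_loader(["a", "b", "c"])  # nothing loaded yet; loading happens inside apply_lora_*()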

View File

@@ -1,28 +1,27 @@
-from typing import Literal, Optional, Union
-from pydantic import BaseModel, Field
-from contextlib import ExitStack
 import re
+from contextlib import ExitStack
+from typing import List, Literal, Optional, Union
+
 import torch
+from compel import Compel
+from compel.prompt_parser import (Blend, Conjunction,
+                                  CrossAttentionControlSubstitute,
+                                  FlattenedPrompt, Fragment)
+from pydantic import BaseModel, Field
 
-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
-from .model import ClipField
-
-from ...backend.util.devices import torch_dtype
-from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
 from ...backend.model_management import BaseModelType, ModelType, SubModelType
 from ...backend.model_management.lora import ModelPatcher
+from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
+from ...backend.util.devices import torch_dtype
+from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
+                             InvocationConfig, InvocationContext)
+from .model import ClipField
 
-from compel import Compel
-from compel.prompt_parser import (
-    Blend,
-    CrossAttentionControlSubstitute,
-    FlattenedPrompt,
-    Fragment, Conjunction,
-)
 
 class ConditioningField(BaseModel):
-    conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
+    conditioning_name: Optional[str] = Field(
+        default=None, description="The name of conditioning data")
 
     class Config:
         schema_extra = {"required": ["conditioning_name"]}
@@ -52,84 +51,92 @@ class CompelInvocation(BaseInvocation):
                 "title": "Prompt (Compel)",
                 "tags": ["prompt", "compel"],
                 "type_hints": {
                     "model": "model"
                 }
             },
         }
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def invoke(self, context: InvocationContext) -> CompelOutput:
         tokenizer_info = context.services.model_manager.get_model(
             **self.clip.tokenizer.dict(),
         )
         text_encoder_info = context.services.model_manager.get_model(
             **self.clip.text_encoder.dict(),
         )
-        with tokenizer_info as orig_tokenizer,\
-                text_encoder_info as text_encoder:
-            loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
 
-            ti_list = []
-            for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
-                name = trigger[1:-1]
-                try:
-                    ti_list.append(
-                        context.services.model_manager.get_model(
-                            model_name=name,
-                            base_model=self.clip.text_encoder.base_model,
-                            model_type=ModelType.TextualInversion,
-                        ).context.model
-                    )
-                except Exception:
-                    #print(e)
-                    #import traceback
-                    #print(traceback.format_exc())
-                    print(f"Warn: trigger: \"{trigger}\" not found")
+        def _lora_loader():
+            for lora in self.clip.loras:
+                lora_info = context.services.model_manager.get_model(
+                    **lora.dict(exclude={"weight"}))
+                yield (lora_info.context.model, lora.weight)
+                del lora_info
+            return
 
-            with ModelPatcher.apply_lora_text_encoder(text_encoder, loras),\
-                 ModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager):
+        #loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
 
-                compel = Compel(
-                    tokenizer=tokenizer,
-                    text_encoder=text_encoder,
-                    textual_inversion_manager=ti_manager,
-                    dtype_for_device_getter=torch_dtype,
-                    truncate_long_prompts=True,  # TODO:
-                )
+        ti_list = []
+        for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
+            name = trigger[1:-1]
+            try:
+                ti_list.append(
+                    context.services.model_manager.get_model(
+                        model_name=name,
+                        base_model=self.clip.text_encoder.base_model,
+                        model_type=ModelType.TextualInversion,
+                    ).context.model
+                )
+            except Exception:
+                # print(e)
+                #import traceback
+                # print(traceback.format_exc())
+                print(f"Warn: trigger: \"{trigger}\" not found")
 
-                conjunction = Compel.parse_prompt_string(self.prompt)
-                prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
+        with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
+                ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
+                text_encoder_info as text_encoder:
 
-                if context.services.configuration.log_tokenization:
-                    log_tokenization_for_prompt_object(prompt, tokenizer)
+            compel = Compel(
+                tokenizer=tokenizer,
+                text_encoder=text_encoder,
+                textual_inversion_manager=ti_manager,
+                dtype_for_device_getter=torch_dtype,
+                truncate_long_prompts=True,  # TODO:
+            )
 
-                c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
+            conjunction = Compel.parse_prompt_string(self.prompt)
+            prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
 
-                # TODO: long prompt support
-                #if not self.truncate_long_prompts:
-                #    [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
-                ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
-                    tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
-                    cross_attention_control_args=options.get("cross_attention_control", None),
-                )
+            if context.services.configuration.log_tokenization:
+                log_tokenization_for_prompt_object(prompt, tokenizer)
 
-                conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
+            c, options = compel.build_conditioning_tensor_for_prompt_object(
+                prompt)
 
-                # TODO: hacky but works ;D maybe rename latents somehow?
-                context.services.latents.save(conditioning_name, (c, ec))
+            # TODO: long prompt support
+            # if not self.truncate_long_prompts:
+            #    [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
+            ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
+                tokens_count_including_eos_bos=get_max_token_count(
+                    tokenizer, conjunction),
+                cross_attention_control_args=options.get(
+                    "cross_attention_control", None),)
 
-                return CompelOutput(
-                    conditioning=ConditioningField(
-                        conditioning_name=conditioning_name,
-                    ),
-                )
+            conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
+
+            # TODO: hacky but works ;D maybe rename latents somehow?
+            context.services.latents.save(conditioning_name, (c, ec))
+
+            return CompelOutput(
+                conditioning=ConditioningField(
+                    conditioning_name=conditioning_name,
+                ),
+            )
 
 
 def get_max_token_count(
-    tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
-) -> int:
+        tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction],
+        truncate_if_too_long=False) -> int:
    if type(prompt) is Blend:
        blend: Blend = prompt
        return max(
@@ -148,13 +155,13 @@ def get_max_token_count(
         )
     else:
         return len(
-            get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)
-        )
+            get_tokens_for_prompt_object(
+                tokenizer, prompt, truncate_if_too_long))
 
 
 def get_tokens_for_prompt_object(
     tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
-) -> [str]:
+) -> List[str]:
     if type(parsed_prompt) is Blend:
         raise ValueError(
             "Blend is not supported here - you need to get tokens for each of its .children"
@@ -183,7 +190,7 @@ def log_tokenization_for_conjunction(
 ):
     display_label_prefix = display_label_prefix or ""
     for i, p in enumerate(c.prompts):
-        if len(c.prompts)>1:
+        if len(c.prompts) > 1:
             this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
         else:
             this_display_label_prefix = display_label_prefix
@@ -238,7 +245,8 @@ def log_tokenization_for_prompt_object(
     )
 
 
-def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
+def log_tokenization_for_text(
+        text, tokenizer, display_label=None, truncate_if_too_long=False):
     """shows how the prompt is tokenized
     # usually tokens have '</w>' to indicate end-of-word,
     # but for readability it has been replaced with ' '
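For reference, the textual-inversion trigger scan kept in this file works on plain angle-bracket tokens; a tiny runnable illustration with a made-up prompt (only the regex comes from the code above):

import re

prompt = "a photo of <easynegative> style, 4k <my-embedding>"  # illustrative prompt
triggers = re.findall(r"<[a-zA-Z0-9., _-]+>", prompt)
names = [t[1:-1] for t in triggers]  # strip the angle brackets, as name = trigger[1:-1] does
print(names)  # ['easynegative', 'my-embedding']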

View File

@@ -4,18 +4,17 @@ from contextlib import ExitStack
 from typing import List, Literal, Optional, Union
 
 import einops
-from pydantic import BaseModel, Field, validator
 import torch
 from diffusers import ControlNetModel, DPMSolverMultistepScheduler
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.schedulers import SchedulerMixin as Scheduler
+from pydantic import BaseModel, Field, validator
 
 from invokeai.app.util.misc import SEED_MAX, get_random_seed
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 
-from ..models.image import ImageCategory, ImageField, ResourceOrigin
 from ...backend.image_util.seamless import configure_model_padding
+from ...backend.model_management.lora import ModelPatcher
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.stable_diffusion.diffusers_pipeline import (
     ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline,
@@ -24,7 +23,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
     PostprocessingSettings
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from ...backend.util.devices import torch_dtype
-from ...backend.model_management.lora import ModelPatcher
+from ..models.image import ImageCategory, ImageField, ResourceOrigin
 from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
                              InvocationConfig, InvocationContext)
 from .compel import ConditioningField
@@ -32,14 +31,17 @@ from .controlnet_image_processors import ControlField
 from .image import ImageOutput
 from .model import ModelInfo, UNetField, VaeField
 
+
 class LatentsField(BaseModel):
     """A latents field used for passing latents between invocations"""
 
-    latents_name: Optional[str] = Field(default=None, description="The name of the latents")
+    latents_name: Optional[str] = Field(
+        default=None, description="The name of the latents")
 
     class Config:
         schema_extra = {"required": ["latents_name"]}
 
+
 class LatentsOutput(BaseInvocationOutput):
     """Base class for invocations that output latents"""
     #fmt: off
@@ -53,11 +55,11 @@ class LatentsOutput(BaseInvocationOutput):
 
 def build_latents_output(latents_name: str, latents: torch.Tensor):
     return LatentsOutput(
         latents=LatentsField(latents_name=latents_name),
         width=latents.size()[3] * 8,
         height=latents.size()[2] * 8,
     )
 
 
 SAMPLER_NAME_VALUES = Literal[
@@ -70,16 +72,19 @@ def get_scheduler(
     scheduler_info: ModelInfo,
     scheduler_name: str,
 ) -> Scheduler:
-    scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
-    orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict())
+    scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
+        scheduler_name, SCHEDULER_MAP['ddim'])
+    orig_scheduler_info = context.services.model_manager.get_model(
+        **scheduler_info.dict())
     with orig_scheduler_info as orig_scheduler:
         scheduler_config = orig_scheduler.config
 
     if "_backup" in scheduler_config:
         scheduler_config = scheduler_config["_backup"]
-    scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
+    scheduler_config = {**scheduler_config, **
+                        scheduler_extra_config, "_backup": scheduler_config}
     scheduler = scheduler_class.from_config(scheduler_config)
 
     # hack copied over from generate.py
     if not hasattr(scheduler, 'uses_inpainting_model'):
         scheduler.uses_inpainting_model = lambda: False
@@ -124,18 +129,18 @@ class TextToLatentsInvocation(BaseInvocation):
             "ui": {
                 "tags": ["latents"],
                 "type_hints": {
                     "model": "model",
                     "control": "control",
                     # "cfg_scale": "float",
                     "cfg_scale": "number"
                 }
             },
         }
 
     # TODO: pass this an emitter method or something? or a session for dispatching?
     def dispatch_progress(
-        self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
-    ) -> None:
+            self, context: InvocationContext, source_node_id: str,
+            intermediate_state: PipelineIntermediateState) -> None:
         stable_diffusion_step_callback(
             context=context,
             intermediate_state=intermediate_state,
@@ -143,9 +148,12 @@ class TextToLatentsInvocation(BaseInvocation):
             source_node_id=source_node_id,
         )
 
-    def get_conditioning_data(self, context: InvocationContext, scheduler) -> ConditioningData:
-        c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
-        uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
+    def get_conditioning_data(
+            self, context: InvocationContext, scheduler) -> ConditioningData:
+        c, extra_conditioning_info = context.services.latents.get(
+            self.positive_conditioning.conditioning_name)
+        uc, _ = context.services.latents.get(
+            self.negative_conditioning.conditioning_name)
         conditioning_data = ConditioningData(
             unconditioned_embeddings=uc,
@@ -153,10 +161,10 @@ class TextToLatentsInvocation(BaseInvocation):
             guidance_scale=self.cfg_scale,
             extra=extra_conditioning_info,
             postprocessing_settings=PostprocessingSettings(
-                threshold=0.0,#threshold,
-                warmup=0.2,#warmup,
-                h_symmetry_time_pct=None,#h_symmetry_time_pct,
-                v_symmetry_time_pct=None#v_symmetry_time_pct,
+                threshold=0.0,  # threshold,
+                warmup=0.2,  # warmup,
+                h_symmetry_time_pct=None,  # h_symmetry_time_pct,
+                v_symmetry_time_pct=None  # v_symmetry_time_pct,
             ),
         )
@@ -164,31 +172,32 @@ class TextToLatentsInvocation(BaseInvocation):
             scheduler,
 
             # for ddim scheduler
-            eta=0.0, #ddim_eta
+            eta=0.0,  # ddim_eta
 
             # for ancestral and sde schedulers
             generator=torch.Generator(device=uc.device).manual_seed(0),
         )
         return conditioning_data
 
-    def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
+    def create_pipeline(
+            self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
         # TODO:
-        #configure_model_padding(
+        # configure_model_padding(
         #    unet,
         #    self.seamless,
         #    self.seamless_axes,
-        #)
+        # )
 
         class FakeVae:
             class FakeVaeConfig:
                 def __init__(self):
                     self.block_out_channels = [0]
 
             def __init__(self):
                 self.config = FakeVae.FakeVaeConfig()
 
         return StableDiffusionGeneratorPipeline(
             vae=FakeVae(),  # TODO: oh...
             text_encoder=None,
             tokenizer=None,
             unet=unet,
@@ -198,11 +207,12 @@ class TextToLatentsInvocation(BaseInvocation):
             requires_safety_checker=False,
             precision="float16" if unet.dtype == torch.float16 else "float32",
         )
 
     def prep_control_data(
         self,
         context: InvocationContext,
-        model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device
+        # really only need model for dtype and device
+        model: StableDiffusionGeneratorPipeline,
         control_input: List[ControlField],
         latents_shape: List[int],
         do_classifier_free_guidance: bool = True,
@@ -238,15 +248,17 @@ class TextToLatentsInvocation(BaseInvocation):
                     print("Using HF model subfolders")
                     print("    control_name: ", control_name)
                     print("    control_subfolder: ", control_subfolder)
-                    control_model = ControlNetModel.from_pretrained(control_name,
-                                                                    subfolder=control_subfolder,
-                                                                    torch_dtype=model.unet.dtype).to(model.device)
+                    control_model = ControlNetModel.from_pretrained(
+                        control_name, subfolder=control_subfolder,
+                        torch_dtype=model.unet.dtype).to(
+                        model.device)
                 else:
-                    control_model = ControlNetModel.from_pretrained(control_info.control_model,
-                                                                    torch_dtype=model.unet.dtype).to(model.device)
+                    control_model = ControlNetModel.from_pretrained(
+                        control_info.control_model, torch_dtype=model.unet.dtype).to(model.device)
                 control_models.append(control_model)
                 control_image_field = control_info.image
-                input_image = context.services.images.get_pil_image(control_image_field.image_name)
+                input_image = context.services.images.get_pil_image(
+                    control_image_field.image_name)
                 # self.image.image_type, self.image.image_name
                 # FIXME: still need to test with different widths, heights, devices, dtypes
                 #       and add in batch_size, num_images_per_prompt?
@@ -263,41 +275,50 @@ class TextToLatentsInvocation(BaseInvocation):
                     dtype=control_model.dtype,
                     control_mode=control_info.control_mode,
                 )
-                control_item = ControlNetData(model=control_model,
-                                              image_tensor=control_image,
-                                              weight=control_info.control_weight,
-                                              begin_step_percent=control_info.begin_step_percent,
-                                              end_step_percent=control_info.end_step_percent,
-                                              control_mode=control_info.control_mode,
-                                              )
+                control_item = ControlNetData(
+                    model=control_model, image_tensor=control_image,
+                    weight=control_info.control_weight,
+                    begin_step_percent=control_info.begin_step_percent,
+                    end_step_percent=control_info.end_step_percent,
+                    control_mode=control_info.control_mode,)
                 control_data.append(control_item)
         # MultiControlNetModel has been refactored out, just need list[ControlNetData]
         return control_data
 
+    @torch.inference_mode()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         noise = context.services.latents.get(self.noise.latents_name)
 
         # Get the source node id (we are invoking the prepared node)
-        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
+        graph_execution_state = context.services.graph_execution_manager.get(
+            context.graph_execution_state_id)
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]
 
         def step_callback(state: PipelineIntermediateState):
             self.dispatch_progress(context, source_node_id, state)
 
-        unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
-        with unet_info as unet:
+        def _lora_loader():
+            for lora in self.unet.loras:
+                lora_info = context.services.model_manager.get_model(
+                    **lora.dict(exclude={"weight"}))
+                yield (lora_info.context.model, lora.weight)
+                del lora_info
+            return
 
+        unet_info = context.services.model_manager.get_model(
+            **self.unet.unet.dict())
+        with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
+                unet_info as unet:
+
             scheduler = get_scheduler(
                 context=context,
                 scheduler_info=self.unet.scheduler,
                 scheduler_name=self.scheduler,
             )
 
             pipeline = self.create_pipeline(unet, scheduler)
             conditioning_data = self.get_conditioning_data(context, scheduler)
 
-            loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
-
             control_data = self.prep_control_data(
                 model=pipeline, context=context, control_input=self.control,
                 latents_shape=noise.shape,
@@ -305,16 +326,15 @@ class TextToLatentsInvocation(BaseInvocation):
                 do_classifier_free_guidance=True,
             )
 
-            with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
-                # TODO: Verify the noise is the right size
-                result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
-                    latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
-                    noise=noise,
-                    num_inference_steps=self.steps,
-                    conditioning_data=conditioning_data,
-                    control_data=control_data,  # list[ControlNetData]
-                    callback=step_callback,
-                )
+            # TODO: Verify the noise is the right size
+            result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
+                latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
+                noise=noise,
+                num_inference_steps=self.steps,
+                conditioning_data=conditioning_data,
+                control_data=control_data,  # list[ControlNetData]
+                callback=step_callback,
+            )
 
             # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
             torch.cuda.empty_cache()
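Note the ordering in the new combined with statement: the LoRA patch context enters before unet_info, so the weights are already patched by the time unet is bound and the denoising loop runs, and the patch is removed only after the model context exits. A toy sketch of that left-to-right ordering (these context managers are stand-ins, not InvokeAI APIs):

from contextlib import contextmanager

@contextmanager
def apply_lora(model):
    print("patch weights in")       # runs first
    try:
        yield
    finally:
        print("patch weights out")  # runs last, after the model context has exited

@contextmanager
def model_context(model):
    print("model checked out of the cache")
    try:
        yield model
    finally:
        print("model released")

model = object()
with apply_lora(model), model_context(model) as unet:
    print("run denoising with the patched unet")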
@@ -323,14 +343,18 @@ class TextToLatentsInvocation(BaseInvocation):
         context.services.latents.save(name, result_latents)
         return build_latents_output(latents_name=name, latents=result_latents)
 
 
 class LatentsToLatentsInvocation(TextToLatentsInvocation):
     """Generates latents using latents as base image."""
 
     type: Literal["l2l"] = "l2l"
 
     # Inputs
-    latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
-    strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use")
+    latents: Optional[LatentsField] = Field(
+        description="The latents to use as a base image")
+    strength: float = Field(
+        default=0.7, ge=0, le=1,
+        description="The strength of the latents to use")
 
     # Schema customisation
     class Config(InvocationConfig):
@@ -345,22 +369,31 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             },
         }
 
+    @torch.inference_mode()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         noise = context.services.latents.get(self.noise.latents_name)
         latent = context.services.latents.get(self.latents.latents_name)
 
         # Get the source node id (we are invoking the prepared node)
-        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
+        graph_execution_state = context.services.graph_execution_manager.get(
+            context.graph_execution_state_id)
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]
 
         def step_callback(state: PipelineIntermediateState):
             self.dispatch_progress(context, source_node_id, state)
 
-        unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.dict(),
-        )
+        def _lora_loader():
+            for lora in self.unet.loras:
+                lora_info = context.services.model_manager.get_model(
+                    **lora.dict(exclude={"weight"}))
+                yield (lora_info.context.model, lora.weight)
+                del lora_info
+            return
 
-        with unet_info as unet:
+        unet_info = context.services.model_manager.get_model(
+            **self.unet.unet.dict())
+        with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
+                unet_info as unet:
 
             scheduler = get_scheduler(
                 context=context,
@@ -370,7 +403,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
 
             pipeline = self.create_pipeline(unet, scheduler)
             conditioning_data = self.get_conditioning_data(context, scheduler)
 
             control_data = self.prep_control_data(
                 model=pipeline, context=context, control_input=self.control,
                 latents_shape=noise.shape,
@@ -380,8 +413,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             # TODO: Verify the noise is the right size
             initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
-                latent, device=unet.device, dtype=latent.dtype
-            )
+                latent, device=unet.device, dtype=latent.dtype)
 
             timesteps, _ = pipeline.get_img2img_timesteps(
                 self.steps,
@@ -389,18 +421,15 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
                 device=unet.device,
             )
 
-            loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
-
-            with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
-                result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
-                    latents=initial_latents,
-                    timesteps=timesteps,
-                    noise=noise,
-                    num_inference_steps=self.steps,
-                    conditioning_data=conditioning_data,
-                    control_data=control_data,  # list[ControlNetData]
-                    callback=step_callback
-                )
+            result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
+                latents=initial_latents,
+                timesteps=timesteps,
+                noise=noise,
+                num_inference_steps=self.steps,
+                conditioning_data=conditioning_data,
+                control_data=control_data,  # list[ControlNetData]
+                callback=step_callback
+            )
 
             # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
             torch.cuda.empty_cache()
@@ -417,9 +446,12 @@ class LatentsToImageInvocation(BaseInvocation):
     type: Literal["l2i"] = "l2i"
 
     # Inputs
-    latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
+    latents: Optional[LatentsField] = Field(
+        description="The latents to generate an image from")
     vae: VaeField = Field(default=None, description="Vae submodel")
-    tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
+    tiled: bool = Field(
+        default=False,
+        description="Decode latents by overlaping tiles(less memory consumption)")
 
     # Schema customisation
     class Config(InvocationConfig):
@@ -429,7 +461,7 @@ class LatentsToImageInvocation(BaseInvocation):
             },
         }
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def invoke(self, context: InvocationContext) -> ImageOutput:
         latents = context.services.latents.get(self.latents.latents_name)
@@ -450,7 +482,7 @@ class LatentsToImageInvocation(BaseInvocation):
             # copied from diffusers pipeline
             latents = latents / vae.config.scaling_factor
             image = vae.decode(latents, return_dict=False)[0]
-            image = (image / 2 + 0.5).clamp(0, 1) # denormalize
+            image = (image / 2 + 0.5).clamp(0, 1)  # denormalize
             # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
             np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -473,9 +505,9 @@ class LatentsToImageInvocation(BaseInvocation):
         height=image_dto.height,
     )
 
-LATENTS_INTERPOLATION_MODE = Literal[
-    "nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
-]
+
+LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear",
+                                     "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
 
 
 class ResizeLatentsInvocation(BaseInvocation):
@@ -484,21 +516,25 @@ class ResizeLatentsInvocation(BaseInvocation):
     type: Literal["lresize"] = "lresize"
 
     # Inputs
-    latents: Optional[LatentsField] = Field(description="The latents to resize")
-    width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
-    height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
-    mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
-    antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
+    latents: Optional[LatentsField] = Field(
+        description="The latents to resize")
+    width: int = Field(
+        ge=64, multiple_of=8, description="The width to resize to (px)")
+    height: int = Field(
+        ge=64, multiple_of=8, description="The height to resize to (px)")
+    mode: LATENTS_INTERPOLATION_MODE = Field(
+        default="bilinear", description="The interpolation mode")
+    antialias: bool = Field(
+        default=False,
+        description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
 
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.services.latents.get(self.latents.latents_name)
 
         resized_latents = torch.nn.functional.interpolate(
-            latents,
-            size=(self.height // 8, self.width // 8),
-            mode=self.mode,
-            antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
-        )
+            latents, size=(self.height // 8, self.width // 8),
+            mode=self.mode, antialias=self.antialias
+            if self.mode in ["bilinear", "bicubic"] else False,)
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()
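As a side note, torch.nn.functional.interpolate only accepts antialias=True for the bilinear and bicubic modes, which is why the rewrapped call above still gates it on self.mode. A small standalone illustration (the tensor shape and mode here are made up):

import torch
import torch.nn.functional as F

latents = torch.randn(1, 4, 64, 64)  # (batch, channels, height, width)
mode = "nearest"                      # a mode that does not support antialiasing
resized = F.interpolate(
    latents, size=(32, 32), mode=mode,
    antialias=True if mode in ["bilinear", "bicubic"] else False,
)
print(resized.shape)  # torch.Size([1, 4, 32, 32])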
@@ -515,21 +551,24 @@ class ScaleLatentsInvocation(BaseInvocation):
     type: Literal["lscale"] = "lscale"
 
     # Inputs
-    latents: Optional[LatentsField] = Field(description="The latents to scale")
-    scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
-    mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
-    antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
+    latents: Optional[LatentsField] = Field(
+        description="The latents to scale")
+    scale_factor: float = Field(
+        gt=0, description="The factor by which to scale the latents")
+    mode: LATENTS_INTERPOLATION_MODE = Field(
+        default="bilinear", description="The interpolation mode")
+    antialias: bool = Field(
+        default=False,
+        description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
 
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.services.latents.get(self.latents.latents_name)
 
         # resizing
         resized_latents = torch.nn.functional.interpolate(
-            latents,
-            scale_factor=self.scale_factor,
-            mode=self.mode,
-            antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
-        )
+            latents, scale_factor=self.scale_factor, mode=self.mode,
+            antialias=self.antialias
+            if self.mode in ["bilinear", "bicubic"] else False,)
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()
@@ -548,7 +587,9 @@ class ImageToLatentsInvocation(BaseInvocation):
     # Inputs
     image: Union[ImageField, None] = Field(description="The image to encode")
     vae: VaeField = Field(default=None, description="Vae submodel")
-    tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)")
+    tiled: bool = Field(
+        default=False,
+        description="Encode latents by overlaping tiles(less memory consumption)")
 
     # Schema customisation
     class Config(InvocationConfig):
@@ -558,7 +599,7 @@ class ImageToLatentsInvocation(BaseInvocation):
             },
         }
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         # image = context.services.images.get(
         #     self.image.image_type, self.image.image_name

View File

@@ -1,18 +1,17 @@
 from __future__ import annotations
 
 import copy
-from pathlib import Path
 from contextlib import contextmanager
-from typing import Optional, Dict, Tuple, Any
+from pathlib import Path
+from typing import Any, Dict, Optional, Tuple
 
 import torch
+from compel.embeddings_provider import BaseTextualInversionManager
+from diffusers.models import UNet2DConditionModel
 from safetensors.torch import load_file
 from torch.utils.hooks import RemovableHandle
-from diffusers.models import UNet2DConditionModel
 from transformers import CLIPTextModel
-from compel.embeddings_provider import BaseTextualInversionManager
 
 
 class LoRALayerBase:
     #rank: Optional[int]
@@ -527,7 +526,7 @@ class ModelPatcher:
     ):
         original_weights = dict()
         try:
-            with torch.no_grad():
+            with torch.inference_mode():
                 for lora, lora_weight in loras:
                     #assert lora.device.type == "cpu"
                     for layer_key, layer in lora.layers.items():
@@ -539,9 +538,10 @@ class ModelPatcher:
                         original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
 
                         # enable autocast to calc fp16 loras on cpu
-                        with torch.autocast(device_type="cpu"):
-                            layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
-                            layer_weight = layer.get_weight() * lora_weight * layer_scale
+                        #with torch.autocast(device_type="cpu"):
+                        layer.to(dtype=torch.float32)
+                        layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
+                        layer_weight = layer.get_weight() * lora_weight * layer_scale
 
                         if module.weight.shape != layer_weight.shape:
                             # TODO: debug on lycoris
@@ -552,7 +552,7 @@ class ModelPatcher:
             yield  # wait for context manager exit
 
         finally:
-            with torch.no_grad():
+            with torch.inference_mode():
                 for module_key, weight in original_weights.items():
                     model.get_submodule(module_key).weight.copy_(weight)
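The autocast change above affects how the patch delta is computed, not what gets applied: the patch still amounts to roughly module.weight += lora_weight * (alpha / rank) * ΔW, but ΔW is now computed after casting the LoRA layer to float32 rather than under CPU autocast. A toy reproduction of just that arithmetic (matrices and numbers are made up; the in-place += and the restore from original_weights happen in the surrounding ModelPatcher code, not shown in this hunk):

import torch

rank, alpha, lora_weight = 4, 4.0, 0.75
up = torch.randn(320, rank, dtype=torch.float16)    # LoRA "up" matrix
down = torch.randn(rank, 320, dtype=torch.float16)  # LoRA "down" matrix

layer_scale = alpha / rank if (alpha and rank) else 1.0
delta_w = up.to(torch.float32) @ down.to(torch.float32)  # fp32 matmul instead of cpu autocast
layer_weight = delta_w * lora_weight * layer_scale

weight = torch.randn(320, 320)  # stand-in for module.weight (already saved for later restore)
weight += layer_weight          # the in-place patch
print(weight.shape)             # torch.Size([320, 320])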

View File

@@ -49,8 +49,6 @@ export const addLoRAsToGraph = (
       '_'
     )}`;
 
-    console.log(lastLoraNodeId, currentLoraNodeId, currentLoraIndex, loraField);
-
     const loraLoaderNode: LoraLoaderInvocation = {
       type: 'lora_loader',
       id: currentLoraNodeId,