fix: Slow loading of Loras

Co-Authored-By: StAlKeR7779 <7768370+StAlKeR7779@users.noreply.github.com>
This commit is contained in:
blessedcoolant 2023-07-05 14:37:16 +12:00 committed by psychedelicious
parent 0f0336b6ef
commit c0501ed5c2
4 changed files with 253 additions and 206 deletions

View File

@ -1,28 +1,27 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel, Field
from contextlib import ExitStack
import re
from contextlib import ExitStack
from typing import List, Literal, Optional, Union
import torch
from compel import Compel
from compel.prompt_parser import (Blend, Conjunction,
CrossAttentionControlSubstitute,
FlattenedPrompt, Fragment)
from pydantic import BaseModel, Field
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from .model import ClipField
from ...backend.util.devices import torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.model_management import BaseModelType, ModelType, SubModelType
from ...backend.model_management.lora import ModelPatcher
from compel import Compel
from compel.prompt_parser import (
Blend,
CrossAttentionControlSubstitute,
FlattenedPrompt,
Fragment, Conjunction,
)
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.util.devices import torch_dtype
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
from .model import ClipField
class ConditioningField(BaseModel):
conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
conditioning_name: Optional[str] = Field(
default=None, description="The name of conditioning data")
class Config:
schema_extra = {"required": ["conditioning_name"]}
@ -52,84 +51,92 @@ class CompelInvocation(BaseInvocation):
"title": "Prompt (Compel)",
"tags": ["prompt", "compel"],
"type_hints": {
"model": "model"
"model": "model"
}
},
}
@torch.no_grad()
@torch.inference_mode()
def invoke(self, context: InvocationContext) -> CompelOutput:
tokenizer_info = context.services.model_manager.get_model(
**self.clip.tokenizer.dict(),
)
text_encoder_info = context.services.model_manager.get_model(
**self.clip.text_encoder.dict(),
)
with tokenizer_info as orig_tokenizer,\
text_encoder_info as text_encoder:
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
def _lora_loader():
for lora in self.clip.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
del lora_info
return
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
name = trigger[1:-1]
try:
ti_list.append(
context.services.model_manager.get_model(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
).context.model
)
except Exception:
#print(e)
#import traceback
#print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found")
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
with ModelPatcher.apply_lora_text_encoder(text_encoder, loras),\
ModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager):
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=True, # TODO:
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
name = trigger[1:-1]
try:
ti_list.append(
context.services.model_manager.get_model(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
).context.model
)
conjunction = Compel.parse_prompt_string(self.prompt)
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
except Exception:
# print(e)
#import traceback
# print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found")
if context.services.configuration.log_tokenization:
log_tokenization_for_prompt_object(prompt, tokenizer)
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
text_encoder_info as text_encoder:
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
# TODO: long prompt support
#if not self.truncate_long_prompts:
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
cross_attention_control_args=options.get("cross_attention_control", None),
)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
# TODO: hacky but works ;D maybe rename latents somehow?
context.services.latents.save(conditioning_name, (c, ec))
return CompelOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=True, # TODO:
)
conjunction = Compel.parse_prompt_string(self.prompt)
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
if context.services.configuration.log_tokenization:
log_tokenization_for_prompt_object(prompt, tokenizer)
c, options = compel.build_conditioning_tensor_for_prompt_object(
prompt)
# TODO: long prompt support
# if not self.truncate_long_prompts:
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(
tokenizer, conjunction),
cross_attention_control_args=options.get(
"cross_attention_control", None),)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
# TODO: hacky but works ;D maybe rename latents somehow?
context.services.latents.save(conditioning_name, (c, ec))
return CompelOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
def get_max_token_count(
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
) -> int:
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction],
truncate_if_too_long=False) -> int:
if type(prompt) is Blend:
blend: Blend = prompt
return max(
@ -148,13 +155,13 @@ def get_max_token_count(
)
else:
return len(
get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)
)
get_tokens_for_prompt_object(
tokenizer, prompt, truncate_if_too_long))
def get_tokens_for_prompt_object(
tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
) -> [str]:
) -> List[str]:
if type(parsed_prompt) is Blend:
raise ValueError(
"Blend is not supported here - you need to get tokens for each of its .children"
@ -183,7 +190,7 @@ def log_tokenization_for_conjunction(
):
display_label_prefix = display_label_prefix or ""
for i, p in enumerate(c.prompts):
if len(c.prompts)>1:
if len(c.prompts) > 1:
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
else:
this_display_label_prefix = display_label_prefix
@ -238,7 +245,8 @@ def log_tokenization_for_prompt_object(
)
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
def log_tokenization_for_text(
text, tokenizer, display_label=None, truncate_if_too_long=False):
"""shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '

View File

@ -4,18 +4,17 @@ from contextlib import ExitStack
from typing import List, Literal, Optional, Union
import einops
from pydantic import BaseModel, Field, validator
import torch
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field, validator
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ...backend.image_util.seamless import configure_model_padding
from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline,
@ -24,7 +23,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import torch_dtype
from ...backend.model_management.lora import ModelPatcher
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
from .compel import ConditioningField
@ -32,14 +31,17 @@ from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField
class LatentsField(BaseModel):
"""A latents field used for passing latents between invocations"""
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
latents_name: Optional[str] = Field(
default=None, description="The name of the latents")
class Config:
schema_extra = {"required": ["latents_name"]}
class LatentsOutput(BaseInvocationOutput):
"""Base class for invocations that output latents"""
#fmt: off
@ -53,11 +55,11 @@ class LatentsOutput(BaseInvocationOutput):
def build_latents_output(latents_name: str, latents: torch.Tensor):
return LatentsOutput(
latents=LatentsField(latents_name=latents_name),
width=latents.size()[3] * 8,
height=latents.size()[2] * 8,
)
return LatentsOutput(
latents=LatentsField(latents_name=latents_name),
width=latents.size()[3] * 8,
height=latents.size()[2] * 8,
)
SAMPLER_NAME_VALUES = Literal[
@ -70,16 +72,19 @@ def get_scheduler(
scheduler_info: ModelInfo,
scheduler_name: str,
) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict())
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
scheduler_name, SCHEDULER_MAP['ddim'])
orig_scheduler_info = context.services.model_manager.get_model(
**scheduler_info.dict())
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"]
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
scheduler_config = {**scheduler_config, **
scheduler_extra_config, "_backup": scheduler_config}
scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py
if not hasattr(scheduler, 'uses_inpainting_model'):
scheduler.uses_inpainting_model = lambda: False
@ -124,18 +129,18 @@ class TextToLatentsInvocation(BaseInvocation):
"ui": {
"tags": ["latents"],
"type_hints": {
"model": "model",
"control": "control",
# "cfg_scale": "float",
"cfg_scale": "number"
"model": "model",
"control": "control",
# "cfg_scale": "float",
"cfg_scale": "number"
}
},
}
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
) -> None:
self, context: InvocationContext, source_node_id: str,
intermediate_state: PipelineIntermediateState) -> None:
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
@ -143,9 +148,12 @@ class TextToLatentsInvocation(BaseInvocation):
source_node_id=source_node_id,
)
def get_conditioning_data(self, context: InvocationContext, scheduler) -> ConditioningData:
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
def get_conditioning_data(
self, context: InvocationContext, scheduler) -> ConditioningData:
c, extra_conditioning_info = context.services.latents.get(
self.positive_conditioning.conditioning_name)
uc, _ = context.services.latents.get(
self.negative_conditioning.conditioning_name)
conditioning_data = ConditioningData(
unconditioned_embeddings=uc,
@ -153,10 +161,10 @@ class TextToLatentsInvocation(BaseInvocation):
guidance_scale=self.cfg_scale,
extra=extra_conditioning_info,
postprocessing_settings=PostprocessingSettings(
threshold=0.0,#threshold,
warmup=0.2,#warmup,
h_symmetry_time_pct=None,#h_symmetry_time_pct,
v_symmetry_time_pct=None#v_symmetry_time_pct,
threshold=0.0, # threshold,
warmup=0.2, # warmup,
h_symmetry_time_pct=None, # h_symmetry_time_pct,
v_symmetry_time_pct=None # v_symmetry_time_pct,
),
)
@ -164,31 +172,32 @@ class TextToLatentsInvocation(BaseInvocation):
scheduler,
# for ddim scheduler
eta=0.0, #ddim_eta
eta=0.0, # ddim_eta
# for ancestral and sde schedulers
generator=torch.Generator(device=uc.device).manual_seed(0),
)
return conditioning_data
def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
def create_pipeline(
self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
# TODO:
#configure_model_padding(
# configure_model_padding(
# unet,
# self.seamless,
# self.seamless_axes,
#)
# )
class FakeVae:
class FakeVaeConfig:
def __init__(self):
self.block_out_channels = [0]
def __init__(self):
self.config = FakeVae.FakeVaeConfig()
return StableDiffusionGeneratorPipeline(
vae=FakeVae(), # TODO: oh...
vae=FakeVae(), # TODO: oh...
text_encoder=None,
tokenizer=None,
unet=unet,
@ -198,11 +207,12 @@ class TextToLatentsInvocation(BaseInvocation):
requires_safety_checker=False,
precision="float16" if unet.dtype == torch.float16 else "float32",
)
def prep_control_data(
self,
context: InvocationContext,
model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device
# really only need model for dtype and device
model: StableDiffusionGeneratorPipeline,
control_input: List[ControlField],
latents_shape: List[int],
do_classifier_free_guidance: bool = True,
@ -238,15 +248,17 @@ class TextToLatentsInvocation(BaseInvocation):
print("Using HF model subfolders")
print(" control_name: ", control_name)
print(" control_subfolder: ", control_subfolder)
control_model = ControlNetModel.from_pretrained(control_name,
subfolder=control_subfolder,
torch_dtype=model.unet.dtype).to(model.device)
control_model = ControlNetModel.from_pretrained(
control_name, subfolder=control_subfolder,
torch_dtype=model.unet.dtype).to(
model.device)
else:
control_model = ControlNetModel.from_pretrained(control_info.control_model,
torch_dtype=model.unet.dtype).to(model.device)
control_model = ControlNetModel.from_pretrained(
control_info.control_model, torch_dtype=model.unet.dtype).to(model.device)
control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_name)
input_image = context.services.images.get_pil_image(
control_image_field.image_name)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
@ -263,41 +275,50 @@ class TextToLatentsInvocation(BaseInvocation):
dtype=control_model.dtype,
control_mode=control_info.control_mode,
)
control_item = ControlNetData(model=control_model,
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,
)
control_item = ControlNetData(
model=control_model, image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,)
control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data
@torch.inference_mode()
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state)
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
with unet_info as unet:
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict())
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler)
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
control_data = self.prep_control_data(
model=pipeline, context=context, control_input=self.control,
latents_shape=noise.shape,
@ -305,16 +326,15 @@ class TextToLatentsInvocation(BaseInvocation):
do_classifier_free_guidance=True,
)
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
# TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback,
)
# TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
@ -323,14 +343,18 @@ class TextToLatentsInvocation(BaseInvocation):
context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents)
class LatentsToLatentsInvocation(TextToLatentsInvocation):
"""Generates latents using latents as base image."""
type: Literal["l2l"] = "l2l"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use")
latents: Optional[LatentsField] = Field(
description="The latents to use as a base image")
strength: float = Field(
default=0.7, ge=0, le=1,
description="The strength of the latents to use")
# Schema customisation
class Config(InvocationConfig):
@ -345,22 +369,31 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
},
}
@torch.inference_mode()
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state)
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(),
)
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
del lora_info
return
with unet_info as unet:
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict())
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:
scheduler = get_scheduler(
context=context,
@ -370,7 +403,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler)
control_data = self.prep_control_data(
model=pipeline, context=context, control_input=self.control,
latents_shape=noise.shape,
@ -380,8 +413,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
# TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
latent, device=unet.device, dtype=latent.dtype
)
latent, device=unet.device, dtype=latent.dtype)
timesteps, _ = pipeline.get_img2img_timesteps(
self.steps,
@ -389,18 +421,15 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
device=unet.device,
)
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=initial_latents,
timesteps=timesteps,
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback
)
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=initial_latents,
timesteps=timesteps,
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
@ -417,9 +446,12 @@ class LatentsToImageInvocation(BaseInvocation):
type: Literal["l2i"] = "l2i"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
latents: Optional[LatentsField] = Field(
description="The latents to generate an image from")
vae: VaeField = Field(default=None, description="Vae submodel")
tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
tiled: bool = Field(
default=False,
description="Decode latents by overlaping tiles(less memory consumption)")
# Schema customisation
class Config(InvocationConfig):
@ -429,7 +461,7 @@ class LatentsToImageInvocation(BaseInvocation):
},
}
@torch.no_grad()
@torch.inference_mode()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.services.latents.get(self.latents.latents_name)
@ -450,7 +482,7 @@ class LatentsToImageInvocation(BaseInvocation):
# copied from diffusers pipeline
latents = latents / vae.config.scaling_factor
image = vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1) # denormalize
image = (image / 2 + 0.5).clamp(0, 1) # denormalize
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@ -473,9 +505,9 @@ class LatentsToImageInvocation(BaseInvocation):
height=image_dto.height,
)
LATENTS_INTERPOLATION_MODE = Literal[
"nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
]
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear",
"bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
class ResizeLatentsInvocation(BaseInvocation):
@ -484,21 +516,25 @@ class ResizeLatentsInvocation(BaseInvocation):
type: Literal["lresize"] = "lresize"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to resize")
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
latents: Optional[LatentsField] = Field(
description="The latents to resize")
width: int = Field(
ge=64, multiple_of=8, description="The width to resize to (px)")
height: int = Field(
ge=64, multiple_of=8, description="The height to resize to (px)")
mode: LATENTS_INTERPOLATION_MODE = Field(
default="bilinear", description="The interpolation mode")
antialias: bool = Field(
default=False,
description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
resized_latents = torch.nn.functional.interpolate(
latents,
size=(self.height // 8, self.width // 8),
mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
latents, size=(self.height // 8, self.width // 8),
mode=self.mode, antialias=self.antialias
if self.mode in ["bilinear", "bicubic"] else False,)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
@ -515,21 +551,24 @@ class ScaleLatentsInvocation(BaseInvocation):
type: Literal["lscale"] = "lscale"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to scale")
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
latents: Optional[LatentsField] = Field(
description="The latents to scale")
scale_factor: float = Field(
gt=0, description="The factor by which to scale the latents")
mode: LATENTS_INTERPOLATION_MODE = Field(
default="bilinear", description="The interpolation mode")
antialias: bool = Field(
default=False,
description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
# resizing
resized_latents = torch.nn.functional.interpolate(
latents,
scale_factor=self.scale_factor,
mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
latents, scale_factor=self.scale_factor, mode=self.mode,
antialias=self.antialias
if self.mode in ["bilinear", "bicubic"] else False,)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
@ -548,7 +587,9 @@ class ImageToLatentsInvocation(BaseInvocation):
# Inputs
image: Union[ImageField, None] = Field(description="The image to encode")
vae: VaeField = Field(default=None, description="Vae submodel")
tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)")
tiled: bool = Field(
default=False,
description="Encode latents by overlaping tiles(less memory consumption)")
# Schema customisation
class Config(InvocationConfig):
@ -558,7 +599,7 @@ class ImageToLatentsInvocation(BaseInvocation):
},
}
@torch.no_grad()
@torch.inference_mode()
def invoke(self, context: InvocationContext) -> LatentsOutput:
# image = context.services.images.get(
# self.image.image_type, self.image.image_name

View File

@ -1,18 +1,17 @@
from __future__ import annotations
import copy
from pathlib import Path
from contextlib import contextmanager
from typing import Optional, Dict, Tuple, Any
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
import torch
from compel.embeddings_provider import BaseTextualInversionManager
from diffusers.models import UNet2DConditionModel
from safetensors.torch import load_file
from torch.utils.hooks import RemovableHandle
from diffusers.models import UNet2DConditionModel
from transformers import CLIPTextModel
from compel.embeddings_provider import BaseTextualInversionManager
class LoRALayerBase:
#rank: Optional[int]
@ -527,7 +526,7 @@ class ModelPatcher:
):
original_weights = dict()
try:
with torch.no_grad():
with torch.inference_mode():
for lora, lora_weight in loras:
#assert lora.device.type == "cpu"
for layer_key, layer in lora.layers.items():
@ -539,9 +538,10 @@ class ModelPatcher:
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
# enable autocast to calc fp16 loras on cpu
with torch.autocast(device_type="cpu"):
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
layer_weight = layer.get_weight() * lora_weight * layer_scale
#with torch.autocast(device_type="cpu"):
layer.to(dtype=torch.float32)
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
layer_weight = layer.get_weight() * lora_weight * layer_scale
if module.weight.shape != layer_weight.shape:
# TODO: debug on lycoris
@ -552,7 +552,7 @@ class ModelPatcher:
yield # wait for context manager exit
finally:
with torch.no_grad():
with torch.inference_mode():
for module_key, weight in original_weights.items():
model.get_submodule(module_key).weight.copy_(weight)

View File

@ -49,8 +49,6 @@ export const addLoRAsToGraph = (
'_'
)}`;
console.log(lastLoraNodeId, currentLoraNodeId, currentLoraIndex, loraField);
const loraLoaderNode: LoraLoaderInvocation = {
type: 'lora_loader',
id: currentLoraNodeId,