From c0501ed5c243efd746c178366d41360922aad0a9 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Wed, 5 Jul 2023 14:37:16 +1200 Subject: [PATCH] fix: Slow loading of Loras Co-Authored-By: StAlKeR7779 <7768370+StAlKeR7779@users.noreply.github.com> --- invokeai/app/invocations/compel.py | 166 ++++++----- invokeai/app/invocations/latent.py | 271 ++++++++++-------- invokeai/backend/model_management/lora.py | 20 +- .../util/graphBuilders/addLoRAsToGraph.ts | 2 - 4 files changed, 253 insertions(+), 206 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 0421841e8a..d77269da20 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -1,28 +1,27 @@ -from typing import Literal, Optional, Union -from pydantic import BaseModel, Field -from contextlib import ExitStack import re +from contextlib import ExitStack +from typing import List, Literal, Optional, Union + import torch +from compel import Compel +from compel.prompt_parser import (Blend, Conjunction, + CrossAttentionControlSubstitute, + FlattenedPrompt, Fragment) +from pydantic import BaseModel, Field -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig -from .model import ClipField - -from ...backend.util.devices import torch_dtype -from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent from ...backend.model_management import BaseModelType, ModelType, SubModelType from ...backend.model_management.lora import ModelPatcher - -from compel import Compel -from compel.prompt_parser import ( - Blend, - CrossAttentionControlSubstitute, - FlattenedPrompt, - Fragment, Conjunction, -) +from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent +from ...backend.util.devices import torch_dtype +from .baseinvocation import (BaseInvocation, BaseInvocationOutput, + InvocationConfig, InvocationContext) +from .model import ClipField class ConditioningField(BaseModel): - conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data") + conditioning_name: Optional[str] = Field( + default=None, description="The name of conditioning data") + class Config: schema_extra = {"required": ["conditioning_name"]} @@ -52,84 +51,92 @@ class CompelInvocation(BaseInvocation): "title": "Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": { - "model": "model" + "model": "model" } }, } - @torch.no_grad() + @torch.inference_mode() def invoke(self, context: InvocationContext) -> CompelOutput: - tokenizer_info = context.services.model_manager.get_model( **self.clip.tokenizer.dict(), ) text_encoder_info = context.services.model_manager.get_model( **self.clip.text_encoder.dict(), ) - with tokenizer_info as orig_tokenizer,\ - text_encoder_info as text_encoder: - loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] + def _lora_loader(): + for lora in self.clip.loras: + lora_info = context.services.model_manager.get_model( + **lora.dict(exclude={"weight"})) + yield (lora_info.context.model, lora.weight) + del lora_info + return - ti_list = [] - for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): - name = trigger[1:-1] - try: - ti_list.append( - context.services.model_manager.get_model( - model_name=name, - base_model=self.clip.text_encoder.base_model, - model_type=ModelType.TextualInversion, - ).context.model - ) - except 
Exception: - #print(e) - #import traceback - #print(traceback.format_exc()) - print(f"Warn: trigger: \"{trigger}\" not found") + #loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] - with ModelPatcher.apply_lora_text_encoder(text_encoder, loras),\ - ModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager): - - compel = Compel( - tokenizer=tokenizer, - text_encoder=text_encoder, - textual_inversion_manager=ti_manager, - dtype_for_device_getter=torch_dtype, - truncate_long_prompts=True, # TODO: + ti_list = [] + for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): + name = trigger[1:-1] + try: + ti_list.append( + context.services.model_manager.get_model( + model_name=name, + base_model=self.clip.text_encoder.base_model, + model_type=ModelType.TextualInversion, + ).context.model ) - - conjunction = Compel.parse_prompt_string(self.prompt) - prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0] + except Exception: + # print(e) + #import traceback + # print(traceback.format_exc()) + print(f"Warn: trigger: \"{trigger}\" not found") - if context.services.configuration.log_tokenization: - log_tokenization_for_prompt_object(prompt, tokenizer) + with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\ + ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\ + text_encoder_info as text_encoder: - c, options = compel.build_conditioning_tensor_for_prompt_object(prompt) - - # TODO: long prompt support - #if not self.truncate_long_prompts: - # [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc]) - ec = InvokeAIDiffuserComponent.ExtraConditioningInfo( - tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction), - cross_attention_control_args=options.get("cross_attention_control", None), - ) - - conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - - # TODO: hacky but works ;D maybe rename latents somehow? - context.services.latents.save(conditioning_name, (c, ec)) - - return CompelOutput( - conditioning=ConditioningField( - conditioning_name=conditioning_name, - ), + compel = Compel( + tokenizer=tokenizer, + text_encoder=text_encoder, + textual_inversion_manager=ti_manager, + dtype_for_device_getter=torch_dtype, + truncate_long_prompts=True, # TODO: ) + conjunction = Compel.parse_prompt_string(self.prompt) + prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0] + + if context.services.configuration.log_tokenization: + log_tokenization_for_prompt_object(prompt, tokenizer) + + c, options = compel.build_conditioning_tensor_for_prompt_object( + prompt) + + # TODO: long prompt support + # if not self.truncate_long_prompts: + # [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc]) + ec = InvokeAIDiffuserComponent.ExtraConditioningInfo( + tokens_count_including_eos_bos=get_max_token_count( + tokenizer, conjunction), + cross_attention_control_args=options.get( + "cross_attention_control", None),) + + conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" + + # TODO: hacky but works ;D maybe rename latents somehow? 
+ context.services.latents.save(conditioning_name, (c, ec)) + + return CompelOutput( + conditioning=ConditioningField( + conditioning_name=conditioning_name, + ), + ) + def get_max_token_count( - tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False -) -> int: + tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], + truncate_if_too_long=False) -> int: if type(prompt) is Blend: blend: Blend = prompt return max( @@ -148,13 +155,13 @@ def get_max_token_count( ) else: return len( - get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long) - ) + get_tokens_for_prompt_object( + tokenizer, prompt, truncate_if_too_long)) def get_tokens_for_prompt_object( tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True -) -> [str]: +) -> List[str]: if type(parsed_prompt) is Blend: raise ValueError( "Blend is not supported here - you need to get tokens for each of its .children" @@ -183,7 +190,7 @@ def log_tokenization_for_conjunction( ): display_label_prefix = display_label_prefix or "" for i, p in enumerate(c.prompts): - if len(c.prompts)>1: + if len(c.prompts) > 1: this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})" else: this_display_label_prefix = display_label_prefix @@ -238,7 +245,8 @@ def log_tokenization_for_prompt_object( ) -def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False): +def log_tokenization_for_text( + text, tokenizer, display_label=None, truncate_if_too_long=False): """shows how the prompt is tokenized # usually tokens have '' to indicate end-of-word, # but for readability it has been replaced with ' ' diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index a9576a2fe1..50c901f15f 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -4,18 +4,17 @@ from contextlib import ExitStack from typing import List, Literal, Optional, Union import einops - -from pydantic import BaseModel, Field, validator import torch from diffusers import ControlNetModel, DPMSolverMultistepScheduler from diffusers.image_processor import VaeImageProcessor from diffusers.schedulers import SchedulerMixin as Scheduler +from pydantic import BaseModel, Field, validator from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.step_callback import stable_diffusion_step_callback -from ..models.image import ImageCategory, ImageField, ResourceOrigin from ...backend.image_util.seamless import configure_model_padding +from ...backend.model_management.lora import ModelPatcher from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion.diffusers_pipeline import ( ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline, @@ -24,7 +23,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \ PostprocessingSettings from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.util.devices import torch_dtype -from ...backend.model_management.lora import ModelPatcher +from ..models.image import ImageCategory, ImageField, ResourceOrigin from .baseinvocation import (BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext) from .compel import ConditioningField @@ -32,14 +31,17 @@ from .controlnet_image_processors import ControlField from .image import ImageOutput from .model import ModelInfo, UNetField, VaeField + class LatentsField(BaseModel): """A latents field used 
for passing latents between invocations""" - latents_name: Optional[str] = Field(default=None, description="The name of the latents") + latents_name: Optional[str] = Field( + default=None, description="The name of the latents") class Config: schema_extra = {"required": ["latents_name"]} + class LatentsOutput(BaseInvocationOutput): """Base class for invocations that output latents""" #fmt: off @@ -53,11 +55,11 @@ class LatentsOutput(BaseInvocationOutput): def build_latents_output(latents_name: str, latents: torch.Tensor): - return LatentsOutput( - latents=LatentsField(latents_name=latents_name), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, - ) + return LatentsOutput( + latents=LatentsField(latents_name=latents_name), + width=latents.size()[3] * 8, + height=latents.size()[2] * 8, + ) SAMPLER_NAME_VALUES = Literal[ @@ -70,16 +72,19 @@ def get_scheduler( scheduler_info: ModelInfo, scheduler_name: str, ) -> Scheduler: - scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim']) - orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict()) + scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get( + scheduler_name, SCHEDULER_MAP['ddim']) + orig_scheduler_info = context.services.model_manager.get_model( + **scheduler_info.dict()) with orig_scheduler_info as orig_scheduler: scheduler_config = orig_scheduler.config - + if "_backup" in scheduler_config: scheduler_config = scheduler_config["_backup"] - scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config} + scheduler_config = {**scheduler_config, ** + scheduler_extra_config, "_backup": scheduler_config} scheduler = scheduler_class.from_config(scheduler_config) - + # hack copied over from generate.py if not hasattr(scheduler, 'uses_inpainting_model'): scheduler.uses_inpainting_model = lambda: False @@ -124,18 +129,18 @@ class TextToLatentsInvocation(BaseInvocation): "ui": { "tags": ["latents"], "type_hints": { - "model": "model", - "control": "control", - # "cfg_scale": "float", - "cfg_scale": "number" + "model": "model", + "control": "control", + # "cfg_scale": "float", + "cfg_scale": "number" } }, } # TODO: pass this an emitter method or something? or a session for dispatching? 
def dispatch_progress( - self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState - ) -> None: + self, context: InvocationContext, source_node_id: str, + intermediate_state: PipelineIntermediateState) -> None: stable_diffusion_step_callback( context=context, intermediate_state=intermediate_state, @@ -143,9 +148,12 @@ class TextToLatentsInvocation(BaseInvocation): source_node_id=source_node_id, ) - def get_conditioning_data(self, context: InvocationContext, scheduler) -> ConditioningData: - c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name) - uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) + def get_conditioning_data( + self, context: InvocationContext, scheduler) -> ConditioningData: + c, extra_conditioning_info = context.services.latents.get( + self.positive_conditioning.conditioning_name) + uc, _ = context.services.latents.get( + self.negative_conditioning.conditioning_name) conditioning_data = ConditioningData( unconditioned_embeddings=uc, @@ -153,10 +161,10 @@ class TextToLatentsInvocation(BaseInvocation): guidance_scale=self.cfg_scale, extra=extra_conditioning_info, postprocessing_settings=PostprocessingSettings( - threshold=0.0,#threshold, - warmup=0.2,#warmup, - h_symmetry_time_pct=None,#h_symmetry_time_pct, - v_symmetry_time_pct=None#v_symmetry_time_pct, + threshold=0.0, # threshold, + warmup=0.2, # warmup, + h_symmetry_time_pct=None, # h_symmetry_time_pct, + v_symmetry_time_pct=None # v_symmetry_time_pct, ), ) @@ -164,31 +172,32 @@ class TextToLatentsInvocation(BaseInvocation): scheduler, # for ddim scheduler - eta=0.0, #ddim_eta + eta=0.0, # ddim_eta # for ancestral and sde schedulers generator=torch.Generator(device=uc.device).manual_seed(0), ) return conditioning_data - def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline: + def create_pipeline( + self, unet, scheduler) -> StableDiffusionGeneratorPipeline: # TODO: - #configure_model_padding( + # configure_model_padding( # unet, # self.seamless, # self.seamless_axes, - #) + # ) class FakeVae: class FakeVaeConfig: def __init__(self): self.block_out_channels = [0] - + def __init__(self): self.config = FakeVae.FakeVaeConfig() return StableDiffusionGeneratorPipeline( - vae=FakeVae(), # TODO: oh... + vae=FakeVae(), # TODO: oh... 
text_encoder=None, tokenizer=None, unet=unet, @@ -198,11 +207,12 @@ class TextToLatentsInvocation(BaseInvocation): requires_safety_checker=False, precision="float16" if unet.dtype == torch.float16 else "float32", ) - + def prep_control_data( self, context: InvocationContext, - model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device + # really only need model for dtype and device + model: StableDiffusionGeneratorPipeline, control_input: List[ControlField], latents_shape: List[int], do_classifier_free_guidance: bool = True, @@ -238,15 +248,17 @@ class TextToLatentsInvocation(BaseInvocation): print("Using HF model subfolders") print(" control_name: ", control_name) print(" control_subfolder: ", control_subfolder) - control_model = ControlNetModel.from_pretrained(control_name, - subfolder=control_subfolder, - torch_dtype=model.unet.dtype).to(model.device) + control_model = ControlNetModel.from_pretrained( + control_name, subfolder=control_subfolder, + torch_dtype=model.unet.dtype).to( + model.device) else: - control_model = ControlNetModel.from_pretrained(control_info.control_model, - torch_dtype=model.unet.dtype).to(model.device) + control_model = ControlNetModel.from_pretrained( + control_info.control_model, torch_dtype=model.unet.dtype).to(model.device) control_models.append(control_model) control_image_field = control_info.image - input_image = context.services.images.get_pil_image(control_image_field.image_name) + input_image = context.services.images.get_pil_image( + control_image_field.image_name) # self.image.image_type, self.image.image_name # FIXME: still need to test with different widths, heights, devices, dtypes # and add in batch_size, num_images_per_prompt? @@ -263,41 +275,50 @@ class TextToLatentsInvocation(BaseInvocation): dtype=control_model.dtype, control_mode=control_info.control_mode, ) - control_item = ControlNetData(model=control_model, - image_tensor=control_image, - weight=control_info.control_weight, - begin_step_percent=control_info.begin_step_percent, - end_step_percent=control_info.end_step_percent, - control_mode=control_info.control_mode, - ) + control_item = ControlNetData( + model=control_model, image_tensor=control_image, + weight=control_info.control_weight, + begin_step_percent=control_info.begin_step_percent, + end_step_percent=control_info.end_step_percent, + control_mode=control_info.control_mode,) control_data.append(control_item) # MultiControlNetModel has been refactored out, just need list[ControlNetData] return control_data + @torch.inference_mode() def invoke(self, context: InvocationContext) -> LatentsOutput: noise = context.services.latents.get(self.noise.latents_name) # Get the source node id (we are invoking the prepared node) - graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) + graph_execution_state = context.services.graph_execution_manager.get( + context.graph_execution_state_id) source_node_id = graph_execution_state.prepared_source_mapping[self.id] def step_callback(state: PipelineIntermediateState): self.dispatch_progress(context, source_node_id, state) - unet_info = context.services.model_manager.get_model(**self.unet.unet.dict()) - with unet_info as unet: + def _lora_loader(): + for lora in self.unet.loras: + lora_info = context.services.model_manager.get_model( + **lora.dict(exclude={"weight"})) + yield (lora_info.context.model, lora.weight) + del lora_info + return + + unet_info = context.services.model_manager.get_model( + **self.unet.unet.dict()) + 
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ + unet_info as unet: scheduler = get_scheduler( context=context, scheduler_info=self.unet.scheduler, scheduler_name=self.scheduler, ) - + pipeline = self.create_pipeline(unet, scheduler) conditioning_data = self.get_conditioning_data(context, scheduler) - loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras] - control_data = self.prep_control_data( model=pipeline, context=context, control_input=self.control, latents_shape=noise.shape, @@ -305,16 +326,15 @@ class TextToLatentsInvocation(BaseInvocation): do_classifier_free_guidance=True, ) - with ModelPatcher.apply_lora_unet(pipeline.unet, loras): - # TODO: Verify the noise is the right size - result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( - latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)), - noise=noise, - num_inference_steps=self.steps, - conditioning_data=conditioning_data, - control_data=control_data, # list[ControlNetData] - callback=step_callback, - ) + # TODO: Verify the noise is the right size + result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( + latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)), + noise=noise, + num_inference_steps=self.steps, + conditioning_data=conditioning_data, + control_data=control_data, # list[ControlNetData] + callback=step_callback, + ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -323,14 +343,18 @@ class TextToLatentsInvocation(BaseInvocation): context.services.latents.save(name, result_latents) return build_latents_output(latents_name=name, latents=result_latents) + class LatentsToLatentsInvocation(TextToLatentsInvocation): """Generates latents using latents as base image.""" type: Literal["l2l"] = "l2l" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to use as a base image") - strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use") + latents: Optional[LatentsField] = Field( + description="The latents to use as a base image") + strength: float = Field( + default=0.7, ge=0, le=1, + description="The strength of the latents to use") # Schema customisation class Config(InvocationConfig): @@ -345,22 +369,31 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): }, } + @torch.inference_mode() def invoke(self, context: InvocationContext) -> LatentsOutput: noise = context.services.latents.get(self.noise.latents_name) latent = context.services.latents.get(self.latents.latents_name) # Get the source node id (we are invoking the prepared node) - graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) + graph_execution_state = context.services.graph_execution_manager.get( + context.graph_execution_state_id) source_node_id = graph_execution_state.prepared_source_mapping[self.id] def step_callback(state: PipelineIntermediateState): self.dispatch_progress(context, source_node_id, state) - unet_info = context.services.model_manager.get_model( - **self.unet.unet.dict(), - ) + def _lora_loader(): + for lora in self.unet.loras: + lora_info = context.services.model_manager.get_model( + **lora.dict(exclude={"weight"})) + yield (lora_info.context.model, lora.weight) + del lora_info + return - with unet_info as unet: + unet_info = context.services.model_manager.get_model( + 
**self.unet.unet.dict()) + with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ + unet_info as unet: scheduler = get_scheduler( context=context, @@ -370,7 +403,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): pipeline = self.create_pipeline(unet, scheduler) conditioning_data = self.get_conditioning_data(context, scheduler) - + control_data = self.prep_control_data( model=pipeline, context=context, control_input=self.control, latents_shape=noise.shape, @@ -380,8 +413,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): # TODO: Verify the noise is the right size initial_latents = latent if self.strength < 1.0 else torch.zeros_like( - latent, device=unet.device, dtype=latent.dtype - ) + latent, device=unet.device, dtype=latent.dtype) timesteps, _ = pipeline.get_img2img_timesteps( self.steps, @@ -389,18 +421,15 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): device=unet.device, ) - loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras] - - with ModelPatcher.apply_lora_unet(pipeline.unet, loras): - result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( - latents=initial_latents, - timesteps=timesteps, - noise=noise, - num_inference_steps=self.steps, - conditioning_data=conditioning_data, - control_data=control_data, # list[ControlNetData] - callback=step_callback - ) + result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( + latents=initial_latents, + timesteps=timesteps, + noise=noise, + num_inference_steps=self.steps, + conditioning_data=conditioning_data, + control_data=control_data, # list[ControlNetData] + callback=step_callback + ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -417,9 +446,12 @@ class LatentsToImageInvocation(BaseInvocation): type: Literal["l2i"] = "l2i" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to generate an image from") + latents: Optional[LatentsField] = Field( + description="The latents to generate an image from") vae: VaeField = Field(default=None, description="Vae submodel") - tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)") + tiled: bool = Field( + default=False, + description="Decode latents by overlaping tiles(less memory consumption)") # Schema customisation class Config(InvocationConfig): @@ -429,7 +461,7 @@ class LatentsToImageInvocation(BaseInvocation): }, } - @torch.no_grad() + @torch.inference_mode() def invoke(self, context: InvocationContext) -> ImageOutput: latents = context.services.latents.get(self.latents.latents_name) @@ -450,7 +482,7 @@ class LatentsToImageInvocation(BaseInvocation): # copied from diffusers pipeline latents = latents / vae.config.scaling_factor image = vae.decode(latents, return_dict=False)[0] - image = (image / 2 + 0.5).clamp(0, 1) # denormalize + image = (image / 2 + 0.5).clamp(0, 1) # denormalize # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 np_image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -473,9 +505,9 @@ class LatentsToImageInvocation(BaseInvocation): height=image_dto.height, ) -LATENTS_INTERPOLATION_MODE = Literal[ - "nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact" -] + +LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", + "bilinear", "bicubic", "trilinear", "area", 
"nearest-exact"] class ResizeLatentsInvocation(BaseInvocation): @@ -484,21 +516,25 @@ class ResizeLatentsInvocation(BaseInvocation): type: Literal["lresize"] = "lresize" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to resize") - width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)") - height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)") - mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode") - antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") + latents: Optional[LatentsField] = Field( + description="The latents to resize") + width: int = Field( + ge=64, multiple_of=8, description="The width to resize to (px)") + height: int = Field( + ge=64, multiple_of=8, description="The height to resize to (px)") + mode: LATENTS_INTERPOLATION_MODE = Field( + default="bilinear", description="The interpolation mode") + antialias: bool = Field( + default=False, + description="Whether or not to antialias (applied in bilinear and bicubic modes only)") def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) resized_latents = torch.nn.functional.interpolate( - latents, - size=(self.height // 8, self.width // 8), - mode=self.mode, - antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False, - ) + latents, size=(self.height // 8, self.width // 8), + mode=self.mode, antialias=self.antialias + if self.mode in ["bilinear", "bicubic"] else False,) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -515,21 +551,24 @@ class ScaleLatentsInvocation(BaseInvocation): type: Literal["lscale"] = "lscale" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to scale") - scale_factor: float = Field(gt=0, description="The factor by which to scale the latents") - mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode") - antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") + latents: Optional[LatentsField] = Field( + description="The latents to scale") + scale_factor: float = Field( + gt=0, description="The factor by which to scale the latents") + mode: LATENTS_INTERPOLATION_MODE = Field( + default="bilinear", description="The interpolation mode") + antialias: bool = Field( + default=False, + description="Whether or not to antialias (applied in bilinear and bicubic modes only)") def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) # resizing resized_latents = torch.nn.functional.interpolate( - latents, - scale_factor=self.scale_factor, - mode=self.mode, - antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False, - ) + latents, scale_factor=self.scale_factor, mode=self.mode, + antialias=self.antialias + if self.mode in ["bilinear", "bicubic"] else False,) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -548,7 +587,9 @@ class ImageToLatentsInvocation(BaseInvocation): # Inputs image: Union[ImageField, None] = Field(description="The image to encode") vae: VaeField = Field(default=None, description="Vae submodel") - tiled: bool = Field(default=False, description="Encode latents by overlaping 
tiles(less memory consumption)") + tiled: bool = Field( + default=False, + description="Encode latents by overlaping tiles(less memory consumption)") # Schema customisation class Config(InvocationConfig): @@ -558,7 +599,7 @@ class ImageToLatentsInvocation(BaseInvocation): }, } - @torch.no_grad() + @torch.inference_mode() def invoke(self, context: InvocationContext) -> LatentsOutput: # image = context.services.images.get( # self.image.image_type, self.image.image_name diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index 6cfcb8dd8d..bcd47ff00a 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -1,18 +1,17 @@ from __future__ import annotations import copy -from pathlib import Path from contextlib import contextmanager -from typing import Optional, Dict, Tuple, Any +from pathlib import Path +from typing import Any, Dict, Optional, Tuple import torch +from compel.embeddings_provider import BaseTextualInversionManager +from diffusers.models import UNet2DConditionModel from safetensors.torch import load_file from torch.utils.hooks import RemovableHandle - -from diffusers.models import UNet2DConditionModel from transformers import CLIPTextModel -from compel.embeddings_provider import BaseTextualInversionManager class LoRALayerBase: #rank: Optional[int] @@ -527,7 +526,7 @@ class ModelPatcher: ): original_weights = dict() try: - with torch.no_grad(): + with torch.inference_mode(): for lora, lora_weight in loras: #assert lora.device.type == "cpu" for layer_key, layer in lora.layers.items(): @@ -539,9 +538,10 @@ class ModelPatcher: original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True) # enable autocast to calc fp16 loras on cpu - with torch.autocast(device_type="cpu"): - layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 - layer_weight = layer.get_weight() * lora_weight * layer_scale + #with torch.autocast(device_type="cpu"): + layer.to(dtype=torch.float32) + layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 + layer_weight = layer.get_weight() * lora_weight * layer_scale if module.weight.shape != layer_weight.shape: # TODO: debug on lycoris @@ -552,7 +552,7 @@ class ModelPatcher: yield # wait for context manager exit finally: - with torch.no_grad(): + with torch.inference_mode(): for module_key, weight in original_weights.items(): model.get_submodule(module_key).weight.copy_(weight) diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts index dd4b713196..9712ef4d5f 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts @@ -49,8 +49,6 @@ export const addLoRAsToGraph = ( '_' )}`; - console.log(lastLoraNodeId, currentLoraNodeId, currentLoraIndex, loraField); - const loraLoaderNode: LoraLoaderInvocation = { type: 'lora_loader', id: currentLoraNodeId,
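
Note (illustrative sketch, not part of the patch): the core of this commit is replacing the eager LoRA lookup with a lazy generator and applying the patch to the model object exposed by the model-manager context (`unet_info.context.model` / `text_encoder_info.context.model`) before the load context is entered. The snippet below restates that pattern in isolation; `model_manager` and `lora_refs` are illustrative stand-ins for `context.services.model_manager` and `self.unet.loras` / `self.clip.loras`, and the surrounding scaffolding is assumed rather than taken from the diff.

# Before (the slow path removed by this commit): every LoRA is resolved eagerly
# into a list, and patching targets the already-loaded model inside its context:
#
#   loras = [(context.services.model_manager.get_model(
#                 **lora.dict(exclude={"weight"})).context.model, lora.weight)
#            for lora in self.unet.loras]
#   with unet_info as unet, ModelPatcher.apply_lora_unet(pipeline.unet, loras):
#       ...
#
# After: a generator defers each model-manager lookup, and the patch is applied
# to unet_info.context.model before `unet_info` itself is entered:

def _lora_loader(model_manager, lora_refs):
    for lora in lora_refs:
        lora_info = model_manager.get_model(**lora.dict(exclude={"weight"}))
        yield (lora_info.context.model, lora.weight)
        del lora_info  # drop the reference before fetching the next LoRA

# Usage, mirroring the invocations above:
#
#   with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader(...)), \
#           unet_info as unet:
#       ...  # LoRA weights are already patched in when the model is loaded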