First working lora implementation

This commit is contained in:
Sergey Borisov
2023-05-30 01:11:00 +03:00
parent f50293920e
commit 79de9047b5
6 changed files with 652 additions and 43 deletions

View File

@ -1,5 +1,6 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel, Field
from contextlib import ExitStack
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
@ -8,6 +9,7 @@ from .model import ClipField
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager
from ...backend.model_management.lora import LoRAHelper
from compel import Compel
from compel.prompt_parser import (
@ -63,7 +65,10 @@ class CompelInvocation(BaseInvocation):
**self.clip.tokenizer.dict(),
)
with text_encoder_info as text_encoder,\
tokenizer_info as tokenizer:
tokenizer_info as tokenizer,\
ExitStack() as stack:
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
# TODO: global? input?
#use_full_precision = precision == "float32" or precision == "autocast"
@ -92,7 +97,8 @@ class CompelInvocation(BaseInvocation):
if context.services.configuration.log_tokenization:
log_tokenization_for_prompt_object(prompt, tokenizer)
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
with LoRAHelper.apply_lora_text_encoder(text_encoder, loras):
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
# TODO: long prompt support
#if not self.truncate_long_prompts:
@ -106,7 +112,7 @@ class CompelInvocation(BaseInvocation):
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
# TODO: hacky but works ;D maybe rename latents somehow?
context.services.latents.set(conditioning_name, (c, ec))
context.services.latents.save(conditioning_name, (c, ec))
return CompelOutput(
conditioning=ConditioningField(