mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
First working lora implementation
This commit is contained in:
@ -1,5 +1,6 @@
|
||||
from typing import Literal, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from contextlib import ExitStack
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
|
||||
@ -8,6 +9,7 @@ from .model import ClipField
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||
from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager
|
||||
from ...backend.model_management.lora import LoRAHelper
|
||||
|
||||
from compel import Compel
|
||||
from compel.prompt_parser import (
|
||||
@ -63,7 +65,10 @@ class CompelInvocation(BaseInvocation):
|
||||
**self.clip.tokenizer.dict(),
|
||||
)
|
||||
with text_encoder_info as text_encoder,\
|
||||
tokenizer_info as tokenizer:
|
||||
tokenizer_info as tokenizer,\
|
||||
ExitStack() as stack:
|
||||
|
||||
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
|
||||
|
||||
# TODO: global? input?
|
||||
#use_full_precision = precision == "float32" or precision == "autocast"
|
||||
@ -92,7 +97,8 @@ class CompelInvocation(BaseInvocation):
|
||||
if context.services.configuration.log_tokenization:
|
||||
log_tokenization_for_prompt_object(prompt, tokenizer)
|
||||
|
||||
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
|
||||
with LoRAHelper.apply_lora_text_encoder(text_encoder, loras):
|
||||
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
|
||||
|
||||
# TODO: long prompt support
|
||||
#if not self.truncate_long_prompts:
|
||||
@ -106,7 +112,7 @@ class CompelInvocation(BaseInvocation):
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
|
||||
# TODO: hacky but works ;D maybe rename latents somehow?
|
||||
context.services.latents.set(conditioning_name, (c, ec))
|
||||
context.services.latents.save(conditioning_name, (c, ec))
|
||||
|
||||
return CompelOutput(
|
||||
conditioning=ConditioningField(
|
||||
|
Reference in New Issue
Block a user