mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add sdxl lora support
This commit is contained in:
committed by
Kent Keirsey
parent
cfc3a20565
commit
1ac14a1e43
@ -173,7 +173,7 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
|
||||
class SDXLPromptInvocationBase:
|
||||
def run_clip_raw(self, context, clip_field, prompt, get_pooled):
|
||||
def run_clip_raw(self, context, clip_field, prompt, get_pooled, lora_prefix):
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**clip_field.tokenizer.dict(),
|
||||
context=context,
|
||||
@ -210,8 +210,8 @@ class SDXLPromptInvocationBase:
|
||||
# print(traceback.format_exc())
|
||||
print(f'Warn: trigger: "{trigger}" not found')
|
||||
|
||||
with ModelPatcher.apply_lora_text_encoder(
|
||||
text_encoder_info.context.model, _lora_loader()
|
||||
with ModelPatcher.apply_lora(
|
||||
text_encoder_info.context.model, _lora_loader(), lora_prefix
|
||||
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
@ -247,7 +247,7 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
return c, c_pooled, None
|
||||
|
||||
def run_clip_compel(self, context, clip_field, prompt, get_pooled):
|
||||
def run_clip_compel(self, context, clip_field, prompt, get_pooled, lora_prefix):
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**clip_field.tokenizer.dict(),
|
||||
context=context,
|
||||
@ -284,8 +284,8 @@ class SDXLPromptInvocationBase:
|
||||
# print(traceback.format_exc())
|
||||
print(f'Warn: trigger: "{trigger}" not found')
|
||||
|
||||
with ModelPatcher.apply_lora_text_encoder(
|
||||
text_encoder_info.context.model, _lora_loader()
|
||||
with ModelPatcher.apply_lora(
|
||||
text_encoder_info.context.model, _lora_loader(), lora_prefix
|
||||
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
@ -357,11 +357,11 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False)
|
||||
c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False, "lora_te1_")
|
||||
if self.style.strip() == "":
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True)
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True, "lora_te2_")
|
||||
else:
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_")
|
||||
|
||||
original_size = (self.original_height, self.original_width)
|
||||
crop_coords = (self.crop_top, self.crop_left)
|
||||
@ -415,7 +415,8 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
|
||||
# TODO: if there will appear lora for refiner - write proper prefix
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>")
|
||||
|
||||
original_size = (self.original_height, self.original_width)
|
||||
crop_coords = (self.crop_top, self.crop_left)
|
||||
@ -467,11 +468,11 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False)
|
||||
c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False, "lora_te1_")
|
||||
if self.style.strip() == "":
|
||||
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True)
|
||||
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True, "lora_te2_")
|
||||
else:
|
||||
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
|
||||
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "lora_te2_")
|
||||
|
||||
original_size = (self.original_height, self.original_width)
|
||||
crop_coords = (self.crop_top, self.crop_left)
|
||||
@ -525,7 +526,8 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
|
||||
# TODO: if there will appear lora for refiner - write proper prefix
|
||||
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "<NONE>")
|
||||
|
||||
original_size = (self.original_height, self.original_width)
|
||||
crop_coords = (self.crop_top, self.crop_left)
|
||||
|
Reference in New Issue
Block a user