mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add sdxl lora support
This commit is contained in:
committed by
Kent Keirsey
parent
cfc3a20565
commit
1ac14a1e43
@ -262,6 +262,103 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
return output
|
||||
|
||||
|
||||
class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output"
|
||||
|
||||
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
|
||||
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
clip2: Optional[ClipField] = Field(default=None, description="Tokenizer2 and text_encoder2 submodels")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader"
|
||||
|
||||
lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name")
|
||||
weight: float = Field(default=0.75, description="With what weight to apply lora")
|
||||
|
||||
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
|
||||
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
|
||||
clip2: Optional[ClipField] = Field(description="Clip2 model for applying lora")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "SDXL Lora Loader",
|
||||
"tags": ["lora", "loader"],
|
||||
"type_hints": {"lora": "lora_model"},
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
|
||||
if self.lora is None:
|
||||
raise Exception("No LoRA provided")
|
||||
|
||||
base_model = self.lora.base_model
|
||||
lora_name = self.lora.model_name
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
):
|
||||
raise Exception(f"Unkown lora name: {lora_name}!")
|
||||
|
||||
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to unet')
|
||||
|
||||
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to clip')
|
||||
|
||||
if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to clip2')
|
||||
|
||||
output = SDXLLoraLoaderOutput()
|
||||
|
||||
if self.unet is not None:
|
||||
output.unet = copy.deepcopy(self.unet)
|
||||
output.unet.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
if self.clip is not None:
|
||||
output.clip = copy.deepcopy(self.clip)
|
||||
output.clip.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
if self.clip2 is not None:
|
||||
output.clip2 = copy.deepcopy(self.clip2)
|
||||
output.clip2.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class VAEModelField(BaseModel):
|
||||
"""Vae model field"""
|
||||
|
||||
|
Reference in New Issue
Block a user