Add lora loader node

This commit is contained in:
Sergey Borisov 2023-05-30 02:12:33 +03:00
parent 79de9047b5
commit 420a76ecdd
2 changed files with 65 additions and 18 deletions

View File

@ -173,7 +173,6 @@ class TextToLatentsInvocation(BaseInvocation):
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
model: str = Field(default="", description="The model to use (currently ignored)")
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")

View File

@ -1,5 +1,6 @@
from typing import Literal, Optional, Union, List
from pydantic import BaseModel, Field
import copy
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
@ -98,21 +99,6 @@ class ModelLoaderInvocation(BaseInvocation):
)
"""
loras = [
LoraInfo(
model_name="sadcatmeme",
model_type=SDModelType.Lora,
submodel=None,
weight=0.75,
),
LoraInfo(
model_name="gunAimingAtYouV1",
model_type=SDModelType.Lora,
submodel=None,
weight=0.75,
),
]
return ModelLoaderOutput(
unet=UNetField(
@ -126,7 +112,7 @@ class ModelLoaderInvocation(BaseInvocation):
model_type=SDModelType.Diffusers,
submodel=SDModelType.Scheduler,
),
loras=loras,
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
@ -139,7 +125,7 @@ class ModelLoaderInvocation(BaseInvocation):
model_type=SDModelType.Diffusers,
submodel=SDModelType.TextEncoder,
),
loras=loras,
loras=[],
),
vae=VaeField(
vae=ModelInfo(
@ -149,3 +135,65 @@ class ModelLoaderInvocation(BaseInvocation):
),
)
)
class LoraLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
#fmt: off
type: Literal["lora_loader_output"] = "lora_loader_output"
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
#fmt: on
class LoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
type: Literal["lora_loader"] = "lora_loader"
lora_name: str = Field(description="Lora model name")
weight: float = Field(default=0.75, description="With what weight to apply lora")
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
if not context.services.model_manager.model_exists(
model_name=self.lora_name,
model_type=SDModelType.Lora,
):
raise Exception(f"Unkown lora name: {self.lora_name}!")
if self.unet is not None and any(lora.model_name == self.lora_name for lora in self.unet.loras):
raise Exception(f"Lora \"{self.lora_name}\" already applied to unet")
if self.clip is not None and any(lora.model_name == self.lora_name for lora in self.clip.loras):
raise Exception(f"Lora \"{self.lora_name}\" already applied to clip")
output = LoraLoaderOutput()
if self.unet is not None:
output.unet = copy.deepcopy(self.unet)
output.unet.loras.append(
LoraInfo(
model_name=self.lora_name,
model_type=SDModelType.Lora,
submodel=None,
weight=self.weight,
)
)
if self.clip is not None:
output.clip = copy.deepcopy(self.clip)
output.clip.loras.append(
LoraInfo(
model_name=self.lora_name,
model_type=SDModelType.Lora,
submodel=None,
weight=self.weight,
)
)
return output