Add lora loader node

Sergey Borisov 2023-05-30 02:12:33 +03:00
parent 79de9047b5
commit 420a76ecdd
2 changed files with 65 additions and 18 deletions

File 1 of 2:

@@ -173,7 +173,6 @@ class TextToLatentsInvocation(BaseInvocation):
     steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
     cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
     scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
-    model: str = Field(default="", description="The model to use (currently ignored)")
     seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
     seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
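
The field removed above was a placeholder the invocation never read (its own description says "currently ignored"); in this branch the model instead reaches invocations through the structured fields that ModelLoaderInvocation emits in the second file. A minimal sketch of that shape, not part of the commit: the model name is hypothetical, and submodel=SDModelType.UNet is an assumption, since the hunks below only show the scheduler entry.

    # Sketch, not from the commit: the unet travels as a structured field
    # rather than a free-form "model" string.
    unet_field = UNetField(
        unet=ModelInfo(
            model_name="stable-diffusion-1.5",  # hypothetical model name
            model_type=SDModelType.Diffusers,
            submodel=SDModelType.UNet,          # assumed submodel tag
        ),
        scheduler=ModelInfo(
            model_name="stable-diffusion-1.5",  # hypothetical model name
            model_type=SDModelType.Diffusers,
            submodel=SDModelType.Scheduler,
        ),
        loras=[],  # starts empty; LoraLoaderInvocation appends to it
    )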

File 2 of 2:

@@ -1,5 +1,6 @@
 from typing import Literal, Optional, Union, List
 from pydantic import BaseModel, Field
+import copy
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
@@ -98,21 +99,6 @@ class ModelLoaderInvocation(BaseInvocation):
         )
         """
-        loras = [
-            LoraInfo(
-                model_name="sadcatmeme",
-                model_type=SDModelType.Lora,
-                submodel=None,
-                weight=0.75,
-            ),
-            LoraInfo(
-                model_name="gunAimingAtYouV1",
-                model_type=SDModelType.Lora,
-                submodel=None,
-                weight=0.75,
-            ),
-        ]
         return ModelLoaderOutput(
             unet=UNetField(
@@ -126,7 +112,7 @@ class ModelLoaderInvocation(BaseInvocation):
                     model_type=SDModelType.Diffusers,
                     submodel=SDModelType.Scheduler,
                 ),
-                loras=loras,
+                loras=[],
             ),
             clip=ClipField(
                 tokenizer=ModelInfo(
@@ -139,7 +125,7 @@ class ModelLoaderInvocation(BaseInvocation):
                     model_type=SDModelType.Diffusers,
                     submodel=SDModelType.TextEncoder,
                 ),
-                loras=loras,
+                loras=[],
             ),
             vae=VaeField(
                 vae=ModelInfo(
@@ -149,3 +135,65 @@ class ModelLoaderInvocation(BaseInvocation):
                 ),
             )
         )
+
+
+class LoraLoaderOutput(BaseInvocationOutput):
+    """Lora loader output"""
+
+    #fmt: off
+    type: Literal["lora_loader_output"] = "lora_loader_output"
+
+    unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
+    clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
+    #fmt: on
+
+
+class LoraLoaderInvocation(BaseInvocation):
+    """Apply selected lora to unet and text_encoder."""
+
+    type: Literal["lora_loader"] = "lora_loader"
+
+    lora_name: str = Field(description="Lora model name")
+    weight: float = Field(default=0.75, description="The weight with which to apply the lora")
+    unet: Optional[UNetField] = Field(description="UNet model for applying lora")
+    clip: Optional[ClipField] = Field(description="Clip model for applying lora")
+
+    def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
+        if not context.services.model_manager.model_exists(
+            model_name=self.lora_name,
+            model_type=SDModelType.Lora,
+        ):
+            raise Exception(f"Unknown lora name: {self.lora_name}!")
+
+        if self.unet is not None and any(lora.model_name == self.lora_name for lora in self.unet.loras):
+            raise Exception(f"Lora \"{self.lora_name}\" already applied to unet")
+
+        if self.clip is not None and any(lora.model_name == self.lora_name for lora in self.clip.loras):
+            raise Exception(f"Lora \"{self.lora_name}\" already applied to clip")
+
+        output = LoraLoaderOutput()
+
+        if self.unet is not None:
+            output.unet = copy.deepcopy(self.unet)
+            output.unet.loras.append(
+                LoraInfo(
+                    model_name=self.lora_name,
+                    model_type=SDModelType.Lora,
+                    submodel=None,
+                    weight=self.weight,
+                )
+            )
+
+        if self.clip is not None:
+            output.clip = copy.deepcopy(self.clip)
+            output.clip.loras.append(
+                LoraInfo(
+                    model_name=self.lora_name,
+                    model_type=SDModelType.Lora,
+                    submodel=None,
+                    weight=self.weight,
+                )
+            )
+
+        return output
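
Taken together, the node stacks: each LoraLoaderInvocation deep-copies the incoming fields and appends one LoraInfo, so several loras can be layered by chaining nodes. A minimal usage sketch, not from the commit: "styleLoraA" and "styleLoraB" are hypothetical lora names, and unet_field/clip_field stand in for the UNetField/ClipField outputs of ModelLoaderInvocation.

    # Sketch: chaining two lora loader nodes the way a graph would.
    # Assumes `context` provides a model_manager that knows both loras.
    first = LoraLoaderInvocation(
        lora_name="styleLoraA",   # hypothetical lora name
        weight=0.75,
        unet=unet_field,
        clip=clip_field,
    ).invoke(context)

    second = LoraLoaderInvocation(
        lora_name="styleLoraB",   # hypothetical lora name
        weight=0.5,
        unet=first.unet,
        clip=first.clip,
    ).invoke(context)

    # second.unet.loras now lists both LoraInfo entries, while unet_field is
    # untouched: invoke() deep-copies its inputs before appending. Applying
    # "styleLoraA" again down this chain would raise the duplicate-name error.

The deepcopy is what keeps this graph-friendly: upstream outputs are never mutated, so one model loader output can feed several independent lora chains.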