Add lora loader node

parent 79de9047b5
commit 420a76ecdd
@@ -173,7 +173,6 @@ class TextToLatentsInvocation(BaseInvocation):
     steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
     cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
     scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
     model: str = Field(default="", description="The model to use (currently ignored)")
     seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
     seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
@@ -1,5 +1,6 @@
 from typing import Literal, Optional, Union, List
 from pydantic import BaseModel, Field
+import copy

 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
@@ -98,21 +99,6 @@ class ModelLoaderInvocation(BaseInvocation):
         )
         """

-        loras = [
-            LoraInfo(
-                model_name="sadcatmeme",
-                model_type=SDModelType.Lora,
-                submodel=None,
-                weight=0.75,
-            ),
-            LoraInfo(
-                model_name="gunAimingAtYouV1",
-                model_type=SDModelType.Lora,
-                submodel=None,
-                weight=0.75,
-            ),
-        ]
-
         return ModelLoaderOutput(
             unet=UNetField(
@@ -126,7 +112,7 @@ class ModelLoaderInvocation(BaseInvocation):
                     model_type=SDModelType.Diffusers,
                     submodel=SDModelType.Scheduler,
                 ),
-                loras=loras,
+                loras=[],
             ),
             clip=ClipField(
                 tokenizer=ModelInfo(
@@ -139,7 +125,7 @@ class ModelLoaderInvocation(BaseInvocation):
                     model_type=SDModelType.Diffusers,
                     submodel=SDModelType.TextEncoder,
                 ),
-                loras=loras,
+                loras=[],
             ),
             vae=VaeField(
                 vae=ModelInfo(
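
The two hunks above replace the hard-coded loras list with loras=[]: the model loader now emits UNet and CLIP fields with no LoRAs attached, and the node added in the next hunk is what appends them. A minimal sketch of that intent (illustrative only, using hypothetical stand-in classes FakeLoraInfo and FakeUNetField rather than the real pydantic models):

# Illustrative sketch, not InvokeAI code: how the now-empty loras lists are meant
# to be filled by chained lora_loader nodes downstream of the model loader.
from dataclasses import dataclass, field
from typing import List

@dataclass
class FakeLoraInfo:          # stand-in for LoraInfo
    model_name: str
    weight: float

@dataclass
class FakeUNetField:         # stand-in for UNetField
    loras: List[FakeLoraInfo] = field(default_factory=list)

# The model loader now emits a field with no LoRAs attached...
unet = FakeUNetField(loras=[])

# ...and each lora_loader node in the chain appends exactly one entry.
for name, weight in [("loraA", 0.75), ("loraB", 0.5)]:   # hypothetical lora names
    unet.loras.append(FakeLoraInfo(model_name=name, weight=weight))

assert [lora.model_name for lora in unet.loras] == ["loraA", "loraB"]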
@@ -149,3 +135,65 @@ class ModelLoaderInvocation(BaseInvocation):
                 ),
             )
         )
+
+
+class LoraLoaderOutput(BaseInvocationOutput):
+    """Lora loader output"""
+
+    #fmt: off
+    type: Literal["lora_loader_output"] = "lora_loader_output"
+
+    unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
+    clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
+    #fmt: on
+
+
+class LoraLoaderInvocation(BaseInvocation):
+    """Apply selected lora to unet and text_encoder."""
+
+    type: Literal["lora_loader"] = "lora_loader"
+
+    lora_name: str = Field(description="Lora model name")
+    weight: float = Field(default=0.75, description="Weight with which to apply the lora")
+
+    unet: Optional[UNetField] = Field(description="UNet model for applying lora")
+    clip: Optional[ClipField] = Field(description="Clip model for applying lora")
+
+    def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
+
+        if not context.services.model_manager.model_exists(
+            model_name=self.lora_name,
+            model_type=SDModelType.Lora,
+        ):
+            raise Exception(f"Unknown lora name: {self.lora_name}!")
+
+        if self.unet is not None and any(lora.model_name == self.lora_name for lora in self.unet.loras):
+            raise Exception(f"Lora \"{self.lora_name}\" already applied to unet")
+
+        if self.clip is not None and any(lora.model_name == self.lora_name for lora in self.clip.loras):
+            raise Exception(f"Lora \"{self.lora_name}\" already applied to clip")
+
+        output = LoraLoaderOutput()
+
+        if self.unet is not None:
+            output.unet = copy.deepcopy(self.unet)
+            output.unet.loras.append(
+                LoraInfo(
+                    model_name=self.lora_name,
+                    model_type=SDModelType.Lora,
+                    submodel=None,
+                    weight=self.weight,
+                )
+            )
+
+        if self.clip is not None:
+            output.clip = copy.deepcopy(self.clip)
+            output.clip.loras.append(
+                LoraInfo(
+                    model_name=self.lora_name,
+                    model_type=SDModelType.Lora,
+                    submodel=None,
+                    weight=self.weight,
+                )
+            )
+
+        return output
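
One detail of the new node worth noting: it deep-copies its unet/clip inputs before appending a LoraInfo, so the upstream model loader's output is not mutated when several branches of a graph reuse it. A rough sketch of the difference (illustrative only, with a hypothetical stand-in FakeClipField instead of the real field type):

# Illustrative sketch, not InvokeAI code: why LoraLoaderInvocation deep-copies its
# input before appending, instead of appending to the shared upstream object.
import copy
from dataclasses import dataclass, field
from typing import List

@dataclass
class FakeClipField:                      # stand-in for ClipField
    loras: List[str] = field(default_factory=list)

upstream = FakeClipField()                # output of a hypothetical model loader node

# Appending to the input directly would leak the change back upstream:
aliased = upstream
aliased.loras.append("loraA")             # hypothetical lora name
assert upstream.loras == ["loraA"]        # upstream output modified as a side effect

# Deep-copying first, as the new node does, keeps the upstream output intact:
upstream = FakeClipField()
copied = copy.deepcopy(upstream)
copied.loras.append("loraA")
assert upstream.loras == []               # upstream unchanged
assert copied.loras == ["loraA"]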