mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add lora loader node
This commit is contained in:
parent
79de9047b5
commit
420a76ecdd
@ -173,7 +173,6 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
|
||||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||||
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from typing import Literal, Optional, Union, List
|
from typing import Literal, Optional, Union, List
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
import copy
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||||
|
|
||||||
@ -98,21 +99,6 @@ class ModelLoaderInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
loras = [
|
|
||||||
LoraInfo(
|
|
||||||
model_name="sadcatmeme",
|
|
||||||
model_type=SDModelType.Lora,
|
|
||||||
submodel=None,
|
|
||||||
weight=0.75,
|
|
||||||
),
|
|
||||||
LoraInfo(
|
|
||||||
model_name="gunAimingAtYouV1",
|
|
||||||
model_type=SDModelType.Lora,
|
|
||||||
submodel=None,
|
|
||||||
weight=0.75,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
return ModelLoaderOutput(
|
return ModelLoaderOutput(
|
||||||
unet=UNetField(
|
unet=UNetField(
|
||||||
@ -126,7 +112,7 @@ class ModelLoaderInvocation(BaseInvocation):
|
|||||||
model_type=SDModelType.Diffusers,
|
model_type=SDModelType.Diffusers,
|
||||||
submodel=SDModelType.Scheduler,
|
submodel=SDModelType.Scheduler,
|
||||||
),
|
),
|
||||||
loras=loras,
|
loras=[],
|
||||||
),
|
),
|
||||||
clip=ClipField(
|
clip=ClipField(
|
||||||
tokenizer=ModelInfo(
|
tokenizer=ModelInfo(
|
||||||
@ -139,7 +125,7 @@ class ModelLoaderInvocation(BaseInvocation):
|
|||||||
model_type=SDModelType.Diffusers,
|
model_type=SDModelType.Diffusers,
|
||||||
submodel=SDModelType.TextEncoder,
|
submodel=SDModelType.TextEncoder,
|
||||||
),
|
),
|
||||||
loras=loras,
|
loras=[],
|
||||||
),
|
),
|
||||||
vae=VaeField(
|
vae=VaeField(
|
||||||
vae=ModelInfo(
|
vae=ModelInfo(
|
||||||
@ -149,3 +135,65 @@ class ModelLoaderInvocation(BaseInvocation):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class LoraLoaderOutput(BaseInvocationOutput):
|
||||||
|
"""Model loader output"""
|
||||||
|
|
||||||
|
#fmt: off
|
||||||
|
type: Literal["lora_loader_output"] = "lora_loader_output"
|
||||||
|
|
||||||
|
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
|
||||||
|
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||||
|
#fmt: on
|
||||||
|
|
||||||
|
class LoraLoaderInvocation(BaseInvocation):
|
||||||
|
"""Apply selected lora to unet and text_encoder."""
|
||||||
|
|
||||||
|
type: Literal["lora_loader"] = "lora_loader"
|
||||||
|
|
||||||
|
lora_name: str = Field(description="Lora model name")
|
||||||
|
weight: float = Field(default=0.75, description="With what weight to apply lora")
|
||||||
|
|
||||||
|
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
|
||||||
|
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
||||||
|
|
||||||
|
if not context.services.model_manager.model_exists(
|
||||||
|
model_name=self.lora_name,
|
||||||
|
model_type=SDModelType.Lora,
|
||||||
|
):
|
||||||
|
raise Exception(f"Unkown lora name: {self.lora_name}!")
|
||||||
|
|
||||||
|
if self.unet is not None and any(lora.model_name == self.lora_name for lora in self.unet.loras):
|
||||||
|
raise Exception(f"Lora \"{self.lora_name}\" already applied to unet")
|
||||||
|
|
||||||
|
if self.clip is not None and any(lora.model_name == self.lora_name for lora in self.clip.loras):
|
||||||
|
raise Exception(f"Lora \"{self.lora_name}\" already applied to clip")
|
||||||
|
|
||||||
|
output = LoraLoaderOutput()
|
||||||
|
|
||||||
|
if self.unet is not None:
|
||||||
|
output.unet = copy.deepcopy(self.unet)
|
||||||
|
output.unet.loras.append(
|
||||||
|
LoraInfo(
|
||||||
|
model_name=self.lora_name,
|
||||||
|
model_type=SDModelType.Lora,
|
||||||
|
submodel=None,
|
||||||
|
weight=self.weight,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.clip is not None:
|
||||||
|
output.clip = copy.deepcopy(self.clip)
|
||||||
|
output.clip.loras.append(
|
||||||
|
LoraInfo(
|
||||||
|
model_name=self.lora_name,
|
||||||
|
model_type=SDModelType.Lora,
|
||||||
|
submodel=None,
|
||||||
|
weight=self.weight,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user