From 420a76ecdd5340b03b6f64e4a8b92ee3d2bddf5e Mon Sep 17 00:00:00 2001
From: Sergey Borisov
Date: Tue, 30 May 2023 02:12:33 +0300
Subject: [PATCH] Add lora loader node

---
 invokeai/app/invocations/latent.py |  1 -
 invokeai/app/invocations/model.py  | 82 +++++++++++++++++++++++-------
 2 files changed, 65 insertions(+), 18 deletions(-)

diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 3e79bad2c5..abfe92f828 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -173,7 +173,6 @@ class TextToLatentsInvocation(BaseInvocation):
     steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
     cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
     scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
-    model: str = Field(default="", description="The model to use (currently ignored)")
     seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
     seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py
index 267389c089..d0a11424c8 100644
--- a/invokeai/app/invocations/model.py
+++ b/invokeai/app/invocations/model.py
@@ -1,5 +1,6 @@
 from typing import Literal, Optional, Union, List
 from pydantic import BaseModel, Field
+import copy
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
@@ -98,21 +99,6 @@ class ModelLoaderInvocation(BaseInvocation):
             )
         """

-        loras = [
-            LoraInfo(
-                model_name="sadcatmeme",
-                model_type=SDModelType.Lora,
-                submodel=None,
-                weight=0.75,
-            ),
-            LoraInfo(
-                model_name="gunAimingAtYouV1",
-                model_type=SDModelType.Lora,
-                submodel=None,
-                weight=0.75,
-            ),
-        ]
-
         return ModelLoaderOutput(
             unet=UNetField(
                 unet=ModelInfo(
                     model_name=self.model,
                     model_type=SDModelType.Diffusers,
                     submodel=SDModelType.UNet,
                 ),
                 scheduler=ModelInfo(
                     model_name=self.model,
                     model_type=SDModelType.Diffusers,
                     submodel=SDModelType.Scheduler,
                 ),
-                loras=loras,
+                loras=[],
             ),
             clip=ClipField(
                 tokenizer=ModelInfo(
                     model_name=self.model,
                     model_type=SDModelType.Diffusers,
                     submodel=SDModelType.Tokenizer,
                 ),
                 text_encoder=ModelInfo(
                     model_name=self.model,
                     model_type=SDModelType.Diffusers,
                     submodel=SDModelType.TextEncoder,
                 ),
-                loras=loras,
+                loras=[],
             ),
             vae=VaeField(
                 vae=ModelInfo(
                     model_name=self.model,
                     model_type=SDModelType.Diffusers,
                     submodel=SDModelType.Vae,
                 ),
             )
         )
+
+class LoraLoaderOutput(BaseInvocationOutput):
+    """Lora loader output"""
+
+    #fmt: off
+    type: Literal["lora_loader_output"] = "lora_loader_output"
+
+    unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
+    clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
+    #fmt: on
+
+class LoraLoaderInvocation(BaseInvocation):
+    """Apply selected lora to unet and text_encoder."""
+
+    type: Literal["lora_loader"] = "lora_loader"
+
+    lora_name: str = Field(description="Lora model name")
+    weight: float = Field(default=0.75, description="Weight with which to apply the lora")
+
+    unet: Optional[UNetField] = Field(description="UNet model for applying lora")
+    clip: Optional[ClipField] = Field(description="Clip model for applying lora")
+
+    def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
+
+        if not context.services.model_manager.model_exists(
+            model_name=self.lora_name,
+            model_type=SDModelType.Lora,
+        ):
+            raise Exception(f"Unknown lora name: {self.lora_name}!")
+
+        if self.unet is not None and any(lora.model_name == self.lora_name for lora in self.unet.loras):
+            raise Exception(f"Lora \"{self.lora_name}\" already applied to unet")
+
+        if self.clip is not None and any(lora.model_name == self.lora_name for lora in self.clip.loras):
+            raise Exception(f"Lora \"{self.lora_name}\" already applied to clip")
+
+        output = LoraLoaderOutput()
+
+        if self.unet is not None:
+            output.unet = copy.deepcopy(self.unet)
+            output.unet.loras.append(
+                LoraInfo(
+                    model_name=self.lora_name,
+                    model_type=SDModelType.Lora,
+                    submodel=None,
+                    weight=self.weight,
+                )
+            )
+
+        if self.clip is not None:
+            output.clip = copy.deepcopy(self.clip)
+            output.clip.loras.append(
+                LoraInfo(
+                    model_name=self.lora_name,
+                    model_type=SDModelType.Lora,
+                    submodel=None,
+                    weight=self.weight,
+                )
+            )
+
+        return output
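
A note on the new node: the heart of LoraLoaderInvocation.invoke() is a copy-then-append step. The incoming UNet/CLIP fields are deep-copied before the LoraInfo is appended, so the upstream model loader's output is never mutated and each downstream edge keeps its own lora list. The standalone sketch below (illustration only, not part of the patch) shows that pattern; LoraInfo and UNetField here are simplified stand-in dataclasses and apply_lora is a hypothetical helper, not InvokeAI API.

import copy
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class LoraInfo:
    model_name: str
    weight: float

@dataclass
class UNetField:
    loras: List[LoraInfo] = field(default_factory=list)

def apply_lora(unet: Optional[UNetField], lora_name: str, weight: float = 0.75) -> Optional[UNetField]:
    """Return a copy of `unet` with the lora appended; the input is left untouched."""
    if unet is None:
        return None
    if any(lora.model_name == lora_name for lora in unet.loras):
        raise ValueError(f'Lora "{lora_name}" already applied to unet')
    out = copy.deepcopy(unet)  # never mutate the upstream node's output
    out.loras.append(LoraInfo(model_name=lora_name, weight=weight))
    return out

if __name__ == "__main__":
    base = UNetField()
    with_lora = apply_lora(base, "my_lora", 0.7)
    assert base.loras == [] and len(with_lora.loras) == 1  # original list is untouched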