From 08d428a5e7fd5fd0d2db3ad90a80f24c69af9a26 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Tue, 4 Jul 2023 21:11:50 +1000
Subject: [PATCH] feat(nodes): add lora field, update lora loader

---
 invokeai/app/invocations/baseinvocation.py | 25 ++++--
 invokeai/app/invocations/model.py          | 97 +++++++++++++---------
 2 files changed, 77 insertions(+), 45 deletions(-)

diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py
index 1bf9353368..4c7314bd2b 100644
--- a/invokeai/app/invocations/baseinvocation.py
+++ b/invokeai/app/invocations/baseinvocation.py
@@ -4,9 +4,10 @@ from __future__ import annotations

 from abc import ABC, abstractmethod
 from inspect import signature
-from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict, TYPE_CHECKING
+from typing import (TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args,
+                    get_type_hints)

-from pydantic import BaseModel, Field
+from pydantic import BaseConfig, BaseModel, Field

 if TYPE_CHECKING:
     from ..services.invocation_services import InvocationServices
@@ -65,8 +66,13 @@ class BaseInvocation(ABC, BaseModel):
     @classmethod
     def get_invocations_map(cls):
         # Get the type strings out of the literals and into a dictionary
-        return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseInvocation.get_all_subclasses()))
-
+        return dict(
+            map(
+                lambda t: (get_args(get_type_hints(t)["type"])[0], t),
+                BaseInvocation.get_all_subclasses(),
+            )
+        )
+
     @classmethod
     def get_output_type(cls):
         return signature(cls.invoke).return_annotation
@@ -75,11 +81,11 @@ class BaseInvocation(ABC, BaseModel):
     def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
         """Invoke with provided context and return outputs."""
         pass
-
-    #fmt: off
+
+    # fmt: off
     id: str = Field(description="The id of this node. Must be unique among all nodes.")
     is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
-    #fmt: on
+    # fmt: on


 # TODO: figure out a better way to provide these hints
@@ -98,16 +104,19 @@ class UIConfig(TypedDict, total=False):
             "model",
             "control",
             "image_collection",
+            "vae_model",
+            "lora_model",
         ],
     ]
     tags: List[str]
     title: str

+
 class CustomisedSchemaExtra(TypedDict):
     ui: UIConfig


-class InvocationConfig(BaseModel.Config):
+class InvocationConfig(BaseConfig):
     """Customizes pydantic's BaseModel.Config class for use by Invocations.

     Provide `schema_extra` a `ui` dict to add hints for generated UIs.
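Not part of the patch: a standalone sketch of the dispatch pattern that the reformatted `get_invocations_map` above relies on. Each invocation subclass declares its node type as a `Literal` string field, and `get_args(get_type_hints(cls)["type"])[0]` recovers that string as the dictionary key. The toy classes below are illustrative stand-ins, not real InvokeAI nodes.

```python
from typing import Literal, get_args, get_type_hints

from pydantic import BaseModel


class AddInvocation(BaseModel):
    # Toy stand-in for a BaseInvocation subclass.
    type: Literal["add"] = "add"


class LoraLoaderInvocation(BaseModel):
    # Toy stand-in for the lora loader node touched by this patch.
    type: Literal["lora_loader"] = "lora_loader"


subclasses = [AddInvocation, LoraLoaderInvocation]

# get_type_hints(cls)["type"] is Literal["add"]; get_args() unwraps it to
# ("add",), and [0] yields the string that keys the class in the map.
invocations_map = dict(
    map(
        lambda t: (get_args(get_type_hints(t)["type"])[0], t),
        subclasses,
    )
)

assert invocations_map["lora_loader"] is LoraLoaderInvocation
```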
diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py
index e51873c59e..17297ba417 100644
--- a/invokeai/app/invocations/model.py
+++ b/invokeai/app/invocations/model.py
@@ -1,5 +1,5 @@
 import copy
-from typing import List, Literal, Optional
+from typing import List, Literal, Optional, Union

 from pydantic import BaseModel, Field

@@ -12,35 +12,42 @@ class ModelInfo(BaseModel):
     model_name: str = Field(description="Info to load submodel")
     base_model: BaseModelType = Field(description="Base model")
     model_type: ModelType = Field(description="Info to load submodel")
-    submodel: Optional[SubModelType] = Field(description="Info to load submodel")
+    submodel: Optional[SubModelType] = Field(
+        default=None, description="Info to load submodel"
+    )
+

 class LoraInfo(ModelInfo):
     weight: float = Field(description="Lora's weight which to use when apply to model")

+
 class UNetField(BaseModel):
     unet: ModelInfo = Field(description="Info to load unet submodel")
     scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
     loras: List[LoraInfo] = Field(description="Loras to apply on model loading")

+
 class ClipField(BaseModel):
     tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
     text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
     loras: List[LoraInfo] = Field(description="Loras to apply on model loading")

+
 class VaeField(BaseModel):
     # TODO: better naming?
     vae: ModelInfo = Field(description="Info to load vae submodel")

+
 class ModelLoaderOutput(BaseInvocationOutput):
     """Model loader output"""

-    #fmt: off
+    # fmt: off
     type: Literal["model_loader_output"] = "model_loader_output"

     unet: UNetField = Field(default=None, description="UNet submodel")
     clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
     vae: VaeField = Field(default=None, description="Vae submodel")
-    #fmt: on
+    # fmt: on


 class MainModelField(BaseModel):
@@ -50,6 +57,13 @@ class MainModelField(BaseModel):
     base_model: BaseModelType = Field(description="Base model")


+class LoRAModelField(BaseModel):
+    """LoRA model field"""
+
+    model_name: str = Field(description="Name of the LoRA model")
+    base_model: BaseModelType = Field(description="Base model")
+
+
 class MainModelLoaderInvocation(BaseInvocation):
     """Loads a main model, outputting its submodels."""

@@ -64,14 +78,11 @@ class MainModelLoaderInvocation(BaseInvocation):
             "ui": {
                 "title": "Model Loader",
                 "tags": ["model", "loader"],
-                "type_hints": {
-                    "model": "model"
-                }
+                "type_hints": {"model": "model"},
             },
         }

     def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
-
         base_model = self.model.base_model
         model_name = self.model.model_name
         model_type = ModelType.Main
@@ -113,7 +124,6 @@ class MainModelLoaderInvocation(BaseInvocation):
         )
         """

-
         return ModelLoaderOutput(
             unet=UNetField(
                 unet=ModelInfo(
@@ -152,25 +162,29 @@ class MainModelLoaderInvocation(BaseInvocation):
                     model_type=model_type,
                     submodel=SubModelType.Vae,
                 ),
-            )
+            ),
         )

+
 class LoraLoaderOutput(BaseInvocationOutput):
     """Model loader output"""

-    #fmt: off
+    # fmt: off
     type: Literal["lora_loader_output"] = "lora_loader_output"

     unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
     clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
-    #fmt: on
+    # fmt: on
+

 class LoraLoaderInvocation(BaseInvocation):
     """Apply selected lora to unet and text_encoder."""

     type: Literal["lora_loader"] = "lora_loader"

-    lora_name: str = Field(description="Lora model name")
+    lora: Union[LoRAModelField, None] = Field(
+        default=None, description="Lora model name"
+    )
     weight: float = Field(default=0.75, description="With what weight to apply lora")

     unet: Optional[UNetField] = Field(description="UNet model for applying lora")
@@ -181,26 +195,33 @@ class LoraLoaderInvocation(BaseInvocation):
             "ui": {
                 "title": "Lora Loader",
                 "tags": ["lora", "loader"],
+                "type_hints": {"lora": "lora_model"},
             },
         }

     def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
+        if self.lora is None:
+            raise Exception("No LoRA provided")

-        # TODO: ui rewrite
-        base_model = BaseModelType.StableDiffusion1
+        base_model = self.lora.base_model
+        lora_name = self.lora.model_name

         if not context.services.model_manager.model_exists(
             base_model=base_model,
-            model_name=self.lora_name,
+            model_name=lora_name,
             model_type=ModelType.Lora,
         ):
-            raise Exception(f"Unkown lora name: {self.lora_name}!")
+            raise Exception(f"Unkown lora name: {lora_name}!")

-        if self.unet is not None and any(lora.model_name == self.lora_name for lora in self.unet.loras):
-            raise Exception(f"Lora \"{self.lora_name}\" already applied to unet")
+        if self.unet is not None and any(
+            lora.model_name == lora_name for lora in self.unet.loras
+        ):
+            raise Exception(f'Lora "{lora_name}" already applied to unet')

-        if self.clip is not None and any(lora.model_name == self.lora_name for lora in self.clip.loras):
-            raise Exception(f"Lora \"{self.lora_name}\" already applied to clip")
+        if self.clip is not None and any(
+            lora.model_name == lora_name for lora in self.clip.loras
+        ):
+            raise Exception(f'Lora "{lora_name}" already applied to clip')

         output = LoraLoaderOutput()

@@ -209,7 +230,7 @@ class LoraLoaderInvocation(BaseInvocation):
             output.unet.loras.append(
                 LoraInfo(
                     base_model=base_model,
-                    model_name=self.lora_name,
+                    model_name=lora_name,
                     model_type=ModelType.Lora,
                     submodel=None,
                     weight=self.weight,
@@ -221,7 +242,7 @@ class LoraLoaderInvocation(BaseInvocation):
             output.clip.loras.append(
                 LoraInfo(
                     base_model=base_model,
-                    model_name=self.lora_name,
+                    model_name=lora_name,
                     model_type=ModelType.Lora,
                     submodel=None,
                     weight=self.weight,
@@ -230,25 +251,29 @@ class LoraLoaderInvocation(BaseInvocation):

         return output

+
 class VAEModelField(BaseModel):
     """Vae model field"""

     model_name: str = Field(description="Name of the model")
     base_model: BaseModelType = Field(description="Base model")

+
 class VaeLoaderOutput(BaseInvocationOutput):
     """Model loader output"""

-    #fmt: off
+    # fmt: off
     type: Literal["vae_loader_output"] = "vae_loader_output"

     vae: VaeField = Field(default=None, description="Vae model")
-    #fmt: on
+    # fmt: on
+

 class VaeLoaderInvocation(BaseInvocation):
     """Loads a VAE model, outputting a VaeLoaderOutput"""
+
     type: Literal["vae_loader"] = "vae_loader"
-
+
     vae_model: VAEModelField = Field(description="The VAE to load")

     # Schema customisation
@@ -257,29 +282,27 @@ class VaeLoaderInvocation(BaseInvocation):
             "ui": {
                 "title": "VAE Loader",
                 "tags": ["vae", "loader"],
-                "type_hints": {
-                    "vae_model": "vae_model"
-                }
+                "type_hints": {"vae_model": "vae_model"},
             },
         }
-
+
     def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
         base_model = self.vae_model.base_model
         model_name = self.vae_model.model_name
         model_type = ModelType.Vae

         if not context.services.model_manager.model_exists(
-            base_model=base_model,
-            model_name=model_name,
-            model_type=model_type,
+            base_model=base_model,
+            model_name=model_name,
+            model_type=model_type,
         ):
             raise Exception(f"Unkown vae name: {model_name}!")
         return VaeLoaderOutput(
             vae=VaeField(
-                vae = ModelInfo(
-                    model_name = model_name,
-                    base_model = base_model,
-                    model_type = model_type,
+                vae=ModelInfo(
+                    model_name=model_name,
+                    base_model=base_model,
+                    model_type=model_type,
                 )
             )
         )
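Also not part of the patch: a minimal usage sketch of the new structured `lora` field that replaces the old `lora_name` string. The model name is a placeholder, and the absolute import path for `BaseModelType` is an assumption (inside the repo it is imported relatively from the model management backend); `invoke()` itself is normally called by the graph executor with an `InvocationContext`.

```python
from invokeai.app.invocations.model import LoRAModelField, LoraLoaderInvocation
# Assumed import path; in-repo code imports BaseModelType relatively from the
# model management backend.
from invokeai.backend.model_management import BaseModelType

# Describe which LoRA to load ("detail_tweaker" is a placeholder name).
lora_field = LoRAModelField(
    model_name="detail_tweaker",
    base_model=BaseModelType.StableDiffusion1,
)

# The node now takes a structured `lora` field instead of the old `lora_name`
# string; `weight` keeps its 0.75 default, and `unet`/`clip` would normally be
# wired in from a main model loader node's output.
node = LoraLoaderInvocation(id="lora_loader_1", lora=lora_field)

# The graph executor calls node.invoke(context) with an InvocationContext;
# invoke() raises if `lora` is None or names a model the manager doesn't know.
```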