feat(nodes): add lora field, update lora loader

This commit is contained in:
psychedelicious 2023-07-04 21:11:50 +10:00
parent 92b163e95c
commit 08d428a5e7
2 changed files with 77 additions and 45 deletions

View File

@ -4,9 +4,10 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from inspect import signature from inspect import signature
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict, TYPE_CHECKING from typing import (TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args,
get_type_hints)
from pydantic import BaseModel, Field from pydantic import BaseConfig, BaseModel, Field
if TYPE_CHECKING: if TYPE_CHECKING:
from ..services.invocation_services import InvocationServices from ..services.invocation_services import InvocationServices
@ -65,7 +66,12 @@ class BaseInvocation(ABC, BaseModel):
@classmethod @classmethod
def get_invocations_map(cls): def get_invocations_map(cls):
# Get the type strings out of the literals and into a dictionary # Get the type strings out of the literals and into a dictionary
return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseInvocation.get_all_subclasses())) return dict(
map(
lambda t: (get_args(get_type_hints(t)["type"])[0], t),
BaseInvocation.get_all_subclasses(),
)
)
@classmethod @classmethod
def get_output_type(cls): def get_output_type(cls):
@ -76,10 +82,10 @@ class BaseInvocation(ABC, BaseModel):
"""Invoke with provided context and return outputs.""" """Invoke with provided context and return outputs."""
pass pass
#fmt: off # fmt: off
id: str = Field(description="The id of this node. Must be unique among all nodes.") id: str = Field(description="The id of this node. Must be unique among all nodes.")
is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.") is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
#fmt: on # fmt: on
# TODO: figure out a better way to provide these hints # TODO: figure out a better way to provide these hints
@ -98,16 +104,19 @@ class UIConfig(TypedDict, total=False):
"model", "model",
"control", "control",
"image_collection", "image_collection",
"vae_model",
"lora_model",
], ],
] ]
tags: List[str] tags: List[str]
title: str title: str
class CustomisedSchemaExtra(TypedDict): class CustomisedSchemaExtra(TypedDict):
ui: UIConfig ui: UIConfig
class InvocationConfig(BaseModel.Config): class InvocationConfig(BaseConfig):
"""Customizes pydantic's BaseModel.Config class for use by Invocations. """Customizes pydantic's BaseModel.Config class for use by Invocations.
Provide `schema_extra` a `ui` dict to add hints for generated UIs. Provide `schema_extra` a `ui` dict to add hints for generated UIs.

View File

@ -1,5 +1,5 @@
import copy import copy
from typing import List, Literal, Optional from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -12,35 +12,42 @@ class ModelInfo(BaseModel):
model_name: str = Field(description="Info to load submodel") model_name: str = Field(description="Info to load submodel")
base_model: BaseModelType = Field(description="Base model") base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Info to load submodel") model_type: ModelType = Field(description="Info to load submodel")
submodel: Optional[SubModelType] = Field(description="Info to load submodel") submodel: Optional[SubModelType] = Field(
default=None, description="Info to load submodel"
)
class LoraInfo(ModelInfo): class LoraInfo(ModelInfo):
weight: float = Field(description="Lora's weight which to use when apply to model") weight: float = Field(description="Lora's weight which to use when apply to model")
class UNetField(BaseModel): class UNetField(BaseModel):
unet: ModelInfo = Field(description="Info to load unet submodel") unet: ModelInfo = Field(description="Info to load unet submodel")
scheduler: ModelInfo = Field(description="Info to load scheduler submodel") scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading") loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
class ClipField(BaseModel): class ClipField(BaseModel):
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel") tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel") text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading") loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
class VaeField(BaseModel): class VaeField(BaseModel):
# TODO: better naming? # TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel") vae: ModelInfo = Field(description="Info to load vae submodel")
class ModelLoaderOutput(BaseInvocationOutput): class ModelLoaderOutput(BaseInvocationOutput):
"""Model loader output""" """Model loader output"""
#fmt: off # fmt: off
type: Literal["model_loader_output"] = "model_loader_output" type: Literal["model_loader_output"] = "model_loader_output"
unet: UNetField = Field(default=None, description="UNet submodel") unet: UNetField = Field(default=None, description="UNet submodel")
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
vae: VaeField = Field(default=None, description="Vae submodel") vae: VaeField = Field(default=None, description="Vae submodel")
#fmt: on # fmt: on
class MainModelField(BaseModel): class MainModelField(BaseModel):
@ -50,6 +57,13 @@ class MainModelField(BaseModel):
base_model: BaseModelType = Field(description="Base model") base_model: BaseModelType = Field(description="Base model")
class LoRAModelField(BaseModel):
"""LoRA model field"""
model_name: str = Field(description="Name of the LoRA model")
base_model: BaseModelType = Field(description="Base model")
class MainModelLoaderInvocation(BaseInvocation): class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels.""" """Loads a main model, outputting its submodels."""
@ -64,14 +78,11 @@ class MainModelLoaderInvocation(BaseInvocation):
"ui": { "ui": {
"title": "Model Loader", "title": "Model Loader",
"tags": ["model", "loader"], "tags": ["model", "loader"],
"type_hints": { "type_hints": {"model": "model"},
"model": "model"
}
}, },
} }
def invoke(self, context: InvocationContext) -> ModelLoaderOutput: def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = self.model.base_model base_model = self.model.base_model
model_name = self.model.model_name model_name = self.model.model_name
model_type = ModelType.Main model_type = ModelType.Main
@ -113,7 +124,6 @@ class MainModelLoaderInvocation(BaseInvocation):
) )
""" """
return ModelLoaderOutput( return ModelLoaderOutput(
unet=UNetField( unet=UNetField(
unet=ModelInfo( unet=ModelInfo(
@ -152,25 +162,29 @@ class MainModelLoaderInvocation(BaseInvocation):
model_type=model_type, model_type=model_type,
submodel=SubModelType.Vae, submodel=SubModelType.Vae,
), ),
) ),
) )
class LoraLoaderOutput(BaseInvocationOutput): class LoraLoaderOutput(BaseInvocationOutput):
"""Model loader output""" """Model loader output"""
#fmt: off # fmt: off
type: Literal["lora_loader_output"] = "lora_loader_output" type: Literal["lora_loader_output"] = "lora_loader_output"
unet: Optional[UNetField] = Field(default=None, description="UNet submodel") unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels") clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
#fmt: on # fmt: on
class LoraLoaderInvocation(BaseInvocation): class LoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder.""" """Apply selected lora to unet and text_encoder."""
type: Literal["lora_loader"] = "lora_loader" type: Literal["lora_loader"] = "lora_loader"
lora_name: str = Field(description="Lora model name") lora: Union[LoRAModelField, None] = Field(
default=None, description="Lora model name"
)
weight: float = Field(default=0.75, description="With what weight to apply lora") weight: float = Field(default=0.75, description="With what weight to apply lora")
unet: Optional[UNetField] = Field(description="UNet model for applying lora") unet: Optional[UNetField] = Field(description="UNet model for applying lora")
@ -181,26 +195,33 @@ class LoraLoaderInvocation(BaseInvocation):
"ui": { "ui": {
"title": "Lora Loader", "title": "Lora Loader",
"tags": ["lora", "loader"], "tags": ["lora", "loader"],
"type_hints": {"lora": "lora_model"},
}, },
} }
def invoke(self, context: InvocationContext) -> LoraLoaderOutput: def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
# TODO: ui rewrite base_model = self.lora.base_model
base_model = BaseModelType.StableDiffusion1 lora_name = self.lora.model_name
if not context.services.model_manager.model_exists( if not context.services.model_manager.model_exists(
base_model=base_model, base_model=base_model,
model_name=self.lora_name, model_name=lora_name,
model_type=ModelType.Lora, model_type=ModelType.Lora,
): ):
raise Exception(f"Unkown lora name: {self.lora_name}!") raise Exception(f"Unkown lora name: {lora_name}!")
if self.unet is not None and any(lora.model_name == self.lora_name for lora in self.unet.loras): if self.unet is not None and any(
raise Exception(f"Lora \"{self.lora_name}\" already applied to unet") lora.model_name == lora_name for lora in self.unet.loras
):
raise Exception(f'Lora "{lora_name}" already applied to unet')
if self.clip is not None and any(lora.model_name == self.lora_name for lora in self.clip.loras): if self.clip is not None and any(
raise Exception(f"Lora \"{self.lora_name}\" already applied to clip") lora.model_name == lora_name for lora in self.clip.loras
):
raise Exception(f'Lora "{lora_name}" already applied to clip')
output = LoraLoaderOutput() output = LoraLoaderOutput()
@ -209,7 +230,7 @@ class LoraLoaderInvocation(BaseInvocation):
output.unet.loras.append( output.unet.loras.append(
LoraInfo( LoraInfo(
base_model=base_model, base_model=base_model,
model_name=self.lora_name, model_name=lora_name,
model_type=ModelType.Lora, model_type=ModelType.Lora,
submodel=None, submodel=None,
weight=self.weight, weight=self.weight,
@ -221,7 +242,7 @@ class LoraLoaderInvocation(BaseInvocation):
output.clip.loras.append( output.clip.loras.append(
LoraInfo( LoraInfo(
base_model=base_model, base_model=base_model,
model_name=self.lora_name, model_name=lora_name,
model_type=ModelType.Lora, model_type=ModelType.Lora,
submodel=None, submodel=None,
weight=self.weight, weight=self.weight,
@ -230,23 +251,27 @@ class LoraLoaderInvocation(BaseInvocation):
return output return output
class VAEModelField(BaseModel): class VAEModelField(BaseModel):
"""Vae model field""" """Vae model field"""
model_name: str = Field(description="Name of the model") model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model") base_model: BaseModelType = Field(description="Base model")
class VaeLoaderOutput(BaseInvocationOutput): class VaeLoaderOutput(BaseInvocationOutput):
"""Model loader output""" """Model loader output"""
#fmt: off # fmt: off
type: Literal["vae_loader_output"] = "vae_loader_output" type: Literal["vae_loader_output"] = "vae_loader_output"
vae: VaeField = Field(default=None, description="Vae model") vae: VaeField = Field(default=None, description="Vae model")
#fmt: on # fmt: on
class VaeLoaderInvocation(BaseInvocation): class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput""" """Loads a VAE model, outputting a VaeLoaderOutput"""
type: Literal["vae_loader"] = "vae_loader" type: Literal["vae_loader"] = "vae_loader"
vae_model: VAEModelField = Field(description="The VAE to load") vae_model: VAEModelField = Field(description="The VAE to load")
@ -257,9 +282,7 @@ class VaeLoaderInvocation(BaseInvocation):
"ui": { "ui": {
"title": "VAE Loader", "title": "VAE Loader",
"tags": ["vae", "loader"], "tags": ["vae", "loader"],
"type_hints": { "type_hints": {"vae_model": "vae_model"},
"vae_model": "vae_model"
}
}, },
} }
@ -269,17 +292,17 @@ class VaeLoaderInvocation(BaseInvocation):
model_type = ModelType.Vae model_type = ModelType.Vae
if not context.services.model_manager.model_exists( if not context.services.model_manager.model_exists(
base_model=base_model, base_model=base_model,
model_name=model_name, model_name=model_name,
model_type=model_type, model_type=model_type,
): ):
raise Exception(f"Unkown vae name: {model_name}!") raise Exception(f"Unkown vae name: {model_name}!")
return VaeLoaderOutput( return VaeLoaderOutput(
vae=VaeField( vae=VaeField(
vae = ModelInfo( vae=ModelInfo(
model_name = model_name, model_name=model_name,
base_model = base_model, base_model=base_model,
model_type = model_type, model_type=model_type,
) )
) )
) )