feat(nodes): add lora field, update lora loader

This commit is contained in:
psychedelicious 2023-07-04 21:11:50 +10:00
parent 92b163e95c
commit 08d428a5e7
2 changed files with 77 additions and 45 deletions

View File

@ -4,9 +4,10 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from inspect import signature
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict, TYPE_CHECKING
from typing import (TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args,
get_type_hints)
from pydantic import BaseModel, Field
from pydantic import BaseConfig, BaseModel, Field
if TYPE_CHECKING:
from ..services.invocation_services import InvocationServices
@ -65,8 +66,13 @@ class BaseInvocation(ABC, BaseModel):
@classmethod
def get_invocations_map(cls):
# Get the type strings out of the literals and into a dictionary
return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseInvocation.get_all_subclasses()))
return dict(
map(
lambda t: (get_args(get_type_hints(t)["type"])[0], t),
BaseInvocation.get_all_subclasses(),
)
)
@classmethod
def get_output_type(cls):
return signature(cls.invoke).return_annotation
@ -75,11 +81,11 @@ class BaseInvocation(ABC, BaseModel):
def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
"""Invoke with provided context and return outputs."""
pass
#fmt: off
# fmt: off
id: str = Field(description="The id of this node. Must be unique among all nodes.")
is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
#fmt: on
# fmt: on
# TODO: figure out a better way to provide these hints
@ -98,16 +104,19 @@ class UIConfig(TypedDict, total=False):
"model",
"control",
"image_collection",
"vae_model",
"lora_model",
],
]
tags: List[str]
title: str
class CustomisedSchemaExtra(TypedDict):
ui: UIConfig
class InvocationConfig(BaseModel.Config):
class InvocationConfig(BaseConfig):
"""Customizes pydantic's BaseModel.Config class for use by Invocations.
Provide `schema_extra` a `ui` dict to add hints for generated UIs.

View File

@ -1,5 +1,5 @@
import copy
from typing import List, Literal, Optional
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
@ -12,35 +12,42 @@ class ModelInfo(BaseModel):
model_name: str = Field(description="Info to load submodel")
base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Info to load submodel")
submodel: Optional[SubModelType] = Field(description="Info to load submodel")
submodel: Optional[SubModelType] = Field(
default=None, description="Info to load submodel"
)
class LoraInfo(ModelInfo):
weight: float = Field(description="Lora's weight which to use when apply to model")
class UNetField(BaseModel):
unet: ModelInfo = Field(description="Info to load unet submodel")
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
class ClipField(BaseModel):
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
class VaeField(BaseModel):
# TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel")
class ModelLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
#fmt: off
# fmt: off
type: Literal["model_loader_output"] = "model_loader_output"
unet: UNetField = Field(default=None, description="UNet submodel")
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
vae: VaeField = Field(default=None, description="Vae submodel")
#fmt: on
# fmt: on
class MainModelField(BaseModel):
@ -50,6 +57,13 @@ class MainModelField(BaseModel):
base_model: BaseModelType = Field(description="Base model")
class LoRAModelField(BaseModel):
"""LoRA model field"""
model_name: str = Field(description="Name of the LoRA model")
base_model: BaseModelType = Field(description="Base model")
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
@ -64,14 +78,11 @@ class MainModelLoaderInvocation(BaseInvocation):
"ui": {
"title": "Model Loader",
"tags": ["model", "loader"],
"type_hints": {
"model": "model"
}
"type_hints": {"model": "model"},
},
}
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Main
@ -113,7 +124,6 @@ class MainModelLoaderInvocation(BaseInvocation):
)
"""
return ModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
@ -152,25 +162,29 @@ class MainModelLoaderInvocation(BaseInvocation):
model_type=model_type,
submodel=SubModelType.Vae,
),
)
),
)
class LoraLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
#fmt: off
# fmt: off
type: Literal["lora_loader_output"] = "lora_loader_output"
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
#fmt: on
# fmt: on
class LoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
type: Literal["lora_loader"] = "lora_loader"
lora_name: str = Field(description="Lora model name")
lora: Union[LoRAModelField, None] = Field(
default=None, description="Lora model name"
)
weight: float = Field(default=0.75, description="With what weight to apply lora")
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
@ -181,26 +195,33 @@ class LoraLoaderInvocation(BaseInvocation):
"ui": {
"title": "Lora Loader",
"tags": ["lora", "loader"],
"type_hints": {"lora": "lora_model"},
},
}
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
# TODO: ui rewrite
base_model = BaseModelType.StableDiffusion1
base_model = self.lora.base_model
lora_name = self.lora.model_name
if not context.services.model_manager.model_exists(
base_model=base_model,
model_name=self.lora_name,
model_name=lora_name,
model_type=ModelType.Lora,
):
raise Exception(f"Unkown lora name: {self.lora_name}!")
raise Exception(f"Unkown lora name: {lora_name}!")
if self.unet is not None and any(lora.model_name == self.lora_name for lora in self.unet.loras):
raise Exception(f"Lora \"{self.lora_name}\" already applied to unet")
if self.unet is not None and any(
lora.model_name == lora_name for lora in self.unet.loras
):
raise Exception(f'Lora "{lora_name}" already applied to unet')
if self.clip is not None and any(lora.model_name == self.lora_name for lora in self.clip.loras):
raise Exception(f"Lora \"{self.lora_name}\" already applied to clip")
if self.clip is not None and any(
lora.model_name == lora_name for lora in self.clip.loras
):
raise Exception(f'Lora "{lora_name}" already applied to clip')
output = LoraLoaderOutput()
@ -209,7 +230,7 @@ class LoraLoaderInvocation(BaseInvocation):
output.unet.loras.append(
LoraInfo(
base_model=base_model,
model_name=self.lora_name,
model_name=lora_name,
model_type=ModelType.Lora,
submodel=None,
weight=self.weight,
@ -221,7 +242,7 @@ class LoraLoaderInvocation(BaseInvocation):
output.clip.loras.append(
LoraInfo(
base_model=base_model,
model_name=self.lora_name,
model_name=lora_name,
model_type=ModelType.Lora,
submodel=None,
weight=self.weight,
@ -230,25 +251,29 @@ class LoraLoaderInvocation(BaseInvocation):
return output
class VAEModelField(BaseModel):
"""Vae model field"""
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
class VaeLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
#fmt: off
# fmt: off
type: Literal["vae_loader_output"] = "vae_loader_output"
vae: VaeField = Field(default=None, description="Vae model")
#fmt: on
# fmt: on
class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
type: Literal["vae_loader"] = "vae_loader"
vae_model: VAEModelField = Field(description="The VAE to load")
# Schema customisation
@ -257,29 +282,27 @@ class VaeLoaderInvocation(BaseInvocation):
"ui": {
"title": "VAE Loader",
"tags": ["vae", "loader"],
"type_hints": {
"vae_model": "vae_model"
}
"type_hints": {"vae_model": "vae_model"},
},
}
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
base_model = self.vae_model.base_model
model_name = self.vae_model.model_name
model_type = ModelType.Vae
if not context.services.model_manager.model_exists(
base_model=base_model,
model_name=model_name,
model_type=model_type,
base_model=base_model,
model_name=model_name,
model_type=model_type,
):
raise Exception(f"Unkown vae name: {model_name}!")
return VaeLoaderOutput(
vae=VaeField(
vae = ModelInfo(
model_name = model_name,
base_model = base_model,
model_type = model_type,
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
)
)