mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
BREAKING CHANGES: invocations now require model key, not base/type/name
- Implement new model loader and modify invocations and embeddings - Finish implementation loaders for all models currently supported by InvokeAI. - Move lora, textual_inversion, and model patching support into backend/embeddings. - Restore support for model cache statistics collection (a little ugly, needs work). - Fixed up invocations that load and patch models. - Move seamless and silencewarnings utils into better location
This commit is contained in:
committed by
psychedelicious
parent
5745ce9c7d
commit
78ef946e01
@ -1,13 +1,13 @@
|
||||
import copy
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
|
||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||
from ...backend.model_manager import SubModelType
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@ -17,13 +17,9 @@ from .baseinvocation import (
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
model_name: str = Field(description="Info to load submodel")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
model_type: ModelType = Field(description="Info to load submodel")
|
||||
key: str = Field(description="Info to load submodel")
|
||||
submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class LoraInfo(ModelInfo):
|
||||
weight: float = Field(description="Lora's weight which to use when apply to model")
|
||||
@ -52,7 +48,7 @@ class VaeField(BaseModel):
|
||||
|
||||
@invocation_output("unet_output")
|
||||
class UNetOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a UNet field"""
|
||||
"""Base class for invocations that output a UNet field."""
|
||||
|
||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
|
||||
@ -81,20 +77,13 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
|
||||
class MainModelField(BaseModel):
|
||||
"""Main model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
model_type: ModelType = Field(description="Model Type")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
key: str = Field(description="Model key")
|
||||
|
||||
|
||||
class LoRAModelField(BaseModel):
|
||||
"""LoRA model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the LoRA model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
key: str = Field(description="LoRA model key")
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -111,74 +100,31 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
# TODO: precision?
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
model_type = ModelType.Main
|
||||
key = self.model.key
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.models.exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||
|
||||
"""
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.Tokenizer,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.TextEncoder,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.UNet,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
"""
|
||||
if not context.services.model_records.exists(key):
|
||||
raise Exception(f"Unknown model {key}")
|
||||
|
||||
return ModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
key=key,
|
||||
submodel=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
key=key,
|
||||
submodel=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
key=key,
|
||||
submodel=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
key=key,
|
||||
submodel=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
@ -186,9 +132,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
key=key,
|
||||
submodel=SubModelType.Vae,
|
||||
),
|
||||
),
|
||||
@ -226,21 +170,16 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
if self.lora is None:
|
||||
raise Exception("No LoRA provided")
|
||||
|
||||
base_model = self.lora.base_model
|
||||
lora_name = self.lora.model_name
|
||||
lora_key = self.lora.key
|
||||
|
||||
if not context.models.exists(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
):
|
||||
raise Exception(f"Unkown lora name: {lora_name}!")
|
||||
if not context.services.model_records.exists(lora_key):
|
||||
raise Exception(f"Unkown lora: {lora_key}!")
|
||||
|
||||
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to unet')
|
||||
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to unet')
|
||||
|
||||
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to clip')
|
||||
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to clip')
|
||||
|
||||
output = LoraLoaderOutput()
|
||||
|
||||
@ -248,9 +187,7 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
output.unet = copy.deepcopy(self.unet)
|
||||
output.unet.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
key=lora_key,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
@ -260,9 +197,7 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
output.clip = copy.deepcopy(self.clip)
|
||||
output.clip.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
key=lora_key,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
@ -315,24 +250,19 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
if self.lora is None:
|
||||
raise Exception("No LoRA provided")
|
||||
|
||||
base_model = self.lora.base_model
|
||||
lora_name = self.lora.model_name
|
||||
lora_key = self.lora.key
|
||||
|
||||
if not context.models.exists(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
):
|
||||
raise Exception(f"Unknown lora name: {lora_name}!")
|
||||
if not context.services.model_records.exists(lora_key):
|
||||
raise Exception(f"Unknown lora: {lora_key}!")
|
||||
|
||||
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to unet')
|
||||
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to unet')
|
||||
|
||||
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to clip')
|
||||
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to clip')
|
||||
|
||||
if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to clip2')
|
||||
if self.clip2 is not None and any(lora.key == lora_key for lora in self.clip2.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to clip2')
|
||||
|
||||
output = SDXLLoraLoaderOutput()
|
||||
|
||||
@ -340,9 +270,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
output.unet = copy.deepcopy(self.unet)
|
||||
output.unet.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
key=lora_key,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
@ -352,9 +280,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
output.clip = copy.deepcopy(self.clip)
|
||||
output.clip.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
key=lora_key,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
@ -364,9 +290,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
output.clip2 = copy.deepcopy(self.clip2)
|
||||
output.clip2.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
key=lora_key,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
@ -378,10 +302,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
class VAEModelField(BaseModel):
|
||||
"""Vae model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
key: str = Field(description="Model's key")
|
||||
|
||||
|
||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1")
|
||||
@ -395,25 +316,12 @@ class VaeLoaderInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> VAEOutput:
|
||||
base_model = self.vae_model.base_model
|
||||
model_name = self.vae_model.model_name
|
||||
model_type = ModelType.Vae
|
||||
key = self.vae_model.key
|
||||
|
||||
if not context.models.exists(
|
||||
base_model=base_model,
|
||||
model_name=model_name,
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unkown vae name: {model_name}!")
|
||||
return VAEOutput(
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
)
|
||||
)
|
||||
if not context.services.model_records.exists(key):
|
||||
raise Exception(f"Unkown vae: {key}!")
|
||||
|
||||
return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
|
||||
|
||||
|
||||
@invocation_output("seamless_output")
|
||||
|
Reference in New Issue
Block a user