Sergey Borisov 2023-06-11 06:12:21 +03:00
parent 3ce3a7ee72
commit 738ba40f51
10 changed files with 106 additions and 82 deletions

View File

@@ -1,12 +1,12 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
from typing import Annotated, Literal, Optional, Union
from typing import Annotated, Literal, Optional, Union, Dict
from fastapi import Query
from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies
from invokeai.backend import SDModelType
from invokeai.backend import BaseModelType, ModelType
models_router = APIRouter(prefix="/v1/models", tags=["models"])
@@ -60,7 +60,8 @@ class ConvertedModelResponse(BaseModel):
info: DiffusersModelInfo = Field(description="The converted model info")
class ModelsList(BaseModel):
models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
models: Dict[BaseModelType, Dict[ModelType, Dict[str, dict]]] # TODO: collect all configs
#models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
@models_router.get(
@@ -69,9 +70,12 @@ class ModelsList(BaseModel):
responses={200: {"model": ModelsList }},
)
async def list_models(
model_type: SDModelType = Query(
default=None, description="The type of model to get"
),
base_model: BaseModelType = Query(
default=None, description="Base model"
),
model_type: ModelType = Query(
default=None, description="The type of model to get"
),
) -> ModelsList:
"""Gets a list of models"""
models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)
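
A minimal sketch of how a client might exercise the reworked endpoint once both filters are in place (the host, port, and enum string values below are assumptions, not part of this commit):

import requests  # assumed HTTP client for the sketch

# Both query parameters are optional filters over the new nested response shape.
resp = requests.get(
    "http://localhost:9090/v1/models/",                       # assumed local server URL
    params={"base_model": "sd-2", "model_type": "pipeline"},  # assumed enum string values
)
resp.raise_for_status()
models = resp.json()["models"]  # Dict[BaseModelType, Dict[ModelType, Dict[str, dict]]]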

View File

@@ -8,7 +8,7 @@ from .model import ClipField
from ...backend.util.devices import torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.model_management import SDModelType
from ...backend.model_management import BaseModelType, ModelType, SubModelType
from ...backend.model_management.lora import ModelPatcher
from compel import Compel
@@ -76,7 +76,11 @@ class CompelInvocation(BaseInvocation):
try:
ti_list.append(
stack.enter_context(
context.services.model_manager.get_model(model_name=name, model_type=SDModelType.TextualInversion)
context.services.model_manager.get_model(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
)
)
)
except Exception:
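
The stack.enter_context(...) call above treats the value returned by get_model() as a context manager, so several embeddings can be loaded and released together. A standalone sketch of the same pattern (the service handle, embedding names, and base-model member are illustrative assumptions):

from contextlib import ExitStack

with ExitStack() as stack:
    ti_list = []
    for name in ["neg-embedding-a", "neg-embedding-b"]:  # hypothetical embedding names
        ti_list.append(
            stack.enter_context(
                model_manager.get_model(  # assumed service handle
                    model_name=name,
                    base_model=BaseModelType.StableDiffusion1,  # assumed member
                    model_type=ModelType.TextualInversion,
                )
            )
        )
    # every entry in ti_list stays loaded until the with-block exits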

View File

@@ -5,12 +5,13 @@ import copy
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.model_management import SDModelType
from ...backend.model_management import BaseModelType, ModelType, SubModelType
class ModelInfo(BaseModel):
model_name: str = Field(description="Info to load submodel")
model_type: SDModelType = Field(description="Info to load submodel")
submodel: Optional[SDModelType] = Field(description="Info to load submodel")
base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Info to load submodel")
submodel: Optional[SubModelType] = Field(description="Info to load submodel")
class LoraInfo(ModelInfo):
weight: float = Field(description="LoRA weight to apply to the model")
@@ -63,10 +64,13 @@ class ModelLoaderInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = BaseModelType.StableDiffusion2 # TODO:
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
base_model=base_model,
model_type=ModelType.Pipeline,
):
raise Exception(f"Unknown model name: {self.model_name}!")
@@ -104,12 +108,14 @@ class ModelLoaderInvocation(BaseInvocation):
unet=UNetField(
unet=ModelInfo(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.UNet,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.Scheduler,
),
loras=[],
@@ -117,12 +123,14 @@ class ModelLoaderInvocation(BaseInvocation):
clip=ClipField(
tokenizer=ModelInfo(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.TextEncoder,
),
loras=[],
@@ -130,7 +138,8 @@ class ModelLoaderInvocation(BaseInvocation):
vae=VaeField(
vae=ModelInfo(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.Vae,
),
)
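
With the old SDModelType split into three enums, a loadable (sub)model is now addressed by the triple (base_model, model_type, submodel). Building one of the fields above by hand would look roughly like this (the model name is illustrative):

unet_info = ModelInfo(
    model_name="stable-diffusion-2-1",  # hypothetical installed model
    base_model=BaseModelType.StableDiffusion2,
    model_type=ModelType.Pipeline,
    submodel=SubModelType.UNet,
)
print(unet_info.json())  # pydantic v1 serialization, as used elsewhere in the app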

View File

@@ -3,7 +3,7 @@
from typing import Any
from invokeai.app.models.image import ProgressImage
from invokeai.app.util.misc import get_timestamp
from invokeai.app.services.model_manager_service import SDModelType, SDModelInfo
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, SDModelInfo
from invokeai.app.models.exceptions import CanceledException
class EventServiceBase:
@@ -109,8 +109,9 @@ class EventServiceBase:
node: dict,
source_node_id: str,
model_name: str,
model_type: SDModelType,
submodel: SDModelType,
base_model: BaseModelType,
model_type: ModelType,
submodel: SubModelType,
) -> None:
"""Emitted when a model is requested"""
self.__emit_session_event(
@@ -120,6 +121,7 @@
node=node,
source_node_id=source_node_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
),
@@ -131,8 +133,9 @@
node: dict,
source_node_id: str,
model_name: str,
model_type: SDModelType,
submodel: SDModelType,
base_model: BaseModelType,
model_type: ModelType,
submodel: SubModelType,
model_info: SDModelInfo,
) -> None:
"""Emitted when a model is correctly loaded (returns model info)"""
@@ -143,6 +146,7 @@
node=node,
source_node_id=source_node_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
model_info=model_info,
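
Both events gain a base_model field alongside the existing ones. Roughly what a subscriber might now receive for the load-started event (all values below are illustrative, not taken from this commit):

payload = {
    "graph_execution_state_id": "some-session-id",
    "node": {},                      # the serialized invocation dict
    "source_node_id": "some-node-id",
    "model_name": "stable-diffusion-2-1",
    "base_model": "sd-2",            # assumed BaseModelType string value
    "model_type": "pipeline",        # assumed ModelType string value
    "submodel": "unet",              # assumed SubModelType string value
}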

View File

@@ -10,7 +10,9 @@ from dataclasses import dataclass
from invokeai.backend.model_management.model_manager import (
ModelManager,
SDModelType,
BaseModelType,
ModelType,
SubModelType,
SDModelInfo,
)
from invokeai.app.models.exceptions import CanceledException
@@ -20,12 +22,6 @@ from ...backend.util import choose_precision, choose_torch_device
if TYPE_CHECKING:
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
@dataclass
class LastUsedModel:
model_name: str=None
model_type: SDModelType=None
last_used_model = LastUsedModel()
class ModelManagerServiceBase(ABC):
"""Responsible for managing models on disk and in memory"""
@@ -48,8 +44,9 @@
def get_model(
self,
model_name: str,
model_type: SDModelType = SDModelType.Diffusers,
submodel: Optional[SDModelType] = None,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
node: Optional[BaseInvocation] = None,
context: Optional[InvocationContext] = None,
) -> SDModelInfo:
@@ -67,12 +64,13 @@
def model_exists(
self,
model_name: str,
model_type: SDModelType,
base_model: BaseModelType,
model_type: ModelType,
) -> bool:
pass
@abstractmethod
def default_model(self) -> Optional[Tuple[str, SDModelType]]:
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
"""
Returns the name, base model, and type of the default model, or None
if none is defined.
@@ -80,26 +78,26 @@
pass
@abstractmethod
def set_default_model(self, model_name: str, model_type: SDModelType):
def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
"""Sets the default model to the indicated name."""
pass
@abstractmethod
def model_info(self, model_name: str, model_type: SDModelType) -> dict:
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Given a model name, returns a dict-like (OmegaConf) object describing it.
"""
pass
@abstractmethod
def model_names(self) -> List[Tuple[str, SDModelType]]:
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
"""
Returns a list of all the model names known.
"""
pass
@abstractmethod
def list_models(self, model_type: SDModelType=None) -> dict:
def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict:
"""
Return a dict of models in the format:
{ model_type1:
@@ -122,7 +120,8 @@
def add_model(
self,
model_name: str,
model_type: SDModelType,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
clobber: bool = False
) -> None:
@@ -139,7 +138,8 @@
def del_model(
self,
model_name: str,
model_type: SDModelType,
base_model: BaseModelType,
model_type: ModelType,
delete_files: bool = False,
):
"""
@@ -297,8 +297,9 @@ class ModelManagerService(ModelManagerServiceBase):
def get_model(
self,
model_name: str,
model_type: SDModelType = SDModelType.Diffusers,
submodel: Optional[SDModelType] = None,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
node: Optional[BaseInvocation] = None,
context: Optional[InvocationContext] = None,
) -> SDModelInfo:
@@ -306,23 +307,6 @@ class ModelManagerService(ModelManagerServiceBase):
Retrieve the indicated model. submodel can be used to get a
part (such as the vae) of a diffusers model.
"""
# Temporary hack here: we remember the last model fetched
# so that when executing a graph, the first node called gets
# to set default model for subsequent nodes in the event that
# they do not set the model explicitly. This should be
# displaced by model loader mechanism.
# This is to work around lack of model loader at current time,
# which was causing inconsistent model usage throughout graph.
global last_used_model
if not model_name:
self.logger.debug('No model name provided, defaulting to last loaded model')
model_name = last_used_model.model_name
model_type = model_type or last_used_model.model_type
else:
last_used_model.model_name = model_name
last_used_model.model_type = model_type
# if we are called from within a node, then we get to emit
# load start and complete events
@@ -331,12 +315,14 @@
node=node,
context=context,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel
submodel=submodel,
)
model_info = self.mgr.get_model(
model_name,
base_model,
model_type,
submodel,
)
@@ -346,6 +332,7 @@
node=node,
context=context,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
model_info=model_info
@@ -356,7 +343,8 @@
def model_exists(
self,
model_name: str,
model_type: SDModelType,
base_model: BaseModelType,
model_type: ModelType,
) -> bool:
"""
Given a model name, returns True if it is a valid
@@ -364,33 +352,38 @@
"""
return self.mgr.model_exists(
model_name,
base_model,
model_type,
)
def default_model(self) -> Optional[Tuple[str, SDModelType]]:
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
"""
Returns the name, base model, and type of the default model, or None
if none is defined.
"""
return self.mgr.default_model()
def set_default_model(self, model_name: str, model_type: SDModelType):
def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
"""Sets the default model to the indicated name."""
self.mgr.set_default_model(model_name)
self.mgr.set_default_model(model_name, base_model, model_type)
def model_info(self, model_name: str, model_type: SDModelType) -> dict:
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Given a model name, returns a dict-like (OmegaConf) object describing it.
"""
return self.mgr.model_info(model_name)
return self.mgr.model_info(model_name, base_model, model_type)
def model_names(self) -> List[Tuple[str, SDModelType]]:
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
"""
Returns a list of all the model names known.
"""
return self.mgr.model_names()
def list_models(self, model_type: SDModelType=None) -> dict:
def list_models(
self,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None
) -> dict:
"""
Return a dict of models in the format:
{ model_type1:
@@ -406,12 +399,13 @@ class ModelManagerService(ModelManagerServiceBase):
{ model_name_n: etc
}
"""
return self.mgr.list_models()
return self.mgr.list_models(base_model, model_type)
def add_model(
self,
model_name: str,
model_type: SDModelType,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
clobber: bool = False,
) -> None:
@@ -422,13 +416,14 @@
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
return self.mgr.add_model(model_name, model_type, model_attributes, clobber)
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
def del_model(
self,
model_name: str,
model_type: SDModelType = SDModelType.Diffusers,
base_model: BaseModelType,
model_type: ModelType,
delete_files: bool = False,
):
"""
@@ -436,7 +431,7 @@
then the underlying weight file or diffusers directory will be deleted
as well. Call commit() to write to disk.
"""
self.mgr.del_model(model_name, model_type, delete_files)
self.mgr.del_model(model_name, base_model, model_type, delete_files)
def import_diffuser_model(
self,
@@ -541,8 +536,9 @@ class ModelManagerService(ModelManagerServiceBase):
node,
context,
model_name: str,
model_type: SDModelType,
submodel: SDModelType,
base_model: BaseModelType,
model_type: ModelType,
submodel: SubModelType,
model_info: Optional[SDModelInfo] = None,
):
if context.services.queue.is_canceled(context.graph_execution_state_id):
@@ -555,6 +551,7 @@
node=node.dict(),
source_node_id=source_node_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
model_info=model_info
@@ -565,6 +562,7 @@
node=node.dict(),
source_node_id=source_node_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
)
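
Taken together, every call into the service now names all three coordinates explicitly. A hedged usage sketch (the access path and model name are placeholders; treating the returned SDModelInfo as a context manager mirrors the compel invocation above):

mgr = ApiDependencies.invoker.services.model_manager  # assumed access path

if mgr.model_exists(
    model_name="stable-diffusion-2-1",  # hypothetical model
    base_model=BaseModelType.StableDiffusion2,
    model_type=ModelType.Pipeline,
):
    with mgr.get_model(
        model_name="stable-diffusion-2-1",
        base_model=BaseModelType.StableDiffusion2,
        model_type=ModelType.Pipeline,
        submodel=SubModelType.Vae,  # fetch just the VAE submodel
    ) as vae:
        pass  # the loaded VAE stays cached while the with-block is open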

View File

@@ -9,5 +9,5 @@ from .generator import (
Img2Img,
Inpaint
)
from .model_management import ModelManager, ModelCache, SDModelType, SDModelInfo
from .model_management import ModelManager, ModelCache, BaseModelType, ModelType, SubModelType, SDModelInfo
from .safety_checker import SafetyChecker

View File

@@ -2,4 +2,5 @@
Initialization file for invokeai.backend.model_management
"""
from .model_manager import ModelManager, SDModelInfo
from .model_cache import ModelCache, SDModelType
from .model_cache import ModelCache
from .models import BaseModelType, ModelType, SubModelType

View File

@@ -39,7 +39,7 @@ from invokeai.app.services.config import get_invokeai_config
from .lora import LoRAModel, TextualInversionModel
from .models import BaseModelType, ModelType, SubModelType
from .models import BaseModelType, ModelType, SubModelType, ModelBase
# Maximum size of the cache, in gigs

View File

@@ -167,6 +167,8 @@ from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from pydantic import BaseModel
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import CUDA_DEVICE, download_with_resume
@@ -539,14 +541,16 @@ class ModelManager(object):
def del_model(
self,
model_name: str,
model_type: SDModelType.Diffusers,
base_model: BaseModelType,
model_type: ModelType,
delete_files: bool = False,
):
"""
Delete the named model.
"""
model_key = self.create_key(model_name, model_type)
model_cfg = self.pop(model_key, None)
raise Exception("TODO: del_model") # TODO: redo
model_key = self.create_key(model_name, base_model, model_type)
model_cfg = self.models.pop(model_key, None)
if model_cfg is None:
self.logger.error(
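
create_key() now folds all three coordinates into the registry key. A plausible sketch of its shape (the separator and ordering are assumptions; the real implementation lives in model_manager.py):

from invokeai.backend.model_management.models import BaseModelType, ModelType

def create_key(model_name: str, base_model: BaseModelType, model_type: ModelType) -> str:
    # Assumed flat-string format: one key per (base, type, name) triple.
    return f"{base_model.value}/{model_type.value}/{model_name}"

# e.g. create_key("stable-diffusion-2-1", BaseModelType.StableDiffusion2, ModelType.Pipeline)
#      -> "sd-2/pipeline/stable-diffusion-2-1"  (illustrative)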

View File

@@ -24,7 +24,7 @@ class ModelType(str, Enum):
ControlNet = "controlnet"
TextualInversion = "embedding"
class SubModelType:
class SubModelType(str, Enum):
UNet = "unet"
TextEncoder = "text_encoder"
Tokenizer = "tokenizer"
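
The (str, Enum) base is the substantive fix here: previously SubModelType was a bare class whose attributes were plain strings, so enum features such as lookup-by-value did not exist. A standalone demonstration (not code from the commit):

from enum import Enum

class SubModelType(str, Enum):
    UNet = "unet"
    TextEncoder = "text_encoder"
    Tokenizer = "tokenizer"

assert SubModelType("unet") is SubModelType.UNet  # lookup by value now works
assert SubModelType.UNet == "unet"                # str mixin keeps string comparisons valid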