This commit is contained in:
Sergey Borisov 2023-06-11 06:12:21 +03:00
parent 3ce3a7ee72
commit 738ba40f51
10 changed files with 106 additions and 82 deletions

View File

@ -1,12 +1,12 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername) # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
from typing import Annotated, Literal, Optional, Union from typing import Annotated, Literal, Optional, Union, Dict
from fastapi import Query from fastapi import Query
from fastapi.routing import APIRouter, HTTPException from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
from invokeai.backend import SDModelType from invokeai.backend import BaseModelType, ModelType
models_router = APIRouter(prefix="/v1/models", tags=["models"]) models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -60,7 +60,8 @@ class ConvertedModelResponse(BaseModel):
info: DiffusersModelInfo = Field(description="The converted model info") info: DiffusersModelInfo = Field(description="The converted model info")
class ModelsList(BaseModel): class ModelsList(BaseModel):
models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]] models: Dict[BaseModelType, Dict[ModelType, Dict[str, dict]]] # TODO: collect all configs
#models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
@models_router.get( @models_router.get(
@ -69,7 +70,10 @@ class ModelsList(BaseModel):
responses={200: {"model": ModelsList }}, responses={200: {"model": ModelsList }},
) )
async def list_models( async def list_models(
model_type: SDModelType = Query( base_model: BaseModelType = Query(
default=None, description="Base model"
),
model_type: ModelType = Query(
default=None, description="The type of model to get" default=None, description="The type of model to get"
), ),
) -> ModelsList: ) -> ModelsList:

View File

@ -8,7 +8,7 @@ from .model import ClipField
from ...backend.util.devices import torch_dtype from ...backend.util.devices import torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.model_management import SDModelType from ...backend.model_management import BaseModelType, ModelType, SubModelType
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from compel import Compel from compel import Compel
@ -76,7 +76,11 @@ class CompelInvocation(BaseInvocation):
try: try:
ti_list.append( ti_list.append(
stack.enter_context( stack.enter_context(
context.services.model_manager.get_model(model_name=name, model_type=SDModelType.TextualInversion) context.services.model_manager.get_model(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
)
) )
) )
except Exception: except Exception:

View File

@ -5,12 +5,13 @@ import copy
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from ...backend.util.devices import choose_torch_device, torch_dtype from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.model_management import SDModelType from ...backend.model_management import BaseModelType, ModelType, SubModelType
class ModelInfo(BaseModel): class ModelInfo(BaseModel):
model_name: str = Field(description="Info to load submodel") model_name: str = Field(description="Info to load submodel")
model_type: SDModelType = Field(description="Info to load submodel") base_model: BaseModelType = Field(description="Base model")
submodel: Optional[SDModelType] = Field(description="Info to load submodel") model_type: ModelType = Field(description="Info to load submodel")
submodel: Optional[SubModelType] = Field(description="Info to load submodel")
class LoraInfo(ModelInfo): class LoraInfo(ModelInfo):
weight: float = Field(description="Lora's weight which to use when apply to model") weight: float = Field(description="Lora's weight which to use when apply to model")
@ -63,10 +64,13 @@ class ModelLoaderInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ModelLoaderOutput: def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = BaseModelType.StableDiffusion2 # TODO:
# TODO: not found exceptions # TODO: not found exceptions
if not context.services.model_manager.model_exists( if not context.services.model_manager.model_exists(
model_name=self.model_name, model_name=self.model_name,
model_type=SDModelType.Diffusers, base_model=base_model,
model_type=ModelType.Pipeline,
): ):
raise Exception(f"Unkown model name: {self.model_name}!") raise Exception(f"Unkown model name: {self.model_name}!")
@ -104,12 +108,14 @@ class ModelLoaderInvocation(BaseInvocation):
unet=UNetField( unet=UNetField(
unet=ModelInfo( unet=ModelInfo(
model_name=self.model_name, model_name=self.model_name,
model_type=SDModelType.Diffusers, base_model=base_model,
submodel=SDModelType.UNet, model_type=ModelType.Pipeline,
submodel=SubModelType.UNet,
), ),
scheduler=ModelInfo( scheduler=ModelInfo(
model_name=self.model_name, model_name=self.model_name,
model_type=SDModelType.Diffusers, base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SDModelType.Scheduler, submodel=SDModelType.Scheduler,
), ),
loras=[], loras=[],
@ -117,12 +123,14 @@ class ModelLoaderInvocation(BaseInvocation):
clip=ClipField( clip=ClipField(
tokenizer=ModelInfo( tokenizer=ModelInfo(
model_name=self.model_name, model_name=self.model_name,
model_type=SDModelType.Diffusers, base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SDModelType.Tokenizer, submodel=SDModelType.Tokenizer,
), ),
text_encoder=ModelInfo( text_encoder=ModelInfo(
model_name=self.model_name, model_name=self.model_name,
model_type=SDModelType.Diffusers, base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SDModelType.TextEncoder, submodel=SDModelType.TextEncoder,
), ),
loras=[], loras=[],
@ -130,7 +138,8 @@ class ModelLoaderInvocation(BaseInvocation):
vae=VaeField( vae=VaeField(
vae=ModelInfo( vae=ModelInfo(
model_name=self.model_name, model_name=self.model_name,
model_type=SDModelType.Diffusers, base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SDModelType.Vae, submodel=SDModelType.Vae,
), ),
) )

View File

@ -3,7 +3,7 @@
from typing import Any from typing import Any
from invokeai.app.models.image import ProgressImage from invokeai.app.models.image import ProgressImage
from invokeai.app.util.misc import get_timestamp from invokeai.app.util.misc import get_timestamp
from invokeai.app.services.model_manager_service import SDModelType, SDModelInfo from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, SDModelInfo
from invokeai.app.models.exceptions import CanceledException from invokeai.app.models.exceptions import CanceledException
class EventServiceBase: class EventServiceBase:
@ -109,8 +109,9 @@ class EventServiceBase:
node: dict, node: dict,
source_node_id: str, source_node_id: str,
model_name: str, model_name: str,
model_type: SDModelType, base_model: BaseModelType,
submodel: SDModelType, model_type: ModelType,
submodel: SubModelType,
) -> None: ) -> None:
"""Emitted when a model is requested""" """Emitted when a model is requested"""
self.__emit_session_event( self.__emit_session_event(
@ -120,6 +121,7 @@ class EventServiceBase:
node=node, node=node,
source_node_id=source_node_id, source_node_id=source_node_id,
model_name=model_name, model_name=model_name,
base_model=base_model,
model_type=model_type, model_type=model_type,
submodel=submodel, submodel=submodel,
), ),
@ -131,8 +133,9 @@ class EventServiceBase:
node: dict, node: dict,
source_node_id: str, source_node_id: str,
model_name: str, model_name: str,
model_type: SDModelType, base_model: BaseModelType,
submodel: SDModelType, model_type: ModelType,
submodel: SubModelType,
model_info: SDModelInfo, model_info: SDModelInfo,
) -> None: ) -> None:
"""Emitted when a model is correctly loaded (returns model info)""" """Emitted when a model is correctly loaded (returns model info)"""
@ -143,6 +146,7 @@ class EventServiceBase:
node=node, node=node,
source_node_id=source_node_id, source_node_id=source_node_id,
model_name=model_name, model_name=model_name,
base_model=base_model,
model_type=model_type, model_type=model_type,
submodel=submodel, submodel=submodel,
model_info=model_info, model_info=model_info,

View File

@ -10,7 +10,9 @@ from dataclasses import dataclass
from invokeai.backend.model_management.model_manager import ( from invokeai.backend.model_management.model_manager import (
ModelManager, ModelManager,
SDModelType, BaseModelType,
ModelType,
SubModelType,
SDModelInfo, SDModelInfo,
) )
from invokeai.app.models.exceptions import CanceledException from invokeai.app.models.exceptions import CanceledException
@ -20,12 +22,6 @@ from ...backend.util import choose_precision, choose_torch_device
if TYPE_CHECKING: if TYPE_CHECKING:
from ..invocations.baseinvocation import BaseInvocation, InvocationContext from ..invocations.baseinvocation import BaseInvocation, InvocationContext
@dataclass
class LastUsedModel:
model_name: str=None
model_type: SDModelType=None
last_used_model = LastUsedModel()
class ModelManagerServiceBase(ABC): class ModelManagerServiceBase(ABC):
"""Responsible for managing models on disk and in memory""" """Responsible for managing models on disk and in memory"""
@ -48,8 +44,9 @@ class ModelManagerServiceBase(ABC):
def get_model( def get_model(
self, self,
model_name: str, model_name: str,
model_type: SDModelType = SDModelType.Diffusers, base_model: BaseModelType,
submodel: Optional[SDModelType] = None, model_type: ModelType,
submodel: Optional[SubModelType] = None,
node: Optional[BaseInvocation] = None, node: Optional[BaseInvocation] = None,
context: Optional[InvocationContext] = None, context: Optional[InvocationContext] = None,
) -> SDModelInfo: ) -> SDModelInfo:
@ -67,12 +64,13 @@ class ModelManagerServiceBase(ABC):
def model_exists( def model_exists(
self, self,
model_name: str, model_name: str,
model_type: SDModelType, base_model: BaseModelType,
model_type: ModelType,
) -> bool: ) -> bool:
pass pass
@abstractmethod @abstractmethod
def default_model(self) -> Optional[Tuple[str, SDModelType]]: def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
""" """
Returns the name and typeof the default model, or None Returns the name and typeof the default model, or None
if none is defined. if none is defined.
@ -80,26 +78,26 @@ class ModelManagerServiceBase(ABC):
pass pass
@abstractmethod @abstractmethod
def set_default_model(self, model_name: str, model_type: SDModelType): def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
"""Sets the default model to the indicated name.""" """Sets the default model to the indicated name."""
pass pass
@abstractmethod @abstractmethod
def model_info(self, model_name: str, model_type: SDModelType) -> dict: def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
""" """
Given a model name returns a dict-like (OmegaConf) object describing it. Given a model name returns a dict-like (OmegaConf) object describing it.
""" """
pass pass
@abstractmethod @abstractmethod
def model_names(self) -> List[Tuple[str, SDModelType]]: def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
""" """
Returns a list of all the model names known. Returns a list of all the model names known.
""" """
pass pass
@abstractmethod @abstractmethod
def list_models(self, model_type: SDModelType=None) -> dict: def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict:
""" """
Return a dict of models in the format: Return a dict of models in the format:
{ model_type1: { model_type1:
@ -122,7 +120,8 @@ class ModelManagerServiceBase(ABC):
def add_model( def add_model(
self, self,
model_name: str, model_name: str,
model_type: SDModelType, base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict, model_attributes: dict,
clobber: bool = False clobber: bool = False
) -> None: ) -> None:
@ -139,7 +138,8 @@ class ModelManagerServiceBase(ABC):
def del_model( def del_model(
self, self,
model_name: str, model_name: str,
model_type: SDModelType, base_model: BaseModelType,
model_type: ModelType,
delete_files: bool = False, delete_files: bool = False,
): ):
""" """
@ -297,8 +297,9 @@ class ModelManagerService(ModelManagerServiceBase):
def get_model( def get_model(
self, self,
model_name: str, model_name: str,
model_type: SDModelType = SDModelType.Diffusers, base_model: BaseModelType,
submodel: Optional[SDModelType] = None, model_type: ModelType,
submodel: Optional[SubModelType] = None,
node: Optional[BaseInvocation] = None, node: Optional[BaseInvocation] = None,
context: Optional[InvocationContext] = None, context: Optional[InvocationContext] = None,
) -> SDModelInfo: ) -> SDModelInfo:
@ -307,23 +308,6 @@ class ModelManagerService(ModelManagerServiceBase):
part (such as the vae) of a diffusers mode. part (such as the vae) of a diffusers mode.
""" """
# Temporary hack here: we remember the last model fetched
# so that when executing a graph, the first node called gets
# to set default model for subsequent nodes in the event that
# they do not set the model explicitly. This should be
# displaced by model loader mechanism.
# This is to work around lack of model loader at current time,
# which was causing inconsistent model usage throughout graph.
global last_used_model
if not model_name:
self.logger.debug('No model name provided, defaulting to last loaded model')
model_name = last_used_model.model_name
model_type = model_type or last_used_model.model_type
else:
last_used_model.model_name = model_name
last_used_model.model_type = model_type
# if we are called from within a node, then we get to emit # if we are called from within a node, then we get to emit
# load start and complete events # load start and complete events
if node and context: if node and context:
@ -331,12 +315,14 @@ class ModelManagerService(ModelManagerServiceBase):
node=node, node=node,
context=context, context=context,
model_name=model_name, model_name=model_name,
base_model=base_model,
model_type=model_type, model_type=model_type,
submodel=submodel submodel=submodel,
) )
model_info = self.mgr.get_model( model_info = self.mgr.get_model(
model_name, model_name,
base_model,
model_type, model_type,
submodel, submodel,
) )
@ -346,6 +332,7 @@ class ModelManagerService(ModelManagerServiceBase):
node=node, node=node,
context=context, context=context,
model_name=model_name, model_name=model_name,
base_model=base_model,
model_type=model_type, model_type=model_type,
submodel=submodel, submodel=submodel,
model_info=model_info model_info=model_info
@ -356,7 +343,8 @@ class ModelManagerService(ModelManagerServiceBase):
def model_exists( def model_exists(
self, self,
model_name: str, model_name: str,
model_type: SDModelType, base_model: BaseModelType,
model_type: ModelType,
) -> bool: ) -> bool:
""" """
Given a model name, returns True if it is a valid Given a model name, returns True if it is a valid
@ -364,33 +352,38 @@ class ModelManagerService(ModelManagerServiceBase):
""" """
return self.mgr.model_exists( return self.mgr.model_exists(
model_name, model_name,
base_model,
model_type, model_type,
) )
def default_model(self) -> Optional[Tuple[str, SDModelType]]: def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
""" """
Returns the name of the default model, or None Returns the name of the default model, or None
if none is defined. if none is defined.
""" """
return self.mgr.default_model() return self.mgr.default_model()
def set_default_model(self, model_name: str, model_type: SDModelType): def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
"""Sets the default model to the indicated name.""" """Sets the default model to the indicated name."""
self.mgr.set_default_model(model_name) self.mgr.set_default_model(model_name, base_model, model_type)
def model_info(self, model_name: str, model_type: SDModelType) -> dict: def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
""" """
Given a model name returns a dict-like (OmegaConf) object describing it. Given a model name returns a dict-like (OmegaConf) object describing it.
""" """
return self.mgr.model_info(model_name) return self.mgr.model_info(model_name, base_model, model_type)
def model_names(self) -> List[Tuple[str, SDModelType]]: def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
""" """
Returns a list of all the model names known. Returns a list of all the model names known.
""" """
return self.mgr.model_names() return self.mgr.model_names()
def list_models(self, model_type: SDModelType=None) -> dict: def list_models(
self,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None
) -> dict:
""" """
Return a dict of models in the format: Return a dict of models in the format:
{ model_type1: { model_type1:
@ -406,12 +399,13 @@ class ModelManagerService(ModelManagerServiceBase):
{ model_name_n: etc { model_name_n: etc
} }
""" """
return self.mgr.list_models() return self.mgr.list_models(base_model, model_type)
def add_model( def add_model(
self, self,
model_name: str, model_name: str,
model_type: SDModelType, base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict, model_attributes: dict,
clobber: bool = False, clobber: bool = False,
)->None: )->None:
@ -422,13 +416,14 @@ class ModelManagerService(ModelManagerServiceBase):
with an assertion error if provided attributes are incorrect or with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk. the model name is missing. Call commit() to write changes to disk.
""" """
return self.mgr.add_model(model_name, model_type, model_attributes, clobber) return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
def del_model( def del_model(
self, self,
model_name: str, model_name: str,
model_type: SDModelType = SDModelType.Diffusers, base_model: BaseModelType,
model_type: ModelType,
delete_files: bool = False, delete_files: bool = False,
): ):
""" """
@ -436,7 +431,7 @@ class ModelManagerService(ModelManagerServiceBase):
then the underlying weight file or diffusers directory will be deleted then the underlying weight file or diffusers directory will be deleted
as well. Call commit() to write to disk. as well. Call commit() to write to disk.
""" """
self.mgr.del_model(model_name, model_type, delete_files) self.mgr.del_model(model_name, base_model, model_type, delete_files)
def import_diffuser_model( def import_diffuser_model(
self, self,
@ -541,8 +536,9 @@ class ModelManagerService(ModelManagerServiceBase):
node, node,
context, context,
model_name: str, model_name: str,
model_type: SDModelType, base_model: BaseModelType,
submodel: SDModelType, model_type: ModelType,
submodel: SubModelType,
model_info: Optional[SDModelInfo] = None, model_info: Optional[SDModelInfo] = None,
): ):
if context.services.queue.is_canceled(context.graph_execution_state_id): if context.services.queue.is_canceled(context.graph_execution_state_id):
@ -555,6 +551,7 @@ class ModelManagerService(ModelManagerServiceBase):
node=node.dict(), node=node.dict(),
source_node_id=source_node_id, source_node_id=source_node_id,
model_name=model_name, model_name=model_name,
base_model=base_model,
model_type=model_type, model_type=model_type,
submodel=submodel, submodel=submodel,
model_info=model_info model_info=model_info
@ -565,6 +562,7 @@ class ModelManagerService(ModelManagerServiceBase):
node=node.dict(), node=node.dict(),
source_node_id=source_node_id, source_node_id=source_node_id,
model_name=model_name, model_name=model_name,
base_model=base_model,
model_type=model_type, model_type=model_type,
submodel=submodel, submodel=submodel,
) )

View File

@ -9,5 +9,5 @@ from .generator import (
Img2Img, Img2Img,
Inpaint Inpaint
) )
from .model_management import ModelManager, ModelCache, SDModelType, SDModelInfo from .model_management import ModelManager, ModelCache, BaseModelType, ModelType, SubModelType, SDModelInfo
from .safety_checker import SafetyChecker from .safety_checker import SafetyChecker

View File

@ -2,4 +2,5 @@
Initialization file for invokeai.backend.model_management Initialization file for invokeai.backend.model_management
""" """
from .model_manager import ModelManager, SDModelInfo from .model_manager import ModelManager, SDModelInfo
from .model_cache import ModelCache, SDModelType from .model_cache import ModelCache
from .models import BaseModelType, ModelType, SubModelType

View File

@ -39,7 +39,7 @@ from invokeai.app.services.config import get_invokeai_config
from .lora import LoRAModel, TextualInversionModel from .lora import LoRAModel, TextualInversionModel
from .models import BaseModelType, ModelType, SubModelType from .models import BaseModelType, ModelType, SubModelType, ModelBase
# Maximum size of the cache, in gigs # Maximum size of the cache, in gigs

View File

@ -167,6 +167,8 @@ from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig from omegaconf.dictconfig import DictConfig
from pydantic import BaseModel
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import CUDA_DEVICE, download_with_resume from invokeai.backend.util import CUDA_DEVICE, download_with_resume
@ -539,14 +541,16 @@ class ModelManager(object):
def del_model( def del_model(
self, self,
model_name: str, model_name: str,
model_type: SDModelType.Diffusers, base_model: BaseModelType,
model_type: ModelType,
delete_files: bool = False, delete_files: bool = False,
): ):
""" """
Delete the named model. Delete the named model.
""" """
model_key = self.create_key(model_name, model_type) raise Exception("TODO: del_model") # TODO: redo
model_cfg = self.pop(model_key, None) model_key = self.create_key(model_name, base_model, model_type)
model_cfg = self.models.pop(model_key, None)
if model_cfg is None: if model_cfg is None:
self.logger.error( self.logger.error(

View File

@ -24,7 +24,7 @@ class ModelType(str, Enum):
ControlNet = "controlnet" ControlNet = "controlnet"
TextualInversion = "embedding" TextualInversion = "embedding"
class SubModelType: class SubModelType(str, Enum):
UNet = "unet" UNet = "unet"
TextEncoder = "text_encoder" TextEncoder = "text_encoder"
Tokenizer = "tokenizer" Tokenizer = "tokenizer"