remove defunct code

This commit is contained in:
Lincoln Stein 2023-06-11 12:57:06 -04:00
commit 8e1a56875e
17 changed files with 969 additions and 244 deletions

View File

@ -1,12 +1,12 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername) # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
from typing import Annotated, Literal, Optional, Union from typing import Annotated, Literal, Optional, Union, Dict
from fastapi import Query from fastapi import Query
from fastapi.routing import APIRouter, HTTPException from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
from invokeai.backend import SDModelType from invokeai.backend import BaseModelType, ModelType
models_router = APIRouter(prefix="/v1/models", tags=["models"]) models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -60,7 +60,8 @@ class ConvertedModelResponse(BaseModel):
info: DiffusersModelInfo = Field(description="The converted model info") info: DiffusersModelInfo = Field(description="The converted model info")
class ModelsList(BaseModel): class ModelsList(BaseModel):
models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]] models: Dict[BaseModelType, Dict[ModelType, Dict[str, dict]]] # TODO: collect all configs
#models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
@models_router.get( @models_router.get(
@ -69,9 +70,12 @@ class ModelsList(BaseModel):
responses={200: {"model": ModelsList }}, responses={200: {"model": ModelsList }},
) )
async def list_models( async def list_models(
model_type: SDModelType = Query( base_model: BaseModelType = Query(
default=None, description="The type of model to get" default=None, description="Base model"
), ),
model_type: ModelType = Query(
default=None, description="The type of model to get"
),
) -> ModelsList: ) -> ModelsList:
"""Gets a list of models""" """Gets a list of models"""
models_raw = ApiDependencies.invoker.services.model_manager.list_models(model_type) models_raw = ApiDependencies.invoker.services.model_manager.list_models(model_type)

View File

@ -8,7 +8,7 @@ from .model import ClipField
from ...backend.util.devices import torch_dtype from ...backend.util.devices import torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.model_management import SDModelType from ...backend.model_management import BaseModelType, ModelType, SubModelType
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from compel import Compel from compel import Compel
@ -76,7 +76,11 @@ class CompelInvocation(BaseInvocation):
try: try:
ti_list.append( ti_list.append(
stack.enter_context( stack.enter_context(
context.services.model_manager.get_model(model_name=name, model_type=SDModelType.TextualInversion) context.services.model_manager.get_model(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
)
) )
) )
except Exception: except Exception:

View File

@ -5,12 +5,13 @@ import copy
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from ...backend.util.devices import choose_torch_device, torch_dtype from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.model_management import SDModelType from ...backend.model_management import BaseModelType, ModelType, SubModelType
class ModelInfo(BaseModel): class ModelInfo(BaseModel):
model_name: str = Field(description="Info to load submodel") model_name: str = Field(description="Info to load submodel")
model_type: SDModelType = Field(description="Info to load submodel") base_model: BaseModelType = Field(description="Base model")
submodel: Optional[SDModelType] = Field(description="Info to load submodel") model_type: ModelType = Field(description="Info to load submodel")
submodel: Optional[SubModelType] = Field(description="Info to load submodel")
class LoraInfo(ModelInfo): class LoraInfo(ModelInfo):
weight: float = Field(description="Lora's weight which to use when apply to model") weight: float = Field(description="Lora's weight which to use when apply to model")
@ -63,10 +64,13 @@ class ModelLoaderInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ModelLoaderOutput: def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = BaseModelType.StableDiffusion2 # TODO:
# TODO: not found exceptions # TODO: not found exceptions
if not context.services.model_manager.model_exists( if not context.services.model_manager.model_exists(
model_name=self.model_name, model_name=self.model_name,
model_type=SDModelType.Diffusers, base_model=base_model,
model_type=ModelType.Pipeline,
): ):
raise Exception(f"Unkown model name: {self.model_name}!") raise Exception(f"Unkown model name: {self.model_name}!")
@ -104,12 +108,14 @@ class ModelLoaderInvocation(BaseInvocation):
unet=UNetField( unet=UNetField(
unet=ModelInfo( unet=ModelInfo(
model_name=self.model_name, model_name=self.model_name,
model_type=SDModelType.Diffusers, base_model=base_model,
submodel=SDModelType.UNet, model_type=ModelType.Pipeline,
submodel=SubModelType.UNet,
), ),
scheduler=ModelInfo( scheduler=ModelInfo(
model_name=self.model_name, model_name=self.model_name,
model_type=SDModelType.Diffusers, base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SDModelType.Scheduler, submodel=SDModelType.Scheduler,
), ),
loras=[], loras=[],
@ -117,12 +123,14 @@ class ModelLoaderInvocation(BaseInvocation):
clip=ClipField( clip=ClipField(
tokenizer=ModelInfo( tokenizer=ModelInfo(
model_name=self.model_name, model_name=self.model_name,
model_type=SDModelType.Diffusers, base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SDModelType.Tokenizer, submodel=SDModelType.Tokenizer,
), ),
text_encoder=ModelInfo( text_encoder=ModelInfo(
model_name=self.model_name, model_name=self.model_name,
model_type=SDModelType.Diffusers, base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SDModelType.TextEncoder, submodel=SDModelType.TextEncoder,
), ),
loras=[], loras=[],
@ -130,7 +138,8 @@ class ModelLoaderInvocation(BaseInvocation):
vae=VaeField( vae=VaeField(
vae=ModelInfo( vae=ModelInfo(
model_name=self.model_name, model_name=self.model_name,
model_type=SDModelType.Diffusers, base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SDModelType.Vae, submodel=SDModelType.Vae,
), ),
) )

View File

@ -3,7 +3,7 @@
from typing import Any from typing import Any
from invokeai.app.models.image import ProgressImage from invokeai.app.models.image import ProgressImage
from invokeai.app.util.misc import get_timestamp from invokeai.app.util.misc import get_timestamp
from invokeai.app.services.model_manager_service import SDModelType, SDModelInfo from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, SDModelInfo
from invokeai.app.models.exceptions import CanceledException from invokeai.app.models.exceptions import CanceledException
class EventServiceBase: class EventServiceBase:
@ -109,8 +109,9 @@ class EventServiceBase:
node: dict, node: dict,
source_node_id: str, source_node_id: str,
model_name: str, model_name: str,
model_type: SDModelType, base_model: BaseModelType,
submodel: SDModelType, model_type: ModelType,
submodel: SubModelType,
) -> None: ) -> None:
"""Emitted when a model is requested""" """Emitted when a model is requested"""
self.__emit_session_event( self.__emit_session_event(
@ -120,6 +121,7 @@ class EventServiceBase:
node=node, node=node,
source_node_id=source_node_id, source_node_id=source_node_id,
model_name=model_name, model_name=model_name,
base_model=base_model,
model_type=model_type, model_type=model_type,
submodel=submodel, submodel=submodel,
), ),
@ -131,8 +133,9 @@ class EventServiceBase:
node: dict, node: dict,
source_node_id: str, source_node_id: str,
model_name: str, model_name: str,
model_type: SDModelType, base_model: BaseModelType,
submodel: SDModelType, model_type: ModelType,
submodel: SubModelType,
model_info: SDModelInfo, model_info: SDModelInfo,
) -> None: ) -> None:
"""Emitted when a model is correctly loaded (returns model info)""" """Emitted when a model is correctly loaded (returns model info)"""
@ -143,6 +146,7 @@ class EventServiceBase:
node=node, node=node,
source_node_id=source_node_id, source_node_id=source_node_id,
model_name=model_name, model_name=model_name,
base_model=base_model,
model_type=model_type, model_type=model_type,
submodel=submodel, submodel=submodel,
model_info=model_info, model_info=model_info,

View File

@ -10,7 +10,9 @@ from dataclasses import dataclass
from invokeai.backend.model_management.model_manager import ( from invokeai.backend.model_management.model_manager import (
ModelManager, ModelManager,
SDModelType, BaseModelType,
ModelType,
SubModelType,
SDModelInfo, SDModelInfo,
) )
from invokeai.app.models.exceptions import CanceledException from invokeai.app.models.exceptions import CanceledException
@ -20,12 +22,6 @@ from ...backend.util import choose_precision, choose_torch_device
if TYPE_CHECKING: if TYPE_CHECKING:
from ..invocations.baseinvocation import BaseInvocation, InvocationContext from ..invocations.baseinvocation import BaseInvocation, InvocationContext
@dataclass
class LastUsedModel:
model_name: str=None
model_type: SDModelType=None
last_used_model = LastUsedModel()
class ModelManagerServiceBase(ABC): class ModelManagerServiceBase(ABC):
"""Responsible for managing models on disk and in memory""" """Responsible for managing models on disk and in memory"""
@ -48,8 +44,9 @@ class ModelManagerServiceBase(ABC):
def get_model( def get_model(
self, self,
model_name: str, model_name: str,
model_type: SDModelType = SDModelType.Diffusers, base_model: BaseModelType,
submodel: Optional[SDModelType] = None, model_type: ModelType,
submodel: Optional[SubModelType] = None,
node: Optional[BaseInvocation] = None, node: Optional[BaseInvocation] = None,
context: Optional[InvocationContext] = None, context: Optional[InvocationContext] = None,
) -> SDModelInfo: ) -> SDModelInfo:
@ -67,12 +64,13 @@ class ModelManagerServiceBase(ABC):
def model_exists( def model_exists(
self, self,
model_name: str, model_name: str,
model_type: SDModelType, base_model: BaseModelType,
model_type: ModelType,
) -> bool: ) -> bool:
pass pass
@abstractmethod @abstractmethod
def default_model(self) -> Optional[Tuple[str, SDModelType]]: def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
""" """
Returns the name and typeof the default model, or None Returns the name and typeof the default model, or None
if none is defined. if none is defined.
@ -80,26 +78,26 @@ class ModelManagerServiceBase(ABC):
pass pass
@abstractmethod @abstractmethod
def set_default_model(self, model_name: str, model_type: SDModelType): def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
"""Sets the default model to the indicated name.""" """Sets the default model to the indicated name."""
pass pass
@abstractmethod @abstractmethod
def model_info(self, model_name: str, model_type: SDModelType) -> dict: def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
""" """
Given a model name returns a dict-like (OmegaConf) object describing it. Given a model name returns a dict-like (OmegaConf) object describing it.
""" """
pass pass
@abstractmethod @abstractmethod
def model_names(self) -> List[Tuple[str, SDModelType]]: def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
""" """
Returns a list of all the model names known. Returns a list of all the model names known.
""" """
pass pass
@abstractmethod @abstractmethod
def list_models(self, model_type: SDModelType=None) -> dict: def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict:
""" """
Return a dict of models in the format: Return a dict of models in the format:
{ model_type1: { model_type1:
@ -122,7 +120,8 @@ class ModelManagerServiceBase(ABC):
def add_model( def add_model(
self, self,
model_name: str, model_name: str,
model_type: SDModelType, base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict, model_attributes: dict,
clobber: bool = False clobber: bool = False
) -> None: ) -> None:
@ -139,7 +138,8 @@ class ModelManagerServiceBase(ABC):
def del_model( def del_model(
self, self,
model_name: str, model_name: str,
model_type: SDModelType, base_model: BaseModelType,
model_type: ModelType,
delete_files: bool = False, delete_files: bool = False,
): ):
""" """
@ -297,8 +297,9 @@ class ModelManagerService(ModelManagerServiceBase):
def get_model( def get_model(
self, self,
model_name: str, model_name: str,
model_type: SDModelType = SDModelType.Diffusers, base_model: BaseModelType,
submodel: Optional[SDModelType] = None, model_type: ModelType,
submodel: Optional[SubModelType] = None,
node: Optional[BaseInvocation] = None, node: Optional[BaseInvocation] = None,
context: Optional[InvocationContext] = None, context: Optional[InvocationContext] = None,
) -> SDModelInfo: ) -> SDModelInfo:
@ -306,23 +307,6 @@ class ModelManagerService(ModelManagerServiceBase):
Retrieve the indicated model. submodel can be used to get a Retrieve the indicated model. submodel can be used to get a
part (such as the vae) of a diffusers mode. part (such as the vae) of a diffusers mode.
""" """
# Temporary hack here: we remember the last model fetched
# so that when executing a graph, the first node called gets
# to set default model for subsequent nodes in the event that
# they do not set the model explicitly. This should be
# displaced by model loader mechanism.
# This is to work around lack of model loader at current time,
# which was causing inconsistent model usage throughout graph.
global last_used_model
if not model_name:
self.logger.debug('No model name provided, defaulting to last loaded model')
model_name = last_used_model.model_name
model_type = model_type or last_used_model.model_type
else:
last_used_model.model_name = model_name
last_used_model.model_type = model_type
# if we are called from within a node, then we get to emit # if we are called from within a node, then we get to emit
# load start and complete events # load start and complete events
@ -331,12 +315,14 @@ class ModelManagerService(ModelManagerServiceBase):
node=node, node=node,
context=context, context=context,
model_name=model_name, model_name=model_name,
base_model=base_model,
model_type=model_type, model_type=model_type,
submodel=submodel submodel=submodel,
) )
model_info = self.mgr.get_model( model_info = self.mgr.get_model(
model_name, model_name,
base_model,
model_type, model_type,
submodel, submodel,
) )
@ -346,6 +332,7 @@ class ModelManagerService(ModelManagerServiceBase):
node=node, node=node,
context=context, context=context,
model_name=model_name, model_name=model_name,
base_model=base_model,
model_type=model_type, model_type=model_type,
submodel=submodel, submodel=submodel,
model_info=model_info model_info=model_info
@ -356,7 +343,8 @@ class ModelManagerService(ModelManagerServiceBase):
def model_exists( def model_exists(
self, self,
model_name: str, model_name: str,
model_type: SDModelType, base_model: BaseModelType,
model_type: ModelType,
) -> bool: ) -> bool:
""" """
Given a model name, returns True if it is a valid Given a model name, returns True if it is a valid
@ -364,33 +352,38 @@ class ModelManagerService(ModelManagerServiceBase):
""" """
return self.mgr.model_exists( return self.mgr.model_exists(
model_name, model_name,
base_model,
model_type, model_type,
) )
def default_model(self) -> Optional[Tuple[str, SDModelType]]: def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
""" """
Returns the name of the default model, or None Returns the name of the default model, or None
if none is defined. if none is defined.
""" """
return self.mgr.default_model() return self.mgr.default_model()
def set_default_model(self, model_name: str, model_type: SDModelType): def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
"""Sets the default model to the indicated name.""" """Sets the default model to the indicated name."""
self.mgr.set_default_model(model_name) self.mgr.set_default_model(model_name, base_model, model_type)
def model_info(self, model_name: str, model_type: SDModelType) -> dict: def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
""" """
Given a model name returns a dict-like (OmegaConf) object describing it. Given a model name returns a dict-like (OmegaConf) object describing it.
""" """
return self.mgr.model_info(model_name) return self.mgr.model_info(model_name, base_model, model_type)
def model_names(self) -> List[Tuple[str, SDModelType]]: def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
""" """
Returns a list of all the model names known. Returns a list of all the model names known.
""" """
return self.mgr.model_names() return self.mgr.model_names()
def list_models(self, model_type: SDModelType=None) -> dict: def list_models(
self,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None
) -> dict:
""" """
Return a dict of models in the format: Return a dict of models in the format:
{ model_type1: { model_type1:
@ -406,12 +399,13 @@ class ModelManagerService(ModelManagerServiceBase):
{ model_name_n: etc { model_name_n: etc
} }
""" """
return self.mgr.list_models() return self.mgr.list_models(base_model, model_type)
def add_model( def add_model(
self, self,
model_name: str, model_name: str,
model_type: SDModelType, base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict, model_attributes: dict,
clobber: bool = False, clobber: bool = False,
)->None: )->None:
@ -422,13 +416,14 @@ class ModelManagerService(ModelManagerServiceBase):
with an assertion error if provided attributes are incorrect or with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk. the model name is missing. Call commit() to write changes to disk.
""" """
return self.mgr.add_model(model_name, model_type, model_attributes, clobber) return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
def del_model( def del_model(
self, self,
model_name: str, model_name: str,
model_type: SDModelType = SDModelType.Diffusers, base_model: BaseModelType,
model_type: ModelType,
delete_files: bool = False, delete_files: bool = False,
): ):
""" """
@ -436,7 +431,7 @@ class ModelManagerService(ModelManagerServiceBase):
then the underlying weight file or diffusers directory will be deleted then the underlying weight file or diffusers directory will be deleted
as well. Call commit() to write to disk. as well. Call commit() to write to disk.
""" """
self.mgr.del_model(model_name, model_type, delete_files) self.mgr.del_model(model_name, base_model, model_type, delete_files)
def import_diffuser_model( def import_diffuser_model(
self, self,
@ -541,8 +536,9 @@ class ModelManagerService(ModelManagerServiceBase):
node, node,
context, context,
model_name: str, model_name: str,
model_type: SDModelType, base_model: BaseModelType,
submodel: SDModelType, model_type: ModelType,
submodel: SubModelType,
model_info: Optional[SDModelInfo] = None, model_info: Optional[SDModelInfo] = None,
): ):
if context.services.queue.is_canceled(context.graph_execution_state_id): if context.services.queue.is_canceled(context.graph_execution_state_id):
@ -555,6 +551,7 @@ class ModelManagerService(ModelManagerServiceBase):
node=node.dict(), node=node.dict(),
source_node_id=source_node_id, source_node_id=source_node_id,
model_name=model_name, model_name=model_name,
base_model=base_model,
model_type=model_type, model_type=model_type,
submodel=submodel, submodel=submodel,
model_info=model_info model_info=model_info
@ -565,6 +562,7 @@ class ModelManagerService(ModelManagerServiceBase):
node=node.dict(), node=node.dict(),
source_node_id=source_node_id, source_node_id=source_node_id,
model_name=model_name, model_name=model_name,
base_model=base_model,
model_type=model_type, model_type=model_type,
submodel=submodel, submodel=submodel,
) )

View File

@ -16,13 +16,14 @@ class RestorationServices:
gfpgan, codeformer, esrgan = None, None, None gfpgan, codeformer, esrgan = None, None, None
if args.restore or args.esrgan: if args.restore or args.esrgan:
restoration = Restoration() restoration = Restoration()
if args.restore: # TODO: redo for new model structure
if False and args.restore:
gfpgan, codeformer = restoration.load_face_restore_models( gfpgan, codeformer = restoration.load_face_restore_models(
args.gfpgan_model_path args.gfpgan_model_path
) )
else: else:
logger.info("Face restoration disabled") logger.info("Face restoration disabled")
if args.esrgan: if False and args.esrgan:
esrgan = restoration.load_esrgan(args.esrgan_bg_tile) esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
else: else:
logger.info("Upscaling disabled") logger.info("Upscaling disabled")

View File

@ -9,5 +9,5 @@ from .generator import (
Img2Img, Img2Img,
Inpaint Inpaint
) )
from .model_management import ModelManager, ModelCache, ModelType, ModelInfo from .model_management import ModelManager, ModelCache, BaseModelType, ModelType, SubmodelType, ModelInfo
from .safety_checker import SafetyChecker from .safety_checker import SafetyChecker

View File

@ -3,4 +3,4 @@ Initialization file for invokeai.backend.model_management
""" """
from .model_manager import ModelManager, ModelInfo from .model_manager import ModelManager, ModelInfo
from .model_cache import ModelCache from .model_cache import ModelCache
from .models import ModelType from .models import BaseModelType, ModelType, SubModelType

View File

@ -30,7 +30,9 @@ import torch
from diffusers import logging as diffusers_logging from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging from transformers import logging as transformers_logging
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from .models import ModelType, SubModelType, ModelBase from invokeai.app.services.config import get_invokeai_config
from .lora import LoRAModel, TextualInversionModel
from .models import BaseModelType, ModelType, SubModelType, ModelBase
# Maximum size of the cache, in gigs # Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously # Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
@ -122,11 +124,12 @@ class ModelCache(object):
def get_key( def get_key(
self, self,
model_path: str, model_path: str,
base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
submodel_type: Optional[ModelType] = None, submodel_type: Optional[SubModelType] = None,
): ):
key = f"{model_path}:{model_type}" key = f"{model_path}:{base_model}:{model_type}"
if submodel_type: if submodel_type:
key += f":{submodel_type}" key += f":{submodel_type}"
return key return key
@ -145,10 +148,13 @@ class ModelCache(object):
self, self,
model_path: str, model_path: str,
model_class: Type[ModelBase], model_class: Type[ModelBase],
base_model: BaseModelType,
model_type: ModelType,
): ):
model_info_key = self.get_key( model_info_key = self.get_key(
model_path=model_path, model_path=model_path,
model_type=model_class, base_model=base_model,
model_type=model_type,
submodel_type=None, submodel_type=None,
) )
@ -165,6 +171,8 @@ class ModelCache(object):
self, self,
model_path: Union[str, Path], model_path: Union[str, Path],
model_class: Type[ModelBase], model_class: Type[ModelBase],
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None, submodel: Optional[SubModelType] = None,
gpu_load: bool = True, gpu_load: bool = True,
) -> Any: ) -> Any:
@ -178,17 +186,20 @@ class ModelCache(object):
model_info = self._get_model_info( model_info = self._get_model_info(
model_path=model_path, model_path=model_path,
model_class=model_class, model_class=model_class,
base_model=base_model,
model_type=model_type,
) )
key = self.get_key( key = self.get_key(
model_path=model_path, model_path=model_path,
model_type=model_class, # TODO: base_model=base_model,
model_type=model_type,
submodel_type=submodel, submodel_type=submodel,
) )
# TODO: lock for no copies on simultaneous calls? # TODO: lock for no copies on simultaneous calls?
cache_entry = self._cached_models.get(key, None) cache_entry = self._cached_models.get(key, None)
if cache_entry is None: if cache_entry is None:
self.logger.info(f'Loading model {model_path}, type {model_class}:{submodel}') self.logger.info(f'Loading model {model_path}, type {base_model}:{model_type}:{submodel}')
# this will remove older cached models until # this will remove older cached models until
# there is sufficient room to load the requested model # there is sufficient room to load the requested model

View File

@ -160,6 +160,8 @@ from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig from omegaconf.dictconfig import DictConfig
from pydantic import BaseModel
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import CUDA_DEVICE, download_with_resume from invokeai.backend.util import CUDA_DEVICE, download_with_resume
@ -184,9 +186,10 @@ class ModelCache(object):
class ModelInfo(): class ModelInfo():
context: ModelLocker context: ModelLocker
name: str name: str
base_model: BaseModelType
type: ModelType type: ModelType
hash: str hash: str
location: Union[Path,str] location: Union[Path, str]
precision: torch.dtype precision: torch.dtype
revision: str = None revision: str = None
_cache: ModelCache = None _cache: ModelCache = None
@ -222,6 +225,9 @@ MAX_CACHE_SIZE = 6.0 # GB
# └── realesrgan # └── realesrgan
class ConfigMeta(BaseModel):
version: str
class ModelManager(object): class ModelManager(object):
""" """
High-level interface to model management. High-level interface to model management.
@ -229,6 +235,38 @@ class ModelManager(object):
logger: types.ModuleType = logger logger: types.ModuleType = logger
# TODO:
def _convert_2_3_models(self, config: DictConfig):
for model_name, model_config in config.items():
if model_config["format"] == "diffusers":
pass
elif model_config["format"] == "ckpt":
if any(model_config["config"].endswith(file) for file in {
"v1-finetune.yaml",
"v1-finetune_style.yaml",
"v1-inference.yaml",
"v1-inpainting-inference.yaml",
"v1-m1-finetune.yaml",
}):
# copy as as sd1.5
pass
# ~99% accurate should be
elif model_config["config"].endswith("v2-inference-v.yaml"):
# copy as sd 2.x (768)
pass
# for real don't know how accurate it
elif model_config["config"].endswith("v2-inference.yaml"):
# copy as sd 2.x-base (512)
pass
else:
# TODO:
raise Exception("Unknown model")
def __init__( def __init__(
self, self,
config: Union[Path, DictConfig, str], config: Union[Path, DictConfig, str],
@ -244,18 +282,29 @@ class ModelManager(object):
and sequential_offload boolean. Note that the default device and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision. type and precision are set up for a CUDA system running at half precision.
""" """
if isinstance(config, DictConfig):
self.config_path = None self.config_path = None
self.config = config if isinstance(config, (str, Path)):
elif isinstance(config,(str,Path)): self.config_path = Path(config)
self.config_path = config config = OmegaConf.load(self.config_path)
self.config = OmegaConf.load(self.config_path)
else: elif not isinstance(config, DictConfig):
raise ValueError('config argument must be an OmegaConf object, a Path or a string') raise ValueError('config argument must be an OmegaConf object, a Path or a string')
#if "__meta__" not in config:
# config = self._convert_2_3_models(config)
config_meta = ConfigMeta(**config.pop("__metadata__")) # TODO: naming
# TODO: metadata not found
self.models = dict()
for model_key, model_config in config.items():
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
self.models[model_key] = model_class.build_config(**model_config)
# check config version number and update on disk/RAM if necessary # check config version number and update on disk/RAM if necessary
self.globals = InvokeAIAppConfig.get_config() self.globals = InvokeAIAppConfig.get_config()
self._update_config_file_version()
self.logger = logger self.logger = logger
self.cache = ModelCache( self.cache = ModelCache(
max_cache_size=max_cache_size, max_cache_size=max_cache_size,
@ -267,7 +316,7 @@ class ModelManager(object):
self.cache_keys = dict() self.cache_keys = dict()
# add controlnet, lora and textual_inversion models from disk # add controlnet, lora and textual_inversion models from disk
self.scan_models_directory(include_diffusers=False) self.scan_models_directory()
def model_exists( def model_exists(
self, self,
@ -280,7 +329,7 @@ class ModelManager(object):
identifier. identifier.
""" """
model_key = self.create_key(model_name, base_model, model_type) model_key = self.create_key(model_name, base_model, model_type)
return model_key in self.config return model_key in self.models
def create_key( def create_key(
self, self,
@ -350,38 +399,49 @@ class ModelManager(object):
""" """
model_class = MODEL_CLASSES[base_model][model_type] model_class = MODEL_CLASSES[base_model][model_type]
model_dir = self.globals.models_path model_key = self.create_key(model_name, base_model, model_type)
if not model_class.has_config:
model_config = None
for ext in {"pt", "ckpt", "safetensors"}: # if model not found try to find it (maybe file just pasted)
model_path = os.path.join(model_dir, base_model, model_type, f"{model_name}.{ext}") if model_key not in self.models:
if os.path.exists(model_path): # TODO: find by mask or try rescan?
break path_mask = f"/models/{base_model}/{model_type}/{model_name}*"
else: if False: # model_path = next(find_by_mask(path_mask)):
model_path = os.path.join(model_dir, base_model, model_type, model_name) model_path = None # TODO:
if not os.path.exists(model_path): model_config = model_class.build_config(
raise InvalidModelError( path=model_path,
f"Model not found - \"{base_model}/{model_type}/{model_name}\" "
)
else:
# find in config
model_key = self.create_key(model_name, base_model, model_type)
if model_key not in self.config:
raise InvalidModelError(
f'"{model_key}" is not a known model name. Please check your models.yaml file'
) )
self.models[model_key] = model_config
else:
raise Exception(f"Model not found - {model_key}")
model_config = self.config[model_key] # if it known model check that target path exists (if manualy deleted)
model_path = model_config.path else:
# logic repeated twice(in rescan too) any way to optimize?
if not os.path.exists(self.models[model_key].path):
if model_class.save_to_config:
self.models[model_key].error = ModelError.NotFound
raise Exception(f"Files for model \"{model_key}\" not found")
# vae/movq override else:
# TODO: self.models.pop(model_key, None)
if submodel_type is not None and submodel_type in model_config: raise Exception(f"Model not found - {model_key}")
model_path = model_config[submodel_type]["path"]
model_type = submodel_type # reset model errors?
submodel_type = None
model_config = self.models[model_key]
# /models/{base_model}/{model_type}/{name}.ckpt or .safentesors
# /models/{base_model}/{model_type}/{name}/
model_path = model_config.path
# vae/movq override
# TODO:
if submodel is not None and submodel in model_config:
model_path = model_config[submodel]
model_type = submodel
submodel = None
dst_convert_path = None # TODO: dst_convert_path = None # TODO:
model_path = model_class.convert_if_required( model_path = model_class.convert_if_required(
@ -414,11 +474,11 @@ class ModelManager(object):
Returns the name of the default model, or None Returns the name of the default model, or None
if none is defined. if none is defined.
""" """
for model_key, model_config in self.config.items(): for model_key, model_config in self.models.items():
if model_config.get("default", False): if model_config.default:
return self.parse_key(model_key) return self.parse_key(model_key)
for model_key, _ in self.config.items(): for model_key, _ in self.models.items():
return self.parse_key(model_key) return self.parse_key(model_key)
else: else:
return None # TODO: or redo as (None, None, None) return None # TODO: or redo as (None, None, None)
@ -435,14 +495,11 @@ class ModelManager(object):
""" """
model_key = self.model_key(model_name, base_model, model_type) model_key = self.model_key(model_name, base_model, model_type)
if model_key not in self.config: if model_key not in self.models:
raise Exception(f"Unknown model: {model_key}") raise Exception(f"Unknown model: {model_key}")
for cur_model_key, config in self.config.items(): for cur_model_key, config in self.models.items():
if cur_model_key == model_key: config.default = cur_model_key == model_key
config["default"] = True
else:
config.pop("default", None)
def model_info( def model_info(
self, self,
@ -454,14 +511,17 @@ class ModelManager(object):
Given a model name returns the OmegaConf (dict-like) object describing it. Given a model name returns the OmegaConf (dict-like) object describing it.
""" """
model_key = self.create_key(model_name, base_model, model_type) model_key = self.create_key(model_name, base_model, model_type)
return self.config.get(model_key, None) if model_key in self.models:
return self.models[model_key].dict(exclude_defaults=True)
else:
return None # TODO: None or empty dict on not found
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]: def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
""" """
Return a list of (str, BaseModelType, ModelType) corresponding to all models Return a list of (str, BaseModelType, ModelType) corresponding to all models
known to the configuration. known to the configuration.
""" """
return [(self.parse_key(x)) for x in self.config.keys() if isinstance(self.config[x], DictConfig)] return [(self.parse_key(x)) for x in self.models.keys()]
def list_models( def list_models(
self, self,
@ -479,61 +539,52 @@ class ModelManager(object):
assert not(model_type is not None and base_model is None), "model_type must be provided with base_model" assert not(model_type is not None and base_model is None), "model_type must be provided with base_model"
models = dict() models = dict()
for model_key in sorted(self.config, key=str.casefold): for model_key in sorted(self.models, key=str.casefold):
stanza = self.config[model_key] model_config = self.models[model_key]
if model_key.startswith('_'): cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
if base_model is not None and cur_base_model != base_model:
continue
if model_type is not None and cur_model_type != model_type:
continue continue
model_name, m_base_model, stanza_type = self.parse_key(model_key) if cur_base_model not in models:
if base_model is not None and m_base_model != base_model: models[cur_base_model] = dict()
continue if cur_model_type not in models[cur_base_model]:
if model_type is not None and model_type != stanza_type: models[cur_base_model][cur_model_type] = dict()
continue
if m_base_model not in models: models[cur_base_model][cur_model_type][cur_model_name] = dict(
models[m_base_model] = dict() **model_config.dict(exclude_defaults=True),
if stanza_type not in models: name=cur_model_name,
models[m_base_model][stanza_type] = dict() base_model=cur_base_model,
type=cur_model_type,
model_class = MODEL_CLASSES[m_base_model][stanza_type]
models[m_base_model][stanza_type][model_name] = model_class.build_config(
**stanza,
name=model_name,
base_model=base_model,
type=stanza_type,
) )
#models[m_base_model][stanza_type][model_name] = model_class.Config(
# **stanza,
# name=model_name,
# base_model=base_model,
# type=stanza_type,
#).dict()
return models return models
def print_models(self) -> None: def print_models(self) -> None:
""" """
Print a table of models, their descriptions, and load status Print a table of models, their descriptions
""" """
# TODO: redo
for model_type, model_dict in self.list_models().items(): for model_type, model_dict in self.list_models().items():
for model_name, model_info in model_dict.items(): for model_name, model_info in model_dict.items():
line = f'{model_info["name"]:25s} {model_info["status"]:>15s} {model_info["type"]:10s} {model_info["description"]}' line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
if model_info["status"] in ["in gpu","locked in gpu"]:
line = f"\033[1m{line}\033[0m"
print(line) print(line)
def del_model( def del_model(
self, self,
model_name: str, model_name: str,
model_type: ModelType.Diffusers, base_model: BaseModelType,
model_type: ModelType,
delete_files: bool = False, delete_files: bool = False,
): ):
""" """
Delete the named model. Delete the named model.
""" """
model_key = self.create_key(model_name, model_type) raise Exception("TODO: del_model") # TODO: redo
model_cfg = self.pop(model_key, None) model_key = self.create_key(model_name, base_model, model_type)
model_cfg = self.models.pop(model_key, None)
if model_cfg is None: if model_cfg is None:
self.logger.error( self.logger.error(
@ -581,27 +632,14 @@ class ModelManager(object):
""" """
model_class = MODEL_CLASSES[base_model][model_type] model_class = MODEL_CLASSES[base_model][model_type]
model_config = model_class.build_config(**model_attributes)
model_class.build_config(
**model_attributes,
name=model_name,
base_model=base_model,
type=model_type,
)
#model_cfg = model_class.Config(
# **model_attributes,
# name=model_name,
# base_model=base_model,
# type=model_type,
#)
model_key = self.create_key(model_name, base_model, model_type) model_key = self.create_key(model_name, base_model, model_type)
assert ( assert (
clobber or model_key not in self.config clobber or model_key not in self.models
), f'attempt to overwrite existing model definition "{model_key}"' ), f'attempt to overwrite existing model definition "{model_key}"'
self.config[model_key] = model_attributes self.models[model_key] = model_config
if clobber and model_key in self.cache_keys: if clobber and model_key in self.cache_keys:
# TODO: # TODO:
@ -633,7 +671,15 @@ class ModelManager(object):
""" """
Write current configuration out to the indicated file. Write current configuration out to the indicated file.
""" """
yaml_str = OmegaConf.to_yaml(self.config) data_to_save = dict()
for model_key, model_config in self.models.items():
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
if model_class.save_to_config:
# TODO: or exclude_unset better fits here?
data_to_save[model_key] = model_config.dict(exclude_defaults=True)
yaml_str = OmegaConf.to_yaml(data_to_save)
config_file_path = conf_file or self.config_path config_file_path = conf_file or self.config_path
assert config_file_path is not None,'no config file path to write to' assert config_file_path is not None,'no config file path to write to'
config_file_path = self.globals.root_dir / config_file_path config_file_path = self.globals.root_dir / config_file_path
@ -697,61 +743,3 @@ class ModelManager(object):
resolved_path = self.globals.root_dir / source resolved_path = self.globals.root_dir / source
return resolved_path return resolved_path
def _update_config_file_version(self):
"""
This gets called at object init time and will update
from older versions of the config file to new ones
as necessary.
"""
current_version = self.config.get("_version","1.0.0")
if version.parse(current_version) < version.parse(CONFIG_FILE_VERSION):
self.logger.warning(f'models.yaml version {current_version} detected. Updating to {CONFIG_FILE_VERSION}')
self.logger.warning('The original file will be renamed models.yaml.orig')
if self.config_path:
old_file = Path(self.config_path)
new_name = old_file.parent / 'models.yaml.orig'
old_file.replace(new_name)
new_config = OmegaConf.create()
new_config["_version"] = CONFIG_FILE_VERSION
for model_key in self.config:
old_stanza = self.config[model_key]
if not isinstance(old_stanza,DictConfig):
continue
# ignore old and ugly way of associating a legacy
# vae with a legacy checkpont model
if old_stanza.get("config") and '/VAE/' in old_stanza.get("config"):
continue
# bare keys are updated to be prefixed with 'diffusers/'
if '/' not in model_key:
new_key = f'diffusers/{model_key}'
else:
new_key = model_key
if old_stanza.get('format')=='diffusers':
model_format = 'folder'
elif old_stanza.get('weights') and Path(old_stanza.get('weights')).suffix == '.ckpt':
model_format = 'ckpt'
elif old_stanza.get('weights') and Path(old_stanza.get('weights')).suffix == '.safetensors':
model_format = 'safetensors'
else:
model_format = old_stanza.get('format')
# copy fields over manually rather than doing a copy() or deepcopy()
# in order to avoid bringing in unwanted fields.
new_config[new_key] = dict(
description = old_stanza.get('description'),
format = model_format,
)
for field in ["repo_id", "path", "weights", "config", "vae"]:
if field_value := old_stanza.get(field):
new_config[new_key].update({field: field_value})
self.config = new_config
if self.config_path:
self.commit()

View File

@ -0,0 +1,37 @@
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase
from .stable_diffusion import StableDiffusion15Model, StableDiffusion2Model, StableDiffusion2BaseModel
from .vae import VaeModel
from .lora import LoRAModel
#from .controlnet import ControlNetModel # TODO:
from .textual_inversion import TextualInversionModel
MODEL_CLASSES = {
BaseModelType.StableDiffusion1_5: {
ModelType.Pipeline: StableDiffusion15Model,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
#ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
BaseModelType.StableDiffusion2: {
ModelType.Pipeline: StableDiffusion2Model,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
#ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
BaseModelType.StableDiffusion2Base: {
ModelType.Pipeline: StableDiffusion2BaseModel,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
#ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
#BaseModelType.Kandinsky2_1: {
# ModelType.Pipeline: Kandinsky2_1Model,
# ModelType.MoVQ: MoVQModel,
# ModelType.Lora: LoRAModel,
# ModelType.ControlNet: ControlNetModel,
# ModelType.TextualInversion: TextualInversionModel,
#},
}

View File

@ -0,0 +1,297 @@
import sys
import typing
import inspect
from enum import Enum
import torch
from diffusers import DiffusionPipeline, ConfigMixin
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal
class BaseModelType(str, Enum):
#StableDiffusion1_5 = "stable_diffusion_1_5"
#StableDiffusion2 = "stable_diffusion_2"
#StableDiffusion2Base = "stable_diffusion_2_base"
# TODO: maybe then add sample size(512/768)?
StableDiffusion1_5 = "sd-1.5"
StableDiffusion2Base = "sd-2-base" # 512 pixels; this will have epsilon parameterization
StableDiffusion2 = "sd-2" # 768 pixels; this will have v-prediction parameterization
#Kandinsky2_1 = "kandinsky_2_1"
class ModelType(str, Enum):
Pipeline = "pipeline"
Vae = "vae"
Lora = "lora"
#ControlNet = "controlnet"
TextualInversion = "embedding"
class SubModelType(str, Enum):
UNet = "unet"
TextEncoder = "text_encoder"
Tokenizer = "tokenizer"
Vae = "vae"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"
#MoVQ = "movq"
class ModelError(str, Enum):
NotFound = "not_found"
class ModelConfigBase(BaseModel):
path: str # or Path
#name: str # not included as present in model key
description: Optional[str] = Field(None)
format: Optional[str] = Field(None)
default: Optional[bool] = Field(False)
# do not save to config
error: Optional[ModelError] = Field(None, exclude=True)
class EmptyConfigLoader(ConfigMixin):
@classmethod
def load_config(cls, *args, **kwargs):
cls.config_name = kwargs.pop("config_name")
return super().load_config(*args, **kwargs)
class ModelBase:
#model_path: str
#base_model: BaseModelType
#model_type: ModelType
def __init__(
self,
model_path: str,
base_model: BaseModelType,
model_type: ModelType,
):
self.model_path = model_path
self.base_model = base_model
self.model_type = model_type
def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
if len(subtypes) < 2:
raise Exception("Invalid subfolder definition!")
if subtypes[0] in ["diffusers", "transformers"]:
res_type = sys.modules[subtypes[0]]
subtypes = subtypes[1:]
else:
res_type = sys.modules["diffusers"]
res_type = getattr(res_type, "pipelines")
for subtype in subtypes:
res_type = getattr(res_type, subtype)
return res_type
@classmethod
def _get_configs(cls):
if not hasattr(cls, "__configs"):
configs = dict()
for name in dir(cls):
if name.startswith("__"):
continue
value = getattr(cls, name)
if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
continue
fields = inspect.get_annotations(value)
if "format" not in fields or typing.get_origin(fields["format"]) != Literal:
raise Exception("Invalid config definition - format field not found")
format_type = typing.get_origin(fields["format"])
if format_type not in {None, Literal}:
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
if format_type is Literal:
format = fields["format"].__args__[0]
else:
format = None
configs[format] = value # TODO: error when override(multiple)?
cls.__configs = configs
return cls.__configs
@classmethod
def build_config(cls, **kwargs):
if "format" not in kwargs:
kwargs["format"] = cls.detect_format(kwargs["path"])
configs = cls._get_configs()
return configs[kwargs["format"]](**kwargs)
@classmethod
def detect_format(cls, path: str) -> str:
raise NotImplementedError()
class DiffusersModel(ModelBase):
#child_types: Dict[str, Type]
#child_sizes: Dict[str, int]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
super().__init__(model_path, base_model, model_type)
self.child_types: Dict[str, Type] = dict()
self.child_sizes: Dict[str, int] = dict()
try:
config_data = DiffusionPipeline.load_config(self.model_path)
#config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
except:
raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")
config_data.pop("_ignore_files", None)
# retrieve all folder_names that contain relevant files
child_components = [k for k, v in config_data.items() if isinstance(v, list)]
for child_name in child_components:
child_type = self._hf_definition_to_type(config_data[child_name])
self.child_types[child_name] = child_type
self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name)
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is None:
return sum(self.child_sizes.values())
else:
return self.child_sizes[child_type]
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
# return pipeline in different function to pass more arguments
if child_type is None:
raise Exception("Child model type can't be null on diffusers model")
if child_type not in self.child_types:
return None # TODO: or raise
if torch_dtype == torch.float16:
variants = ["fp16", None]
else:
variants = [None, "fp16"]
# TODO: better error handling(differentiate not found from others)
for variant in variants:
try:
# TODO: set cache_dir to /dev/null to be sure that cache not used?
model = self.child_types[child_type].from_pretrained(
self.model_path,
subfolder=child_type.value,
torch_dtype=torch_dtype,
variant=variant,
local_files_only=True,
)
break
except Exception as e:
print("====ERR LOAD====")
print(f"{variant}: {e}")
# calc more accurate size
self.child_sizes[child_type] = calc_model_size_by_data(model)
return model
#def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str:
def calc_model_size_by_fs(
model_path: str,
subfolder: Optional[str] = None,
variant: Optional[str] = None
):
if subfolder is not None:
model_path = os.path.join(model_path, subfolder)
# this can happen when, for example, the safety checker
# is not downloaded.
if not os.path.exists(model_path):
return 0
all_files = os.listdir(model_path)
all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
fp16_files = set([f for f in all_files if ".fp16." in f or ".fp16-" in f])
bit8_files = set([f for f in all_files if ".8bit." in f or ".8bit-" in f])
other_files = set(all_files) - fp16_files - bit8_files
if variant is None:
files = other_files
elif variant == "fp16":
files = fp16_files
elif variant == "8bit":
files = bit8_files
else:
raise NotImplementedError(f"Unknown variant: {variant}")
# try read from index if exists
index_postfix = ".index.json"
if variant is not None:
index_postfix = f".index.{variant}.json"
for file in files:
if not file.endswith(index_postfix):
continue
try:
with open(os.path.join(model_path, file), "r") as f:
index_data = json.loads(f.read())
return int(index_data["metadata"]["total_size"])
except:
pass
# calculate files size if there is no index file
formats = [
(".safetensors",), # safetensors
(".bin",), # torch
(".onnx", ".pb"), # onnx
(".msgpack",), # flax
(".ckpt",), # tf
(".h5",), # tf2
]
for file_format in formats:
model_files = [f for f in files if f.endswith(file_format)]
if len(model_files) == 0:
continue
model_size = 0
for model_file in model_files:
file_stats = os.stat(os.path.join(model_path, model_file))
model_size += file_stats.st_size
return model_size
#raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
def calc_model_size_by_data(model) -> int:
if isinstance(model, DiffusionPipeline):
return _calc_pipeline_by_data(model)
elif isinstance(model, torch.nn.Module):
return _calc_model_by_data(model)
else:
return 0
def _calc_pipeline_by_data(pipeline) -> int:
res = 0
for submodel_key in pipeline.components.keys():
submodel = getattr(pipeline, submodel_key)
if submodel is not None and isinstance(submodel, torch.nn.Module):
res += _calc_model_by_data(submodel)
return res
def _calc_model_by_data(model) -> int:
mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()])
mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
mem = mem_params + mem_bufs # in bytes
return mem

View File

@ -0,0 +1,63 @@
import torch
from typing import Optional
from .base import (
ModelBase,
ModelConfigBase,
BaseModelType,
ModelType,
SubModelType,
)
# TODO: naming
from ..lora import LoRAModel as LoRAModelRaw
class LoRAModel(ModelBase):
#model_size: int
class Config(ModelConfigBase):
format: None
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Lora
super().__init__(model_path, base_model, model_type)
self.model_size = os.path.getsize(self.model_path)
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise Exception("There is no child models in lora")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in lora")
model = LoRAModelRaw.from_checkpoint(
file_path=self.model_path,
dtype=torch_dtype,
)
self.model_size = model.calc_size()
return model
@classmethod
def save_to_config(cls) -> bool:
return False
@classmethod
def detect_format(cls, path: str):
if os.path.isdir(path):
return "diffusers"
else:
return "lycoris"
@staticmethod
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: Optional[dict]) -> str:
if cls.detect_format(model_path) == "diffusers":
# TODO: add diffusers lora when it stabilizes a bit
raise NotImplementedError("Diffusers lora not supported")
else:
return model_path

View File

@ -0,0 +1,131 @@
import os
import torch
from pydantic import Field
from typing import Literal, Optional
from .base import (
ModelBase,
ModelConfigBase,
BaseModelType,
ModelType,
SubModelType,
DiffusersModel,
)
from invokeai.app.services.config import InvokeAIAppConfig
# TODO: how to name properly
class StableDiffusion15Model(DiffusersModel):
# TODO: str -> Path?
class DiffusersConfig(ModelConfigBase):
format: Literal["diffusers"]
vae: Optional[str] = Field(None)
class CheckpointConfig(ModelConfigBase):
format: Literal["checkpoint"]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion1_5
assert model_type == ModelType.Pipeline
super().__init__(
model_path=model_path,
base_model=BaseModelType.StableDiffusion1_5,
model_type=ModelType.Pipeline,
)
@classmethod
def save_to_config(cls) -> bool:
return True
@classmethod
def detect_format(cls, model_path: str):
if os.path.isdir(model_path):
return "diffusers"
else:
return "checkpoint"
@classmethod
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: Optional[dict]) -> str:
cfg = cls.build_config(**config)
if isinstance(cfg, cls.CheckpointConfig):
return _convert_ckpt_and_cache(cfg) # TODO: args
else:
return model_path
# all same
class StableDiffusion2BaseModel(StableDiffusion15Model):
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
# skip StableDiffusion15Model __init__
assert base_model == BaseModelType.StableDiffusion2Base
assert model_type == ModelType.Pipeline
super(StableDiffusion15Model, self).__init__(
model_path=model_path,
base_model=BaseModelType.StableDiffusion2Base,
model_type=ModelType.Pipeline,
)
class StableDiffusion2Model(DiffusersModel):
# TODO: str -> Path?
# overwrite configs
class DiffusersConfig(ModelConfigBase):
format: Literal["diffusers"]
vae: Optional[str] = Field(None)
attention_upscale: bool = Field(True)
class CheckpointConfig(ModelConfigBase):
format: Literal["checkpoint"]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
attention_upscale: bool = Field(True)
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
# skip StableDiffusion15Model __init__
assert base_model == BaseModelType.StableDiffusion2
assert model_type == ModelType.Pipeline
super().__init__(
model_path=model_path,
base_model=BaseModelType.StableDiffusion2,
model_type=ModelType.Pipeline,
)
# TODO: rework
DictConfig = dict
def _convert_ckpt_and_cache(self, mconfig: DictConfig) -> str:
"""
Convert the checkpoint model indicated in mconfig into a
diffusers, cache it to disk, and return Path to converted
file. If already on disk then just returns Path.
"""
app_config = InvokeAIAppConfig.get_config()
weights = app_config.root_dir / mconfig.path
config_file = app_config.root_dir / mconfig.config
diffusers_path = app_config.converted_ckpts_dir / weights.stem
# return cached version if it exists
if diffusers_path.exists():
return diffusers_path
# TODO: I think that it more correctly to convert with embedded vae
# as if user will delete custom vae he will got not embedded but also custom vae
#vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
vae_ckpt_path, vae_model = None, None
# to avoid circular import errors
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
with SilenceWarnings():
convert_ckpt_to_diffusers(
weights,
diffusers_path,
extract_ema=True,
original_config_file=config_file,
vae=vae_model,
vae_path=str(app_config.root_dir / vae_ckpt_path) if vae_ckpt_path else None,
scan_needed=True,
)
return diffusers_path

View File

@ -0,0 +1,56 @@
import torch
from typing import Optional
from .base import (
ModelBase,
ModelConfigBase,
BaseModelType,
ModelType,
SubModelType,
)
# TODO: naming
from ..lora import TextualInversionModel as TextualInversionModelRaw
class TextualInversionModel(ModelBase):
#model_size: int
class Config(ModelConfigBase):
format: None
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.TextualInversion
super().__init__(model_path, base_model, model_type)
self.model_size = os.path.getsize(self.model_path)
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise Exception("There is no child models in textual inversion")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in textual inversion")
model = TextualInversionModelRaw.from_checkpoint(
file_path=self.model_path,
dtype=torch_dtype,
)
self.model_size = model.embedding.nelement() * model.embedding.element_size()
return model
@classmethod
def save_to_config(cls) -> bool:
return False
@classmethod
def detect_format(cls, path: str):
return None
@staticmethod
def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str:
return model_path

View File

@ -0,0 +1,122 @@
import os
import torch
from typing import Optional
from .base import (
ModelBase,
ModelConfigBase,
BaseModelType,
ModelType,
SubModelType,
EmptyConfigLoader,
calc_model_size_by_fs,
calc_model_size_by_data,
)
from invokeai.app.services.config import InvokeAIAppConfig
class VaeModel(ModelBase):
#vae_class: Type
#model_size: int
class Config(ModelConfigBase):
format: None
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Vae
super().__init__(model_path, base_model, model_type)
try:
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
#config = json.loads(os.path.join(self.model_path, "config.json"))
except:
raise Exception("Invalid vae model! (config.json not found or invalid)")
try:
vae_class_name = config.get("_class_name", "AutoencoderKL")
self.vae_class = self._hf_definition_to_type(["diffusers", vae_class_name])
self.model_size = calc_model_size_by_fs(self.model_path)
except:
raise Exception("Invalid vae model! (Unkown vae type)")
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise Exception("There is no child models in vae model")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in vae model")
model = self.vae_class.from_pretrained(
self.model_path,
torch_dtype=torch_dtype,
)
# calc more accurate size
self.model_size = calc_model_size_by_data(model)
return model
@classmethod
def save_to_config(cls) -> bool:
return False
@classmethod
def detect_format(cls, path: str):
if os.path.isdir(path):
return "diffusers"
else:
return "checkpoint"
@classmethod
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: Optional[dict]) -> str:
if cls.detect_format(model_path) != "diffusers":
# TODO:
#_convert_vae_ckpt_and_cache
raise NotImplementedError("TODO: vae convert")
else:
return model_path
# TODO: rework
DictConfig = dict
def _convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> str:
"""
Convert the VAE indicated in mconfig into a diffusers AutoencoderKL
object, cache it to disk, and return Path to converted
file. If already on disk then just returns Path.
"""
app_config = InvokeAIAppConfig.get_config()
root = app_config.root_dir
weights_file = root / mconfig.path
config_file = root / mconfig.config
diffusers_path = app_config.converted_ckpts_dir / weights_file.stem
image_size = mconfig.get('width') or mconfig.get('height') or 512
# return cached version if it exists
if diffusers_path.exists():
return diffusers_path
# this avoids circular import error
from .convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
if weights_file.suffix == '.safetensors':
checkpoint = safetensors.torch.load_file(weights_file)
else:
checkpoint = torch.load(weights_file, map_location="cpu")
# sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
config = OmegaConf.load(config_file)
vae_model = convert_ldm_vae_to_diffusers(
checkpoint = checkpoint,
vae_config = config,
image_size = image_size
)
vae_model.save_pretrained(
diffusers_path,
safe_serialization=is_safetensors_available()
)
return diffusers_path

View File

@ -14,7 +14,7 @@ export const receivedModels = createAppAsyncThunk(
const response = await ModelsService.listModels(); const response = await ModelsService.listModels();
const deserializedModels = reduce( const deserializedModels = reduce(
response.models['diffusers'], response.models['sd-1.5']['pipeline'],
(modelsAccumulator, model, modelName) => { (modelsAccumulator, model, modelName) => {
modelsAccumulator[modelName] = { ...model, name: modelName }; modelsAccumulator[modelName] = { ...model, name: modelName };
@ -25,7 +25,7 @@ export const receivedModels = createAppAsyncThunk(
models.info( models.info(
{ response }, { response },
`Received ${size(response.models['diffusers'])} models` `Received ${size(response.models['sd-1.5']['pipeline'])} models`
); );
return deserializedModels; return deserializedModels;