mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fixes
This commit is contained in:
parent
3ce3a7ee72
commit
738ba40f51
@ -1,12 +1,12 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
|
||||||
|
|
||||||
from typing import Annotated, Literal, Optional, Union
|
from typing import Annotated, Literal, Optional, Union, Dict
|
||||||
|
|
||||||
from fastapi import Query
|
from fastapi import Query
|
||||||
from fastapi.routing import APIRouter, HTTPException
|
from fastapi.routing import APIRouter, HTTPException
|
||||||
from pydantic import BaseModel, Field, parse_obj_as
|
from pydantic import BaseModel, Field, parse_obj_as
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
from invokeai.backend import SDModelType
|
from invokeai.backend import BaseModelType, ModelType
|
||||||
|
|
||||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||||
|
|
||||||
@ -60,7 +60,8 @@ class ConvertedModelResponse(BaseModel):
|
|||||||
info: DiffusersModelInfo = Field(description="The converted model info")
|
info: DiffusersModelInfo = Field(description="The converted model info")
|
||||||
|
|
||||||
class ModelsList(BaseModel):
|
class ModelsList(BaseModel):
|
||||||
models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
|
models: Dict[BaseModelType, Dict[ModelType, Dict[str, dict]]] # TODO: collect all configs
|
||||||
|
#models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
|
||||||
|
|
||||||
|
|
||||||
@models_router.get(
|
@models_router.get(
|
||||||
@ -69,7 +70,10 @@ class ModelsList(BaseModel):
|
|||||||
responses={200: {"model": ModelsList }},
|
responses={200: {"model": ModelsList }},
|
||||||
)
|
)
|
||||||
async def list_models(
|
async def list_models(
|
||||||
model_type: SDModelType = Query(
|
base_model: BaseModelType = Query(
|
||||||
|
default=None, description="Base model"
|
||||||
|
),
|
||||||
|
model_type: ModelType = Query(
|
||||||
default=None, description="The type of model to get"
|
default=None, description="The type of model to get"
|
||||||
),
|
),
|
||||||
) -> ModelsList:
|
) -> ModelsList:
|
||||||
|
@ -8,7 +8,7 @@ from .model import ClipField
|
|||||||
|
|
||||||
from ...backend.util.devices import torch_dtype
|
from ...backend.util.devices import torch_dtype
|
||||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||||
from ...backend.model_management import SDModelType
|
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||||
from ...backend.model_management.lora import ModelPatcher
|
from ...backend.model_management.lora import ModelPatcher
|
||||||
|
|
||||||
from compel import Compel
|
from compel import Compel
|
||||||
@ -76,7 +76,11 @@ class CompelInvocation(BaseInvocation):
|
|||||||
try:
|
try:
|
||||||
ti_list.append(
|
ti_list.append(
|
||||||
stack.enter_context(
|
stack.enter_context(
|
||||||
context.services.model_manager.get_model(model_name=name, model_type=SDModelType.TextualInversion)
|
context.services.model_manager.get_model(
|
||||||
|
model_name=name,
|
||||||
|
base_model=self.clip.text_encoder.base_model,
|
||||||
|
model_type=ModelType.TextualInversion,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -5,12 +5,13 @@ import copy
|
|||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||||
|
|
||||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||||
from ...backend.model_management import SDModelType
|
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||||
|
|
||||||
class ModelInfo(BaseModel):
|
class ModelInfo(BaseModel):
|
||||||
model_name: str = Field(description="Info to load submodel")
|
model_name: str = Field(description="Info to load submodel")
|
||||||
model_type: SDModelType = Field(description="Info to load submodel")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
submodel: Optional[SDModelType] = Field(description="Info to load submodel")
|
model_type: ModelType = Field(description="Info to load submodel")
|
||||||
|
submodel: Optional[SubModelType] = Field(description="Info to load submodel")
|
||||||
|
|
||||||
class LoraInfo(ModelInfo):
|
class LoraInfo(ModelInfo):
|
||||||
weight: float = Field(description="Lora's weight which to use when apply to model")
|
weight: float = Field(description="Lora's weight which to use when apply to model")
|
||||||
@ -63,10 +64,13 @@ class ModelLoaderInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||||
|
|
||||||
|
base_model = BaseModelType.StableDiffusion2 # TODO:
|
||||||
|
|
||||||
# TODO: not found exceptions
|
# TODO: not found exceptions
|
||||||
if not context.services.model_manager.model_exists(
|
if not context.services.model_manager.model_exists(
|
||||||
model_name=self.model_name,
|
model_name=self.model_name,
|
||||||
model_type=SDModelType.Diffusers,
|
base_model=base_model,
|
||||||
|
model_type=ModelType.Pipeline,
|
||||||
):
|
):
|
||||||
raise Exception(f"Unkown model name: {self.model_name}!")
|
raise Exception(f"Unkown model name: {self.model_name}!")
|
||||||
|
|
||||||
@ -104,12 +108,14 @@ class ModelLoaderInvocation(BaseInvocation):
|
|||||||
unet=UNetField(
|
unet=UNetField(
|
||||||
unet=ModelInfo(
|
unet=ModelInfo(
|
||||||
model_name=self.model_name,
|
model_name=self.model_name,
|
||||||
model_type=SDModelType.Diffusers,
|
base_model=base_model,
|
||||||
submodel=SDModelType.UNet,
|
model_type=ModelType.Pipeline,
|
||||||
|
submodel=SubModelType.UNet,
|
||||||
),
|
),
|
||||||
scheduler=ModelInfo(
|
scheduler=ModelInfo(
|
||||||
model_name=self.model_name,
|
model_name=self.model_name,
|
||||||
model_type=SDModelType.Diffusers,
|
base_model=base_model,
|
||||||
|
model_type=ModelType.Pipeline,
|
||||||
submodel=SDModelType.Scheduler,
|
submodel=SDModelType.Scheduler,
|
||||||
),
|
),
|
||||||
loras=[],
|
loras=[],
|
||||||
@ -117,12 +123,14 @@ class ModelLoaderInvocation(BaseInvocation):
|
|||||||
clip=ClipField(
|
clip=ClipField(
|
||||||
tokenizer=ModelInfo(
|
tokenizer=ModelInfo(
|
||||||
model_name=self.model_name,
|
model_name=self.model_name,
|
||||||
model_type=SDModelType.Diffusers,
|
base_model=base_model,
|
||||||
|
model_type=ModelType.Pipeline,
|
||||||
submodel=SDModelType.Tokenizer,
|
submodel=SDModelType.Tokenizer,
|
||||||
),
|
),
|
||||||
text_encoder=ModelInfo(
|
text_encoder=ModelInfo(
|
||||||
model_name=self.model_name,
|
model_name=self.model_name,
|
||||||
model_type=SDModelType.Diffusers,
|
base_model=base_model,
|
||||||
|
model_type=ModelType.Pipeline,
|
||||||
submodel=SDModelType.TextEncoder,
|
submodel=SDModelType.TextEncoder,
|
||||||
),
|
),
|
||||||
loras=[],
|
loras=[],
|
||||||
@ -130,7 +138,8 @@ class ModelLoaderInvocation(BaseInvocation):
|
|||||||
vae=VaeField(
|
vae=VaeField(
|
||||||
vae=ModelInfo(
|
vae=ModelInfo(
|
||||||
model_name=self.model_name,
|
model_name=self.model_name,
|
||||||
model_type=SDModelType.Diffusers,
|
base_model=base_model,
|
||||||
|
model_type=ModelType.Pipeline,
|
||||||
submodel=SDModelType.Vae,
|
submodel=SDModelType.Vae,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
from invokeai.app.models.image import ProgressImage
|
from invokeai.app.models.image import ProgressImage
|
||||||
from invokeai.app.util.misc import get_timestamp
|
from invokeai.app.util.misc import get_timestamp
|
||||||
from invokeai.app.services.model_manager_service import SDModelType, SDModelInfo
|
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, SDModelInfo
|
||||||
from invokeai.app.models.exceptions import CanceledException
|
from invokeai.app.models.exceptions import CanceledException
|
||||||
|
|
||||||
class EventServiceBase:
|
class EventServiceBase:
|
||||||
@ -109,8 +109,9 @@ class EventServiceBase:
|
|||||||
node: dict,
|
node: dict,
|
||||||
source_node_id: str,
|
source_node_id: str,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_type: SDModelType,
|
base_model: BaseModelType,
|
||||||
submodel: SDModelType,
|
model_type: ModelType,
|
||||||
|
submodel: SubModelType,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Emitted when a model is requested"""
|
"""Emitted when a model is requested"""
|
||||||
self.__emit_session_event(
|
self.__emit_session_event(
|
||||||
@ -120,6 +121,7 @@ class EventServiceBase:
|
|||||||
node=node,
|
node=node,
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
submodel=submodel,
|
submodel=submodel,
|
||||||
),
|
),
|
||||||
@ -131,8 +133,9 @@ class EventServiceBase:
|
|||||||
node: dict,
|
node: dict,
|
||||||
source_node_id: str,
|
source_node_id: str,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_type: SDModelType,
|
base_model: BaseModelType,
|
||||||
submodel: SDModelType,
|
model_type: ModelType,
|
||||||
|
submodel: SubModelType,
|
||||||
model_info: SDModelInfo,
|
model_info: SDModelInfo,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Emitted when a model is correctly loaded (returns model info)"""
|
"""Emitted when a model is correctly loaded (returns model info)"""
|
||||||
@ -143,6 +146,7 @@ class EventServiceBase:
|
|||||||
node=node,
|
node=node,
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
submodel=submodel,
|
submodel=submodel,
|
||||||
model_info=model_info,
|
model_info=model_info,
|
||||||
|
@ -10,7 +10,9 @@ from dataclasses import dataclass
|
|||||||
|
|
||||||
from invokeai.backend.model_management.model_manager import (
|
from invokeai.backend.model_management.model_manager import (
|
||||||
ModelManager,
|
ModelManager,
|
||||||
SDModelType,
|
BaseModelType,
|
||||||
|
ModelType,
|
||||||
|
SubModelType,
|
||||||
SDModelInfo,
|
SDModelInfo,
|
||||||
)
|
)
|
||||||
from invokeai.app.models.exceptions import CanceledException
|
from invokeai.app.models.exceptions import CanceledException
|
||||||
@ -20,12 +22,6 @@ from ...backend.util import choose_precision, choose_torch_device
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
|
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class LastUsedModel:
|
|
||||||
model_name: str=None
|
|
||||||
model_type: SDModelType=None
|
|
||||||
|
|
||||||
last_used_model = LastUsedModel()
|
|
||||||
|
|
||||||
class ModelManagerServiceBase(ABC):
|
class ModelManagerServiceBase(ABC):
|
||||||
"""Responsible for managing models on disk and in memory"""
|
"""Responsible for managing models on disk and in memory"""
|
||||||
@ -48,8 +44,9 @@ class ModelManagerServiceBase(ABC):
|
|||||||
def get_model(
|
def get_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_type: SDModelType = SDModelType.Diffusers,
|
base_model: BaseModelType,
|
||||||
submodel: Optional[SDModelType] = None,
|
model_type: ModelType,
|
||||||
|
submodel: Optional[SubModelType] = None,
|
||||||
node: Optional[BaseInvocation] = None,
|
node: Optional[BaseInvocation] = None,
|
||||||
context: Optional[InvocationContext] = None,
|
context: Optional[InvocationContext] = None,
|
||||||
) -> SDModelInfo:
|
) -> SDModelInfo:
|
||||||
@ -67,12 +64,13 @@ class ModelManagerServiceBase(ABC):
|
|||||||
def model_exists(
|
def model_exists(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_type: SDModelType,
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def default_model(self) -> Optional[Tuple[str, SDModelType]]:
|
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
|
||||||
"""
|
"""
|
||||||
Returns the name and typeof the default model, or None
|
Returns the name and typeof the default model, or None
|
||||||
if none is defined.
|
if none is defined.
|
||||||
@ -80,26 +78,26 @@ class ModelManagerServiceBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def set_default_model(self, model_name: str, model_type: SDModelType):
|
def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
"""Sets the default model to the indicated name."""
|
"""Sets the default model to the indicated name."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def model_info(self, model_name: str, model_type: SDModelType) -> dict:
|
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||||
"""
|
"""
|
||||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def model_names(self) -> List[Tuple[str, SDModelType]]:
|
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||||
"""
|
"""
|
||||||
Returns a list of all the model names known.
|
Returns a list of all the model names known.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def list_models(self, model_type: SDModelType=None) -> dict:
|
def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict:
|
||||||
"""
|
"""
|
||||||
Return a dict of models in the format:
|
Return a dict of models in the format:
|
||||||
{ model_type1:
|
{ model_type1:
|
||||||
@ -122,7 +120,8 @@ class ModelManagerServiceBase(ABC):
|
|||||||
def add_model(
|
def add_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_type: SDModelType,
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
model_attributes: dict,
|
model_attributes: dict,
|
||||||
clobber: bool = False
|
clobber: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -139,7 +138,8 @@ class ModelManagerServiceBase(ABC):
|
|||||||
def del_model(
|
def del_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_type: SDModelType,
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
delete_files: bool = False,
|
delete_files: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -297,8 +297,9 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
def get_model(
|
def get_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_type: SDModelType = SDModelType.Diffusers,
|
base_model: BaseModelType,
|
||||||
submodel: Optional[SDModelType] = None,
|
model_type: ModelType,
|
||||||
|
submodel: Optional[SubModelType] = None,
|
||||||
node: Optional[BaseInvocation] = None,
|
node: Optional[BaseInvocation] = None,
|
||||||
context: Optional[InvocationContext] = None,
|
context: Optional[InvocationContext] = None,
|
||||||
) -> SDModelInfo:
|
) -> SDModelInfo:
|
||||||
@ -307,23 +308,6 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
part (such as the vae) of a diffusers mode.
|
part (such as the vae) of a diffusers mode.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Temporary hack here: we remember the last model fetched
|
|
||||||
# so that when executing a graph, the first node called gets
|
|
||||||
# to set default model for subsequent nodes in the event that
|
|
||||||
# they do not set the model explicitly. This should be
|
|
||||||
# displaced by model loader mechanism.
|
|
||||||
# This is to work around lack of model loader at current time,
|
|
||||||
# which was causing inconsistent model usage throughout graph.
|
|
||||||
global last_used_model
|
|
||||||
|
|
||||||
if not model_name:
|
|
||||||
self.logger.debug('No model name provided, defaulting to last loaded model')
|
|
||||||
model_name = last_used_model.model_name
|
|
||||||
model_type = model_type or last_used_model.model_type
|
|
||||||
else:
|
|
||||||
last_used_model.model_name = model_name
|
|
||||||
last_used_model.model_type = model_type
|
|
||||||
|
|
||||||
# if we are called from within a node, then we get to emit
|
# if we are called from within a node, then we get to emit
|
||||||
# load start and complete events
|
# load start and complete events
|
||||||
if node and context:
|
if node and context:
|
||||||
@ -331,12 +315,14 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
node=node,
|
node=node,
|
||||||
context=context,
|
context=context,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
submodel=submodel
|
submodel=submodel,
|
||||||
)
|
)
|
||||||
|
|
||||||
model_info = self.mgr.get_model(
|
model_info = self.mgr.get_model(
|
||||||
model_name,
|
model_name,
|
||||||
|
base_model,
|
||||||
model_type,
|
model_type,
|
||||||
submodel,
|
submodel,
|
||||||
)
|
)
|
||||||
@ -346,6 +332,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
node=node,
|
node=node,
|
||||||
context=context,
|
context=context,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
submodel=submodel,
|
submodel=submodel,
|
||||||
model_info=model_info
|
model_info=model_info
|
||||||
@ -356,7 +343,8 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
def model_exists(
|
def model_exists(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_type: SDModelType,
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Given a model name, returns True if it is a valid
|
Given a model name, returns True if it is a valid
|
||||||
@ -364,33 +352,38 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
"""
|
"""
|
||||||
return self.mgr.model_exists(
|
return self.mgr.model_exists(
|
||||||
model_name,
|
model_name,
|
||||||
|
base_model,
|
||||||
model_type,
|
model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def default_model(self) -> Optional[Tuple[str, SDModelType]]:
|
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
|
||||||
"""
|
"""
|
||||||
Returns the name of the default model, or None
|
Returns the name of the default model, or None
|
||||||
if none is defined.
|
if none is defined.
|
||||||
"""
|
"""
|
||||||
return self.mgr.default_model()
|
return self.mgr.default_model()
|
||||||
|
|
||||||
def set_default_model(self, model_name: str, model_type: SDModelType):
|
def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
"""Sets the default model to the indicated name."""
|
"""Sets the default model to the indicated name."""
|
||||||
self.mgr.set_default_model(model_name)
|
self.mgr.set_default_model(model_name, base_model, model_type)
|
||||||
|
|
||||||
def model_info(self, model_name: str, model_type: SDModelType) -> dict:
|
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||||
"""
|
"""
|
||||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||||
"""
|
"""
|
||||||
return self.mgr.model_info(model_name)
|
return self.mgr.model_info(model_name, base_model, model_type)
|
||||||
|
|
||||||
def model_names(self) -> List[Tuple[str, SDModelType]]:
|
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||||
"""
|
"""
|
||||||
Returns a list of all the model names known.
|
Returns a list of all the model names known.
|
||||||
"""
|
"""
|
||||||
return self.mgr.model_names()
|
return self.mgr.model_names()
|
||||||
|
|
||||||
def list_models(self, model_type: SDModelType=None) -> dict:
|
def list_models(
|
||||||
|
self,
|
||||||
|
base_model: Optional[BaseModelType] = None,
|
||||||
|
model_type: Optional[ModelType] = None
|
||||||
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Return a dict of models in the format:
|
Return a dict of models in the format:
|
||||||
{ model_type1:
|
{ model_type1:
|
||||||
@ -406,12 +399,13 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
{ model_name_n: etc
|
{ model_name_n: etc
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
return self.mgr.list_models()
|
return self.mgr.list_models(base_model, model_type)
|
||||||
|
|
||||||
def add_model(
|
def add_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_type: SDModelType,
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
model_attributes: dict,
|
model_attributes: dict,
|
||||||
clobber: bool = False,
|
clobber: bool = False,
|
||||||
)->None:
|
)->None:
|
||||||
@ -422,13 +416,14 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
with an assertion error if provided attributes are incorrect or
|
with an assertion error if provided attributes are incorrect or
|
||||||
the model name is missing. Call commit() to write changes to disk.
|
the model name is missing. Call commit() to write changes to disk.
|
||||||
"""
|
"""
|
||||||
return self.mgr.add_model(model_name, model_type, model_attributes, clobber)
|
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
|
||||||
|
|
||||||
|
|
||||||
def del_model(
|
def del_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_type: SDModelType = SDModelType.Diffusers,
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
delete_files: bool = False,
|
delete_files: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -436,7 +431,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
then the underlying weight file or diffusers directory will be deleted
|
then the underlying weight file or diffusers directory will be deleted
|
||||||
as well. Call commit() to write to disk.
|
as well. Call commit() to write to disk.
|
||||||
"""
|
"""
|
||||||
self.mgr.del_model(model_name, model_type, delete_files)
|
self.mgr.del_model(model_name, base_model, model_type, delete_files)
|
||||||
|
|
||||||
def import_diffuser_model(
|
def import_diffuser_model(
|
||||||
self,
|
self,
|
||||||
@ -541,8 +536,9 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
node,
|
node,
|
||||||
context,
|
context,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_type: SDModelType,
|
base_model: BaseModelType,
|
||||||
submodel: SDModelType,
|
model_type: ModelType,
|
||||||
|
submodel: SubModelType,
|
||||||
model_info: Optional[SDModelInfo] = None,
|
model_info: Optional[SDModelInfo] = None,
|
||||||
):
|
):
|
||||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||||
@ -555,6 +551,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
node=node.dict(),
|
node=node.dict(),
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
submodel=submodel,
|
submodel=submodel,
|
||||||
model_info=model_info
|
model_info=model_info
|
||||||
@ -565,6 +562,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
node=node.dict(),
|
node=node.dict(),
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
submodel=submodel,
|
submodel=submodel,
|
||||||
)
|
)
|
||||||
|
@ -9,5 +9,5 @@ from .generator import (
|
|||||||
Img2Img,
|
Img2Img,
|
||||||
Inpaint
|
Inpaint
|
||||||
)
|
)
|
||||||
from .model_management import ModelManager, ModelCache, SDModelType, SDModelInfo
|
from .model_management import ModelManager, ModelCache, BaseModelType, ModelType, SubModelType, SDModelInfo
|
||||||
from .safety_checker import SafetyChecker
|
from .safety_checker import SafetyChecker
|
||||||
|
@ -2,4 +2,5 @@
|
|||||||
Initialization file for invokeai.backend.model_management
|
Initialization file for invokeai.backend.model_management
|
||||||
"""
|
"""
|
||||||
from .model_manager import ModelManager, SDModelInfo
|
from .model_manager import ModelManager, SDModelInfo
|
||||||
from .model_cache import ModelCache, SDModelType
|
from .model_cache import ModelCache
|
||||||
|
from .models import BaseModelType, ModelType, SubModelType
|
||||||
|
@ -39,7 +39,7 @@ from invokeai.app.services.config import get_invokeai_config
|
|||||||
|
|
||||||
from .lora import LoRAModel, TextualInversionModel
|
from .lora import LoRAModel, TextualInversionModel
|
||||||
|
|
||||||
from .models import BaseModelType, ModelType, SubModelType
|
from .models import BaseModelType, ModelType, SubModelType, ModelBase
|
||||||
|
|
||||||
|
|
||||||
# Maximum size of the cache, in gigs
|
# Maximum size of the cache, in gigs
|
||||||
|
@ -167,6 +167,8 @@ from huggingface_hub import scan_cache_dir
|
|||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from omegaconf.dictconfig import DictConfig
|
from omegaconf.dictconfig import DictConfig
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.util import CUDA_DEVICE, download_with_resume
|
from invokeai.backend.util import CUDA_DEVICE, download_with_resume
|
||||||
@ -539,14 +541,16 @@ class ModelManager(object):
|
|||||||
def del_model(
|
def del_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_type: SDModelType.Diffusers,
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
delete_files: bool = False,
|
delete_files: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Delete the named model.
|
Delete the named model.
|
||||||
"""
|
"""
|
||||||
model_key = self.create_key(model_name, model_type)
|
raise Exception("TODO: del_model") # TODO: redo
|
||||||
model_cfg = self.pop(model_key, None)
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
|
model_cfg = self.models.pop(model_key, None)
|
||||||
|
|
||||||
if model_cfg is None:
|
if model_cfg is None:
|
||||||
self.logger.error(
|
self.logger.error(
|
||||||
|
@ -24,7 +24,7 @@ class ModelType(str, Enum):
|
|||||||
ControlNet = "controlnet"
|
ControlNet = "controlnet"
|
||||||
TextualInversion = "embedding"
|
TextualInversion = "embedding"
|
||||||
|
|
||||||
class SubModelType:
|
class SubModelType(str, Enum):
|
||||||
UNet = "unet"
|
UNet = "unet"
|
||||||
TextEncoder = "text_encoder"
|
TextEncoder = "text_encoder"
|
||||||
Tokenizer = "tokenizer"
|
Tokenizer = "tokenizer"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user