mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
services rewritten; starting work on routes
@@ -114,7 +114,7 @@ class ApiDependencies:
         )

         services = InvocationServices(
-            model_manager=ModelManagerService(config, logger),
+            model_manager=ModelManagerService(config, events),
             events=events,
             latents=latents,
             images=images,
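Note: the hunk above hands the event bus, rather than a logger, to ModelManagerService, so install/download progress can reach connected clients. A minimal sketch of the pattern, assuming stand-in shapes (EventBus, _notify, and the dict config are illustrative, not the real InvokeAI classes):

    from typing import Callable, List, Optional


    class EventBus:
        """Stand-in for EventServiceBase: fans events out to subscribers."""

        def __init__(self) -> None:
            self._handlers: List[Callable[[dict], None]] = []

        def subscribe(self, handler: Callable[[dict], None]) -> None:
            self._handlers.append(handler)

        def dispatch(self, event_name: str, payload: dict) -> None:
            for handler in self._handlers:
                handler({"event": event_name, "payload": payload})


    class ModelManagerService:
        """Shape of the constructor change above: bus in, logger out."""

        def __init__(self, config: dict, events: Optional[EventBus] = None) -> None:
            self._config = config
            self._events = events

        def _notify(self, status: str) -> None:
            # hypothetical helper: install/download progress goes to the bus
            if self._events:
                self._events.dispatch("model_event", {"status": status})


    bus = EventBus()
    bus.subscribe(print)  # a client listening on the bus
    ModelManagerService(config={}, events=bus)._notify("running")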
@@ -10,11 +10,12 @@ from pydantic import BaseModel, parse_obj_as
 from starlette.exceptions import HTTPException

 from invokeai.backend import BaseModelType, ModelType
-from invokeai.backend.model_management import MergeInterpolationMethod
-from invokeai.backend.model_management.models import (
+from invokeai.backend.model_manager import MergeInterpolationMethod
+from invokeai.backend.model_manager import (
     OPENAPI_MODEL_CONFIGS,
+    ModelConfigBase,
     InvalidModelException,
-    ModelNotFoundException,
+    UnknownModelException,
     SchedulerPredictionType,
 )

@@ -28,9 +29,12 @@ ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
 MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
 ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]

+# class ModelsList(BaseModel):
+#     models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
+
+
 class ModelsList(BaseModel):
-    models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
+    models: List[ModelConfigBase]


 @models_router.get(
@@ -42,13 +46,14 @@ async def list_models(
     base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
     model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
 ) -> ModelsList:
-    """Gets a list of models"""
+    """Get a list of models."""
+    manager = ApiDependencies.invoker.services.model_manager
     if base_models and len(base_models) > 0:
         models_raw = list()
         for base_model in base_models:
-            models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
+            models_raw.extend(manager.list_models(base_model=base_model, model_type=model_type))
     else:
-        models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
+        models_raw = manager.list_models(model_type=model_type)
     models = parse_obj_as(ModelsList, {"models": models_raw})
     return models

@@ -118,7 +123,7 @@ async def update_model(
             model_type=model_type,
         )
         model_response = parse_obj_as(UpdateModelResponse, model_raw)
-    except ModelNotFoundException as e:
+    except UnknownModelException as e:
         raise HTTPException(status_code=404, detail=str(e))
     except ValueError as e:
         logger.error(str(e))
@@ -171,7 +176,7 @@ async def import_model(
         )
         return parse_obj_as(ImportModelResponse, model_raw)

-    except ModelNotFoundException as e:
+    except UnknownModelException as e:
         logger.error(str(e))
         raise HTTPException(status_code=404, detail=str(e))
     except InvalidModelException as e:
@@ -210,7 +215,7 @@ async def add_model(
             model_name=info.model_name, base_model=info.base_model, model_type=info.model_type
         )
         return parse_obj_as(ImportModelResponse, model_raw)
-    except ModelNotFoundException as e:
+    except UnknownModelException as e:
         logger.error(str(e))
         raise HTTPException(status_code=404, detail=str(e))
     except ValueError as e:
@@ -239,7 +244,7 @@ async def delete_model(
         )
         logger.info(f"Deleted model: {model_name}")
         return Response(status_code=204)
-    except ModelNotFoundException as e:
+    except UnknownModelException as e:
         logger.error(str(e))
         raise HTTPException(status_code=404, detail=str(e))

@@ -278,7 +283,7 @@ async def convert_model(
             model_name, base_model=base_model, model_type=model_type
         )
         response = parse_obj_as(ConvertModelResponse, model_raw)
-    except ModelNotFoundException as e:
+    except UnknownModelException as e:
         raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
     except ValueError as e:
         raise HTTPException(status_code=400, detail=str(e))
@@ -380,7 +385,7 @@ async def merge_models(
             model_type=ModelType.Main,
         )
         response = parse_obj_as(ConvertModelResponse, model_raw)
-    except ModelNotFoundException:
+    except UnknownModelException:
         raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
     except ValueError as e:
         raise HTTPException(status_code=400, detail=str(e))
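Note: ModelsList now validates each entry against ModelConfigBase instead of the OpenAPI union. A minimal sketch of the parse_obj_as path used by list_models(), assuming pydantic v1 (parse_obj_as was removed in v2) and a simplified stand-in for ModelConfigBase:

    from typing import List

    from pydantic import BaseModel, parse_obj_as


    class ModelConfigBase(BaseModel):
        """Simplified stand-in for the real ModelConfigBase."""

        key: str
        name: str


    class ModelsList(BaseModel):
        models: List[ModelConfigBase]


    # raw dicts from the store are validated into typed models in one call
    models_raw = [{"key": "abc123", "name": "stable-diffusion-v1-5"}]
    models = parse_obj_as(ModelsList, {"models": models_raw})
    print(models.models[0].key)  # -> abc123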
@@ -147,7 +147,7 @@ def custom_openapi():
         invoker_schema["output"] = outputs_ref
         invoker_schema["class"] = "invocation"

-    from invokeai.backend.model_management.models import get_model_config_enums
+    from invokeai.backend.model_manager.models import get_model_config_enums

     for model_config_format_enum in set(get_model_config_enums()):
         name = model_config_format_enum.__qualname__
@@ -3,7 +3,7 @@ from typing import List, Optional

 from pydantic import BaseModel, Field

-from ...backend.model_manager import BaseModelType, ModelType, SubModelType
+from ...backend.model_manager import SubModelType
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
@@ -19,9 +19,7 @@ from .baseinvocation import (


 class ModelInfo(BaseModel):
-    model_name: str = Field(description="Info to load submodel")
-    base_model: BaseModelType = Field(description="Base model")
-    model_type: ModelType = Field(description="Info to load submodel")
+    key: str = Field(description="Unique ID for model")
     submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")


@@ -61,16 +59,13 @@ class ModelLoaderOutput(BaseInvocationOutput):
 class MainModelField(BaseModel):
     """Main model field"""

-    model_name: str = Field(description="Name of the model")
-    base_model: BaseModelType = Field(description="Base model")
-    model_type: ModelType = Field(description="Model Type")
+    key: str = Field(description="Unique ID of the model")


 class LoRAModelField(BaseModel):
     """LoRA model field"""

-    model_name: str = Field(description="Name of the LoRA model")
-    base_model: BaseModelType = Field(description="Base model")
+    key: str = Field(description="Unique ID for model")


 @invocation("main_model_loader", title="Main Model", tags=["model"], category="model", version="1.0.0")
@@ -81,17 +76,12 @@ class MainModelLoaderInvocation(BaseInvocation):
     # TODO: precision?

     def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
-        base_model = self.model.base_model
-        model_name = self.model.model_name
-        model_type = ModelType.Main
+        """Load a main model, outputting its submodels."""
+        key = self.model.key

         # TODO: not found exceptions
-        if not context.services.model_manager.model_exists(
-            model_name=model_name,
-            base_model=base_model,
-            model_type=model_type,
-        ):
-            raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
+        if not context.services.model_manager.model_exists(key):
+            raise Exception(f"Unknown model {key}")

         """
         if not context.services.model_manager.model_exists(
@@ -125,30 +115,22 @@ class MainModelLoaderInvocation(BaseInvocation):
         return ModelLoaderOutput(
             unet=UNetField(
                 unet=ModelInfo(
-                    model_name=model_name,
-                    base_model=base_model,
-                    model_type=model_type,
+                    key=key,
                     submodel=SubModelType.UNet,
                 ),
                 scheduler=ModelInfo(
-                    model_name=model_name,
-                    base_model=base_model,
-                    model_type=model_type,
+                    key=key,
                     submodel=SubModelType.Scheduler,
                 ),
                 loras=[],
             ),
             clip=ClipField(
                 tokenizer=ModelInfo(
-                    model_name=model_name,
-                    base_model=base_model,
-                    model_type=model_type,
+                    key=key,
                     submodel=SubModelType.Tokenizer,
                 ),
                 text_encoder=ModelInfo(
-                    model_name=model_name,
-                    base_model=base_model,
-                    model_type=model_type,
+                    key=key,
                     submodel=SubModelType.TextEncoder,
                 ),
                 loras=[],
@@ -156,9 +138,7 @@ class MainModelLoaderInvocation(BaseInvocation):
             ),
             vae=VaeField(
                 vae=ModelInfo(
-                    model_name=model_name,
-                    base_model=base_model,
-                    model_type=model_type,
+                    key=key,
                     submodel=SubModelType.Vae,
                 ),
             ),
@@ -167,7 +147,7 @@ class MainModelLoaderInvocation(BaseInvocation):

 @invocation_output("lora_loader_output")
 class LoraLoaderOutput(BaseInvocationOutput):
-    """Model loader output"""
+    """Model loader output."""

     unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
     clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@@ -187,24 +167,20 @@ class LoraLoaderInvocation(BaseInvocation):
     )

     def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
+        """Load a LoRA model."""
         if self.lora is None:
             raise Exception("No LoRA provided")

-        base_model = self.lora.base_model
-        lora_name = self.lora.model_name
+        key = self.lora.key

-        if not context.services.model_manager.model_exists(
-            base_model=base_model,
-            model_name=lora_name,
-            model_type=ModelType.Lora,
-        ):
-            raise Exception(f"Unkown lora name: {lora_name}!")
+        if not context.services.model_manager.model_exists(key):
+            raise Exception(f"Unknown lora: {key}!")

-        if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
-            raise Exception(f'Lora "{lora_name}" already applied to unet')
+        if self.unet is not None and any(lora.key == key for lora in self.unet.loras):
+            raise Exception(f'Lora "{key}" already applied to unet')

-        if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
-            raise Exception(f'Lora "{lora_name}" already applied to clip')
+        if self.clip is not None and any(lora.key == key for lora in self.clip.loras):
+            raise Exception(f'Lora "{key}" already applied to clip')

         output = LoraLoaderOutput()

@@ -212,9 +188,7 @@ class LoraLoaderInvocation(BaseInvocation):
             output.unet = copy.deepcopy(self.unet)
             output.unet.loras.append(
                 LoraInfo(
-                    base_model=base_model,
-                    model_name=lora_name,
-                    model_type=ModelType.Lora,
+                    key=key,
                     submodel=None,
                     weight=self.weight,
                 )
@@ -224,9 +198,7 @@ class LoraLoaderInvocation(BaseInvocation):
             output.clip = copy.deepcopy(self.clip)
             output.clip.loras.append(
                 LoraInfo(
-                    base_model=base_model,
-                    model_name=lora_name,
-                    model_type=ModelType.Lora,
+                    key=key,
                     submodel=None,
                     weight=self.weight,
                 )
@@ -237,7 +209,7 @@ class LoraLoaderInvocation(BaseInvocation):

 @invocation_output("sdxl_lora_loader_output")
 class SDXLLoraLoaderOutput(BaseInvocationOutput):
-    """SDXL LoRA Loader Output"""
+    """SDXL LoRA Loader Output."""

     unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
     clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
@@ -261,27 +233,22 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
     )

     def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
+        """Load an SDXL LoRA."""
         if self.lora is None:
             raise Exception("No LoRA provided")

-        base_model = self.lora.base_model
-        lora_name = self.lora.model_name
+        key = self.lora.key
+        if not context.services.model_manager.model_exists(key):
+            raise Exception(f"Unknown lora name: {key}!")

-        if not context.services.model_manager.model_exists(
-            base_model=base_model,
-            model_name=lora_name,
-            model_type=ModelType.Lora,
-        ):
-            raise Exception(f"Unknown lora name: {lora_name}!")
+        if self.unet is not None and any(lora.key == key for lora in self.unet.loras):
+            raise Exception(f'Lora "{key}" already applied to unet')

-        if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
-            raise Exception(f'Lora "{lora_name}" already applied to unet')
+        if self.clip is not None and any(lora.key == key for lora in self.clip.loras):
+            raise Exception(f'Lora "{key}" already applied to clip')

-        if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
-            raise Exception(f'Lora "{lora_name}" already applied to clip')
-
-        if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras):
-            raise Exception(f'Lora "{lora_name}" already applied to clip2')
+        if self.clip2 is not None and any(lora.key == key for lora in self.clip2.loras):
+            raise Exception(f'Lora "{key}" already applied to clip2')

         output = SDXLLoraLoaderOutput()

@@ -289,9 +256,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
             output.unet = copy.deepcopy(self.unet)
             output.unet.loras.append(
                 LoraInfo(
-                    base_model=base_model,
-                    model_name=lora_name,
-                    model_type=ModelType.Lora,
+                    key=key,
                     submodel=None,
                     weight=self.weight,
                 )
@@ -301,9 +266,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
             output.clip = copy.deepcopy(self.clip)
             output.clip.loras.append(
                 LoraInfo(
-                    base_model=base_model,
-                    model_name=lora_name,
-                    model_type=ModelType.Lora,
+                    key=key,
                     submodel=None,
                     weight=self.weight,
                 )
@@ -313,9 +276,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
             output.clip2 = copy.deepcopy(self.clip2)
             output.clip2.loras.append(
                 LoraInfo(
-                    base_model=base_model,
-                    model_name=lora_name,
-                    model_type=ModelType.Lora,
+                    key=key,
                     submodel=None,
                     weight=self.weight,
                 )
@@ -325,10 +286,9 @@ class SDXLLoraLoaderInvocation(BaseInvocation):


 class VAEModelField(BaseModel):
-    """Vae model field"""
+    """Vae model field."""

-    model_name: str = Field(description="Name of the model")
-    base_model: BaseModelType = Field(description="Base model")
+    key: str = Field(description="Unique ID for VAE model")


 @invocation_output("vae_loader_output")
@@ -340,29 +300,22 @@ class VaeLoaderOutput(BaseInvocationOutput):

 @invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0")
 class VaeLoaderInvocation(BaseInvocation):
-    """Loads a VAE model, outputting a VaeLoaderOutput"""
+    """Loads a VAE model, outputting a VaeLoaderOutput."""

     vae_model: VAEModelField = InputField(
         description=FieldDescriptions.vae_model, input=Input.Direct, ui_type=UIType.VaeModel, title="VAE"
     )

     def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
-        base_model = self.vae_model.base_model
-        model_name = self.vae_model.model_name
-        model_type = ModelType.Vae
+        """Load a VAE model."""
+        key = self.vae_model.key

-        if not context.services.model_manager.model_exists(
-            base_model=base_model,
-            model_name=model_name,
-            model_type=model_type,
-        ):
-            raise Exception(f"Unkown vae name: {model_name}!")
+        if not context.services.model_manager.model_exists(key):
+            raise Exception(f"Unknown vae name: {key}!")
         return VaeLoaderOutput(
             vae=VaeField(
                 vae=ModelInfo(
-                    model_name=model_name,
-                    base_model=base_model,
-                    model_type=model_type,
+                    key=key,
                 )
             )
         )
@@ -370,7 +323,7 @@ class VaeLoaderInvocation(BaseInvocation):

 @invocation_output("seamless_output")
 class SeamlessModeOutput(BaseInvocationOutput):
-    """Modified Seamless Model output"""
+    """Modified Seamless Model output."""

     unet: Optional[UNetField] = OutputField(description=FieldDescriptions.unet, title="UNet")
     vae: Optional[VaeField] = OutputField(description=FieldDescriptions.vae, title="VAE")
@@ -390,6 +343,7 @@ class SeamlessModeInvocation(BaseInvocation):
     seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")

     def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
+        """Apply seamless transformation."""
         # Conditionally append 'x' and 'y' based on seamless_x and seamless_y
         unet = copy.deepcopy(self.unet)
         vae = copy.deepcopy(self.vae)
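Note: across these hunks the invocations stop carrying (model_name, base_model, model_type) triples and identify models by a single opaque key. A minimal sketch of the resulting lookup pattern; ModelManagerStub and load_main_model are simplified assumptions standing in for context.services.model_manager and invoke():

    from typing import Dict


    class ModelManagerStub:
        """Minimal stand-in for context.services.model_manager."""

        def __init__(self, known: Dict[str, dict]) -> None:
            self._known = known

        def model_exists(self, key: str) -> bool:
            return key in self._known


    def load_main_model(manager: ModelManagerStub, key: str) -> dict:
        # same guard the rewritten invoke() uses, but against a single key
        if not manager.model_exists(key):
            raise Exception(f"Unknown model {key}")
        # fan one key out to per-submodel references, as ModelLoaderOutput does
        submodels = ("unet", "scheduler", "tokenizer", "text_encoder", "vae")
        return {sub: {"key": key, "submodel": sub} for sub in submodels}


    manager = ModelManagerStub({"abc123": {}})
    print(load_main_model(manager, "abc123")["unet"])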
@@ -117,9 +117,7 @@ class EventServiceBase:
     def emit_model_load_started(
         self,
         graph_execution_state_id: str,
-        model_name: str,
-        base_model: BaseModelType,
-        model_type: ModelType,
+        model_key: str,
         submodel: SubModelType,
     ) -> None:
         """Emitted when a model is requested"""
@@ -127,9 +125,7 @@ class EventServiceBase:
             event_name="model_load_started",
             payload=dict(
                 graph_execution_state_id=graph_execution_state_id,
-                model_name=model_name,
-                base_model=base_model,
-                model_type=model_type,
+                model_key=model_key,
                 submodel=submodel,
             ),
         )
@@ -137,9 +133,7 @@ class EventServiceBase:
     def emit_model_load_completed(
         self,
         graph_execution_state_id: str,
-        model_name: str,
-        base_model: BaseModelType,
-        model_type: ModelType,
+        model_key: str,
         submodel: SubModelType,
         model_info: ModelInfo,
     ) -> None:
@@ -148,9 +142,7 @@ class EventServiceBase:
             event_name="model_load_completed",
             payload=dict(
                 graph_execution_state_id=graph_execution_state_id,
-                model_name=model_name,
-                base_model=base_model,
-                model_type=model_type,
+                model_key=model_key,
                 submodel=submodel,
                 hash=model_info.hash,
                 location=str(model_info.location),
@@ -192,8 +184,8 @@ class EventServiceBase:
             ),
         )

-    def emit_model_download_event(self, job: DownloadJobBase):
-        """Emit event when the status of a download job changes."""
+    def emit_model_event(self, job: DownloadJobBase):
+        """Emit event when the status of a download/install job changes."""
         self.dispatch(  # use dispatch() directly here because we are not a session event.
-            event_name="install_model_event", payload=dict(job=job)
+            event_name="model_event", payload=dict(job=job)
         )
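Note: the last hunk folds the download-specific event into a single generic "model_event" carrying the whole job. A minimal sketch of that consolidated emit, assuming a simplified DownloadJobBase and a dispatch() stand-in for the socket layer:

    from dataclasses import asdict, dataclass


    @dataclass
    class DownloadJobBase:
        """Simplified stand-in for the real download/install job class."""

        id: int
        status: str


    def dispatch(event_name: str, payload: dict) -> None:
        print(event_name, payload)  # the real code forwards to connected clients


    def emit_model_event(job: DownloadJobBase) -> None:
        """Emit event when the status of a download/install job changes."""
        dispatch(event_name="model_event", payload=dict(job=asdict(job)))


    emit_model_event(DownloadJobBase(id=1, status="running"))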
@@ -28,24 +28,23 @@ from pydantic.networks import AnyHttpUrl

 from invokeai.app.models.exceptions import CanceledException
 from .config import InvokeAIAppConfig
+from .events import EventServiceBase

 if TYPE_CHECKING:
     from ..invocations.baseinvocation import InvocationContext


 class ModelManagerServiceBase(ABC):
-    """Responsible for managing models on disk and in memory"""
+    """Responsible for managing models on disk and in memory."""

     @abstractmethod
-    def __init__(
-        self,
-        config: InvokeAIAppConfig,
-    ):
+    def __init__(self, config: InvokeAIAppConfig, event_bus: Optional[EventServiceBase] = None):
         """
-        Initialize with the path to the models.yaml config file.
-        Optional parameters are the torch device type, precision, max_models,
-        and sequential_offload boolean. Note that the default device
-        type and precision are set up for a CUDA system running at half precision.
+        Initialize a ModelManagerService.
+
+        :param config: InvokeAIAppConfig object
+        :param event_bus: Optional EventServiceBase object. If provided,
+        installation and download events will be sent to the event bus.
         """
         pass
@@ -102,6 +101,7 @@ class ModelManagerServiceBase(ABC):
     def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> ModelConfigBase:
         """
         Return information about the model using the same format as list_models().
+
         If there are more than one model that match, raises a DuplicateModelException.
         If no model matches, raises an UnknownModelException
         """
@@ -115,9 +115,7 @@ class ModelManagerServiceBase(ABC):
         return model_configs[0]

     def all_models(self) -> List[ModelConfigBase]:
-        """
-        Returns a list of all the models.
-        """
+        """Return a list of all the models."""
         return self.list_models()

     @abstractmethod
@@ -125,8 +123,9 @@ class ModelManagerServiceBase(ABC):
         self, model_path: Path, probe_overrides: Optional[Dict[str, Any]] = None, wait: bool = False
     ) -> ModelInstallJob:
         """
-        Add a model using its path, with a dictionary of attributes. Will fail with an
-        assertion error if the name already exists.
+        Add a model using its path, with a dictionary of attributes.
+
+        Will fail with an assertion error if the name already exists.
         """
         pass
@@ -167,9 +166,7 @@ class ModelManagerServiceBase(ABC):

     @abstractmethod
     def list_checkpoint_configs(self) -> List[Path]:
-        """
-        List the checkpoint config paths from ROOT/configs/stable-diffusion.
-        """
+        """List the checkpoint config paths from ROOT/configs/stable-diffusion."""
         pass

     @abstractmethod
@@ -248,6 +245,8 @@ class ModelManagerServiceBase(ABC):
     @abstractmethod
     def sync_to_config(self):
         """
+        Synchronize the in-memory models with on-disk.
+
         Re-read models.yaml, rescan the models directory, and reimport models
         in the autoimport directories. Call after making changes outside the
         model manager API.
@@ -256,29 +255,57 @@ class ModelManagerServiceBase(ABC):

     @abstractmethod
     def collect_cache_stats(self, cache_stats: CacheStats):
-        """
-        Reset model cache statistics for graph with graph_id.
-        """
+        """Reset model cache statistics for graph with graph_id."""
+        pass
+
+    @abstractmethod
+    def cancel_job(self, job: ModelInstallJob):
+        """Cancel this job."""
+        pass
+
+    @abstractmethod
+    def pause_job(self, job: ModelInstallJob):
+        """Pause this job."""
+        pass
+
+    @abstractmethod
+    def start_job(self, job: ModelInstallJob):
+        """(re)start this job."""
+        pass
+
+    @abstractmethod
+    def change_priority(self, job: ModelInstallJob, delta: int):
+        """
+        Raise or lower the priority of the job.
+
+        :param job: Job to apply change to
+        :param delta: Value to increment or decrement priority.
+
+        Lower values are higher priority. The default starting value is 10.
+        Thus to make my_job a really high priority job:
+        manager.change_priority(my_job, -10).
+        """
         pass


 # implementation
 class ModelManagerService(ModelManagerServiceBase):
-    """Responsible for managing models on disk and in memory"""
+    """Responsible for managing models on disk and in memory."""

     _loader: ModelLoader = Field(description="InvokeAIAppConfig object for the current process")
+    _event_bus: EventServiceBase = Field(description="an event bus to send install events to", default=None)

-    def __init__(
-        self,
-        config: InvokeAIAppConfig,
-    ):
+    def __init__(self, config: InvokeAIAppConfig, event_bus: Optional[EventServiceBase] = None):
         """
-        Initialize with the path to the models.yaml config file.
-        Optional parameters are the torch device type, precision, max_models,
-        and sequential_offload boolean. Note that the default device
-        type and precision are set up for a CUDA system running at half precision.
+        Initialize a ModelManagerService.
+
+        :param config: InvokeAIAppConfig object
+        :param event_bus: Optional EventServiceBase object. If provided,
+        installation and download events will be sent to the event bus.
         """
-        self._loader = ModelLoader(config)
+        self._event_bus = event_bus
+        handlers = [self._event_bus.emit_model_event] if self._event_bus else None
+        self._loader = ModelLoader(config, event_handlers=handlers)

     def get_model(
         self,
@@ -287,10 +314,10 @@ class ModelManagerService(ModelManagerServiceBase):
         context: Optional[InvocationContext] = None,
     ) -> ModelInfo:
         """
-        Retrieve the indicated model. submodel can be used to get a
-        part (such as the vae) of a diffusers mode.
+        Retrieve the indicated model.
+
+        The submodel is required when fetching a main model.
         """
         model_info: ModelInfo = self._loader.get_model(key, submodel_type)

         # we can emit model loading events if we are executing with access to the invocation context
@@ -309,6 +336,8 @@ class ModelManagerService(ModelManagerServiceBase):
         key: str,
     ) -> bool:
         """
+        Verify that a model with the given key exists.
+
         Given a model key, returns True if it is a valid
         identifier.
         """
@@ -316,7 +345,9 @@ class ModelManagerService(ModelManagerServiceBase):

     def model_info(self, key: str) -> ModelConfigBase:
         """
-        Given a model name returns a dict-like (OmegaConf) object describing it.
+        Return configuration information about a model.
+
+        Given a model key returns the ModelConfigBase describing it.
         """
         return self._loader.store.get_model(key)

@@ -332,12 +363,8 @@ class ModelManagerService(ModelManagerServiceBase):
         """
         Return a ModelConfigBase object for each model in the database.
         """
-        return self._loader.store.search_by_name(model_name, base_model, model_type)
-
-        """
-        Return information about the model using the same format as list_models()
-        """
-        return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type)
+        self.logger.warning(f"list_model(model_type={model_type})")
+        return self._loader.store.search_by_name(model_name=model_name, base_model=base_model, model_type=model_type)

     def add_model(
         self, model_path: Path, model_attributes: Optional[dict] = None, wait: bool = False
@@ -438,9 +465,7 @@ class ModelManagerService(ModelManagerServiceBase):
     def _emit_load_event(
         self,
         context,
-        model_name: str,
-        base_model: BaseModelType,
-        model_type: ModelType,
+        model_key: str,
         submodel: Optional[SubModelType] = None,
         model_info: Optional[ModelInfo] = None,
     ):
@@ -450,18 +475,14 @@ class ModelManagerService(ModelManagerServiceBase):
         if model_info:
             context.services.events.emit_model_load_completed(
                 graph_execution_state_id=context.graph_execution_state_id,
-                model_name=model_name,
-                base_model=base_model,
-                model_type=model_type,
+                model_key=model_key,
                 submodel=submodel,
                 model_info=model_info,
             )
         else:
             context.services.events.emit_model_load_started(
                 graph_execution_state_id=context.graph_execution_state_id,
-                model_name=model_name,
-                base_model=base_model,
-                model_type=model_type,
+                model_key=model_key,
                 submodel=submodel,
             )

@@ -488,7 +509,7 @@ class ModelManagerService(ModelManagerServiceBase):
         :param interp: Interpolation method. None (default)
         :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
         """
-        merger = ModelMerger(self.mgr)
+        merger = ModelMerger(self._loader)
         try:
             self.logger.error("ModelMerger needs to be rewritten.")
             result = merger.merge_diffusion_models_and_save(
@@ -506,11 +527,16 @@ class ModelManagerService(ModelManagerServiceBase):
     def search_for_models(self, directory: Path) -> List[Path]:
         """
         Return list of all models found in the designated directory.
+
+        :param directory: Path to the directory to recursively search.
+        returns a list of model paths
         """
         return ModelSearch().search(directory)

     def sync_to_config(self):
         """
+        Synchronize the model manager to the database.
+
         Re-read models.yaml, rescan the models directory, and reimport models
         in the autoimport directories. Call after making changes outside the
         model manager API.
@@ -518,10 +544,8 @@ class ModelManagerService(ModelManagerServiceBase):
         return self._loader.sync_to_config()

     def list_checkpoint_configs(self) -> List[Path]:
-        """
-        List the checkpoint config paths from ROOT/configs/stable-diffusion.
-        """
-        config = self.mgr.app_config
+        """List the checkpoint config paths from ROOT/configs/stable-diffusion."""
+        config = self._loader.config
         conf_path = config.legacy_conf_path
         root_path = config.root_path
         return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]
@@ -538,3 +562,28 @@ class ModelManagerService(ModelManagerServiceBase):
         :param new_name: New name for the model
         """
         return self.update_model(key, {"name": new_name})
+
+    def cancel_job(self, job: ModelInstallJob):
+        """Cancel this job."""
+        self._loader.queue.cancel_job(job)
+
+    def pause_job(self, job: ModelInstallJob):
+        """Pause this job."""
+        self._loader.queue.pause_job(job)
+
+    def start_job(self, job: ModelInstallJob):
+        """(re)start this job."""
+        self._loader.queue.start_job(job)
+
+    def change_priority(self, job: ModelInstallJob, delta: int):
+        """
+        Raise or lower the priority of the job.
+
+        :param job: Job to apply change to
+        :param delta: Value to increment or decrement priority.
+
+        Lower values are higher priority. The default starting value is 10.
+        Thus to make my_job a really high priority job:
+        manager.change_priority(my_job, -10).
+        """
+        self._loader.queue.change_priority(job, delta)
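Note: per the change_priority() docstring above, lower values are higher priority and jobs start at 10, so a delta of -10 puts a job at the front. A minimal sketch of those semantics with a toy min-heap (the real queue lives in the download module; the tuples here are illustrative):

    import heapq

    queue = []  # (priority, name) tuples; heapq pops the smallest first
    heapq.heappush(queue, (10, "default-job"))  # default starting priority
    heapq.heappush(queue, (10 - 10, "my_job"))  # change_priority(my_job, -10)
    print(heapq.heappop(queue))  # -> (0, 'my_job'), ahead of the default job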
@@ -1,7 +1,7 @@
 """
 Initialization file for invokeai.backend.model_manager.config
 """
-from .models.base import read_checkpoint_meta  # noqa F401
+from .models import read_checkpoint_meta, OPENAPI_MODEL_CONFIGS  # noqa F401
 from .config import (  # noqa F401
     BaseModelType,
     InvalidModelConfigException,
@@ -83,10 +83,6 @@ class DownloadQueue(DownloadQueueBase):
     _sequence: int = 0  # This is for debugging and used to tag jobs in dequeueing order
     _requests: requests.sessions.Session

-    # for debugging
-    _gets: int = 0
-    _dones: int = 0
-
     def __init__(
         self,
         max_parallel_dl: int = 5,
@@ -112,10 +108,6 @@ class DownloadQueue(DownloadQueueBase):

         self._start_workers(max_parallel_dl)

-        # debugging - get rid of this
-        self._gets = 0
-        self._dones = 0
-
     def create_download_job(
         self,
         source: Union[str, Path, AnyHttpUrl],
@@ -297,7 +289,6 @@ class DownloadQueue(DownloadQueueBase):
         done = False
         while not done:
             job = self._queue.get()
-            self._gets += 1

             try:  # this is for debugging priority
                 self._lock.acquire()
@@ -326,7 +317,6 @@ class DownloadQueue(DownloadQueueBase):
                 if self._in_terminal_state(job):
                     del self._jobs[job.id]

-            self._dones += 1
             self._queue.task_done()

     def _get_metadata_and_url(self, job: DownloadJobBase) -> AnyHttpUrl:
@@ -53,14 +53,20 @@ import tempfile
 from abc import ABC, abstractmethod
 from pathlib import Path
 from shutil import rmtree
-from typing import Optional, List, Union, Dict, Set, Any
+from typing import Optional, List, Union, Dict, Set, Any, Callable
 from pydantic import Field
 from pydantic.networks import AnyHttpUrl
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.util.logging import InvokeAILogger
 from .search import ModelSearch
 from .storage import ModelConfigStore, DuplicateModelException, get_config_store
-from .download import DownloadQueueBase, DownloadQueue, DownloadJobBase, ModelSourceMetadata
+from .download import (
+    DownloadQueueBase,
+    DownloadQueue,
+    DownloadJobBase,
+    ModelSourceMetadata,
+    DownloadEventHandler,
+)
 from .download.queue import DownloadJobURL, DownloadJobRepoID, DownloadJobPath
 from .hash import FastModelHash
 from .probe import ModelProbe, ModelProbeInfo, InvalidModelException
@@ -97,6 +103,9 @@ class ModelInstallPathJob(DownloadJobPath, ModelInstallJob):
     """Job for installing local paths."""


+ModelInstallEventHandler = Callable[["ModelInstallJob"], None]
+
+
 class ModelInstallBase(ABC):
     """Abstract base class for InvokeAI model installation"""

@@ -107,6 +116,7 @@ class ModelInstallBase(ABC):
         config: Optional[InvokeAIAppConfig] = None,
         logger: Optional[InvokeAILogger] = None,
         download: Optional[DownloadQueueBase] = None,
+        event_handlers: Optional[List[DownloadEventHandler]] = None,
     ):
         """
         Create ModelInstall object.
@@ -119,6 +129,7 @@ class ModelInstallBase(ABC):
         uses the system-wide default logger.
         :param download: Optional DownloadQueueBase object. If None passed,
         a default queue object will be created.
+        :param event_handlers: List of event handlers to pass to the queue object.
         """
         pass

@@ -302,11 +313,12 @@ class ModelInstall(ModelInstallBase):
         config: Optional[InvokeAIAppConfig] = None,
         logger: Optional[InvokeAILogger] = None,
         download: Optional[DownloadQueueBase] = None,
+        event_handlers: Optional[List[DownloadEventHandler]] = None,
     ):  # noqa D107 - use base class docstrings
         self._config = config or InvokeAIAppConfig.get_config()
         self._logger = logger or InvokeAILogger.getLogger(config=self._config)
         self._store = store or get_config_store(self._config.model_conf_path)
-        self._download_queue = download or DownloadQueue(config=self._config)
+        self._download_queue = download or DownloadQueue(config=self._config, event_handlers=event_handlers)
         self._async_installs = dict()
         self._installed = set()
         self._tmpdir = None
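Note: these hunks thread event_handlers from the installer constructor down into the download queue it creates. A minimal sketch of that forwarding, with simplified stand-ins for ModelInstall and DownloadQueue (only the handler plumbing is shown; _notify is a hypothetical helper):

    from typing import Callable, List, Optional

    DownloadEventHandler = Callable[[object], None]


    class DownloadQueue:
        def __init__(self, event_handlers: Optional[List[DownloadEventHandler]] = None) -> None:
            self.event_handlers = event_handlers or []

        def _notify(self, job: object) -> None:
            # the queue, not the installer, invokes the handlers on status changes
            for handler in self.event_handlers:
                handler(job)


    class ModelInstall:
        def __init__(self, event_handlers: Optional[List[DownloadEventHandler]] = None) -> None:
            # the installer only forwards the handlers to the queue it creates
            self.queue = DownloadQueue(event_handlers=event_handlers)


    install = ModelInstall(event_handlers=[lambda job: print("job changed:", job)])
    install.queue._notify("job-1")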
@@ -5,7 +5,7 @@ import hashlib
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Union, Optional
+from typing import Union, Optional, List

 import torch

@@ -16,6 +16,7 @@ from .install import ModelInstallBase, ModelInstall
 from .storage import ModelConfigStore, get_config_store
 from .cache import ModelCache, ModelLocker, CacheStats
 from .models import InvalidModelException, ModelBase, MODEL_CLASSES
+from .download import DownloadEventHandler


 @dataclass
@@ -75,6 +76,17 @@ class ModelLoaderBase(ABC):
         """Return the current logger."""
         pass

+    @property
+    def config(self) -> InvokeAIAppConfig:
+        """Return the config object used by this installer."""
+        pass
+
+    @property
+    @abstractmethod
+    def queue(self) -> str:
+        """Return the download queue object used by this object."""
+        pass
+
     @abstractmethod
     def collect_cache_stats(self, cache_stats: CacheStats):
         """Replace cache statistics."""
@@ -110,6 +122,7 @@ class ModelLoader(ModelLoaderBase):
     def __init__(
         self,
         config: InvokeAIAppConfig,
+        event_handlers: Optional[List[DownloadEventHandler]] = None,
     ):
         """
         Initialize ModelLoader object.
@@ -127,7 +140,12 @@ class ModelLoader(ModelLoaderBase):
         self._app_config = config
         self._store = store
         self._logger = InvokeAILogger.getLogger()
-        self._installer = ModelInstall(store=self._store, logger=self._logger, config=self._app_config)
+        self._installer = ModelInstall(
+            store=self._store,
+            logger=self._logger,
+            config=self._app_config,
+            event_handlers=event_handlers,
+        )
         self._cache_keys = dict()
         self._models_file = models_file
         device = torch.device(choose_torch_device())
@@ -173,6 +191,16 @@ class ModelLoader(ModelLoaderBase):
         """Return the current logger."""
         return self._logger

+    @property
+    def config(self) -> InvokeAIAppConfig:
+        """Return the config object used by the installer."""
+        return self._app_config
+
+    @property
+    def queue(self) -> str:
+        """Return the download queue object used by this object."""
+        return self._installer.queue
+
     def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> ModelInfo:
         """
         Get the ModelInfo corresponding to the model with key "key".
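Note: the loader now exposes config and queue as properties that delegate to its own state and to the installer, which is what lets the service call self._loader.config and self._loader.queue in the hunks earlier. A minimal sketch of the delegation, with an InstallerStub standing in for ModelInstall (the string queue value is illustrative only):

    from typing import Optional


    class InstallerStub:
        """Stand-in for ModelInstall; only the queue attribute matters here."""

        def __init__(self) -> None:
            self.queue = "download-queue"


    class ModelLoader:
        def __init__(self, config: dict, installer: Optional[InstallerStub] = None) -> None:
            self._app_config = config
            self._installer = installer or InstallerStub()

        @property
        def config(self) -> dict:
            """Return the config object used by the installer."""
            return self._app_config

        @property
        def queue(self):
            """Delegate to the installer's download queue."""
            return self._installer.queue


    loader = ModelLoader({"root": "/tmp"})
    print(loader.config, loader.queue)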
@@ -15,6 +15,7 @@ from .base import (  # noqa: F401
     ModelVariantType,
     SchedulerPredictionType,
     InvalidModelException,
+    read_checkpoint_meta,
 )
 from .controlnet import ControlNetModel  # TODO:
 from .lora import LoRAModel
@@ -73,9 +74,7 @@ OPENAPI_MODEL_CONFIGS = list()


 class OpenAPIModelInfoBase(BaseModel):
-    model_name: str
-    base_model: BaseModelType
-    model_type: ModelType
+    key: str


 for base_model, models in MODEL_CLASSES.items():