services rewritten; starting work on routes

Lincoln Stein
2023-09-15 18:22:24 -04:00
parent a033ccc776
commit 3529925234
11 changed files with 224 additions and 195 deletions

View File

@ -114,7 +114,7 @@ class ApiDependencies:
)
services = InvocationServices(
model_manager=ModelManagerService(config, logger),
model_manager=ModelManagerService(config, events),
events=events,
latents=latents,
images=images,
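With this hunk the model manager receives the event bus rather than a logger, so install and download progress can be re-emitted to clients. A minimal wiring sketch, assuming the surrounding ApiDependencies setup (FastAPIEventService and its handler id are taken from context, not from this hunk):

# Sketch under assumed names: `events` is the bus created earlier in ApiDependencies.
events = FastAPIEventService(event_handler_id)
model_manager = ModelManagerService(config, events)  # new two-argument form per this hunk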

View File

@ -10,11 +10,12 @@ from pydantic import BaseModel, parse_obj_as
from starlette.exceptions import HTTPException
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management import MergeInterpolationMethod
from invokeai.backend.model_management.models import (
from invokeai.backend.model_manager import MergeInterpolationMethod
from invokeai.backend.model_manager import (
OPENAPI_MODEL_CONFIGS,
ModelConfigBase,
InvalidModelException,
ModelNotFoundException,
UnknownModelException,
SchedulerPredictionType,
)
@ -28,9 +29,12 @@ ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
# class ModelsList(BaseModel):
# models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
class ModelsList(BaseModel):
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
models: List[ModelConfigBase]
@models_router.get(
@ -42,13 +46,14 @@ async def list_models(
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
) -> ModelsList:
"""Gets a list of models"""
"""Get a list of models."""
manager = ApiDependencies.invoker.services.model_manager
if base_models and len(base_models) > 0:
models_raw = list()
for base_model in base_models:
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
models_raw.extend(manager.list_models(base_model=base_model, model_type=model_type))
else:
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
models_raw = manager.list_models(model_type=model_type)
models = parse_obj_as(ModelsList, {"models": models_raw})
return models
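A hedged usage sketch of the rewritten route body; the enum members below are assumptions, chosen only to show the keyword-argument calling convention:

# Hypothetical query: collect main models for two base architectures.
manager = ApiDependencies.invoker.services.model_manager
models_raw = []
for base in (BaseModelType.StableDiffusion1, BaseModelType.StableDiffusionXL):
    models_raw.extend(manager.list_models(base_model=base, model_type=ModelType.Main))
models = parse_obj_as(ModelsList, {"models": models_raw})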
@ -118,7 +123,7 @@ async def update_model(
model_type=model_type,
)
model_response = parse_obj_as(UpdateModelResponse, model_raw)
except ModelNotFoundException as e:
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
@ -171,7 +176,7 @@ async def import_model(
)
return parse_obj_as(ImportModelResponse, model_raw)
except ModelNotFoundException as e:
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except InvalidModelException as e:
@ -210,7 +215,7 @@ async def add_model(
model_name=info.model_name, base_model=info.base_model, model_type=info.model_type
)
return parse_obj_as(ImportModelResponse, model_raw)
except ModelNotFoundException as e:
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
@ -239,7 +244,7 @@ async def delete_model(
)
logger.info(f"Deleted model: {model_name}")
return Response(status_code=204)
except ModelNotFoundException as e:
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
@ -278,7 +283,7 @@ async def convert_model(
model_name, base_model=base_model, model_type=model_type
)
response = parse_obj_as(ConvertModelResponse, model_raw)
except ModelNotFoundException as e:
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@ -380,7 +385,7 @@ async def merge_models(
model_type=ModelType.Main,
)
response = parse_obj_as(ConvertModelResponse, model_raw)
except ModelNotFoundException:
except UnknownModelException:
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
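Every route in this file now maps the renamed exception identically; a condensed sketch of the recurring pattern (do_model_operation is a placeholder, not a real helper):

try:
    result = do_model_operation()  # placeholder for each route's body
except UnknownModelException as e:
    raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
    raise HTTPException(status_code=400, detail=str(e))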

View File

@ -147,7 +147,7 @@ def custom_openapi():
invoker_schema["output"] = outputs_ref
invoker_schema["class"] = "invocation"
from invokeai.backend.model_management.models import get_model_config_enums
from invokeai.backend.model_manager.models import get_model_config_enums
for model_config_format_enum in set(get_model_config_enums()):
name = model_config_format_enum.__qualname__

View File

@ -3,7 +3,7 @@ from typing import List, Optional
from pydantic import BaseModel, Field
from ...backend.model_manager import BaseModelType, ModelType, SubModelType
from ...backend.model_manager import SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -19,9 +19,7 @@ from .baseinvocation import (
class ModelInfo(BaseModel):
model_name: str = Field(description="Info to load submodel")
base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Info to load submodel")
key: str = Field(description="Unique ID for model")
submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
@ -61,16 +59,13 @@ class ModelLoaderOutput(BaseInvocationOutput):
class MainModelField(BaseModel):
"""Main model field"""
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Model Type")
key: str = Field(description="Unique ID of the model")
class LoRAModelField(BaseModel):
"""LoRA model field"""
model_name: str = Field(description="Name of the LoRA model")
base_model: BaseModelType = Field(description="Base model")
key: str = Field(description="Unique ID for model")
@invocation("main_model_loader", title="Main Model", tags=["model"], category="model", version="1.0.0")
@ -81,17 +76,12 @@ class MainModelLoaderInvocation(BaseInvocation):
# TODO: precision?
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Main
"""Load a main model, outputting its submodels."""
key = self.model.key
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,
):
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
if not context.services.model_manager.model_exists(key):
raise Exception(f"Unknown model {key}")
"""
if not context.services.model_manager.model_exists(
@ -125,30 +115,22 @@ class MainModelLoaderInvocation(BaseInvocation):
return ModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=key,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=key,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=key,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=key,
submodel=SubModelType.TextEncoder,
),
loras=[],
@ -156,9 +138,7 @@ class MainModelLoaderInvocation(BaseInvocation):
),
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=key,
submodel=SubModelType.Vae,
),
),
@ -167,7 +147,7 @@ class MainModelLoaderInvocation(BaseInvocation):
@invocation_output("lora_loader_output")
class LoraLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
"""Model loader output."""
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@ -187,24 +167,20 @@ class LoraLoaderInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
"""Load a LoRA model."""
if self.lora is None:
raise Exception("No LoRA provided")
base_model = self.lora.base_model
lora_name = self.lora.model_name
key = self.lora.key
if not context.services.model_manager.model_exists(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
):
raise Exception(f"Unkown lora name: {lora_name}!")
if not context.services.model_manager.model_exists(key):
raise Exception(f"Unkown lora: {key}!")
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
raise Exception(f'Lora "{lora_name}" already applied to unet')
if self.unet is not None and any(lora.key == key for lora in self.unet.loras):
raise Exception(f'Lora "{key}" already applied to unet')
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
raise Exception(f'Lora "{lora_name}" already applied to clip')
if self.clip is not None and any(lora.key == key for lora in self.clip.loras):
raise Exception(f'Lora "{key}" already applied to clip')
output = LoraLoaderOutput()
@ -212,9 +188,7 @@ class LoraLoaderInvocation(BaseInvocation):
output.unet = copy.deepcopy(self.unet)
output.unet.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
key=key,
submodel=None,
weight=self.weight,
)
@ -224,9 +198,7 @@ class LoraLoaderInvocation(BaseInvocation):
output.clip = copy.deepcopy(self.clip)
output.clip.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
key=key,
submodel=None,
weight=self.weight,
)
@ -237,7 +209,7 @@ class LoraLoaderInvocation(BaseInvocation):
@invocation_output("sdxl_lora_loader_output")
class SDXLLoraLoaderOutput(BaseInvocationOutput):
"""SDXL LoRA Loader Output"""
"""SDXL LoRA Loader Output."""
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
@ -261,27 +233,22 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
"""Load an SDXL LoRA."""
if self.lora is None:
raise Exception("No LoRA provided")
base_model = self.lora.base_model
lora_name = self.lora.model_name
key = self.lora.key
if not context.services.model_manager.model_exists(key):
raise Exception(f"Unknown lora name: {key}!")
if not context.services.model_manager.model_exists(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
):
raise Exception(f"Unknown lora name: {lora_name}!")
if self.unet is not None and any(lora.key == key for lora in self.unet.loras):
raise Exception(f'Lora "{key}" already applied to unet')
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
raise Exception(f'Lora "{lora_name}" already applied to unet')
if self.clip is not None and any(lora.key == key for lora in self.clip.loras):
raise Exception(f'Lora "{key}" already applied to clip')
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
raise Exception(f'Lora "{lora_name}" already applied to clip')
if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras):
raise Exception(f'Lora "{lora_name}" already applied to clip2')
if self.clip2 is not None and any(lora.key == key for lora in self.clip2.loras):
raise Exception(f'Lora "{key}" already applied to clip2')
output = SDXLLoraLoaderOutput()
@ -289,9 +256,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
output.unet = copy.deepcopy(self.unet)
output.unet.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
key=key,
submodel=None,
weight=self.weight,
)
@ -301,9 +266,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
output.clip = copy.deepcopy(self.clip)
output.clip.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
key=key,
submodel=None,
weight=self.weight,
)
@ -313,9 +276,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
output.clip2 = copy.deepcopy(self.clip2)
output.clip2.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
key=key,
submodel=None,
weight=self.weight,
)
@ -325,10 +286,9 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
class VAEModelField(BaseModel):
"""Vae model field"""
"""Vae model field."""
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
key: str = Field(description="Unique ID for VAE model")
@invocation_output("vae_loader_output")
@ -340,29 +300,22 @@ class VaeLoaderOutput(BaseInvocationOutput):
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0")
class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
"""Loads a VAE model, outputting a VaeLoaderOutput."""
vae_model: VAEModelField = InputField(
description=FieldDescriptions.vae_model, input=Input.Direct, ui_type=UIType.VaeModel, title="VAE"
)
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
base_model = self.vae_model.base_model
model_name = self.vae_model.model_name
model_type = ModelType.Vae
"""Load a VAE model."""
key = self.vae_model.key
if not context.services.model_manager.model_exists(
base_model=base_model,
model_name=model_name,
model_type=model_type,
):
raise Exception(f"Unkown vae name: {model_name}!")
if not context.services.model_manager.model_exists(key):
raise Exception(f"Unkown vae name: {key}!")
return VaeLoaderOutput(
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=key,
)
)
)
@ -370,7 +323,7 @@ class VaeLoaderInvocation(BaseInvocation):
@invocation_output("seamless_output")
class SeamlessModeOutput(BaseInvocationOutput):
"""Modified Seamless Model output"""
"""Modified Seamless Model output."""
unet: Optional[UNetField] = OutputField(description=FieldDescriptions.unet, title="UNet")
vae: Optional[VaeField] = OutputField(description=FieldDescriptions.vae, title="VAE")
@ -390,6 +343,7 @@ class SeamlessModeInvocation(BaseInvocation):
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
"""Apply seamless transformation."""
# Conditionally append 'x' and 'y' based on seamless_x and seamless_y
unet = copy.deepcopy(self.unet)
vae = copy.deepcopy(self.vae)
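A hypothetical invocation sketch showing the net effect of this file's changes: model fields now carry one opaque key instead of the (model_name, base_model, model_type) triple. The key value and node id are invented; the context is supplied by the session runner:

field = MainModelField(key="0123abcd")                # key value is invented
node = MainModelLoaderInvocation(id="loader_1", model=field)
out = node.invoke(context)                            # context comes from the runner
assert out.unet.unet.key == "0123abcd"                # every submodel reuses the key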

View File

@ -117,9 +117,7 @@ class EventServiceBase:
def emit_model_load_started(
self,
graph_execution_state_id: str,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_key: str,
submodel: SubModelType,
) -> None:
"""Emitted when a model is requested"""
@ -127,9 +125,7 @@ class EventServiceBase:
event_name="model_load_started",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
model_key=model_key,
submodel=submodel,
),
)
@ -137,9 +133,7 @@ class EventServiceBase:
def emit_model_load_completed(
self,
graph_execution_state_id: str,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_key: str,
submodel: SubModelType,
model_info: ModelInfo,
) -> None:
@ -148,9 +142,7 @@ class EventServiceBase:
event_name="model_load_completed",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
model_key=model_key,
submodel=submodel,
hash=model_info.hash,
location=str(model_info.location),
@ -192,8 +184,8 @@ class EventServiceBase:
),
)
def emit_model_download_event(self, job: DownloadJobBase):
"""Emit event when the status of a download job changes."""
def emit_model_event(self, job: DownloadJobBase):
"""Emit event when the status of a download/install job changes."""
self.dispatch( # use dispatch() directly here because we are not a session event.
event_name="install_model_event", payload=dict(job=job)
event_name="model_event", payload=dict(job=job)
)
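Consumers now watch one event name for both downloads and installs. A hedged handler sketch; the payload shape comes from the dispatch above, the registration mechanism is assumed:

def on_model_event(payload: dict) -> None:
    job = payload["job"]             # a DownloadJobBase, per the dispatch above
    print(f"model_event: {job.id}")  # job.id as used elsewhere in this commit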

View File

@ -28,24 +28,23 @@ from pydantic.networks import AnyHttpUrl
from invokeai.app.models.exceptions import CanceledException
from .config import InvokeAIAppConfig
from .events import EventServiceBase
if TYPE_CHECKING:
from ..invocations.baseinvocation import InvocationContext
class ModelManagerServiceBase(ABC):
"""Responsible for managing models on disk and in memory"""
"""Responsible for managing models on disk and in memory."""
@abstractmethod
def __init__(
self,
config: InvokeAIAppConfig,
):
def __init__(self, config: InvokeAIAppConfig, event_bus: Optional[EventServiceBase] = None):
"""
Initialize with the path to the models.yaml config file.
Optional parameters are the torch device type, precision, max_models,
and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision.
Initialize a ModelManagerService.
:param config: InvokeAIAppConfig object
:param event_bus: Optional EventServiceBase object. If provided,
installation and download events will be sent to the event bus.
"""
pass
@ -102,6 +101,7 @@ class ModelManagerServiceBase(ABC):
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> ModelConfigBase:
"""
Return information about the model using the same format as list_models().
If more than one model matches, raises a DuplicateModelException.
If no model matches, raises an UnknownModelException.
"""
@ -115,9 +115,7 @@ class ModelManagerServiceBase(ABC):
return model_configs[0]
def all_models(self) -> List[ModelConfigBase]:
"""
Returns a list of all the models.
"""
"""Return a list of all the models."""
return self.list_models()
@abstractmethod
@ -125,8 +123,9 @@ class ModelManagerServiceBase(ABC):
self, model_path: Path, probe_overrides: Optional[Dict[str, Any]] = None, wait: bool = False
) -> ModelInstallJob:
"""
Add a model using its path, with a dictionary of attributes. Will fail with an
assertion error if the name already exists.
Add a model using its path, with a dictionary of attributes.
Will fail with an assertion error if the name already exists.
"""
pass
@ -167,9 +166,7 @@ class ModelManagerServiceBase(ABC):
@abstractmethod
def list_checkpoint_configs(self) -> List[Path]:
"""
List the checkpoint config paths from ROOT/configs/stable-diffusion.
"""
"""List the checkpoint config paths from ROOT/configs/stable-diffusion."""
pass
@abstractmethod
@ -248,6 +245,8 @@ class ModelManagerServiceBase(ABC):
@abstractmethod
def sync_to_config(self):
"""
Synchronize the in-memory models with those on disk.
Re-read models.yaml, rescan the models directory, and reimport models
in the autoimport directories. Call after making changes outside the
model manager API.
@ -256,29 +255,57 @@ class ModelManagerServiceBase(ABC):
@abstractmethod
def collect_cache_stats(self, cache_stats: CacheStats):
"""Reset model cache statistics for graph with graph_id."""
pass
@abstractmethod
def cancel_job(self, job: ModelInstallJob):
"""Cancel this job."""
pass
@abstractmethod
def pause_job(self, job: ModelInstallJob):
"""Pause this job."""
pass
@abstractmethod
def start_job(self, job: ModelInstallJob):
"""(re)start this job."""
pass
@abstractmethod
def change_priority(self, job: ModelInstallJob, delta: int):
"""
Raise or lower the priority of the job.
:param job: Job to apply change to
:param delta: Value to increment or decrement priority.
Lower values are higher priority. The default starting value is 10.
Thus to make my_job a really high priority job:
manager.change_priority(my_job, -10).
"""
pass
# implementation
class ModelManagerService(ModelManagerServiceBase):
"""Responsible for managing models on disk and in memory"""
"""Responsible for managing models on disk and in memory."""
_loader: ModelLoader = Field(description="the ModelLoader object for the current process")
_event_bus: Optional[EventServiceBase] = Field(description="an event bus to send install events to", default=None)
def __init__(
self,
config: InvokeAIAppConfig,
):
def __init__(self, config: InvokeAIAppConfig, event_bus: Optional[EventServiceBase] = None):
"""
Initialize with the path to the models.yaml config file.
Optional parameters are the torch device type, precision, max_models,
and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision.
Initialize a ModelManagerService.
:param config: InvokeAIAppConfig object
:param event_bus: Optional EventServiceBase object. If provided,
installation and download events will be sent to the event bus.
"""
self._loader = ModelLoader(config)
self._event_bus = event_bus
handlers = [self._event_bus.emit_model_event] if self._event_bus else None
self._loader = ModelLoader(config, event_handlers=handlers)
def get_model(
self,
@ -287,10 +314,10 @@ class ModelManagerService(ModelManagerServiceBase):
context: Optional[InvocationContext] = None,
) -> ModelInfo:
"""
Retrieve the indicated model. submodel can be used to get a
part (such as the vae) of a diffusers mode.
"""
Retrieve the indicated model.
The submodel is required when fetching a main model.
"""
model_info: ModelInfo = self._loader.get_model(key, submodel_type)
# we can emit model loading events if we are executing with access to the invocation context
@ -309,6 +336,8 @@ class ModelManagerService(ModelManagerServiceBase):
key: str,
) -> bool:
"""
Verify that a model with the given key exists.
Given a model key, returns True if it is a valid
identifier.
"""
@ -316,7 +345,9 @@ class ModelManagerService(ModelManagerServiceBase):
def model_info(self, key: str) -> ModelConfigBase:
"""
Given a model name returns a dict-like (OmegaConf) object describing it.
Return configuration information about a model.
Given a model key returns the ModelConfigBase describing it.
"""
return self._loader.store.get_model(key)
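A hedged sketch of key-based retrieval through the service; the context-manager use of ModelInfo and the decode helper are assumptions:

info = manager.get_model(key, submodel_type=SubModelType.Vae)  # manager/key from context
with info.context as vae:                 # assumed ModelInfo locking protocol
    image = decode_latents(vae, latents)  # decode_latents is a placeholder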
@ -332,12 +363,8 @@ class ModelManagerService(ModelManagerServiceBase):
"""
Return a ModelConfigBase object for each model in the database.
"""
return self._loader.store.search_by_name(model_name, base_model, model_type)
"""
Return information about the model using the same format as list_models()
"""
return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type)
self.logger.warning(f"list_model(model_type={model_type})")
return self._loader.store.search_by_name(model_name=model_name, base_model=base_model, model_type=model_type)
def add_model(
self, model_path: Path, model_attributes: Optional[dict] = None, wait: bool = False
@ -438,9 +465,7 @@ class ModelManagerService(ModelManagerServiceBase):
def _emit_load_event(
self,
context,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_key: str,
submodel: Optional[SubModelType] = None,
model_info: Optional[ModelInfo] = None,
):
@ -450,18 +475,14 @@ class ModelManagerService(ModelManagerServiceBase):
if model_info:
context.services.events.emit_model_load_completed(
graph_execution_state_id=context.graph_execution_state_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
model_key=model_key,
submodel=submodel,
model_info=model_info,
)
else:
context.services.events.emit_model_load_started(
graph_execution_state_id=context.graph_execution_state_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
model_key=model_key,
submodel=submodel,
)
@ -488,7 +509,7 @@ class ModelManagerService(ModelManagerServiceBase):
:param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
"""
merger = ModelMerger(self.mgr)
merger = ModelMerger(self._loader)
try:
self.logger.error("ModelMerger needs to be rewritten.")
result = merger.merge_diffusion_models_and_save(
@ -506,11 +527,16 @@ class ModelManagerService(ModelManagerServiceBase):
def search_for_models(self, directory: Path) -> List[Path]:
"""
Return a list of all models found in the designated directory.
:param directory: Path to the directory to recursively search.
Returns a list of model paths.
"""
return ModelSearch().search(directory)
def sync_to_config(self):
"""
Synchronize the model manager to the database.
Re-read models.yaml, rescan the models directory, and reimport models
in the autoimport directories. Call after making changes outside the
model manager API.
@ -518,10 +544,8 @@ class ModelManagerService(ModelManagerServiceBase):
return self._loader.sync_to_config()
def list_checkpoint_configs(self) -> List[Path]:
"""
List the checkpoint config paths from ROOT/configs/stable-diffusion.
"""
config = self.mgr.app_config
"""List the checkpoint config paths from ROOT/configs/stable-diffusion."""
config = self._loader.config
conf_path = config.legacy_conf_path
root_path = config.root_path
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]
@ -538,3 +562,28 @@ class ModelManagerService(ModelManagerServiceBase):
:param new_name: New name for the model
"""
return self.update_model(key, {"name": new_name})
def cancel_job(self, job: ModelInstallJob):
"""Cancel this job."""
self._loader.queue.cancel_job(job)
def pause_job(self, job: ModelInstallJob):
"""Pause this job."""
self._loader.queue.pause_job(job)
def start_job(self, job: ModelInstallJob):
"""(re)start this job."""
self._loader.queue.start_job(job)
def change_priority(self, job: ModelInstallJob, delta: int):
"""
Raise or lower the priority of the job.
:param job: Job to apply change to
:param delta: Value to increment or decrement priority.
Lower values are higher priority. The default starting value is 10.
Thus to make my_job a really high priority job:
manager.change_priority(my_job, -10).
"""
self._loader.queue.change_priority(job, delta)
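A hedged sketch tying the new pass-throughs together; the model path is invented and add_model's return type is taken from the base-class signature above:

job = manager.add_model(Path("/tmp/some-model.safetensors"))  # returns a ModelInstallJob
manager.pause_job(job)
manager.change_priority(job, -10)  # lower value == higher priority
manager.start_job(job)
manager.cancel_job(job)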

View File

@ -1,7 +1,7 @@
"""
Initialization file for invokeai.backend.model_manager
"""
from .models.base import read_checkpoint_meta # noqa F401
from .models import read_checkpoint_meta, OPENAPI_MODEL_CONFIGS # noqa F401
from .config import ( # noqa F401
BaseModelType,
InvalidModelConfigException,

View File

@ -83,10 +83,6 @@ class DownloadQueue(DownloadQueueBase):
_sequence: int = 0 # This is for debugging and used to tag jobs in dequeueing order
_requests: requests.sessions.Session
# for debugging
_gets: int = 0
_dones: int = 0
def __init__(
self,
max_parallel_dl: int = 5,
@ -112,10 +108,6 @@ class DownloadQueue(DownloadQueueBase):
self._start_workers(max_parallel_dl)
# debugging - get rid of this
self._gets = 0
self._dones = 0
def create_download_job(
self,
source: Union[str, Path, AnyHttpUrl],
@ -297,7 +289,6 @@ class DownloadQueue(DownloadQueueBase):
done = False
while not done:
job = self._queue.get()
self._gets += 1
try: # this is for debugging priority
self._lock.acquire()
@ -326,7 +317,6 @@ class DownloadQueue(DownloadQueueBase):
if self._in_terminal_state(job):
del self._jobs[job.id]
self._dones += 1
self._queue.task_done()
def _get_metadata_and_url(self, job: DownloadJobBase) -> AnyHttpUrl:

View File

@ -53,14 +53,20 @@ import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from shutil import rmtree
from typing import Optional, List, Union, Dict, Set, Any
from typing import Optional, List, Union, Dict, Set, Any, Callable
from pydantic import Field
from pydantic.networks import AnyHttpUrl
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from .search import ModelSearch
from .storage import ModelConfigStore, DuplicateModelException, get_config_store
from .download import DownloadQueueBase, DownloadQueue, DownloadJobBase, ModelSourceMetadata
from .download import (
DownloadQueueBase,
DownloadQueue,
DownloadJobBase,
ModelSourceMetadata,
DownloadEventHandler,
)
from .download.queue import DownloadJobURL, DownloadJobRepoID, DownloadJobPath
from .hash import FastModelHash
from .probe import ModelProbe, ModelProbeInfo, InvalidModelException
@ -97,6 +103,9 @@ class ModelInstallPathJob(DownloadJobPath, ModelInstallJob):
"""Job for installing local paths."""
ModelInstallEventHandler = Callable[["ModelInstallJob"], None]
class ModelInstallBase(ABC):
"""Abstract base class for InvokeAI model installation"""
@ -107,6 +116,7 @@ class ModelInstallBase(ABC):
config: Optional[InvokeAIAppConfig] = None,
logger: Optional[InvokeAILogger] = None,
download: Optional[DownloadQueueBase] = None,
event_handlers: Optional[List[DownloadEventHandler]] = None,
):
"""
Create ModelInstall object.
@ -119,6 +129,7 @@ class ModelInstallBase(ABC):
uses the system-wide default logger.
:param download: Optional DownloadQueueBase object. If None passed,
a default queue object will be created.
:param event_handlers: List of event handlers to pass to the queue object.
"""
pass
@ -302,11 +313,12 @@ class ModelInstall(ModelInstallBase):
config: Optional[InvokeAIAppConfig] = None,
logger: Optional[InvokeAILogger] = None,
download: Optional[DownloadQueueBase] = None,
event_handlers: Optional[List[DownloadEventHandler]] = None,
): # noqa D107 - use base class docstrings
self._config = config or InvokeAIAppConfig.get_config()
self._logger = logger or InvokeAILogger.getLogger(config=self._config)
self._store = store or get_config_store(self._config.model_conf_path)
self._download_queue = download or DownloadQueue(config=self._config)
self._download_queue = download or DownloadQueue(config=self._config, event_handlers=event_handlers)
self._async_installs = dict()
self._installed = set()
self._tmpdir = None
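A hedged wiring sketch: a handler passed to ModelInstall is forwarded to the DownloadQueue, which calls it on each job state change (the field names on the job are assumptions):

def log_progress(job: DownloadJobBase) -> None:
    print(f"{job.source}: {job.status}")  # source/status fields assumed

installer = ModelInstall(event_handlers=[log_progress])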

View File

@ -5,7 +5,7 @@ import hashlib
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Union, Optional
from typing import Union, Optional, List
import torch
@ -16,6 +16,7 @@ from .install import ModelInstallBase, ModelInstall
from .storage import ModelConfigStore, get_config_store
from .cache import ModelCache, ModelLocker, CacheStats
from .models import InvalidModelException, ModelBase, MODEL_CLASSES
from .download import DownloadEventHandler, DownloadQueueBase
@dataclass
@ -75,6 +76,17 @@ class ModelLoaderBase(ABC):
"""Return the current logger."""
pass
@property
@abstractmethod
def config(self) -> InvokeAIAppConfig:
"""Return the config object used by this installer."""
pass
@property
@abstractmethod
def queue(self) -> DownloadQueueBase:
"""Return the download queue object used by this object."""
pass
@abstractmethod
def collect_cache_stats(self, cache_stats: CacheStats):
"""Replace cache statistics."""
@ -110,6 +122,7 @@ class ModelLoader(ModelLoaderBase):
def __init__(
self,
config: InvokeAIAppConfig,
event_handlers: Optional[List[DownloadEventHandler]] = None,
):
"""
Initialize ModelLoader object.
@ -127,7 +140,12 @@ class ModelLoader(ModelLoaderBase):
self._app_config = config
self._store = store
self._logger = InvokeAILogger.getLogger()
self._installer = ModelInstall(store=self._store, logger=self._logger, config=self._app_config)
self._installer = ModelInstall(
store=self._store,
logger=self._logger,
config=self._app_config,
event_handlers=event_handlers,
)
self._cache_keys = dict()
self._models_file = models_file
device = torch.device(choose_torch_device())
@ -173,6 +191,16 @@ class ModelLoader(ModelLoaderBase):
"""Return the current logger."""
return self._logger
@property
def config(self) -> InvokeAIAppConfig:
"""Return the config object used by the installer."""
return self._app_config
@property
def queue(self) -> DownloadQueueBase:
"""Return the download queue object used by this object."""
return self._installer.queue
def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> ModelInfo:
"""
Get the ModelInfo corresponding to the model with key "key".

View File

@ -15,6 +15,7 @@ from .base import ( # noqa: F401
ModelVariantType,
SchedulerPredictionType,
InvalidModelException,
read_checkpoint_meta,
)
from .controlnet import ControlNetModel # TODO:
from .lora import LoRAModel
@ -73,9 +74,7 @@ OPENAPI_MODEL_CONFIGS = list()
class OpenAPIModelInfoBase(BaseModel):
model_name: str
base_model: BaseModelType
model_type: ModelType
key: str
for base_model, models in MODEL_CLASSES.items():