Commit 8ef596eac7: "further changes for ruff"
Repository: https://github.com/invoke-ai/InvokeAI (mirror, synced 2024-08-30 20:32:17 +00:00)
Parent commit: 8f4f4d48d5
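Every hunk in this commit is a mechanical style fix made to satisfy ruff's linter and formatter: single-quoted strings become double-quoted, spaces around "=" in keyword arguments are dropped, and calls, signatures, comprehensions, and __all__ lists that overflow the line-length limit are wrapped onto multiple lines. The hunks are formatting-only and should not change runtime behavior. As an illustrative sketch only (the queue.submit call and its arguments below are hypothetical, not taken from this diff), the before/after pattern looks like this:

# Before: single quotes, spaced keyword arguments, one overlong line
job = queue.submit(source = 'https://example.com/model.safetensors', inplace = False, access_token = None)

# After ruff formatting: double quotes, tight keyword arguments, call wrapped to fit the line-length limit
job = queue.submit(
    source="https://example.com/model.safetensors",
    inplace=False,
    access_token=None,
)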
@@ -88,7 +88,9 @@ class ApiDependencies:
         latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
         model_manager = ModelManagerService(config, logger)
         model_record_service = ModelRecordServiceSQL(db=db)
-        model_install_service = ModelInstallService(app_config=config, record_store=model_record_service, event_bus=events)
+        model_install_service = ModelInstallService(
+            app_config=config, record_store=model_record_service, event_bus=events
+        )
         names = SimpleNameService()
         performance_statistics = InvocationStatsService()
         processor = DefaultInvocationProcessor()
@@ -51,7 +51,9 @@ async def list_model_records(
     found_models: list[AnyModelConfig] = []
     if base_models:
         for base_model in base_models:
-            found_models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type, model_name=model_name))
+            found_models.extend(
+                record_store.search_by_attr(base_model=base_model, model_type=model_type, model_name=model_name)
+            )
     else:
         found_models.extend(record_store.search_by_attr(model_type=model_type, model_name=model_name))
     return ModelsList(models=found_models)
@@ -184,25 +186,25 @@ async def add_model_record(
     status_code=201,
 )
 async def import_model(
     source: ModelSource = Body(
         description="A model path, repo_id or URL to import. NOTE: only model path is implemented currently!"
     ),
     config: Optional[Dict[str, Any]] = Body(
         description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
         default=None,
     ),
     variant: Optional[str] = Body(
         description="When fetching a repo_id, force variant type to fetch such as 'fp16'",
         default=None,
     ),
     subfolder: Optional[str] = Body(
         description="When fetching a repo_id, specify subfolder to fetch model from",
         default=None,
     ),
     access_token: Optional[str] = Body(
         description="When fetching a repo_id or URL, access token for web access",
         default=None,
     ),
 ) -> ModelInstallJob:
     """Add a model using its local path, repo_id, or remote URL.

@@ -250,14 +252,16 @@ async def import_model(
         raise HTTPException(status_code=409, detail=str(e))
     return result

+
 @model_records_router.get(
     "/import",
     operation_id="list_model_install_jobs",
 )
 async def list_model_install_jobs(
-    source: Optional[str] = Query(description="Filter list by install source, partial string match.",
-        default=None,
-    )
+    source: Optional[str] = Query(
+        description="Filter list by install source, partial string match.",
+        default=None,
+    ),
 ) -> List[ModelInstallJob]:
     """
     Return list of model install jobs.
@@ -268,6 +272,7 @@ async def list_model_install_jobs(
     jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs(source)
     return jobs

+
 @model_records_router.patch(
     "/import",
     operation_id="prune_model_install_jobs",
@@ -276,14 +281,14 @@ async def list_model_install_jobs(
         400: {"description": "Bad request"},
     },
 )
-async def prune_model_install_jobs(
-) -> Response:
+async def prune_model_install_jobs() -> Response:
     """
     Prune all completed and errored jobs from the install job list.
     """
     ApiDependencies.invoker.services.model_install.prune_jobs()
     return Response(status_code=204)

+
 @model_records_router.patch(
     "/sync",
     operation_id="sync_models_to_config",
@@ -292,8 +297,7 @@ async def prune_model_install_jobs(
         400: {"description": "Bad request"},
     },
 )
-async def sync_models_to_config(
-) -> Response:
+async def sync_models_to_config() -> Response:
     """
     Traverse the models and autoimport directories. Model files without a corresponding
     record in the database are added. Orphan records without a models file are deleted.
@@ -37,9 +37,5 @@ class SocketIO:
         if "queue_id" in data:
             await self.__sio.leave_room(sid, data["queue_id"])

-
     async def _handle_model_event(self, event: Event) -> None:
-        await self.__sio.emit(
-            event=event[1]["event"],
-            data=event[1]["data"]
-        )
+        await self.__sio.emit(event=event[1]["event"], data=event[1]["data"])
@@ -2,4 +2,4 @@

 from .config_default import InvokeAIAppConfig, get_invokeai_config

-__all__ = ['InvokeAIAppConfig', 'get_invokeai_config']
+__all__ = ["InvokeAIAppConfig", "get_invokeai_config"]
@@ -331,9 +331,7 @@ class EventServiceBase:
         """
         self.__emit_model_event(
             event_name="model_install_started",
-            payload={
-                "source": source
-            },
+            payload={"source": source},
         )

     def emit_model_install_completed(self, source: str, key: str) -> None:
@@ -351,11 +349,12 @@ class EventServiceBase:
             },
         )

-    def emit_model_install_progress(self,
-        source: str,
-        current_bytes: int,
-        total_bytes: int,
-    ) -> None:
+    def emit_model_install_progress(
+        self,
+        source: str,
+        current_bytes: int,
+        total_bytes: int,
+    ) -> None:
         """
         Emitted while the install job is in progress.
         (Downloaded models only)
@@ -373,12 +372,12 @@ class EventServiceBase:
             },
         )

-    def emit_model_install_error(self,
+    def emit_model_install_error(
+        self,
         source: str,
         error_type: str,
         error: str,
     ) -> None:
         """
         Emitted when an install job encounters an exception.

@@ -9,10 +9,11 @@ from .model_install_base import (
 )
 from .model_install_default import ModelInstallService

-__all__ = ['ModelInstallServiceBase',
-    'ModelInstallService',
-    'InstallStatus',
-    'ModelInstallJob',
-    'UnknownInstallJobException',
-    'ModelSource',
-]
+__all__ = [
+    "ModelInstallServiceBase",
+    "ModelInstallService",
+    "InstallStatus",
+    "ModelInstallJob",
+    "UnknownInstallJobException",
+    "ModelSource",
+]
@@ -17,10 +17,10 @@ from invokeai.backend.model_manager import AnyModelConfig
 class InstallStatus(str, Enum):
     """State of an install job running in the background."""

     WAITING = "waiting" # waiting to be dequeued
     RUNNING = "running" # being processed
     COMPLETED = "completed" # finished running
     ERROR = "error" # terminated with an error message


 class UnknownInstallJobException(Exception):
@@ -32,10 +32,17 @@ ModelSource = Union[str, Path, AnyHttpUrl]

 class ModelInstallJob(BaseModel):
     """Object that tracks the current status of an install request."""

     status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
-    config_in: Dict[str, Any] = Field(default_factory=dict, description="Configuration information (e.g. 'description') to apply to model.")
-    config_out: Optional[AnyModelConfig] = Field(default=None, description="After successful installation, this will hold the configuration object.")
-    inplace: bool = Field(default=False, description="Leave model in its current location; otherwise install under models directory")
+    config_in: Dict[str, Any] = Field(
+        default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
+    )
+    config_out: Optional[AnyModelConfig] = Field(
+        default=None, description="After successful installation, this will hold the configuration object."
+    )
+    inplace: bool = Field(
+        default=False, description="Leave model in its current location; otherwise install under models directory"
+    )
     source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
     local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
     error_type: Optional[str] = Field(default=None, description="Class name of the exception that led to status==ERROR")
@@ -53,10 +60,10 @@ class ModelInstallServiceBase(ABC):

     @abstractmethod
     def __init__(
         self,
         app_config: InvokeAIAppConfig,
         record_store: ModelRecordServiceBase,
         event_bus: Optional["EventServiceBase"] = None,
     ):
         """
         Create ModelInstallService object.
@@ -86,9 +93,9 @@ class ModelInstallServiceBase(ABC):

     @abstractmethod
     def register_path(
         self,
         model_path: Union[Path, str],
         config: Optional[Dict[str, Any]] = None,
     ) -> str:
         """
         Probe and register the model at model_path.
@@ -114,9 +121,9 @@ class ModelInstallServiceBase(ABC):

     @abstractmethod
     def install_path(
         self,
         model_path: Union[Path, str],
         config: Optional[Dict[str, Any]] = None,
     ) -> str:
         """
         Probe, register and install the model in the models directory.
@@ -131,13 +138,13 @@ class ModelInstallServiceBase(ABC):

     @abstractmethod
     def import_model(
         self,
         source: Union[str, Path, AnyHttpUrl],
         inplace: bool = False,
         variant: Optional[str] = None,
         subfolder: Optional[str] = None,
         config: Optional[Dict[str, Any]] = None,
         access_token: Optional[str] = None,
     ) -> ModelInstallJob:
         """Install the indicated model.

@@ -189,7 +196,7 @@ class ModelInstallServiceBase(ABC):
         """Return the ModelInstallJob corresponding to the provided source."""

     @abstractmethod
-    def list_jobs(self, source: Optional[ModelSource]=None) -> List[ModelInstallJob]: # noqa D102
+    def list_jobs(self, source: Optional[ModelSource] = None) -> List[ModelInstallJob]: # noqa D102
         """
         List active and complete install jobs.

@@ -46,11 +46,12 @@ class ModelInstallService(ModelInstallServiceBase):
     _cached_model_paths: Set[Path]
     _models_installed: Set[str]

-    def __init__(self,
-        app_config: InvokeAIAppConfig,
-        record_store: ModelRecordServiceBase,
-        event_bus: Optional[EventServiceBase] = None
-    ):
+    def __init__(
+        self,
+        app_config: InvokeAIAppConfig,
+        record_store: ModelRecordServiceBase,
+        event_bus: Optional[EventServiceBase] = None,
+    ):
         """
         Initialize the installer object.

@@ -73,11 +74,11 @@ class ModelInstallService(ModelInstallServiceBase):
         return self._app_config

     @property
     def record_store(self) -> ModelRecordServiceBase: # noqa D102
         return self._record_store

     @property
     def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
         return self._event_bus

     def _start_installer_thread(self) -> None:
@@ -129,25 +130,25 @@ class ModelInstallService(ModelInstallServiceBase):
             self._event_bus.emit_model_install_error(str(job.source), error_type, error)

     def register_path(
         self,
         model_path: Union[Path, str],
         config: Optional[Dict[str, Any]] = None,
     ) -> str: # noqa D102
         model_path = Path(model_path)
         config = config or {}
-        if config.get('source') is None:
-            config['source'] = model_path.resolve().as_posix()
+        if config.get("source") is None:
+            config["source"] = model_path.resolve().as_posix()
         return self._register(model_path, config)

     def install_path(
         self,
         model_path: Union[Path, str],
         config: Optional[Dict[str, Any]] = None,
     ) -> str: # noqa D102
         model_path = Path(model_path)
         config = config or {}
-        if config.get('source') is None:
-            config['source'] = model_path.resolve().as_posix()
+        if config.get("source") is None:
+            config["source"] = model_path.resolve().as_posix()

         info: AnyModelConfig = self._probe_model(Path(model_path), config)

@@ -164,14 +165,14 @@ class ModelInstallService(ModelInstallServiceBase):
         )

     def import_model(
         self,
         source: ModelSource,
         inplace: bool = False,
         variant: Optional[str] = None,
         subfolder: Optional[str] = None,
         config: Optional[Dict[str, Any]] = None,
         access_token: Optional[str] = None,
     ) -> ModelInstallJob: # noqa D102
         # Clean up a common source of error. Doesn't work with Paths.
         if isinstance(source, str):
             source = source.strip()
@@ -181,11 +182,12 @@ class ModelInstallService(ModelInstallServiceBase):

         # Installing a local path
         if isinstance(source, (str, Path)) and Path(source).exists(): # a path that is already on disk
-            job = ModelInstallJob(config_in=config,
-                source=source,
-                inplace=inplace,
-                local_path=Path(source),
-            )
+            job = ModelInstallJob(
+                config_in=config,
+                source=source,
+                inplace=inplace,
+                local_path=Path(source),
+            )
             self._install_jobs[source] = job
             self._install_queue.put(job)
             return job
@@ -193,7 +195,7 @@ class ModelInstallService(ModelInstallServiceBase):
         else: # here is where we'd download a URL or repo_id. Implementation pending download queue.
             raise UnknownModelException("File or directory not found")

-    def list_jobs(self, source: Optional[ModelSource]=None) -> List[ModelInstallJob]: # noqa D102
+    def list_jobs(self, source: Optional[ModelSource] = None) -> List[ModelInstallJob]: # noqa D102
         jobs = self._install_jobs
         if not source:
             return list(jobs.values())
@@ -205,17 +207,19 @@ class ModelInstallService(ModelInstallServiceBase):
         try:
             return self._install_jobs[source]
         except KeyError:
-            raise UnknownInstallJobException(f'{source}: unknown install job')
+            raise UnknownInstallJobException(f"{source}: unknown install job")

     def wait_for_installs(self) -> Dict[ModelSource, ModelInstallJob]: # noqa D102
         self._install_queue.join()
         return self._install_jobs

     def prune_jobs(self) -> None:
         """Prune all completed and errored jobs."""
-        finished_jobs = [source for source in self._install_jobs
-            if self._install_jobs[source].status in [InstallStatus.COMPLETED, InstallStatus.ERROR]
-        ]
+        finished_jobs = [
+            source
+            for source in self._install_jobs
+            if self._install_jobs[source].status in [InstallStatus.COMPLETED, InstallStatus.ERROR]
+        ]
         for source in finished_jobs:
             del self._install_jobs[source]

@@ -228,7 +232,7 @@ class ModelInstallService(ModelInstallServiceBase):
             self._logger.info(f"{len(installed)} new models registered")
         self._logger.info("Model installer (re)initialized")

     def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
         self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()}
         callback = self._scan_install if install else self._scan_register
         search = ModelSearch(on_model_found=callback)
@@ -295,7 +299,6 @@ class ModelInstallService(ModelInstallServiceBase):
             self.record_store.update_model(key, model)
             return model

-
     def _scan_register(self, model: Path) -> bool:
         if model in self._cached_model_paths:
             return True
@@ -308,7 +311,6 @@ class ModelInstallService(ModelInstallServiceBase):
             pass
         return True

-
     def _scan_install(self, model: Path) -> bool:
         if model in self._cached_model_paths:
             return True
@@ -320,7 +322,7 @@ class ModelInstallService(ModelInstallServiceBase):
             pass
         return True

     def unregister(self, key: str) -> None: # noqa D102
         self.record_store.del_model(key)

     def delete(self, key: str) -> None: # noqa D102
@@ -333,7 +335,7 @@ class ModelInstallService(ModelInstallServiceBase):
         else:
             self.unregister(key)

     def unconditionally_delete(self, key: str) -> None: # noqa D102
         model = self.record_store.get_model(key)
         path = self.app_config.models_path / model.path
         if path.is_dir():
@@ -378,11 +380,9 @@ class ModelInstallService(ModelInstallServiceBase):
     def _create_key(self) -> str:
         return sha256(randbytes(100)).hexdigest()[0:32]

-    def _register(self,
-        model_path: Path,
-        config: Optional[Dict[str, Any]] = None,
-        info: Optional[AnyModelConfig] = None) -> str:
+    def _register(
+        self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
+    ) -> str:

         info = info or ModelProbe.probe(model_path, config)
         key = self._create_key()
@@ -393,7 +393,7 @@ class ModelInstallService(ModelInstallServiceBase):
         info.path = model_path.as_posix()

         # add 'main' specific fields
-        if hasattr(info, 'config'):
+        if hasattr(info, "config"):
             # make config relative to our root
             legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve()
             info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
@@ -8,9 +8,9 @@ from .model_records_base import ( # noqa F401
 from .model_records_sql import ModelRecordServiceSQL # noqa F401

 __all__ = [
-    'ModelRecordServiceBase',
-    'ModelRecordServiceSQL',
-    'DuplicateModelException',
-    'InvalidModelException',
-    'UnknownModelException',
+    "ModelRecordServiceBase",
+    "ModelRecordServiceSQL",
+    "DuplicateModelException",
+    "InvalidModelException",
+    "UnknownModelException",
 ]
@@ -123,8 +123,8 @@ class ModelProbe(object):
             base_type=base_type,
             variant_type=variant_type,
             prediction_type=prediction_type,
-            name = name,
-            description = description,
+            name=name,
+            description=description,
             upcast_attention=(
                 base_type == BaseModelType.StableDiffusion2
                 and prediction_type == SchedulerPredictionType.VPrediction
@@ -150,7 +150,7 @@ class ModelProbe(object):

     @classmethod
     def get_model_name(cls, model_path: Path) -> str:
-        if model_path.suffix in {'.safetensors', '.bin', '.pt', '.ckpt'}:
+        if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
             return model_path.stem
         else:
             return model_path.name
@@ -14,15 +14,16 @@ from .config import (
 from .probe import ModelProbe
 from .search import ModelSearch

-__all__ = ['ModelProbe', 'ModelSearch',
-    'InvalidModelConfigException',
-    'ModelConfigFactory',
-    'BaseModelType',
-    'ModelType',
-    'SubModelType',
-    'ModelVariantType',
-    'ModelFormat',
-    'SchedulerPredictionType',
-    'AnyModelConfig',
-]
+__all__ = [
+    "ModelProbe",
+    "ModelSearch",
+    "InvalidModelConfigException",
+    "ModelConfigFactory",
+    "BaseModelType",
+    "ModelType",
+    "SubModelType",
+    "ModelVariantType",
+    "ModelFormat",
+    "SchedulerPredictionType",
+    "AnyModelConfig",
+]
@@ -49,6 +49,7 @@ LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[Sched
     },
 }

+
 class ProbeBase(object):
     """Base class for probes."""

@@ -71,6 +72,7 @@ class ProbeBase(object):
         """Get model scheduler prediction type."""
         return None

+
 class ModelProbe(object):
     PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
         "diffusers": {},
@@ -100,9 +102,9 @@ class ModelProbe(object):

     @classmethod
     def heuristic_probe(
         cls,
         model_path: Path,
         fields: Optional[Dict[str, Any]] = None,
     ) -> AnyModelConfig:
         return cls.probe(model_path, fields)

@@ -138,29 +140,38 @@ class ModelProbe(object):
         hash = FastModelHash.hash(model_path)
         probe = probe_class(model_path)

-        fields['path'] = model_path.as_posix()
-        fields['type'] = fields.get('type') or model_type
-        fields['base'] = fields.get('base') or probe.get_base_type()
-        fields['variant'] = fields.get('variant') or probe.get_variant_type()
-        fields['prediction_type'] = fields.get('prediction_type') or probe.get_scheduler_prediction_type()
-        fields['name'] = fields.get('name') or cls.get_model_name(model_path)
-        fields['description'] = fields.get('description') or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
-        fields['format'] = fields.get('format') or probe.get_format()
-        fields['original_hash'] = fields.get('original_hash') or hash
-        fields['current_hash'] = fields.get('current_hash') or hash
+        fields["path"] = model_path.as_posix()
+        fields["type"] = fields.get("type") or model_type
+        fields["base"] = fields.get("base") or probe.get_base_type()
+        fields["variant"] = fields.get("variant") or probe.get_variant_type()
+        fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type()
+        fields["name"] = fields.get("name") or cls.get_model_name(model_path)
+        fields["description"] = (
+            fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
+        )
+        fields["format"] = fields.get("format") or probe.get_format()
+        fields["original_hash"] = fields.get("original_hash") or hash
+        fields["current_hash"] = fields.get("current_hash") or hash

         # additional fields needed for main and controlnet models
-        if fields['type'] in [ModelType.Main, ModelType.ControlNet] and fields['format'] == ModelFormat.Checkpoint:
-            fields['config'] = cls._get_checkpoint_config_path(model_path,
-                model_type=fields['type'],
-                base_type=fields['base'],
-                variant_type=fields['variant'],
-                prediction_type=fields['prediction_type']).as_posix()
+        if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint:
+            fields["config"] = cls._get_checkpoint_config_path(
+                model_path,
+                model_type=fields["type"],
+                base_type=fields["base"],
+                variant_type=fields["variant"],
+                prediction_type=fields["prediction_type"],
+            ).as_posix()

         # additional fields needed for main non-checkpoint models
-        elif fields['type'] == ModelType.Main and fields['format'] in [ModelFormat.Onnx, ModelFormat.Olive, ModelFormat.Diffusers]:
-            fields['upcast_attention'] = fields.get('upcast_attention') or (
-                fields['base'] == BaseModelType.StableDiffusion2 and fields['prediction_type'] == SchedulerPredictionType.VPrediction
+        elif fields["type"] == ModelType.Main and fields["format"] in [
+            ModelFormat.Onnx,
+            ModelFormat.Olive,
+            ModelFormat.Diffusers,
+        ]:
+            fields["upcast_attention"] = fields.get("upcast_attention") or (
+                fields["base"] == BaseModelType.StableDiffusion2
+                and fields["prediction_type"] == SchedulerPredictionType.VPrediction
             )

         model_info = ModelConfigFactory.make_config(fields)
@@ -168,7 +179,7 @@ class ModelProbe(object):

     @classmethod
     def get_model_name(cls, model_path: Path) -> str:
-        if model_path.suffix in {'.safetensors', '.bin', '.pt', '.ckpt'}:
+        if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
             return model_path.stem
         else:
             return model_path.name
@@ -247,13 +258,14 @@ class ModelProbe(object):
         )

     @classmethod
-    def _get_checkpoint_config_path(cls,
-        model_path: Path,
-        model_type: ModelType,
-        base_type: BaseModelType,
-        variant_type: ModelVariantType,
-        prediction_type: SchedulerPredictionType) -> Path:
+    def _get_checkpoint_config_path(
+        cls,
+        model_path: Path,
+        model_type: ModelType,
+        base_type: BaseModelType,
+        variant_type: ModelVariantType,
+        prediction_type: SchedulerPredictionType,
+    ) -> Path:
         # look for a YAML file adjacent to the model file first
         possible_conf = model_path.with_suffix(".yaml")
         if possible_conf.exists():
@@ -264,9 +276,13 @@ class ModelProbe(object):
             if isinstance(config_file, dict): # need another tier for sd-2.x models
                 config_file = config_file[prediction_type]
         elif model_type == ModelType.ControlNet:
-            config_file = "../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml"
+            config_file = (
+                "../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml"
+            )
         else:
-            raise InvalidModelConfigException(f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}")
+            raise InvalidModelConfigException(
+                f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}"
+            )
         assert isinstance(config_file, str)
         return Path(config_file)

@@ -297,6 +313,7 @@ class ModelProbe(object):
 # Checkpoint probing
 # ##################################################3

+
 class CheckpointProbeBase(ProbeBase):
     def __init__(self, model_path: Path):
         super().__init__(model_path)
@@ -446,7 +463,6 @@ class T2IAdapterCheckpointProbe(CheckpointProbeBase):
 # classes for probing folders
 #######################################################
 class FolderProbeBase(ProbeBase):
-
     def get_variant_type(self) -> ModelVariantType:
         return ModelVariantType.Normal

@@ -537,7 +553,9 @@ class TextualInversionFolderProbe(FolderProbeBase):
     def get_base_type(self) -> BaseModelType:
         path = self.model_path / "learned_embeds.bin"
         if not path.exists():
-            raise InvalidModelConfigException(f"{self.model_path.as_posix()} does not contain expected 'learned_embeds.bin' file")
+            raise InvalidModelConfigException(
+                f"{self.model_path.as_posix()} does not contain expected 'learned_embeds.bin' file"
+            )
         return TextualInversionCheckpointProbe(path).get_base_type()


@@ -608,7 +626,9 @@ class IPAdapterFolderProbe(FolderProbeBase):
         elif cross_attention_dim == 2048:
             return BaseModelType.StableDiffusionXL
         else:
-            raise InvalidModelConfigException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.")
+            raise InvalidModelConfigException(
+                f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}."
+            )


 class CLIPVisionFolderProbe(FolderProbeBase):
@@ -165,14 +165,14 @@ class ModelSearch(ModelSearchBase):
                     self.scanned_dirs.add(path)
                     continue
                 if any(
                     (path / x).exists()
                     for x in [
                         "config.json",
                         "model_index.json",
                         "learned_embeds.bin",
                         "pytorch_lora_weights.bin",
                         "image_encoder.txt",
                     ]
                 ):
                     self.scanned_dirs.add(path)
                     try:
@@ -14,4 +14,4 @@ from .devices import ( # noqa: F401
 from .logging import InvokeAILogger
 from .util import Chdir, ask_user, download_with_resume, instantiate_from_config, url_attachment_name # noqa: F401

-__all__ = ['Chdir', 'InvokeAILogger', 'choose_precision', 'choose_torch_device']
+__all__ = ["Chdir", "InvokeAILogger", "choose_precision", "choose_torch_device"]
@@ -44,12 +44,12 @@ def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:


 @pytest.fixture
-def installer(app_config: InvokeAIAppConfig,
-    store: ModelRecordServiceBase) -> ModelInstallServiceBase:
-    return ModelInstallService(app_config=app_config,
+def installer(app_config: InvokeAIAppConfig, store: ModelRecordServiceBase) -> ModelInstallServiceBase:
+    return ModelInstallService(
+        app_config=app_config,
         record_store=store,
         event_bus=DummyEventService(),
     )


 class DummyEvent(BaseModel):
@@ -70,10 +70,8 @@ class DummyEventService(EventServiceBase):

     def dispatch(self, event_name: str, payload: Any) -> None:
         """Dispatch an event by appending it to self.events."""
-        self.events.append(
-            DummyEvent(event_name=payload['event'],
-            payload=payload['data'])
-        )
+        self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"]))

 def test_registration(installer: ModelInstallServiceBase, test_file: Path) -> None:
     store = installer.record_store
@@ -83,6 +81,7 @@ def test_registration(installer: ModelInstallServiceBase, test_file: Path) -> No
     assert key is not None
     assert len(key) == 32

+
 def test_registration_meta(installer: ModelInstallServiceBase, test_file: Path) -> None:
     store = installer.record_store
     key = installer.register_path(test_file)
@@ -91,31 +90,30 @@ def test_registration_meta(installer: ModelInstallServiceBase, test_file: Path)
     assert model_record.name == "test_embedding"
     assert model_record.type == ModelType.TextualInversion
     assert Path(model_record.path) == test_file
-    assert model_record.base == BaseModelType('sd-1')
+    assert model_record.base == BaseModelType("sd-1")
     assert model_record.description is not None
     assert model_record.source is not None
     assert Path(model_record.source) == test_file

+
 def test_registration_meta_override_fail(installer: ModelInstallServiceBase, test_file: Path) -> None:
     key = None
     with pytest.raises(ValidationError):
         key = installer.register_path(test_file, {"name": "banana_sushi", "type": ModelType("lora")})
     assert key is None

+
 def test_registration_meta_override_succeed(installer: ModelInstallServiceBase, test_file: Path) -> None:
     store = installer.record_store
-    key = installer.register_path(test_file,
-        {
-            "name": "banana_sushi",
-            "source": "fake/repo_id",
-            "current_hash": "New Hash"
-        }
-    )
+    key = installer.register_path(
+        test_file, {"name": "banana_sushi", "source": "fake/repo_id", "current_hash": "New Hash"}
+    )
     model_record = store.get_model(key)
     assert model_record.name == "banana_sushi"
     assert model_record.source == "fake/repo_id"
     assert model_record.current_hash == "New Hash"

+
 def test_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None:
     store = installer.record_store
     key = installer.install_path(test_file)
@@ -123,6 +121,7 @@ def test_install(installer: ModelInstallServiceBase, test_file: Path, app_config
     assert model_record.path == "sd-1/embedding/test_embedding.safetensors"
     assert model_record.source == test_file.as_posix()

+
 def test_background_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None:
     """Note: may want to break this down into several smaller unit tests."""
     source = test_file
@@ -142,7 +141,7 @@ def test_background_install(installer: ModelInstallServiceBase, test_file: Path,

     # test that the expected events were issued
     bus = installer.event_bus
     assert bus is not None # sigh - ruff is a stickler for type checking
     assert isinstance(bus, DummyEventService)
     assert len(bus.events) == 2
     event_names = [x.event_name for x in bus.events]
@@ -167,6 +166,7 @@ def test_background_install(installer: ModelInstallServiceBase, test_file: Path,
     with pytest.raises(UnknownInstallJobException):
         assert installer.get_job(source)

+
 def test_delete_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig):
     store = installer.record_store
     key = installer.install_path(test_file)
@@ -174,11 +174,14 @@ def test_delete_install(installer: ModelInstallServiceBase, test_file: Path, app
     assert Path(app_config.models_dir / model_record.path).exists()
     assert test_file.exists() # original should still be there after installation
     installer.delete(key)
-    assert not Path(app_config.models_dir / model_record.path).exists() # after deletion, installed copy should not exist
+    assert not Path(
+        app_config.models_dir / model_record.path
+    ).exists() # after deletion, installed copy should not exist
     assert test_file.exists() # but original should still be there
     with pytest.raises(UnknownModelException):
         store.get_model(key)

+
 def test_delete_register(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig):
     store = installer.record_store
     key = installer.register_path(test_file)