mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
implement review suggestions from @RyanjDick
This commit is contained in:
parent
018ccebd6f
commit
2b583ffcdf
@ -643,11 +643,10 @@ install_job = installer.import_model(
|
|||||||
|
|
||||||
This section describes additional methods provided by the installer class.
|
This section describes additional methods provided by the installer class.
|
||||||
|
|
||||||
#### source2job = installer.wait_for_installs()
|
#### jobs = installer.wait_for_installs()
|
||||||
|
|
||||||
Block until all pending installs are completed or errored and return a
|
Block until all pending installs are completed or errored and then
|
||||||
dictionary that maps the model `source` to the completed
|
returns a list of completed jobs.
|
||||||
`ModelInstallJob`.
|
|
||||||
|
|
||||||
#### jobs = installer.list_jobs([source])
|
#### jobs = installer.list_jobs([source])
|
||||||
|
|
||||||
@ -655,9 +654,10 @@ Return a list of all active and complete `ModelInstallJobs`. An
|
|||||||
optional `source` argument allows you to filter the returned list by a
|
optional `source` argument allows you to filter the returned list by a
|
||||||
model source string pattern using a partial string match.
|
model source string pattern using a partial string match.
|
||||||
|
|
||||||
#### job = installer.get_job(source)
|
#### jobs = installer.get_job(source)
|
||||||
|
|
||||||
Return the `ModelInstallJob` corresponding to the indicated model source.
|
Return a list of `ModelInstallJob` corresponding to the indicated
|
||||||
|
model source.
|
||||||
|
|
||||||
#### installer.prune_jobs
|
#### installer.prune_jobs
|
||||||
|
|
||||||
|
@ -178,7 +178,6 @@ async def add_model_record(
|
|||||||
operation_id="import_model_record",
|
operation_id="import_model_record",
|
||||||
responses={
|
responses={
|
||||||
201: {"description": "The model imported successfully"},
|
201: {"description": "The model imported successfully"},
|
||||||
404: {"description": "The model could not be found"},
|
|
||||||
415: {"description": "Unrecognized file/folder format"},
|
415: {"description": "Unrecognized file/folder format"},
|
||||||
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
|
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
|
||||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||||
@ -258,7 +257,7 @@ async def import_model(
|
|||||||
logger.info(f"Started installation of {source}")
|
logger.info(f"Started installation of {source}")
|
||||||
except UnknownModelException as e:
|
except UnknownModelException as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=424, detail=str(e))
|
||||||
except InvalidModelException as e:
|
except InvalidModelException as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=415)
|
raise HTTPException(status_code=415)
|
||||||
|
@ -112,6 +112,16 @@ class URLModelSource(StringLikeSource):
|
|||||||
return str(self.url)
|
return str(self.url)
|
||||||
|
|
||||||
|
|
||||||
|
# Body() is being applied here rather than Field() because otherwise FastAPI will
|
||||||
|
# refuse to generate a schema. Relevant links:
|
||||||
|
#
|
||||||
|
# "Model Manager Refactor Phase 1 - SQL-based config storage
|
||||||
|
# https://github.com/invoke-ai/InvokeAI/pull/5039#discussion_r1389752119 (comment)
|
||||||
|
# Param: xyz can only be a request body, using Body() when using discriminated unions
|
||||||
|
# https://github.com/tiangolo/fastapi/discussions/9761
|
||||||
|
# Body parameter cannot be a pydantic union anymore sinve v0.95
|
||||||
|
# https://github.com/tiangolo/fastapi/discussions/9287
|
||||||
|
|
||||||
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Body(discriminator="type")]
|
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Body(discriminator="type")]
|
||||||
ModelSourceValidator = TypeAdapter(ModelSource)
|
ModelSourceValidator = TypeAdapter(ModelSource)
|
||||||
|
|
||||||
@ -263,8 +273,8 @@ class ModelInstallServiceBase(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_job(self, source: ModelSource) -> ModelInstallJob:
|
def get_job(self, source: ModelSource) -> List[ModelInstallJob]:
|
||||||
"""Return the ModelInstallJob corresponding to the provided source."""
|
"""Return the ModelInstallJob(s) corresponding to the provided source."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def list_jobs(self, source: Optional[ModelSource | str] = None) -> List[ModelInstallJob]: # noqa D102
|
def list_jobs(self, source: Optional[ModelSource | str] = None) -> List[ModelInstallJob]: # noqa D102
|
||||||
@ -279,7 +289,7 @@ class ModelInstallServiceBase(ABC):
|
|||||||
"""Prune all completed and errored jobs."""
|
"""Prune all completed and errored jobs."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def wait_for_installs(self) -> Dict[ModelSource, ModelInstallJob]:
|
def wait_for_installs(self) -> List[ModelInstallJob]:
|
||||||
"""
|
"""
|
||||||
Wait for all pending installs to complete.
|
Wait for all pending installs to complete.
|
||||||
|
|
||||||
@ -288,8 +298,7 @@ class ModelInstallServiceBase(ABC):
|
|||||||
block indefinitely if one or more jobs are in the
|
block indefinitely if one or more jobs are in the
|
||||||
paused state.
|
paused state.
|
||||||
|
|
||||||
It will return a dict that maps the source model
|
It will return the current list of jobs.
|
||||||
path, URL or repo_id to the ID of the installed model.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -305,12 +314,3 @@ class ModelInstallServiceBase(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def sync_to_config(self) -> None:
|
def sync_to_config(self) -> None:
|
||||||
"""Synchronize models on disk to those in the model record database."""
|
"""Synchronize models on disk to those in the model record database."""
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def release(self) -> None:
|
|
||||||
"""
|
|
||||||
Signal the install thread to exit.
|
|
||||||
|
|
||||||
This is useful if you are done with the installer and wish to
|
|
||||||
release its resources.
|
|
||||||
"""
|
|
||||||
|
@ -29,7 +29,6 @@ from .model_install_base import (
|
|||||||
ModelInstallJob,
|
ModelInstallJob,
|
||||||
ModelInstallServiceBase,
|
ModelInstallServiceBase,
|
||||||
ModelSource,
|
ModelSource,
|
||||||
UnknownInstallJobException,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# marker that the queue is done and that thread should exit
|
# marker that the queue is done and that thread should exit
|
||||||
@ -46,7 +45,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
_record_store: ModelRecordServiceBase
|
_record_store: ModelRecordServiceBase
|
||||||
_event_bus: Optional[EventServiceBase] = None
|
_event_bus: Optional[EventServiceBase] = None
|
||||||
_install_queue: Queue[ModelInstallJob]
|
_install_queue: Queue[ModelInstallJob]
|
||||||
_install_jobs: Dict[ModelSource, ModelInstallJob]
|
_install_jobs: List[ModelInstallJob]
|
||||||
_logger: Logger
|
_logger: Logger
|
||||||
_cached_model_paths: Set[Path]
|
_cached_model_paths: Set[Path]
|
||||||
_models_installed: Set[str]
|
_models_installed: Set[str]
|
||||||
@ -68,12 +67,16 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._record_store = record_store
|
self._record_store = record_store
|
||||||
self._event_bus = event_bus
|
self._event_bus = event_bus
|
||||||
self._logger = InvokeAILogger.get_logger(name=self.__class__.__name__)
|
self._logger = InvokeAILogger.get_logger(name=self.__class__.__name__)
|
||||||
self._install_jobs = {}
|
self._install_jobs = []
|
||||||
self._install_queue = Queue()
|
self._install_queue = Queue()
|
||||||
self._cached_model_paths = set()
|
self._cached_model_paths = set()
|
||||||
self._models_installed = set()
|
self._models_installed = set()
|
||||||
self._start_installer_thread()
|
self._start_installer_thread()
|
||||||
|
|
||||||
|
def __del__(self) -> None:
|
||||||
|
"""At GC time, we stop the install thread and release its resources."""
|
||||||
|
self._install_queue.put(STOP_JOB)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def app_config(self) -> InvokeAIAppConfig: # noqa D102
|
def app_config(self) -> InvokeAIAppConfig: # noqa D102
|
||||||
return self._app_config
|
return self._app_config
|
||||||
@ -189,7 +192,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
config_in=config,
|
config_in=config,
|
||||||
local_path=Path(source.path),
|
local_path=Path(source.path),
|
||||||
)
|
)
|
||||||
self._install_jobs[source] = job
|
self._install_jobs.append(job)
|
||||||
self._install_queue.put(job)
|
self._install_queue.put(job)
|
||||||
return job
|
return job
|
||||||
|
|
||||||
@ -199,29 +202,23 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
def list_jobs(self, source: Optional[ModelSource | str] = None) -> List[ModelInstallJob]: # noqa D102
|
def list_jobs(self, source: Optional[ModelSource | str] = None) -> List[ModelInstallJob]: # noqa D102
|
||||||
jobs = self._install_jobs
|
jobs = self._install_jobs
|
||||||
if not source:
|
if not source:
|
||||||
return list(jobs.values())
|
return jobs
|
||||||
else:
|
else:
|
||||||
return [jobs[x] for x in jobs if str(source) in str(x)]
|
return [x for x in jobs if str(source) in str(x)]
|
||||||
|
|
||||||
def get_job(self, source: ModelSource) -> ModelInstallJob: # noqa D102
|
def get_job(self, source: ModelSource) -> List[ModelInstallJob]: # noqa D102
|
||||||
try:
|
return [x for x in self._install_jobs if x.source == source]
|
||||||
return self._install_jobs[source]
|
|
||||||
except KeyError:
|
|
||||||
raise UnknownInstallJobException(f"{source}: unknown install job")
|
|
||||||
|
|
||||||
def wait_for_installs(self) -> Dict[ModelSource, ModelInstallJob]: # noqa D102
|
def wait_for_installs(self) -> List[ModelInstallJob]: # noqa D102
|
||||||
self._install_queue.join()
|
self._install_queue.join()
|
||||||
return self._install_jobs
|
return self._install_jobs
|
||||||
|
|
||||||
def prune_jobs(self) -> None:
|
def prune_jobs(self) -> None:
|
||||||
"""Prune all completed and errored jobs."""
|
"""Prune all completed and errored jobs."""
|
||||||
finished_jobs = [
|
unfinished_jobs = [
|
||||||
source
|
x for x in self._install_jobs if x.status not in [InstallStatus.COMPLETED, InstallStatus.ERROR]
|
||||||
for source in self._install_jobs
|
|
||||||
if self._install_jobs[source].status in [InstallStatus.COMPLETED, InstallStatus.ERROR]
|
|
||||||
]
|
]
|
||||||
for source in finished_jobs:
|
self._install_jobs = unfinished_jobs
|
||||||
del self._install_jobs[source]
|
|
||||||
|
|
||||||
def sync_to_config(self) -> None:
|
def sync_to_config(self) -> None:
|
||||||
"""Synchronize models on disk to those in the config record store database."""
|
"""Synchronize models on disk to those in the config record store database."""
|
||||||
@ -344,10 +341,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
path.unlink()
|
path.unlink()
|
||||||
self.unregister(key)
|
self.unregister(key)
|
||||||
|
|
||||||
def release(self) -> None:
|
|
||||||
"""Stop the install thread and release its resources."""
|
|
||||||
self._install_queue.put(STOP_JOB)
|
|
||||||
|
|
||||||
def _copy_model(self, old_path: Path, new_path: Path) -> Path:
|
def _copy_model(self, old_path: Path, new_path: Path) -> Path:
|
||||||
if old_path == new_path:
|
if old_path == new_path:
|
||||||
return old_path
|
return old_path
|
||||||
|
@ -115,7 +115,6 @@ class ModelSearch(ModelSearchBase):
|
|||||||
# returns all models that have 'anime' in the path
|
# returns all models that have 'anime' in the path
|
||||||
"""
|
"""
|
||||||
|
|
||||||
directory: Path = Field(default=None)
|
|
||||||
models_found: Set[Path] = Field(default=None)
|
models_found: Set[Path] = Field(default=None)
|
||||||
scanned_dirs: Set[Path] = Field(default=None)
|
scanned_dirs: Set[Path] = Field(default=None)
|
||||||
pruned_paths: Set[Path] = Field(default=None)
|
pruned_paths: Set[Path] = Field(default=None)
|
||||||
|
@ -16,7 +16,6 @@ from invokeai.app.services.model_install import (
|
|||||||
ModelInstallJob,
|
ModelInstallJob,
|
||||||
ModelInstallService,
|
ModelInstallService,
|
||||||
ModelInstallServiceBase,
|
ModelInstallServiceBase,
|
||||||
UnknownInstallJobException,
|
|
||||||
)
|
)
|
||||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
|
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
|
||||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||||
@ -133,13 +132,14 @@ def test_background_install(installer: ModelInstallServiceBase, test_file: Path,
|
|||||||
assert isinstance(job, ModelInstallJob)
|
assert isinstance(job, ModelInstallJob)
|
||||||
|
|
||||||
# See if job is registered properly
|
# See if job is registered properly
|
||||||
assert installer.get_job(source) == job
|
assert job in installer.get_job(source)
|
||||||
|
|
||||||
# test that the job object tracked installation correctly
|
# test that the job object tracked installation correctly
|
||||||
jobs = installer.wait_for_installs()
|
jobs = installer.wait_for_installs()
|
||||||
assert jobs[source] is not None
|
assert len(jobs) > 0
|
||||||
assert jobs[source] == job
|
my_job = [x for x in jobs if x.source == source]
|
||||||
assert jobs[source].status == InstallStatus.COMPLETED
|
assert len(my_job) == 1
|
||||||
|
assert my_job[0].status == InstallStatus.COMPLETED
|
||||||
|
|
||||||
# test that the expected events were issued
|
# test that the expected events were issued
|
||||||
bus = installer.event_bus
|
bus = installer.event_bus
|
||||||
@ -165,9 +165,7 @@ def test_background_install(installer: ModelInstallServiceBase, test_file: Path,
|
|||||||
|
|
||||||
# see if prune works properly
|
# see if prune works properly
|
||||||
installer.prune_jobs()
|
installer.prune_jobs()
|
||||||
with pytest.raises(UnknownInstallJobException):
|
assert not installer.get_job(source)
|
||||||
assert installer.get_job(source)
|
|
||||||
|
|
||||||
|
|
||||||
def test_delete_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig):
|
def test_delete_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig):
|
||||||
store = installer.record_store
|
store = installer.record_store
|
||||||
|
@ -0,0 +1 @@
|
|||||||
|
This directory is used by pytest-datadir.
|
Loading…
Reference in New Issue
Block a user