implement review suggestions from @RyanjDick

This commit is contained in:
Lincoln Stein 2023-12-04 21:12:10 -05:00
parent 018ccebd6f
commit 2b583ffcdf
7 changed files with 43 additions and 53 deletions

View File

@ -643,11 +643,10 @@ install_job = installer.import_model(
This section describes additional methods provided by the installer class.
#### source2job = installer.wait_for_installs()
#### jobs = installer.wait_for_installs()
Block until all pending installs are completed or errored and return a
dictionary that maps the model `source` to the completed
`ModelInstallJob`.
Block until all pending installs are completed or errored and then
returns a list of completed jobs.
#### jobs = installer.list_jobs([source])
@ -655,9 +654,10 @@ Return a list of all active and complete `ModelInstallJobs`. An
optional `source` argument allows you to filter the returned list by a
model source string pattern using a partial string match.
#### job = installer.get_job(source)
#### jobs = installer.get_job(source)
Return the `ModelInstallJob` corresponding to the indicated model source.
Return a list of `ModelInstallJob` corresponding to the indicated
model source.
#### installer.prune_jobs

View File

@ -178,7 +178,6 @@ async def add_model_record(
operation_id="import_model_record",
responses={
201: {"description": "The model imported successfully"},
404: {"description": "The model could not be found"},
415: {"description": "Unrecognized file/folder format"},
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
@ -258,7 +257,7 @@ async def import_model(
logger.info(f"Started installation of {source}")
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
raise HTTPException(status_code=424, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)

View File

@ -112,6 +112,16 @@ class URLModelSource(StringLikeSource):
return str(self.url)
# Body() is being applied here rather than Field() because otherwise FastAPI will
# refuse to generate a schema. Relevant links:
#
# "Model Manager Refactor Phase 1 - SQL-based config storage
# https://github.com/invoke-ai/InvokeAI/pull/5039#discussion_r1389752119 (comment)
# Param: xyz can only be a request body, using Body() when using discriminated unions
# https://github.com/tiangolo/fastapi/discussions/9761
# Body parameter cannot be a pydantic union anymore sinve v0.95
# https://github.com/tiangolo/fastapi/discussions/9287
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Body(discriminator="type")]
ModelSourceValidator = TypeAdapter(ModelSource)
@ -263,8 +273,8 @@ class ModelInstallServiceBase(ABC):
"""
@abstractmethod
def get_job(self, source: ModelSource) -> ModelInstallJob:
"""Return the ModelInstallJob corresponding to the provided source."""
def get_job(self, source: ModelSource) -> List[ModelInstallJob]:
"""Return the ModelInstallJob(s) corresponding to the provided source."""
@abstractmethod
def list_jobs(self, source: Optional[ModelSource | str] = None) -> List[ModelInstallJob]: # noqa D102
@ -279,7 +289,7 @@ class ModelInstallServiceBase(ABC):
"""Prune all completed and errored jobs."""
@abstractmethod
def wait_for_installs(self) -> Dict[ModelSource, ModelInstallJob]:
def wait_for_installs(self) -> List[ModelInstallJob]:
"""
Wait for all pending installs to complete.
@ -288,8 +298,7 @@ class ModelInstallServiceBase(ABC):
block indefinitely if one or more jobs are in the
paused state.
It will return a dict that maps the source model
path, URL or repo_id to the ID of the installed model.
It will return the current list of jobs.
"""
@abstractmethod
@ -305,12 +314,3 @@ class ModelInstallServiceBase(ABC):
@abstractmethod
def sync_to_config(self) -> None:
"""Synchronize models on disk to those in the model record database."""
@abstractmethod
def release(self) -> None:
"""
Signal the install thread to exit.
This is useful if you are done with the installer and wish to
release its resources.
"""

View File

@ -29,7 +29,6 @@ from .model_install_base import (
ModelInstallJob,
ModelInstallServiceBase,
ModelSource,
UnknownInstallJobException,
)
# marker that the queue is done and that thread should exit
@ -46,7 +45,7 @@ class ModelInstallService(ModelInstallServiceBase):
_record_store: ModelRecordServiceBase
_event_bus: Optional[EventServiceBase] = None
_install_queue: Queue[ModelInstallJob]
_install_jobs: Dict[ModelSource, ModelInstallJob]
_install_jobs: List[ModelInstallJob]
_logger: Logger
_cached_model_paths: Set[Path]
_models_installed: Set[str]
@ -68,12 +67,16 @@ class ModelInstallService(ModelInstallServiceBase):
self._record_store = record_store
self._event_bus = event_bus
self._logger = InvokeAILogger.get_logger(name=self.__class__.__name__)
self._install_jobs = {}
self._install_jobs = []
self._install_queue = Queue()
self._cached_model_paths = set()
self._models_installed = set()
self._start_installer_thread()
def __del__(self) -> None:
"""At GC time, we stop the install thread and release its resources."""
self._install_queue.put(STOP_JOB)
@property
def app_config(self) -> InvokeAIAppConfig: # noqa D102
return self._app_config
@ -189,7 +192,7 @@ class ModelInstallService(ModelInstallServiceBase):
config_in=config,
local_path=Path(source.path),
)
self._install_jobs[source] = job
self._install_jobs.append(job)
self._install_queue.put(job)
return job
@ -199,29 +202,23 @@ class ModelInstallService(ModelInstallServiceBase):
def list_jobs(self, source: Optional[ModelSource | str] = None) -> List[ModelInstallJob]: # noqa D102
jobs = self._install_jobs
if not source:
return list(jobs.values())
return jobs
else:
return [jobs[x] for x in jobs if str(source) in str(x)]
return [x for x in jobs if str(source) in str(x)]
def get_job(self, source: ModelSource) -> ModelInstallJob: # noqa D102
try:
return self._install_jobs[source]
except KeyError:
raise UnknownInstallJobException(f"{source}: unknown install job")
def get_job(self, source: ModelSource) -> List[ModelInstallJob]: # noqa D102
return [x for x in self._install_jobs if x.source == source]
def wait_for_installs(self) -> Dict[ModelSource, ModelInstallJob]: # noqa D102
def wait_for_installs(self) -> List[ModelInstallJob]: # noqa D102
self._install_queue.join()
return self._install_jobs
def prune_jobs(self) -> None:
"""Prune all completed and errored jobs."""
finished_jobs = [
source
for source in self._install_jobs
if self._install_jobs[source].status in [InstallStatus.COMPLETED, InstallStatus.ERROR]
unfinished_jobs = [
x for x in self._install_jobs if x.status not in [InstallStatus.COMPLETED, InstallStatus.ERROR]
]
for source in finished_jobs:
del self._install_jobs[source]
self._install_jobs = unfinished_jobs
def sync_to_config(self) -> None:
"""Synchronize models on disk to those in the config record store database."""
@ -344,10 +341,6 @@ class ModelInstallService(ModelInstallServiceBase):
path.unlink()
self.unregister(key)
def release(self) -> None:
"""Stop the install thread and release its resources."""
self._install_queue.put(STOP_JOB)
def _copy_model(self, old_path: Path, new_path: Path) -> Path:
if old_path == new_path:
return old_path

View File

@ -115,7 +115,6 @@ class ModelSearch(ModelSearchBase):
# returns all models that have 'anime' in the path
"""
directory: Path = Field(default=None)
models_found: Set[Path] = Field(default=None)
scanned_dirs: Set[Path] = Field(default=None)
pruned_paths: Set[Path] = Field(default=None)

View File

@ -16,7 +16,6 @@ from invokeai.app.services.model_install import (
ModelInstallJob,
ModelInstallService,
ModelInstallServiceBase,
UnknownInstallJobException,
)
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
from invokeai.app.services.shared.sqlite import SqliteDatabase
@ -133,13 +132,14 @@ def test_background_install(installer: ModelInstallServiceBase, test_file: Path,
assert isinstance(job, ModelInstallJob)
# See if job is registered properly
assert installer.get_job(source) == job
assert job in installer.get_job(source)
# test that the job object tracked installation correctly
jobs = installer.wait_for_installs()
assert jobs[source] is not None
assert jobs[source] == job
assert jobs[source].status == InstallStatus.COMPLETED
assert len(jobs) > 0
my_job = [x for x in jobs if x.source == source]
assert len(my_job) == 1
assert my_job[0].status == InstallStatus.COMPLETED
# test that the expected events were issued
bus = installer.event_bus
@ -165,9 +165,7 @@ def test_background_install(installer: ModelInstallServiceBase, test_file: Path,
# see if prune works properly
installer.prune_jobs()
with pytest.raises(UnknownInstallJobException):
assert installer.get_job(source)
assert not installer.get_job(source)
def test_delete_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig):
store = installer.record_store

View File

@ -0,0 +1 @@
This directory is used by pytest-datadir.