From 2b583ffcdf8e37c54c04731e9a6d2a19e5d85efb Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 4 Dec 2023 21:12:10 -0500 Subject: [PATCH] implement review suggestions from @RyanjDick --- docs/contributing/MODEL_MANAGER.md | 12 +++--- invokeai/app/api/routers/model_records.py | 3 +- .../model_install/model_install_base.py | 28 +++++++------- .../model_install/model_install_default.py | 37 ++++++++----------- invokeai/backend/model_manager/search.py | 1 - .../model_install/test_model_install.py | 14 +++---- .../model_install/test_model_install/README | 1 + 7 files changed, 43 insertions(+), 53 deletions(-) create mode 100644 tests/app/services/model_install/test_model_install/README diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index c230979361..ce55a4d0ad 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -643,11 +643,10 @@ install_job = installer.import_model( This section describes additional methods provided by the installer class. -#### source2job = installer.wait_for_installs() +#### jobs = installer.wait_for_installs() -Block until all pending installs are completed or errored and return a -dictionary that maps the model `source` to the completed -`ModelInstallJob`. +Block until all pending installs are completed or errored and then +returns a list of completed jobs. #### jobs = installer.list_jobs([source]) @@ -655,9 +654,10 @@ Return a list of all active and complete `ModelInstallJobs`. An optional `source` argument allows you to filter the returned list by a model source string pattern using a partial string match. -#### job = installer.get_job(source) +#### jobs = installer.get_job(source) -Return the `ModelInstallJob` corresponding to the indicated model source. +Return a list of `ModelInstallJob` corresponding to the indicated +model source. #### installer.prune_jobs diff --git a/invokeai/app/api/routers/model_records.py b/invokeai/app/api/routers/model_records.py index 9cae5e44ce..0efd539476 100644 --- a/invokeai/app/api/routers/model_records.py +++ b/invokeai/app/api/routers/model_records.py @@ -178,7 +178,6 @@ async def add_model_record( operation_id="import_model_record", responses={ 201: {"description": "The model imported successfully"}, - 404: {"description": "The model could not be found"}, 415: {"description": "Unrecognized file/folder format"}, 424: {"description": "The model appeared to import successfully, but could not be found in the model manager"}, 409: {"description": "There is already a model corresponding to this path or repo_id"}, @@ -258,7 +257,7 @@ async def import_model( logger.info(f"Started installation of {source}") except UnknownModelException as e: logger.error(str(e)) - raise HTTPException(status_code=404, detail=str(e)) + raise HTTPException(status_code=424, detail=str(e)) except InvalidModelException as e: logger.error(str(e)) raise HTTPException(status_code=415) diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 7f359772b3..1f62f2a740 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -112,6 +112,16 @@ class URLModelSource(StringLikeSource): return str(self.url) +# Body() is being applied here rather than Field() because otherwise FastAPI will +# refuse to generate a schema. Relevant links: +# +# "Model Manager Refactor Phase 1 - SQL-based config storage +# https://github.com/invoke-ai/InvokeAI/pull/5039#discussion_r1389752119 (comment) +# Param: xyz can only be a request body, using Body() when using discriminated unions +# https://github.com/tiangolo/fastapi/discussions/9761 +# Body parameter cannot be a pydantic union anymore sinve v0.95 +# https://github.com/tiangolo/fastapi/discussions/9287 + ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Body(discriminator="type")] ModelSourceValidator = TypeAdapter(ModelSource) @@ -263,8 +273,8 @@ class ModelInstallServiceBase(ABC): """ @abstractmethod - def get_job(self, source: ModelSource) -> ModelInstallJob: - """Return the ModelInstallJob corresponding to the provided source.""" + def get_job(self, source: ModelSource) -> List[ModelInstallJob]: + """Return the ModelInstallJob(s) corresponding to the provided source.""" @abstractmethod def list_jobs(self, source: Optional[ModelSource | str] = None) -> List[ModelInstallJob]: # noqa D102 @@ -279,7 +289,7 @@ class ModelInstallServiceBase(ABC): """Prune all completed and errored jobs.""" @abstractmethod - def wait_for_installs(self) -> Dict[ModelSource, ModelInstallJob]: + def wait_for_installs(self) -> List[ModelInstallJob]: """ Wait for all pending installs to complete. @@ -288,8 +298,7 @@ class ModelInstallServiceBase(ABC): block indefinitely if one or more jobs are in the paused state. - It will return a dict that maps the source model - path, URL or repo_id to the ID of the installed model. + It will return the current list of jobs. """ @abstractmethod @@ -305,12 +314,3 @@ class ModelInstallServiceBase(ABC): @abstractmethod def sync_to_config(self) -> None: """Synchronize models on disk to those in the model record database.""" - - @abstractmethod - def release(self) -> None: - """ - Signal the install thread to exit. - - This is useful if you are done with the installer and wish to - release its resources. - """ diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 657e4aa293..9ae1ee50c3 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -29,7 +29,6 @@ from .model_install_base import ( ModelInstallJob, ModelInstallServiceBase, ModelSource, - UnknownInstallJobException, ) # marker that the queue is done and that thread should exit @@ -46,7 +45,7 @@ class ModelInstallService(ModelInstallServiceBase): _record_store: ModelRecordServiceBase _event_bus: Optional[EventServiceBase] = None _install_queue: Queue[ModelInstallJob] - _install_jobs: Dict[ModelSource, ModelInstallJob] + _install_jobs: List[ModelInstallJob] _logger: Logger _cached_model_paths: Set[Path] _models_installed: Set[str] @@ -68,12 +67,16 @@ class ModelInstallService(ModelInstallServiceBase): self._record_store = record_store self._event_bus = event_bus self._logger = InvokeAILogger.get_logger(name=self.__class__.__name__) - self._install_jobs = {} + self._install_jobs = [] self._install_queue = Queue() self._cached_model_paths = set() self._models_installed = set() self._start_installer_thread() + def __del__(self) -> None: + """At GC time, we stop the install thread and release its resources.""" + self._install_queue.put(STOP_JOB) + @property def app_config(self) -> InvokeAIAppConfig: # noqa D102 return self._app_config @@ -189,7 +192,7 @@ class ModelInstallService(ModelInstallServiceBase): config_in=config, local_path=Path(source.path), ) - self._install_jobs[source] = job + self._install_jobs.append(job) self._install_queue.put(job) return job @@ -199,29 +202,23 @@ class ModelInstallService(ModelInstallServiceBase): def list_jobs(self, source: Optional[ModelSource | str] = None) -> List[ModelInstallJob]: # noqa D102 jobs = self._install_jobs if not source: - return list(jobs.values()) + return jobs else: - return [jobs[x] for x in jobs if str(source) in str(x)] + return [x for x in jobs if str(source) in str(x)] - def get_job(self, source: ModelSource) -> ModelInstallJob: # noqa D102 - try: - return self._install_jobs[source] - except KeyError: - raise UnknownInstallJobException(f"{source}: unknown install job") + def get_job(self, source: ModelSource) -> List[ModelInstallJob]: # noqa D102 + return [x for x in self._install_jobs if x.source == source] - def wait_for_installs(self) -> Dict[ModelSource, ModelInstallJob]: # noqa D102 + def wait_for_installs(self) -> List[ModelInstallJob]: # noqa D102 self._install_queue.join() return self._install_jobs def prune_jobs(self) -> None: """Prune all completed and errored jobs.""" - finished_jobs = [ - source - for source in self._install_jobs - if self._install_jobs[source].status in [InstallStatus.COMPLETED, InstallStatus.ERROR] + unfinished_jobs = [ + x for x in self._install_jobs if x.status not in [InstallStatus.COMPLETED, InstallStatus.ERROR] ] - for source in finished_jobs: - del self._install_jobs[source] + self._install_jobs = unfinished_jobs def sync_to_config(self) -> None: """Synchronize models on disk to those in the config record store database.""" @@ -344,10 +341,6 @@ class ModelInstallService(ModelInstallServiceBase): path.unlink() self.unregister(key) - def release(self) -> None: - """Stop the install thread and release its resources.""" - self._install_queue.put(STOP_JOB) - def _copy_model(self, old_path: Path, new_path: Path) -> Path: if old_path == new_path: return old_path diff --git a/invokeai/backend/model_manager/search.py b/invokeai/backend/model_manager/search.py index 7492e471d3..45019bb103 100644 --- a/invokeai/backend/model_manager/search.py +++ b/invokeai/backend/model_manager/search.py @@ -115,7 +115,6 @@ class ModelSearch(ModelSearchBase): # returns all models that have 'anime' in the path """ - directory: Path = Field(default=None) models_found: Set[Path] = Field(default=None) scanned_dirs: Set[Path] = Field(default=None) pruned_paths: Set[Path] = Field(default=None) diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index 7d6656f23e..849d21188d 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -16,7 +16,6 @@ from invokeai.app.services.model_install import ( ModelInstallJob, ModelInstallService, ModelInstallServiceBase, - UnknownInstallJobException, ) from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException from invokeai.app.services.shared.sqlite import SqliteDatabase @@ -133,13 +132,14 @@ def test_background_install(installer: ModelInstallServiceBase, test_file: Path, assert isinstance(job, ModelInstallJob) # See if job is registered properly - assert installer.get_job(source) == job + assert job in installer.get_job(source) # test that the job object tracked installation correctly jobs = installer.wait_for_installs() - assert jobs[source] is not None - assert jobs[source] == job - assert jobs[source].status == InstallStatus.COMPLETED + assert len(jobs) > 0 + my_job = [x for x in jobs if x.source == source] + assert len(my_job) == 1 + assert my_job[0].status == InstallStatus.COMPLETED # test that the expected events were issued bus = installer.event_bus @@ -165,9 +165,7 @@ def test_background_install(installer: ModelInstallServiceBase, test_file: Path, # see if prune works properly installer.prune_jobs() - with pytest.raises(UnknownInstallJobException): - assert installer.get_job(source) - + assert not installer.get_job(source) def test_delete_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig): store = installer.record_store diff --git a/tests/app/services/model_install/test_model_install/README b/tests/app/services/model_install/test_model_install/README new file mode 100644 index 0000000000..af555d18a3 --- /dev/null +++ b/tests/app/services/model_install/test_model_install/README @@ -0,0 +1 @@ +This directory is used by pytest-datadir.