mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
implement review suggestions from @RyanjDick
This commit is contained in:
parent
018ccebd6f
commit
2b583ffcdf
@ -643,11 +643,10 @@ install_job = installer.import_model(
|
||||
|
||||
This section describes additional methods provided by the installer class.
|
||||
|
||||
#### source2job = installer.wait_for_installs()
|
||||
#### jobs = installer.wait_for_installs()
|
||||
|
||||
Block until all pending installs are completed or errored and return a
|
||||
dictionary that maps the model `source` to the completed
|
||||
`ModelInstallJob`.
|
||||
Block until all pending installs are completed or errored and then
|
||||
returns a list of completed jobs.
|
||||
|
||||
#### jobs = installer.list_jobs([source])
|
||||
|
||||
@ -655,9 +654,10 @@ Return a list of all active and complete `ModelInstallJobs`. An
|
||||
optional `source` argument allows you to filter the returned list by a
|
||||
model source string pattern using a partial string match.
|
||||
|
||||
#### job = installer.get_job(source)
|
||||
#### jobs = installer.get_job(source)
|
||||
|
||||
Return the `ModelInstallJob` corresponding to the indicated model source.
|
||||
Return a list of `ModelInstallJob` corresponding to the indicated
|
||||
model source.
|
||||
|
||||
#### installer.prune_jobs
|
||||
|
||||
|
@ -178,7 +178,6 @@ async def add_model_record(
|
||||
operation_id="import_model_record",
|
||||
responses={
|
||||
201: {"description": "The model imported successfully"},
|
||||
404: {"description": "The model could not be found"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
@ -258,7 +257,7 @@ async def import_model(
|
||||
logger.info(f"Started installation of {source}")
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
raise HTTPException(status_code=424, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
|
@ -112,6 +112,16 @@ class URLModelSource(StringLikeSource):
|
||||
return str(self.url)
|
||||
|
||||
|
||||
# Body() is being applied here rather than Field() because otherwise FastAPI will
|
||||
# refuse to generate a schema. Relevant links:
|
||||
#
|
||||
# "Model Manager Refactor Phase 1 - SQL-based config storage
|
||||
# https://github.com/invoke-ai/InvokeAI/pull/5039#discussion_r1389752119 (comment)
|
||||
# Param: xyz can only be a request body, using Body() when using discriminated unions
|
||||
# https://github.com/tiangolo/fastapi/discussions/9761
|
||||
# Body parameter cannot be a pydantic union anymore sinve v0.95
|
||||
# https://github.com/tiangolo/fastapi/discussions/9287
|
||||
|
||||
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Body(discriminator="type")]
|
||||
ModelSourceValidator = TypeAdapter(ModelSource)
|
||||
|
||||
@ -263,8 +273,8 @@ class ModelInstallServiceBase(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_job(self, source: ModelSource) -> ModelInstallJob:
|
||||
"""Return the ModelInstallJob corresponding to the provided source."""
|
||||
def get_job(self, source: ModelSource) -> List[ModelInstallJob]:
|
||||
"""Return the ModelInstallJob(s) corresponding to the provided source."""
|
||||
|
||||
@abstractmethod
|
||||
def list_jobs(self, source: Optional[ModelSource | str] = None) -> List[ModelInstallJob]: # noqa D102
|
||||
@ -279,7 +289,7 @@ class ModelInstallServiceBase(ABC):
|
||||
"""Prune all completed and errored jobs."""
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_installs(self) -> Dict[ModelSource, ModelInstallJob]:
|
||||
def wait_for_installs(self) -> List[ModelInstallJob]:
|
||||
"""
|
||||
Wait for all pending installs to complete.
|
||||
|
||||
@ -288,8 +298,7 @@ class ModelInstallServiceBase(ABC):
|
||||
block indefinitely if one or more jobs are in the
|
||||
paused state.
|
||||
|
||||
It will return a dict that maps the source model
|
||||
path, URL or repo_id to the ID of the installed model.
|
||||
It will return the current list of jobs.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@ -305,12 +314,3 @@ class ModelInstallServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def sync_to_config(self) -> None:
|
||||
"""Synchronize models on disk to those in the model record database."""
|
||||
|
||||
@abstractmethod
|
||||
def release(self) -> None:
|
||||
"""
|
||||
Signal the install thread to exit.
|
||||
|
||||
This is useful if you are done with the installer and wish to
|
||||
release its resources.
|
||||
"""
|
||||
|
@ -29,7 +29,6 @@ from .model_install_base import (
|
||||
ModelInstallJob,
|
||||
ModelInstallServiceBase,
|
||||
ModelSource,
|
||||
UnknownInstallJobException,
|
||||
)
|
||||
|
||||
# marker that the queue is done and that thread should exit
|
||||
@ -46,7 +45,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
_record_store: ModelRecordServiceBase
|
||||
_event_bus: Optional[EventServiceBase] = None
|
||||
_install_queue: Queue[ModelInstallJob]
|
||||
_install_jobs: Dict[ModelSource, ModelInstallJob]
|
||||
_install_jobs: List[ModelInstallJob]
|
||||
_logger: Logger
|
||||
_cached_model_paths: Set[Path]
|
||||
_models_installed: Set[str]
|
||||
@ -68,12 +67,16 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._record_store = record_store
|
||||
self._event_bus = event_bus
|
||||
self._logger = InvokeAILogger.get_logger(name=self.__class__.__name__)
|
||||
self._install_jobs = {}
|
||||
self._install_jobs = []
|
||||
self._install_queue = Queue()
|
||||
self._cached_model_paths = set()
|
||||
self._models_installed = set()
|
||||
self._start_installer_thread()
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""At GC time, we stop the install thread and release its resources."""
|
||||
self._install_queue.put(STOP_JOB)
|
||||
|
||||
@property
|
||||
def app_config(self) -> InvokeAIAppConfig: # noqa D102
|
||||
return self._app_config
|
||||
@ -189,7 +192,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
config_in=config,
|
||||
local_path=Path(source.path),
|
||||
)
|
||||
self._install_jobs[source] = job
|
||||
self._install_jobs.append(job)
|
||||
self._install_queue.put(job)
|
||||
return job
|
||||
|
||||
@ -199,29 +202,23 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def list_jobs(self, source: Optional[ModelSource | str] = None) -> List[ModelInstallJob]: # noqa D102
|
||||
jobs = self._install_jobs
|
||||
if not source:
|
||||
return list(jobs.values())
|
||||
return jobs
|
||||
else:
|
||||
return [jobs[x] for x in jobs if str(source) in str(x)]
|
||||
return [x for x in jobs if str(source) in str(x)]
|
||||
|
||||
def get_job(self, source: ModelSource) -> ModelInstallJob: # noqa D102
|
||||
try:
|
||||
return self._install_jobs[source]
|
||||
except KeyError:
|
||||
raise UnknownInstallJobException(f"{source}: unknown install job")
|
||||
def get_job(self, source: ModelSource) -> List[ModelInstallJob]: # noqa D102
|
||||
return [x for x in self._install_jobs if x.source == source]
|
||||
|
||||
def wait_for_installs(self) -> Dict[ModelSource, ModelInstallJob]: # noqa D102
|
||||
def wait_for_installs(self) -> List[ModelInstallJob]: # noqa D102
|
||||
self._install_queue.join()
|
||||
return self._install_jobs
|
||||
|
||||
def prune_jobs(self) -> None:
|
||||
"""Prune all completed and errored jobs."""
|
||||
finished_jobs = [
|
||||
source
|
||||
for source in self._install_jobs
|
||||
if self._install_jobs[source].status in [InstallStatus.COMPLETED, InstallStatus.ERROR]
|
||||
unfinished_jobs = [
|
||||
x for x in self._install_jobs if x.status not in [InstallStatus.COMPLETED, InstallStatus.ERROR]
|
||||
]
|
||||
for source in finished_jobs:
|
||||
del self._install_jobs[source]
|
||||
self._install_jobs = unfinished_jobs
|
||||
|
||||
def sync_to_config(self) -> None:
|
||||
"""Synchronize models on disk to those in the config record store database."""
|
||||
@ -344,10 +341,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
path.unlink()
|
||||
self.unregister(key)
|
||||
|
||||
def release(self) -> None:
|
||||
"""Stop the install thread and release its resources."""
|
||||
self._install_queue.put(STOP_JOB)
|
||||
|
||||
def _copy_model(self, old_path: Path, new_path: Path) -> Path:
|
||||
if old_path == new_path:
|
||||
return old_path
|
||||
|
@ -115,7 +115,6 @@ class ModelSearch(ModelSearchBase):
|
||||
# returns all models that have 'anime' in the path
|
||||
"""
|
||||
|
||||
directory: Path = Field(default=None)
|
||||
models_found: Set[Path] = Field(default=None)
|
||||
scanned_dirs: Set[Path] = Field(default=None)
|
||||
pruned_paths: Set[Path] = Field(default=None)
|
||||
|
@ -16,7 +16,6 @@ from invokeai.app.services.model_install import (
|
||||
ModelInstallJob,
|
||||
ModelInstallService,
|
||||
ModelInstallServiceBase,
|
||||
UnknownInstallJobException,
|
||||
)
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
@ -133,13 +132,14 @@ def test_background_install(installer: ModelInstallServiceBase, test_file: Path,
|
||||
assert isinstance(job, ModelInstallJob)
|
||||
|
||||
# See if job is registered properly
|
||||
assert installer.get_job(source) == job
|
||||
assert job in installer.get_job(source)
|
||||
|
||||
# test that the job object tracked installation correctly
|
||||
jobs = installer.wait_for_installs()
|
||||
assert jobs[source] is not None
|
||||
assert jobs[source] == job
|
||||
assert jobs[source].status == InstallStatus.COMPLETED
|
||||
assert len(jobs) > 0
|
||||
my_job = [x for x in jobs if x.source == source]
|
||||
assert len(my_job) == 1
|
||||
assert my_job[0].status == InstallStatus.COMPLETED
|
||||
|
||||
# test that the expected events were issued
|
||||
bus = installer.event_bus
|
||||
@ -165,9 +165,7 @@ def test_background_install(installer: ModelInstallServiceBase, test_file: Path,
|
||||
|
||||
# see if prune works properly
|
||||
installer.prune_jobs()
|
||||
with pytest.raises(UnknownInstallJobException):
|
||||
assert installer.get_job(source)
|
||||
|
||||
assert not installer.get_job(source)
|
||||
|
||||
def test_delete_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig):
|
||||
store = installer.record_store
|
||||
|
@ -0,0 +1 @@
|
||||
This directory is used by pytest-datadir.
|
Loading…
Reference in New Issue
Block a user