From fe1038665cadd88f64bebf6015326293a8048a11 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 9 Oct 2023 00:28:21 -0400 Subject: [PATCH] address all PR 4252 comments from ryan through October 5 --- invokeai/app/api/routers/models.py | 102 +++---- invokeai/app/services/download_manager.py | 2 +- .../app/services/model_install_service.py | 2 +- invokeai/app/services/model_loader_service.py | 2 +- invokeai/backend/__init__.py | 2 - invokeai/backend/install/install_helper.py | 3 +- .../backend/install/invokeai_configure.py | 19 +- .../backend/model_manager/download/base.py | 2 +- .../backend/model_manager/download/queue.py | 148 ++++------ invokeai/backend/model_manager/hash.py | 8 +- invokeai/backend/model_manager/storage/sql.py | 270 ++++++++---------- .../backend/model_manager/storage/yaml.py | 40 +-- tests/AC_model_manager/test_model_download.py | 2 - 13 files changed, 260 insertions(+), 342 deletions(-) diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 9e1b617bfa..3c86ce5052 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -98,13 +98,18 @@ async def update_model( ) -> InvokeAIModelConfig: """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.""" logger = ApiDependencies.invoker.services.logger + info_dict = info.dict() + record_store = ApiDependencies.invoker.services.model_record_store + model_install = ApiDependencies.invoker.services.model_installer + try: + new_config = record_store.update_model(key, config=info_dict) + except UnknownModelException as e: + raise HTTPException(status_code=404, detail=str(e)) + except ValueError as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) try: - info_dict = info.dict() - info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()} - record_store = ApiDependencies.invoker.services.model_record_store - model_install = ApiDependencies.invoker.services.model_installer - new_config = record_store.update_model(key, config=info_dict) # In the event that the model's name, type or base has changed, and the model itself # resides in the invokeai root models directory, then the next statement will move # the model file into its new canonical location. @@ -115,9 +120,6 @@ async def update_model( except ValueError as e: logger.error(str(e)) raise HTTPException(status_code=409, detail=str(e)) - except Exception as e: - logger.error(str(e)) - raise HTTPException(status_code=400, detail=str(e)) return model_response @@ -198,8 +200,8 @@ async def import_model( responses={ 201: {"description": "The model added successfully"}, 404: {"description": "The model could not be found"}, - 424: {"description": "The model appeared to add successfully, but could not be found in the model manager"}, 409: {"description": "There is already a model corresponding to this path or repo_id"}, + 415: {"description": "Unrecognized file/folder format"}, }, status_code=201, response_model=InvokeAIModelConfig, @@ -213,16 +215,22 @@ async def add_model( """ logger = ApiDependencies.invoker.services.logger + path = info.path + installer = ApiDependencies.invoker.services.model_installer + record_store = ApiDependencies.invoker.services.model_record_store try: - path = info.path - installer = ApiDependencies.invoker.services.model_installer - record_store = ApiDependencies.invoker.services.model_record_store key = installer.install_path(path) logger.info(f"Created model {key} for {path}") + except DuplicateModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + except InvalidModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=415) - # update with the provided info + # update with the provided info + try: info_dict = info.dict() - info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()} new_config = record_store.update_model(key, new_config=info_dict) return parse_obj_as(InvokeAIModelConfig, new_config.dict()) except UnknownModelException as e: @@ -405,24 +413,19 @@ async def merge_models( ) async def list_install_jobs() -> List[ModelImportStatus]: """List active and pending model installation jobs.""" - logger = ApiDependencies.invoker.services.logger job_mgr = ApiDependencies.invoker.services.model_installer - try: - jobs = job_mgr.list_install_jobs() - return [ - ModelImportStatus( - job_id=x.id, - source=x.source, - priority=x.priority, - bytes=x.bytes, - total_bytes=x.total_bytes, - status=x.status, - ) - for x in jobs - ] - except Exception as e: - logger.error(str(e)) - raise HTTPException(status_code=400, detail=str(e)) + jobs = job_mgr.list_install_jobs() + return [ + ModelImportStatus( + job_id=x.id, + source=x.source, + priority=x.priority, + bytes=x.bytes, + total_bytes=x.total_bytes, + status=x.status, + ) + for x in jobs + ] @models_router.patch( @@ -459,11 +462,11 @@ async def control_install_jobs( elif operation == JobControlOperation.CANCEL: job_mgr.cancel_job(job_id) - elif operation == JobControlOperation.CHANGE_PRIORITY: + elif operation == JobControlOperation.CHANGE_PRIORITY and priority_delta is not None: job_mgr.change_job_priority(job_id, priority_delta) else: - raise ValueError(f"Unknown operation {JobControlOperation.value}") + raise ValueError(f"Unknown operation {operation.value}") return ModelImportStatus( job_id=job_id, @@ -478,9 +481,6 @@ async def control_install_jobs( except ValueError as e: logger.error(str(e)) raise HTTPException(status_code=409, detail=str(e)) - except Exception as e: - logger.error(str(e)) - raise HTTPException(status_code=400, detail=str(e)) @models_router.patch( @@ -490,39 +490,27 @@ async def control_install_jobs( 204: {"description": "All jobs cancelled successfully"}, 400: {"description": "Bad request"}, }, - status_code=200, - response_model=ModelImportStatus, ) async def cancel_install_jobs(): - """Cancel all pending install jobs.""" + """Cancel all model installation jobs.""" logger = ApiDependencies.invoker.services.logger - try: - job_mgr = ApiDependencies.invoker.services.model_installer - logger.info("Cancelling all running model installation jobs.") - job_mgr.cancel_all_jobs() - return Response(status_code=204) - except Exception as e: - logger.error(str(e)) - raise HTTPException(status_code=400, detail=str(e)) + job_mgr = ApiDependencies.invoker.services.model_installer + logger.info("Cancelling all model installation jobs.") + job_mgr.cancel_all_jobs() + return Response(status_code=204) @models_router.patch( "/jobs/prune", operation_id="prune_jobs", responses={ - 204: {"description": "All jobs cancelled successfully"}, + 204: {"description": "All completed jobs have been pruned"}, 400: {"description": "Bad request"}, }, - status_code=200, - response_model=ModelImportStatus, ) async def prune_jobs(): """Prune all completed and errored jobs.""" logger = ApiDependencies.invoker.services.logger - try: - mgr = ApiDependencies.invoker.services.model_installer - mgr.prune_jobs() - return Response(status_code=204) - except Exception as e: - logger.error(str(e)) - raise HTTPException(status_code=400, detail=str(e)) + mgr = ApiDependencies.invoker.services.model_installer + mgr.prune_jobs() + return Response(status_code=204) diff --git a/invokeai/app/services/download_manager.py b/invokeai/app/services/download_manager.py index 27e9c0f9cb..c2455ff0c0 100644 --- a/invokeai/app/services/download_manager.py +++ b/invokeai/app/services/download_manager.py @@ -126,7 +126,7 @@ class DownloadQueueService(DownloadQueueServiceBase): e.g. `max_parallel_dl`. """ self._event_bus = event_bus - self._queue = DownloadQueue() + self._queue = DownloadQueue(**kwargs) def create_download_job( self, diff --git a/invokeai/app/services/model_install_service.py b/invokeai/app/services/model_install_service.py index 4f859cbe59..1f3f9bc994 100644 --- a/invokeai/app/services/model_install_service.py +++ b/invokeai/app/services/model_install_service.py @@ -92,7 +92,7 @@ class ModelInstallServiceBase(ModelInstallBase): # This is an ABC @abstractmethod def cancel_all_jobs(self): - """Cancel all active jobs.""" + """Cancel all installation jobs.""" pass @abstractmethod diff --git a/invokeai/app/services/model_loader_service.py b/invokeai/app/services/model_loader_service.py index 93a69df0c2..8b2de15176 100644 --- a/invokeai/app/services/model_loader_service.py +++ b/invokeai/app/services/model_loader_service.py @@ -62,7 +62,7 @@ class ModelLoadServiceBase(ABC): class ModelLoadService(ModelLoadServiceBase): """Responsible for managing models on disk and in memory.""" - _loader: ModelLoad = Field(description="InvokeAIAppConfig object for the current process") + _loader: ModelLoad def __init__( self, diff --git a/invokeai/backend/__init__.py b/invokeai/backend/__init__.py index 49c71f024e..420b90d7b4 100644 --- a/invokeai/backend/__init__.py +++ b/invokeai/backend/__init__.py @@ -12,6 +12,4 @@ from .model_manager import ( # noqa F401 SilenceWarnings, SubModelType, ) -from .model_manager.install import ModelInstall # noqa F401 -from .model_manager.loader import ModelLoad # noqa F401 from .util.devices import get_precision # noqa F401 diff --git a/invokeai/backend/install/install_helper.py b/invokeai/backend/install/install_helper.py index 0d2e1d81fb..8770ee2c2d 100644 --- a/invokeai/backend/install/install_helper.py +++ b/invokeai/backend/install/install_helper.py @@ -18,7 +18,6 @@ from invokeai.backend.model_manager.install import ModelInstall, ModelInstallJob # name of the starter models file INITIAL_MODELS = "INITIAL_MODELS.yaml" -ACCESS_TOKEN = HfFolder.get_token() class UnifiedModelInfo(BaseModel): @@ -173,7 +172,7 @@ class InstallHelper(object): model.source, subfolder=model.subfolder, variant="fp16" if self._config.precision == "float16" else None, - access_token=ACCESS_TOKEN, # this is a global, + access_token=HfFolder.get_token(), metadata=metadata, ) diff --git a/invokeai/backend/install/invokeai_configure.py b/invokeai/backend/install/invokeai_configure.py index 4e15a023ff..3029b27a8b 100755 --- a/invokeai/backend/install/invokeai_configure.py +++ b/invokeai/backend/install/invokeai_configure.py @@ -185,14 +185,17 @@ class ProgressBar: # --------------------------------------------- def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs): - logger.addFilter(lambda x: "fp16 is not a valid" not in x.getMessage()) - - model = model_class.from_pretrained( - model_name, - resume_download=True, - **kwargs, - ) - model.save_pretrained(destination, safe_serialization=True) + filter = lambda x: "fp16 is not a valid" not in x.getMessage() + logger.addFilter(filter) + try: + model = model_class.from_pretrained( + model_name, + resume_download=True, + **kwargs, + ) + model.save_pretrained(destination, safe_serialization=True) + finally: + logger.removeFilter(filter) return destination diff --git a/invokeai/backend/model_manager/download/base.py b/invokeai/backend/model_manager/download/base.py index 56998a806a..dc53771f9b 100644 --- a/invokeai/backend/model_manager/download/base.py +++ b/invokeai/backend/model_manager/download/base.py @@ -212,7 +212,7 @@ class DownloadQueueBase(ABC): @abstractmethod def cancel_all_jobs(self, preserve_partial: bool = False): """ - Cancel all active and enquedjobs. + Cancel all jobs (those in enqueued, running and paused states). :param preserve_partial: Keep partially downloaded files [False]. """ diff --git a/invokeai/backend/model_manager/download/queue.py b/invokeai/backend/model_manager/download/queue.py index f25424a44f..bd5adceec8 100644 --- a/invokeai/backend/model_manager/download/queue.py +++ b/invokeai/backend/model_manager/download/queue.py @@ -12,8 +12,7 @@ from queue import PriorityQueue from typing import Callable, Dict, List, Optional, Set, Tuple, Union import requests -from huggingface_hub import HfApi, hf_hub_url -from pydantic import Field, parse_obj_as, validator +from pydantic import Field from pydantic.networks import AnyHttpUrl from requests import HTTPError @@ -59,7 +58,7 @@ class DownloadQueue(DownloadQueueBase): _jobs: Dict[int, DownloadJobBase] _worker_pool: Set[threading.Thread] _queue: PriorityQueue - _lock: threading.RLock + _lock: threading.RLock # to allow methods called within the same thread to lock without blocking _logger: Logger _event_handlers: List[DownloadEventHandler] = Field(default_factory=list) _next_job_id: int = 0 @@ -136,13 +135,10 @@ class DownloadQueue(DownloadQueueBase): # add the queue's handlers for handler in self._event_handlers: job.add_event_handler(handler) - try: - self._lock.acquire() + with self._lock: job.id = self._next_job_id self._jobs[job.id] = job self._next_job_id += 1 - finally: - self._lock.release() if start: self.start_job(job) return job @@ -163,29 +159,25 @@ class DownloadQueue(DownloadQueueBase): def change_priority(self, job: DownloadJobBase, delta: int): """Change the priority of a job. Smaller priorities run first.""" - try: - self._lock.acquire() - assert isinstance(self._jobs[job.id], DownloadJobBase) - job.priority += delta - except (AssertionError, KeyError) as excp: - raise UnknownJobIDException("Unrecognized job") from excp - finally: - self._lock.release() + with self._lock: + try: + assert isinstance(self._jobs[job.id], DownloadJobBase) + job.priority += delta + except (AssertionError, KeyError) as excp: + raise UnknownJobIDException("Unrecognized job") from excp def prune_jobs(self): """Prune completed and errored queue items from the job list.""" - try: + with self._lock: to_delete = set() - self._lock.acquire() - for job_id, job in self._jobs.items(): - if self._in_terminal_state(job): - to_delete.add(job_id) - for job_id in to_delete: - del self._jobs[job_id] - except KeyError as excp: - raise UnknownJobIDException("Unrecognized job") from excp - finally: - self._lock.release() + try: + for job_id, job in self._jobs.items(): + if self._in_terminal_state(job): + to_delete.add(job_id) + for job_id in to_delete: + del self._jobs[job_id] + except KeyError as excp: + raise UnknownJobIDException("Unrecognized job") from excp def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False): """ @@ -194,16 +186,14 @@ class DownloadQueue(DownloadQueueBase): If it is running it will be stopped. job.status will be set to DownloadJobStatus.CANCELLED """ - try: - self._lock.acquire() - assert isinstance(self._jobs[job.id], DownloadJobBase) - job.preserve_partial_downloads = preserve_partial - self._update_job_status(job, DownloadJobStatus.CANCELLED) - job.cleanup() - except (AssertionError, KeyError) as excp: - raise UnknownJobIDException("Unrecognized job") from excp - finally: - self._lock.release() + with self._lock: + try: + assert isinstance(self._jobs[job.id], DownloadJobBase) + job.preserve_partial_downloads = preserve_partial + self._update_job_status(job, DownloadJobStatus.CANCELLED) + job.cleanup() + except (AssertionError, KeyError) as excp: + raise UnknownJobIDException("Unrecognized job") from excp def id_to_job(self, id: int) -> DownloadJobBase: """Translate a job ID into a DownloadJobBase object.""" @@ -214,12 +204,13 @@ class DownloadQueue(DownloadQueueBase): def start_job(self, job: DownloadJobBase): """Enqueue (start) the indicated job.""" - try: - assert isinstance(self._jobs[job.id], DownloadJobBase) - self._update_job_status(job, DownloadJobStatus.ENQUEUED) - self._queue.put(job) - except (AssertionError, KeyError) as excp: - raise UnknownJobIDException("Unrecognized job") from excp + with self._lock: + try: + assert isinstance(self._jobs[job.id], DownloadJobBase) + self._update_job_status(job, DownloadJobStatus.ENQUEUED) + self._queue.put(job) + except (AssertionError, KeyError) as excp: + raise UnknownJobIDException("Unrecognized job") from excp def pause_job(self, job: DownloadJobBase): """ @@ -228,45 +219,34 @@ class DownloadQueue(DownloadQueueBase): The job can be restarted with start_job() and the download will pick up from where it left off. """ - try: - self._lock.acquire() - assert isinstance(self._jobs[job.id], DownloadJobBase) - self._update_job_status(job, DownloadJobStatus.PAUSED) - job.cleanup() - except (AssertionError, KeyError) as excp: - raise UnknownJobIDException("Unrecognized job") from excp - finally: - self._lock.release() + with self._lock: + try: + assert isinstance(self._jobs[job.id], DownloadJobBase) + self._update_job_status(job, DownloadJobStatus.PAUSED) + job.cleanup() + except (AssertionError, KeyError) as excp: + raise UnknownJobIDException("Unrecognized job") from excp def start_all_jobs(self): """Start (enqueue) all jobs that are idle or paused.""" - try: - self._lock.acquire() + with self._lock: for job in self._jobs.values(): if job.status in [DownloadJobStatus.IDLE, DownloadJobStatus.PAUSED]: self.start_job(job) - finally: - self._lock.release() def pause_all_jobs(self): """Pause all running jobs.""" - try: - self._lock.acquire() + with self._lock: for job in self._jobs.values(): if job.status == DownloadJobStatus.RUNNING: self.pause_job(job) - finally: - self._lock.release() def cancel_all_jobs(self, preserve_partial: bool = False): - """Cancel all running jobs.""" - try: - self._lock.acquire() + """Cancel all jobs (those not in enqueued, running or paused state).""" + with self._lock: for job in self._jobs.values(): if not self._in_terminal_state(job): self.cancel_job(job, preserve_partial) - finally: - self._lock.release() def _in_terminal_state(self, job: DownloadJobBase): return job.status in [ @@ -288,26 +268,26 @@ class DownloadQueue(DownloadQueueBase): while not done: job = self._queue.get() - try: # this is for debugging priority - self._lock.acquire() + with self._lock: job.job_sequence = self._sequence self._sequence += 1 + + try: + if job == STOP_JOB: # marker that queue is done + done = True + + if ( + job.status == DownloadJobStatus.ENQUEUED + ): # Don't do anything for non-enqueued jobs (shouldn't happen) + if not self._quiet: + self._logger.info(f"{job.source}: Downloading to {job.destination}") + do_download = self.select_downloader(job) + do_download(job) + + if job.status == DownloadJobStatus.CANCELLED: + self._cleanup_cancelled_job(job) finally: - self._lock.release() - - if job == STOP_JOB: # marker that queue is done - done = True - - if job.status == DownloadJobStatus.ENQUEUED: # Don't do anything for non-enqueued jobs (shouldn't happen) - if not self._quiet: - self._logger.info(f"{job.source}: Downloading to {job.destination}") - do_download = self.select_downloader(job) - do_download(job) - - if job.status == DownloadJobStatus.CANCELLED: - self._cleanup_cancelled_job(job) - - self._queue.task_done() + self._queue.task_done() def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]: """Based on the job type select the download method.""" @@ -397,11 +377,7 @@ class DownloadQueue(DownloadQueueBase): self._update_job_status(job, DownloadJobStatus.COMPLETED) except KeyboardInterrupt as excp: raise excp - except DuplicateModelException as excp: - self._logger.error(f"A model with the same hash as {dest} is already installed.") - job.error = excp - self._update_job_status(job, DownloadJobStatus.ERROR) - except Exception as excp: + except (HTTPError, OSError) as excp: self._logger.error(f"An error occurred while downloading/installing {job.source}: {str(excp)}") print(traceback.format_exc()) job.error = excp diff --git a/invokeai/backend/model_manager/hash.py b/invokeai/backend/model_manager/hash.py index c7ebe2628b..e445fa03ab 100644 --- a/invokeai/backend/model_manager/hash.py +++ b/invokeai/backend/model_manager/hash.py @@ -15,6 +15,8 @@ from typing import Dict, Union from imohash import hashfile +from .models import InvalidModelException + class FastModelHash(object): """FastModelHash obect provides one public class method, hash().""" @@ -32,9 +34,6 @@ class FastModelHash(object): elif model_location.is_dir(): return cls._hash_dir(model_location) else: - # avoid circular import - from .models import InvalidModelException - raise InvalidModelException(f"Not a valid file or directory: {model_location}") @classmethod @@ -54,7 +53,8 @@ class FastModelHash(object): for root, dirs, files in os.walk(model_location): for file in files: - # only tally tensor files + # only tally tensor files because diffusers config files change slightly + # depending on how the model was downloaded/converted. if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")): continue path = (Path(root) / file).as_posix() diff --git a/invokeai/backend/model_manager/storage/sql.py b/invokeai/backend/model_manager/storage/sql.py index 73972c5b96..e76cd87071 100644 --- a/invokeai/backend/model_manager/storage/sql.py +++ b/invokeai/backend/model_manager/storage/sql.py @@ -72,14 +72,11 @@ class ModelConfigStoreSQL(ModelConfigStore): self._cursor = self._conn.cursor() self._lock = lock - try: - self._lock.acquire() + with self._lock: # Enable foreign keys self._conn.execute("PRAGMA foreign_keys = ON;") self._create_tables() self._conn.commit() - finally: - self._lock.release() assert ( str(self.version) == CONFIG_FILE_VERSION ), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}" @@ -189,52 +186,49 @@ class ModelConfigStoreSQL(ModelConfigStore): """ record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect. json_serialized = json.dumps(record.dict()) # and turn it into a json string. - try: - self._lock.acquire() - self._cursor.execute( - """--sql - INSERT INTO model_config ( - id, - base_model, - model_type, - model_name, - model_path, - config - ) - VALUES (?,?,?,?,?,?); - """, - ( - key, - record.base_model, - record.model_type, - record.name, - record.path, - json_serialized, - ), - ) - if record.tags: - self._update_tags(key, record.tags) - self._conn.commit() + with self._lock: + try: + self._cursor.execute( + """--sql + INSERT INTO model_config ( + id, + base_model, + model_type, + model_name, + model_path, + config + ) + VALUES (?,?,?,?,?,?); + """, + ( + key, + record.base_model, + record.model_type, + record.name, + record.path, + json_serialized, + ), + ) + if record.tags: + self._update_tags(key, record.tags) + self._conn.commit() - except sqlite3.IntegrityError as e: - self._conn.rollback() - if "UNIQUE constraint failed" in str(e): - raise DuplicateModelException(f"A model with key '{key}' is already installed") from e - else: + except sqlite3.IntegrityError as e: + self._conn.rollback() + if "UNIQUE constraint failed" in str(e): + raise DuplicateModelException(f"A model with key '{key}' is already installed") from e + else: + raise e + except sqlite3.Error as e: + self._conn.rollback() raise e - except sqlite3.Error as e: - self._conn.rollback() - raise e - finally: - self._lock.release() return self.get_model(key) @property def version(self) -> str: """Return the version of the database schema.""" - try: - self._lock.acquire() + with self._lock: self._cursor.execute( """--sql SELECT metadata_value FROM model_manager_metadata @@ -246,8 +240,6 @@ class ModelConfigStoreSQL(ModelConfigStore): if not rows: raise KeyError("Models database does not have metadata key 'version'") return rows[0] - finally: - self._lock.release() def _update_tags(self, key: str, tags: List[str]) -> None: """Update tags for model with key.""" @@ -301,23 +293,21 @@ class ModelConfigStoreSQL(ModelConfigStore): Can raise an UnknownModelException """ - try: - self._lock.acquire() - self._cursor.execute( - """--sql - DELETE FROM model_config - WHERE id=?; - """, - (key,), - ) - if self._cursor.rowcount == 0: - raise UnknownModelException - self._conn.commit() - except sqlite3.Error as e: - self._conn.rollback() - raise e - finally: - self._lock.release() + with self._lock: + try: + self._cursor.execute( + """--sql + DELETE FROM model_config + WHERE id=?; + """, + (key,), + ) + if self._cursor.rowcount == 0: + raise UnknownModelException + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise e def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase: """ @@ -329,30 +319,29 @@ class ModelConfigStoreSQL(ModelConfigStore): """ record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect json_serialized = json.dumps(record.dict()) # and turn it into a json string. - try: - self._lock.acquire() - self._cursor.execute( - """--sql - UPDATE model_config - SET base_model=?, - model_type=?, - model_name=?, - model_path=?, - config=? - WHERE id=?; - """, - (record.base_model, record.model_type, record.name, record.path, json_serialized, key), - ) - if self._cursor.rowcount == 0: - raise UnknownModelException - if record.tags: - self._update_tags(key, record.tags) - self._conn.commit() - except sqlite3.Error as e: - self._conn.rollback() - raise e - finally: - self._lock.release() + with self._lock: + try: + self._cursor.execute( + """--sql + UPDATE model_config + SET base_model=?, + model_type=?, + model_name=?, + model_path=?, + config=? + WHERE id=?; + """, + (record.base_model, record.model_type, record.name, record.path, json_serialized, key), + ) + if self._cursor.rowcount == 0: + raise UnknownModelException + if record.tags: + self._update_tags(key, record.tags) + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise e + return self.get_model(key) def get_model(self, key: str) -> AnyModelConfig: @@ -363,8 +352,7 @@ class ModelConfigStoreSQL(ModelConfigStore): Exceptions: UnknownModelException """ - try: - self._lock.acquire() + with self._lock: self._cursor.execute( """--sql SELECT config FROM model_config @@ -376,8 +364,6 @@ class ModelConfigStoreSQL(ModelConfigStore): if not rows: raise UnknownModelException model = ModelConfigFactory.make_config(json.loads(rows[0])) - finally: - self._lock.release() return model def exists(self, key: str) -> bool: @@ -387,20 +373,18 @@ class ModelConfigStoreSQL(ModelConfigStore): :param key: Unique key for the model to be deleted """ count = 0 - try: - self._lock.acquire() - self._cursor.execute( - """--sql - select count(*) FROM model_config - WHERE id=?; - """, - (key,), - ) - count = self._cursor.fetchone()[0] - except sqlite3.Error as e: - raise e - finally: - self._lock.release() + with self._lock: + try: + self._cursor.execute( + """--sql + select count(*) FROM model_config + WHERE id=?; + """, + (key,), + ) + count = self._cursor.fetchone()[0] + except sqlite3.Error as e: + raise e return count > 0 def search_by_tag(self, tags: Set[str]) -> List[AnyModelConfig]: @@ -408,34 +392,32 @@ class ModelConfigStoreSQL(ModelConfigStore): # rather than create a hairy SQL cross-product, we intersect # tag results in a stepwise fashion at the python level. results = [] - try: - self._lock.acquire() - matches: Set[str] = set() - for tag in tags: - self._cursor.execute( - """--sql - SELECT a.id FROM model_tag AS a, - tags AS b - WHERE a.tag_id=b.tag_id - AND b.tag_text=?; - """, - (tag,), - ) - model_keys = {x[0] for x in self._cursor.fetchall()} - matches = matches.intersection(model_keys) if len(matches) > 0 else model_keys - if matches: - self._cursor.execute( - f"""--sql - SELECT config FROM model_config - WHERE id IN ({','.join('?' * len(matches))}); - """, - tuple(matches), - ) - results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] - except sqlite3.Error as e: - raise e - finally: - self._lock.release() + with self._lock: + try: + matches: Set[str] = set() + for tag in tags: + self._cursor.execute( + """--sql + SELECT a.id FROM model_tag AS a, + tags AS b + WHERE a.tag_id=b.tag_id + AND b.tag_text=?; + """, + (tag,), + ) + model_keys = {x[0] for x in self._cursor.fetchall()} + matches = matches.intersection(model_keys) if len(matches) > 0 else model_keys + if matches: + self._cursor.execute( + f"""--sql + SELECT config FROM model_config + WHERE id IN ({','.join('?' * len(matches))}); + """, + tuple(matches), + ) + results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] + except sqlite3.Error as e: + raise e return results def search_by_name( @@ -467,20 +449,18 @@ class ModelConfigStoreSQL(ModelConfigStore): where_clause.append("model_type=?") bindings.append(model_type) where = f"WHERE {' AND '.join(where_clause)}" if where_clause else "" - try: - self._lock.acquire() - self._cursor.execute( - f"""--sql - select config FROM model_config - {where}; - """, - tuple(bindings), - ) - results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] - except sqlite3.Error as e: - raise e - finally: - self._lock.release() + with self._lock: + try: + self._cursor.execute( + f"""--sql + select config FROM model_config + {where}; + """, + tuple(bindings), + ) + results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] + except sqlite3.Error as e: + raise e return results def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]: diff --git a/invokeai/backend/model_manager/storage/yaml.py b/invokeai/backend/model_manager/storage/yaml.py index a59c934d42..5eb5b66b3f 100644 --- a/invokeai/backend/model_manager/storage/yaml.py +++ b/invokeai/backend/model_manager/storage/yaml.py @@ -80,24 +80,18 @@ class ModelConfigStoreYAML(ModelConfigStore): raise ConfigFileVersionMismatchException def _initialize_yaml(self): - try: - self._lock.acquire() + with self._lock: self._filename.parent.mkdir(parents=True, exist_ok=True) with open(self._filename, "w") as yaml_file: yaml_file.write(yaml.dump({"__metadata__": {"version": CONFIG_FILE_VERSION}})) - finally: - self._lock.release() def _commit(self): - try: - self._lock.acquire() + with self._lock: newfile = Path(str(self._filename) + ".new") yaml_str = OmegaConf.to_yaml(self._config) with open(newfile, "w", encoding="utf-8") as outfile: outfile.write(yaml_str) newfile.replace(self._filename) - finally: - self._lock.release() @property def version(self) -> str: @@ -116,8 +110,7 @@ class ModelConfigStoreYAML(ModelConfigStore): """ record = ModelConfigFactory.make_config(config, key) # ensure it is a valid config obect dict_fields = record.dict() # and back to a dict with valid fields - try: - self._lock.acquire() + with self._lock: if key in self._config: existing_model = self.get_model(key) raise DuplicateModelException( @@ -125,8 +118,6 @@ class ModelConfigStoreYAML(ModelConfigStore): ) self._config[key] = self._fix_enums(dict_fields) self._commit() - finally: - self._lock.release() return self.get_model(key) def _fix_enums(self, original: dict) -> dict: @@ -144,14 +135,11 @@ class ModelConfigStoreYAML(ModelConfigStore): Can raise an UnknownModelException """ - try: - self._lock.acquire() + with self._lock: if key not in self._config: raise UnknownModelException(f"Unknown key '{key}' for model config") self._config.pop(key) self._commit() - finally: - self._lock.release() def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase: """ @@ -163,14 +151,11 @@ class ModelConfigStoreYAML(ModelConfigStore): """ record = ModelConfigFactory.make_config(config, key) # ensure it is a valid config obect dict_fields = record.dict() # and back to a dict with valid fields - try: - self._lock.acquire() + with self._lock: if key not in self._config: raise UnknownModelException(f"Unknown key '{key}' for model config") self._config[key] = self._fix_enums(dict_fields) self._commit() - finally: - self._lock.release() return self.get_model(key) def get_model(self, key: str) -> AnyModelConfig: @@ -203,15 +188,12 @@ class ModelConfigStoreYAML(ModelConfigStore): """ results = [] tags = set(tags) - try: - self._lock.acquire() + with self._lock: for config in self.all_models(): config_tags = set(config.tags or []) if tags.difference(config_tags): # not all tags in the model continue results.append(config) - finally: - self._lock.release() return results def search_by_name( @@ -231,8 +213,7 @@ class ModelConfigStoreYAML(ModelConfigStore): models in the database. """ results: List[ModelConfigBase] = list() - try: - self._lock.acquire() + with self._lock: for key, record in self._config.items(): if key == "__metadata__": continue @@ -244,20 +225,15 @@ class ModelConfigStoreYAML(ModelConfigStore): if model_type and model.model_type != model_type: continue results.append(model) - finally: - self._lock.release() return results def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]: """Return the model with the indicated path, or None.""" - try: - self._lock.acquire() + with self._lock: for key, record in self._config.items(): if key == "__metadata__": continue model = ModelConfigFactory.make_config(record, str(key)) if model.path == path: return model - finally: - self._lock.release() return None diff --git a/tests/AC_model_manager/test_model_download.py b/tests/AC_model_manager/test_model_download.py index 3ef7d0c930..8ea77d164c 100644 --- a/tests/AC_model_manager/test_model_download.py +++ b/tests/AC_model_manager/test_model_download.py @@ -270,8 +270,6 @@ def test_bad_urls(): def test_pause_cancel_url(): # this one is tricky because of potential race conditions def event_handler(job: DownloadJobBase): - if job.id == 0: - print(job.status, job.bytes) time.sleep(0.5) # slow down the thread so that we can recover the paused state queue = ModelDownloadQueue(requests_session=session, event_handlers=[event_handler])