address all PR 4252 comments from ryan through October 5

This commit is contained in:
Lincoln Stein
2023-10-09 00:28:21 -04:00
parent ce2baa36a9
commit fe1038665c
13 changed files with 260 additions and 342 deletions

View File

@ -98,13 +98,18 @@ async def update_model(
) -> InvokeAIModelConfig: ) -> InvokeAIModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.""" """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
info_dict = info.dict()
record_store = ApiDependencies.invoker.services.model_record_store
model_install = ApiDependencies.invoker.services.model_installer
try:
new_config = record_store.update_model(key, config=info_dict)
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
try: try:
info_dict = info.dict()
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
record_store = ApiDependencies.invoker.services.model_record_store
model_install = ApiDependencies.invoker.services.model_installer
new_config = record_store.update_model(key, config=info_dict)
# In the event that the model's name, type or base has changed, and the model itself # In the event that the model's name, type or base has changed, and the model itself
# resides in the invokeai root models directory, then the next statement will move # resides in the invokeai root models directory, then the next statement will move
# the model file into its new canonical location. # the model file into its new canonical location.
@ -115,9 +120,6 @@ async def update_model(
except ValueError as e: except ValueError as e:
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e)) raise HTTPException(status_code=409, detail=str(e))
except Exception as e:
logger.error(str(e))
raise HTTPException(status_code=400, detail=str(e))
return model_response return model_response
@ -198,8 +200,8 @@ async def import_model(
responses={ responses={
201: {"description": "The model added successfully"}, 201: {"description": "The model added successfully"},
404: {"description": "The model could not be found"}, 404: {"description": "The model could not be found"},
424: {"description": "The model appeared to add successfully, but could not be found in the model manager"},
409: {"description": "There is already a model corresponding to this path or repo_id"}, 409: {"description": "There is already a model corresponding to this path or repo_id"},
415: {"description": "Unrecognized file/folder format"},
}, },
status_code=201, status_code=201,
response_model=InvokeAIModelConfig, response_model=InvokeAIModelConfig,
@ -213,16 +215,22 @@ async def add_model(
""" """
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
path = info.path
installer = ApiDependencies.invoker.services.model_installer
record_store = ApiDependencies.invoker.services.model_record_store
try: try:
path = info.path
installer = ApiDependencies.invoker.services.model_installer
record_store = ApiDependencies.invoker.services.model_record_store
key = installer.install_path(path) key = installer.install_path(path)
logger.info(f"Created model {key} for {path}") logger.info(f"Created model {key} for {path}")
except DuplicateModelException as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
# update with the provided info # update with the provided info
try:
info_dict = info.dict() info_dict = info.dict()
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
new_config = record_store.update_model(key, new_config=info_dict) new_config = record_store.update_model(key, new_config=info_dict)
return parse_obj_as(InvokeAIModelConfig, new_config.dict()) return parse_obj_as(InvokeAIModelConfig, new_config.dict())
except UnknownModelException as e: except UnknownModelException as e:
@ -405,24 +413,19 @@ async def merge_models(
) )
async def list_install_jobs() -> List[ModelImportStatus]: async def list_install_jobs() -> List[ModelImportStatus]:
"""List active and pending model installation jobs.""" """List active and pending model installation jobs."""
logger = ApiDependencies.invoker.services.logger
job_mgr = ApiDependencies.invoker.services.model_installer job_mgr = ApiDependencies.invoker.services.model_installer
try: jobs = job_mgr.list_install_jobs()
jobs = job_mgr.list_install_jobs() return [
return [ ModelImportStatus(
ModelImportStatus( job_id=x.id,
job_id=x.id, source=x.source,
source=x.source, priority=x.priority,
priority=x.priority, bytes=x.bytes,
bytes=x.bytes, total_bytes=x.total_bytes,
total_bytes=x.total_bytes, status=x.status,
status=x.status, )
) for x in jobs
for x in jobs ]
]
except Exception as e:
logger.error(str(e))
raise HTTPException(status_code=400, detail=str(e))
@models_router.patch( @models_router.patch(
@ -459,11 +462,11 @@ async def control_install_jobs(
elif operation == JobControlOperation.CANCEL: elif operation == JobControlOperation.CANCEL:
job_mgr.cancel_job(job_id) job_mgr.cancel_job(job_id)
elif operation == JobControlOperation.CHANGE_PRIORITY: elif operation == JobControlOperation.CHANGE_PRIORITY and priority_delta is not None:
job_mgr.change_job_priority(job_id, priority_delta) job_mgr.change_job_priority(job_id, priority_delta)
else: else:
raise ValueError(f"Unknown operation {JobControlOperation.value}") raise ValueError(f"Unknown operation {operation.value}")
return ModelImportStatus( return ModelImportStatus(
job_id=job_id, job_id=job_id,
@ -478,9 +481,6 @@ async def control_install_jobs(
except ValueError as e: except ValueError as e:
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e)) raise HTTPException(status_code=409, detail=str(e))
except Exception as e:
logger.error(str(e))
raise HTTPException(status_code=400, detail=str(e))
@models_router.patch( @models_router.patch(
@ -490,39 +490,27 @@ async def control_install_jobs(
204: {"description": "All jobs cancelled successfully"}, 204: {"description": "All jobs cancelled successfully"},
400: {"description": "Bad request"}, 400: {"description": "Bad request"},
}, },
status_code=200,
response_model=ModelImportStatus,
) )
async def cancel_install_jobs(): async def cancel_install_jobs():
"""Cancel all pending install jobs.""" """Cancel all model installation jobs."""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
try: job_mgr = ApiDependencies.invoker.services.model_installer
job_mgr = ApiDependencies.invoker.services.model_installer logger.info("Cancelling all model installation jobs.")
logger.info("Cancelling all running model installation jobs.") job_mgr.cancel_all_jobs()
job_mgr.cancel_all_jobs() return Response(status_code=204)
return Response(status_code=204)
except Exception as e:
logger.error(str(e))
raise HTTPException(status_code=400, detail=str(e))
@models_router.patch( @models_router.patch(
"/jobs/prune", "/jobs/prune",
operation_id="prune_jobs", operation_id="prune_jobs",
responses={ responses={
204: {"description": "All jobs cancelled successfully"}, 204: {"description": "All completed jobs have been pruned"},
400: {"description": "Bad request"}, 400: {"description": "Bad request"},
}, },
status_code=200,
response_model=ModelImportStatus,
) )
async def prune_jobs(): async def prune_jobs():
"""Prune all completed and errored jobs.""" """Prune all completed and errored jobs."""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
try: mgr = ApiDependencies.invoker.services.model_installer
mgr = ApiDependencies.invoker.services.model_installer mgr.prune_jobs()
mgr.prune_jobs() return Response(status_code=204)
return Response(status_code=204)
except Exception as e:
logger.error(str(e))
raise HTTPException(status_code=400, detail=str(e))

View File

@ -126,7 +126,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
e.g. `max_parallel_dl`. e.g. `max_parallel_dl`.
""" """
self._event_bus = event_bus self._event_bus = event_bus
self._queue = DownloadQueue() self._queue = DownloadQueue(**kwargs)
def create_download_job( def create_download_job(
self, self,

View File

@ -92,7 +92,7 @@ class ModelInstallServiceBase(ModelInstallBase): # This is an ABC
@abstractmethod @abstractmethod
def cancel_all_jobs(self): def cancel_all_jobs(self):
"""Cancel all active jobs.""" """Cancel all installation jobs."""
pass pass
@abstractmethod @abstractmethod

View File

@ -62,7 +62,7 @@ class ModelLoadServiceBase(ABC):
class ModelLoadService(ModelLoadServiceBase): class ModelLoadService(ModelLoadServiceBase):
"""Responsible for managing models on disk and in memory.""" """Responsible for managing models on disk and in memory."""
_loader: ModelLoad = Field(description="InvokeAIAppConfig object for the current process") _loader: ModelLoad
def __init__( def __init__(
self, self,

View File

@ -12,6 +12,4 @@ from .model_manager import ( # noqa F401
SilenceWarnings, SilenceWarnings,
SubModelType, SubModelType,
) )
from .model_manager.install import ModelInstall # noqa F401
from .model_manager.loader import ModelLoad # noqa F401
from .util.devices import get_precision # noqa F401 from .util.devices import get_precision # noqa F401

View File

@ -18,7 +18,6 @@ from invokeai.backend.model_manager.install import ModelInstall, ModelInstallJob
# name of the starter models file # name of the starter models file
INITIAL_MODELS = "INITIAL_MODELS.yaml" INITIAL_MODELS = "INITIAL_MODELS.yaml"
ACCESS_TOKEN = HfFolder.get_token()
class UnifiedModelInfo(BaseModel): class UnifiedModelInfo(BaseModel):
@ -173,7 +172,7 @@ class InstallHelper(object):
model.source, model.source,
subfolder=model.subfolder, subfolder=model.subfolder,
variant="fp16" if self._config.precision == "float16" else None, variant="fp16" if self._config.precision == "float16" else None,
access_token=ACCESS_TOKEN, # this is a global, access_token=HfFolder.get_token(),
metadata=metadata, metadata=metadata,
) )

View File

@ -185,14 +185,17 @@ class ProgressBar:
# --------------------------------------------- # ---------------------------------------------
def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs): def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs):
logger.addFilter(lambda x: "fp16 is not a valid" not in x.getMessage()) filter = lambda x: "fp16 is not a valid" not in x.getMessage()
logger.addFilter(filter)
model = model_class.from_pretrained( try:
model_name, model = model_class.from_pretrained(
resume_download=True, model_name,
**kwargs, resume_download=True,
) **kwargs,
model.save_pretrained(destination, safe_serialization=True) )
model.save_pretrained(destination, safe_serialization=True)
finally:
logger.removeFilter(filter)
return destination return destination

View File

@ -212,7 +212,7 @@ class DownloadQueueBase(ABC):
@abstractmethod @abstractmethod
def cancel_all_jobs(self, preserve_partial: bool = False): def cancel_all_jobs(self, preserve_partial: bool = False):
""" """
Cancel all active and enquedjobs. Cancel all jobs (those in enqueued, running and paused states).
:param preserve_partial: Keep partially downloaded files [False]. :param preserve_partial: Keep partially downloaded files [False].
""" """

View File

@ -12,8 +12,7 @@ from queue import PriorityQueue
from typing import Callable, Dict, List, Optional, Set, Tuple, Union from typing import Callable, Dict, List, Optional, Set, Tuple, Union
import requests import requests
from huggingface_hub import HfApi, hf_hub_url from pydantic import Field
from pydantic import Field, parse_obj_as, validator
from pydantic.networks import AnyHttpUrl from pydantic.networks import AnyHttpUrl
from requests import HTTPError from requests import HTTPError
@ -59,7 +58,7 @@ class DownloadQueue(DownloadQueueBase):
_jobs: Dict[int, DownloadJobBase] _jobs: Dict[int, DownloadJobBase]
_worker_pool: Set[threading.Thread] _worker_pool: Set[threading.Thread]
_queue: PriorityQueue _queue: PriorityQueue
_lock: threading.RLock _lock: threading.RLock # to allow methods called within the same thread to lock without blocking
_logger: Logger _logger: Logger
_event_handlers: List[DownloadEventHandler] = Field(default_factory=list) _event_handlers: List[DownloadEventHandler] = Field(default_factory=list)
_next_job_id: int = 0 _next_job_id: int = 0
@ -136,13 +135,10 @@ class DownloadQueue(DownloadQueueBase):
# add the queue's handlers # add the queue's handlers
for handler in self._event_handlers: for handler in self._event_handlers:
job.add_event_handler(handler) job.add_event_handler(handler)
try: with self._lock:
self._lock.acquire()
job.id = self._next_job_id job.id = self._next_job_id
self._jobs[job.id] = job self._jobs[job.id] = job
self._next_job_id += 1 self._next_job_id += 1
finally:
self._lock.release()
if start: if start:
self.start_job(job) self.start_job(job)
return job return job
@ -163,29 +159,25 @@ class DownloadQueue(DownloadQueueBase):
def change_priority(self, job: DownloadJobBase, delta: int): def change_priority(self, job: DownloadJobBase, delta: int):
"""Change the priority of a job. Smaller priorities run first.""" """Change the priority of a job. Smaller priorities run first."""
try: with self._lock:
self._lock.acquire() try:
assert isinstance(self._jobs[job.id], DownloadJobBase) assert isinstance(self._jobs[job.id], DownloadJobBase)
job.priority += delta job.priority += delta
except (AssertionError, KeyError) as excp: except (AssertionError, KeyError) as excp:
raise UnknownJobIDException("Unrecognized job") from excp raise UnknownJobIDException("Unrecognized job") from excp
finally:
self._lock.release()
def prune_jobs(self): def prune_jobs(self):
"""Prune completed and errored queue items from the job list.""" """Prune completed and errored queue items from the job list."""
try: with self._lock:
to_delete = set() to_delete = set()
self._lock.acquire() try:
for job_id, job in self._jobs.items(): for job_id, job in self._jobs.items():
if self._in_terminal_state(job): if self._in_terminal_state(job):
to_delete.add(job_id) to_delete.add(job_id)
for job_id in to_delete: for job_id in to_delete:
del self._jobs[job_id] del self._jobs[job_id]
except KeyError as excp: except KeyError as excp:
raise UnknownJobIDException("Unrecognized job") from excp raise UnknownJobIDException("Unrecognized job") from excp
finally:
self._lock.release()
def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False): def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False):
""" """
@ -194,16 +186,14 @@ class DownloadQueue(DownloadQueueBase):
If it is running it will be stopped. If it is running it will be stopped.
job.status will be set to DownloadJobStatus.CANCELLED job.status will be set to DownloadJobStatus.CANCELLED
""" """
try: with self._lock:
self._lock.acquire() try:
assert isinstance(self._jobs[job.id], DownloadJobBase) assert isinstance(self._jobs[job.id], DownloadJobBase)
job.preserve_partial_downloads = preserve_partial job.preserve_partial_downloads = preserve_partial
self._update_job_status(job, DownloadJobStatus.CANCELLED) self._update_job_status(job, DownloadJobStatus.CANCELLED)
job.cleanup() job.cleanup()
except (AssertionError, KeyError) as excp: except (AssertionError, KeyError) as excp:
raise UnknownJobIDException("Unrecognized job") from excp raise UnknownJobIDException("Unrecognized job") from excp
finally:
self._lock.release()
def id_to_job(self, id: int) -> DownloadJobBase: def id_to_job(self, id: int) -> DownloadJobBase:
"""Translate a job ID into a DownloadJobBase object.""" """Translate a job ID into a DownloadJobBase object."""
@ -214,12 +204,13 @@ class DownloadQueue(DownloadQueueBase):
def start_job(self, job: DownloadJobBase): def start_job(self, job: DownloadJobBase):
"""Enqueue (start) the indicated job.""" """Enqueue (start) the indicated job."""
try: with self._lock:
assert isinstance(self._jobs[job.id], DownloadJobBase) try:
self._update_job_status(job, DownloadJobStatus.ENQUEUED) assert isinstance(self._jobs[job.id], DownloadJobBase)
self._queue.put(job) self._update_job_status(job, DownloadJobStatus.ENQUEUED)
except (AssertionError, KeyError) as excp: self._queue.put(job)
raise UnknownJobIDException("Unrecognized job") from excp except (AssertionError, KeyError) as excp:
raise UnknownJobIDException("Unrecognized job") from excp
def pause_job(self, job: DownloadJobBase): def pause_job(self, job: DownloadJobBase):
""" """
@ -228,45 +219,34 @@ class DownloadQueue(DownloadQueueBase):
The job can be restarted with start_job() and the download will pick up The job can be restarted with start_job() and the download will pick up
from where it left off. from where it left off.
""" """
try: with self._lock:
self._lock.acquire() try:
assert isinstance(self._jobs[job.id], DownloadJobBase) assert isinstance(self._jobs[job.id], DownloadJobBase)
self._update_job_status(job, DownloadJobStatus.PAUSED) self._update_job_status(job, DownloadJobStatus.PAUSED)
job.cleanup() job.cleanup()
except (AssertionError, KeyError) as excp: except (AssertionError, KeyError) as excp:
raise UnknownJobIDException("Unrecognized job") from excp raise UnknownJobIDException("Unrecognized job") from excp
finally:
self._lock.release()
def start_all_jobs(self): def start_all_jobs(self):
"""Start (enqueue) all jobs that are idle or paused.""" """Start (enqueue) all jobs that are idle or paused."""
try: with self._lock:
self._lock.acquire()
for job in self._jobs.values(): for job in self._jobs.values():
if job.status in [DownloadJobStatus.IDLE, DownloadJobStatus.PAUSED]: if job.status in [DownloadJobStatus.IDLE, DownloadJobStatus.PAUSED]:
self.start_job(job) self.start_job(job)
finally:
self._lock.release()
def pause_all_jobs(self): def pause_all_jobs(self):
"""Pause all running jobs.""" """Pause all running jobs."""
try: with self._lock:
self._lock.acquire()
for job in self._jobs.values(): for job in self._jobs.values():
if job.status == DownloadJobStatus.RUNNING: if job.status == DownloadJobStatus.RUNNING:
self.pause_job(job) self.pause_job(job)
finally:
self._lock.release()
def cancel_all_jobs(self, preserve_partial: bool = False): def cancel_all_jobs(self, preserve_partial: bool = False):
"""Cancel all running jobs.""" """Cancel all jobs (those not in enqueued, running or paused state)."""
try: with self._lock:
self._lock.acquire()
for job in self._jobs.values(): for job in self._jobs.values():
if not self._in_terminal_state(job): if not self._in_terminal_state(job):
self.cancel_job(job, preserve_partial) self.cancel_job(job, preserve_partial)
finally:
self._lock.release()
def _in_terminal_state(self, job: DownloadJobBase): def _in_terminal_state(self, job: DownloadJobBase):
return job.status in [ return job.status in [
@ -288,26 +268,26 @@ class DownloadQueue(DownloadQueueBase):
while not done: while not done:
job = self._queue.get() job = self._queue.get()
try: # this is for debugging priority with self._lock:
self._lock.acquire()
job.job_sequence = self._sequence job.job_sequence = self._sequence
self._sequence += 1 self._sequence += 1
try:
if job == STOP_JOB: # marker that queue is done
done = True
if (
job.status == DownloadJobStatus.ENQUEUED
): # Don't do anything for non-enqueued jobs (shouldn't happen)
if not self._quiet:
self._logger.info(f"{job.source}: Downloading to {job.destination}")
do_download = self.select_downloader(job)
do_download(job)
if job.status == DownloadJobStatus.CANCELLED:
self._cleanup_cancelled_job(job)
finally: finally:
self._lock.release() self._queue.task_done()
if job == STOP_JOB: # marker that queue is done
done = True
if job.status == DownloadJobStatus.ENQUEUED: # Don't do anything for non-enqueued jobs (shouldn't happen)
if not self._quiet:
self._logger.info(f"{job.source}: Downloading to {job.destination}")
do_download = self.select_downloader(job)
do_download(job)
if job.status == DownloadJobStatus.CANCELLED:
self._cleanup_cancelled_job(job)
self._queue.task_done()
def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]: def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]:
"""Based on the job type select the download method.""" """Based on the job type select the download method."""
@ -397,11 +377,7 @@ class DownloadQueue(DownloadQueueBase):
self._update_job_status(job, DownloadJobStatus.COMPLETED) self._update_job_status(job, DownloadJobStatus.COMPLETED)
except KeyboardInterrupt as excp: except KeyboardInterrupt as excp:
raise excp raise excp
except DuplicateModelException as excp: except (HTTPError, OSError) as excp:
self._logger.error(f"A model with the same hash as {dest} is already installed.")
job.error = excp
self._update_job_status(job, DownloadJobStatus.ERROR)
except Exception as excp:
self._logger.error(f"An error occurred while downloading/installing {job.source}: {str(excp)}") self._logger.error(f"An error occurred while downloading/installing {job.source}: {str(excp)}")
print(traceback.format_exc()) print(traceback.format_exc())
job.error = excp job.error = excp

View File

@ -15,6 +15,8 @@ from typing import Dict, Union
from imohash import hashfile from imohash import hashfile
from .models import InvalidModelException
class FastModelHash(object): class FastModelHash(object):
"""FastModelHash obect provides one public class method, hash().""" """FastModelHash obect provides one public class method, hash()."""
@ -32,9 +34,6 @@ class FastModelHash(object):
elif model_location.is_dir(): elif model_location.is_dir():
return cls._hash_dir(model_location) return cls._hash_dir(model_location)
else: else:
# avoid circular import
from .models import InvalidModelException
raise InvalidModelException(f"Not a valid file or directory: {model_location}") raise InvalidModelException(f"Not a valid file or directory: {model_location}")
@classmethod @classmethod
@ -54,7 +53,8 @@ class FastModelHash(object):
for root, dirs, files in os.walk(model_location): for root, dirs, files in os.walk(model_location):
for file in files: for file in files:
# only tally tensor files # only tally tensor files because diffusers config files change slightly
# depending on how the model was downloaded/converted.
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")): if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
continue continue
path = (Path(root) / file).as_posix() path = (Path(root) / file).as_posix()

View File

@ -72,14 +72,11 @@ class ModelConfigStoreSQL(ModelConfigStore):
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
self._lock = lock self._lock = lock
try: with self._lock:
self._lock.acquire()
# Enable foreign keys # Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;") self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables() self._create_tables()
self._conn.commit() self._conn.commit()
finally:
self._lock.release()
assert ( assert (
str(self.version) == CONFIG_FILE_VERSION str(self.version) == CONFIG_FILE_VERSION
), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}" ), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
@ -189,52 +186,49 @@ class ModelConfigStoreSQL(ModelConfigStore):
""" """
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect. record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
json_serialized = json.dumps(record.dict()) # and turn it into a json string. json_serialized = json.dumps(record.dict()) # and turn it into a json string.
try: with self._lock:
self._lock.acquire() try:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
INSERT INTO model_config ( INSERT INTO model_config (
id, id,
base_model, base_model,
model_type, model_type,
model_name, model_name,
model_path, model_path,
config config
) )
VALUES (?,?,?,?,?,?); VALUES (?,?,?,?,?,?);
""", """,
( (
key, key,
record.base_model, record.base_model,
record.model_type, record.model_type,
record.name, record.name,
record.path, record.path,
json_serialized, json_serialized,
), ),
) )
if record.tags: if record.tags:
self._update_tags(key, record.tags) self._update_tags(key, record.tags)
self._conn.commit() self._conn.commit()
except sqlite3.IntegrityError as e: except sqlite3.IntegrityError as e:
self._conn.rollback() self._conn.rollback()
if "UNIQUE constraint failed" in str(e): if "UNIQUE constraint failed" in str(e):
raise DuplicateModelException(f"A model with key '{key}' is already installed") from e raise DuplicateModelException(f"A model with key '{key}' is already installed") from e
else: else:
raise e
except sqlite3.Error as e:
self._conn.rollback()
raise e raise e
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
return self.get_model(key) return self.get_model(key)
@property @property
def version(self) -> str: def version(self) -> str:
"""Return the version of the database schema.""" """Return the version of the database schema."""
try: with self._lock:
self._lock.acquire()
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
SELECT metadata_value FROM model_manager_metadata SELECT metadata_value FROM model_manager_metadata
@ -246,8 +240,6 @@ class ModelConfigStoreSQL(ModelConfigStore):
if not rows: if not rows:
raise KeyError("Models database does not have metadata key 'version'") raise KeyError("Models database does not have metadata key 'version'")
return rows[0] return rows[0]
finally:
self._lock.release()
def _update_tags(self, key: str, tags: List[str]) -> None: def _update_tags(self, key: str, tags: List[str]) -> None:
"""Update tags for model with key.""" """Update tags for model with key."""
@ -301,23 +293,21 @@ class ModelConfigStoreSQL(ModelConfigStore):
Can raise an UnknownModelException Can raise an UnknownModelException
""" """
try: with self._lock:
self._lock.acquire() try:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
DELETE FROM model_config DELETE FROM model_config
WHERE id=?; WHERE id=?;
""", """,
(key,), (key,),
) )
if self._cursor.rowcount == 0: if self._cursor.rowcount == 0:
raise UnknownModelException raise UnknownModelException
self._conn.commit() self._conn.commit()
except sqlite3.Error as e: except sqlite3.Error as e:
self._conn.rollback() self._conn.rollback()
raise e raise e
finally:
self._lock.release()
def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase: def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
""" """
@ -329,30 +319,29 @@ class ModelConfigStoreSQL(ModelConfigStore):
""" """
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
json_serialized = json.dumps(record.dict()) # and turn it into a json string. json_serialized = json.dumps(record.dict()) # and turn it into a json string.
try: with self._lock:
self._lock.acquire() try:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
UPDATE model_config UPDATE model_config
SET base_model=?, SET base_model=?,
model_type=?, model_type=?,
model_name=?, model_name=?,
model_path=?, model_path=?,
config=? config=?
WHERE id=?; WHERE id=?;
""", """,
(record.base_model, record.model_type, record.name, record.path, json_serialized, key), (record.base_model, record.model_type, record.name, record.path, json_serialized, key),
) )
if self._cursor.rowcount == 0: if self._cursor.rowcount == 0:
raise UnknownModelException raise UnknownModelException
if record.tags: if record.tags:
self._update_tags(key, record.tags) self._update_tags(key, record.tags)
self._conn.commit() self._conn.commit()
except sqlite3.Error as e: except sqlite3.Error as e:
self._conn.rollback() self._conn.rollback()
raise e raise e
finally:
self._lock.release()
return self.get_model(key) return self.get_model(key)
def get_model(self, key: str) -> AnyModelConfig: def get_model(self, key: str) -> AnyModelConfig:
@ -363,8 +352,7 @@ class ModelConfigStoreSQL(ModelConfigStore):
Exceptions: UnknownModelException Exceptions: UnknownModelException
""" """
try: with self._lock:
self._lock.acquire()
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
SELECT config FROM model_config SELECT config FROM model_config
@ -376,8 +364,6 @@ class ModelConfigStoreSQL(ModelConfigStore):
if not rows: if not rows:
raise UnknownModelException raise UnknownModelException
model = ModelConfigFactory.make_config(json.loads(rows[0])) model = ModelConfigFactory.make_config(json.loads(rows[0]))
finally:
self._lock.release()
return model return model
def exists(self, key: str) -> bool: def exists(self, key: str) -> bool:
@ -387,20 +373,18 @@ class ModelConfigStoreSQL(ModelConfigStore):
:param key: Unique key for the model to be deleted :param key: Unique key for the model to be deleted
""" """
count = 0 count = 0
try: with self._lock:
self._lock.acquire() try:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
select count(*) FROM model_config select count(*) FROM model_config
WHERE id=?; WHERE id=?;
""", """,
(key,), (key,),
) )
count = self._cursor.fetchone()[0] count = self._cursor.fetchone()[0]
except sqlite3.Error as e: except sqlite3.Error as e:
raise e raise e
finally:
self._lock.release()
return count > 0 return count > 0
def search_by_tag(self, tags: Set[str]) -> List[AnyModelConfig]: def search_by_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
@ -408,34 +392,32 @@ class ModelConfigStoreSQL(ModelConfigStore):
# rather than create a hairy SQL cross-product, we intersect # rather than create a hairy SQL cross-product, we intersect
# tag results in a stepwise fashion at the python level. # tag results in a stepwise fashion at the python level.
results = [] results = []
try: with self._lock:
self._lock.acquire() try:
matches: Set[str] = set() matches: Set[str] = set()
for tag in tags: for tag in tags:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
SELECT a.id FROM model_tag AS a, SELECT a.id FROM model_tag AS a,
tags AS b tags AS b
WHERE a.tag_id=b.tag_id WHERE a.tag_id=b.tag_id
AND b.tag_text=?; AND b.tag_text=?;
""", """,
(tag,), (tag,),
) )
model_keys = {x[0] for x in self._cursor.fetchall()} model_keys = {x[0] for x in self._cursor.fetchall()}
matches = matches.intersection(model_keys) if len(matches) > 0 else model_keys matches = matches.intersection(model_keys) if len(matches) > 0 else model_keys
if matches: if matches:
self._cursor.execute( self._cursor.execute(
f"""--sql f"""--sql
SELECT config FROM model_config SELECT config FROM model_config
WHERE id IN ({','.join('?' * len(matches))}); WHERE id IN ({','.join('?' * len(matches))});
""", """,
tuple(matches), tuple(matches),
) )
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
except sqlite3.Error as e: except sqlite3.Error as e:
raise e raise e
finally:
self._lock.release()
return results return results
def search_by_name( def search_by_name(
@ -467,20 +449,18 @@ class ModelConfigStoreSQL(ModelConfigStore):
where_clause.append("model_type=?") where_clause.append("model_type=?")
bindings.append(model_type) bindings.append(model_type)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else "" where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
try: with self._lock:
self._lock.acquire() try:
self._cursor.execute( self._cursor.execute(
f"""--sql f"""--sql
select config FROM model_config select config FROM model_config
{where}; {where};
""", """,
tuple(bindings), tuple(bindings),
) )
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
except sqlite3.Error as e: except sqlite3.Error as e:
raise e raise e
finally:
self._lock.release()
return results return results
def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]: def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]:

View File

@ -80,24 +80,18 @@ class ModelConfigStoreYAML(ModelConfigStore):
raise ConfigFileVersionMismatchException raise ConfigFileVersionMismatchException
def _initialize_yaml(self): def _initialize_yaml(self):
try: with self._lock:
self._lock.acquire()
self._filename.parent.mkdir(parents=True, exist_ok=True) self._filename.parent.mkdir(parents=True, exist_ok=True)
with open(self._filename, "w") as yaml_file: with open(self._filename, "w") as yaml_file:
yaml_file.write(yaml.dump({"__metadata__": {"version": CONFIG_FILE_VERSION}})) yaml_file.write(yaml.dump({"__metadata__": {"version": CONFIG_FILE_VERSION}}))
finally:
self._lock.release()
def _commit(self): def _commit(self):
try: with self._lock:
self._lock.acquire()
newfile = Path(str(self._filename) + ".new") newfile = Path(str(self._filename) + ".new")
yaml_str = OmegaConf.to_yaml(self._config) yaml_str = OmegaConf.to_yaml(self._config)
with open(newfile, "w", encoding="utf-8") as outfile: with open(newfile, "w", encoding="utf-8") as outfile:
outfile.write(yaml_str) outfile.write(yaml_str)
newfile.replace(self._filename) newfile.replace(self._filename)
finally:
self._lock.release()
@property @property
def version(self) -> str: def version(self) -> str:
@ -116,8 +110,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
""" """
record = ModelConfigFactory.make_config(config, key) # ensure it is a valid config obect record = ModelConfigFactory.make_config(config, key) # ensure it is a valid config obect
dict_fields = record.dict() # and back to a dict with valid fields dict_fields = record.dict() # and back to a dict with valid fields
try: with self._lock:
self._lock.acquire()
if key in self._config: if key in self._config:
existing_model = self.get_model(key) existing_model = self.get_model(key)
raise DuplicateModelException( raise DuplicateModelException(
@ -125,8 +118,6 @@ class ModelConfigStoreYAML(ModelConfigStore):
) )
self._config[key] = self._fix_enums(dict_fields) self._config[key] = self._fix_enums(dict_fields)
self._commit() self._commit()
finally:
self._lock.release()
return self.get_model(key) return self.get_model(key)
def _fix_enums(self, original: dict) -> dict: def _fix_enums(self, original: dict) -> dict:
@ -144,14 +135,11 @@ class ModelConfigStoreYAML(ModelConfigStore):
Can raise an UnknownModelException Can raise an UnknownModelException
""" """
try: with self._lock:
self._lock.acquire()
if key not in self._config: if key not in self._config:
raise UnknownModelException(f"Unknown key '{key}' for model config") raise UnknownModelException(f"Unknown key '{key}' for model config")
self._config.pop(key) self._config.pop(key)
self._commit() self._commit()
finally:
self._lock.release()
def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase: def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
""" """
@ -163,14 +151,11 @@ class ModelConfigStoreYAML(ModelConfigStore):
""" """
record = ModelConfigFactory.make_config(config, key) # ensure it is a valid config obect record = ModelConfigFactory.make_config(config, key) # ensure it is a valid config obect
dict_fields = record.dict() # and back to a dict with valid fields dict_fields = record.dict() # and back to a dict with valid fields
try: with self._lock:
self._lock.acquire()
if key not in self._config: if key not in self._config:
raise UnknownModelException(f"Unknown key '{key}' for model config") raise UnknownModelException(f"Unknown key '{key}' for model config")
self._config[key] = self._fix_enums(dict_fields) self._config[key] = self._fix_enums(dict_fields)
self._commit() self._commit()
finally:
self._lock.release()
return self.get_model(key) return self.get_model(key)
def get_model(self, key: str) -> AnyModelConfig: def get_model(self, key: str) -> AnyModelConfig:
@ -203,15 +188,12 @@ class ModelConfigStoreYAML(ModelConfigStore):
""" """
results = [] results = []
tags = set(tags) tags = set(tags)
try: with self._lock:
self._lock.acquire()
for config in self.all_models(): for config in self.all_models():
config_tags = set(config.tags or []) config_tags = set(config.tags or [])
if tags.difference(config_tags): # not all tags in the model if tags.difference(config_tags): # not all tags in the model
continue continue
results.append(config) results.append(config)
finally:
self._lock.release()
return results return results
def search_by_name( def search_by_name(
@ -231,8 +213,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
models in the database. models in the database.
""" """
results: List[ModelConfigBase] = list() results: List[ModelConfigBase] = list()
try: with self._lock:
self._lock.acquire()
for key, record in self._config.items(): for key, record in self._config.items():
if key == "__metadata__": if key == "__metadata__":
continue continue
@ -244,20 +225,15 @@ class ModelConfigStoreYAML(ModelConfigStore):
if model_type and model.model_type != model_type: if model_type and model.model_type != model_type:
continue continue
results.append(model) results.append(model)
finally:
self._lock.release()
return results return results
def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]: def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]:
"""Return the model with the indicated path, or None.""" """Return the model with the indicated path, or None."""
try: with self._lock:
self._lock.acquire()
for key, record in self._config.items(): for key, record in self._config.items():
if key == "__metadata__": if key == "__metadata__":
continue continue
model = ModelConfigFactory.make_config(record, str(key)) model = ModelConfigFactory.make_config(record, str(key))
if model.path == path: if model.path == path:
return model return model
finally:
self._lock.release()
return None return None

View File

@ -270,8 +270,6 @@ def test_bad_urls():
def test_pause_cancel_url(): # this one is tricky because of potential race conditions def test_pause_cancel_url(): # this one is tricky because of potential race conditions
def event_handler(job: DownloadJobBase): def event_handler(job: DownloadJobBase):
if job.id == 0:
print(job.status, job.bytes)
time.sleep(0.5) # slow down the thread so that we can recover the paused state time.sleep(0.5) # slow down the thread so that we can recover the paused state
queue = ModelDownloadQueue(requests_session=session, event_handlers=[event_handler]) queue = ModelDownloadQueue(requests_session=session, event_handlers=[event_handler])