mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
address all PR 4252 comments from ryan through October 5
This commit is contained in:
@ -98,13 +98,18 @@ async def update_model(
|
||||
) -> InvokeAIModelConfig:
|
||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
info_dict = info.dict()
|
||||
record_store = ApiDependencies.invoker.services.model_record_store
|
||||
model_install = ApiDependencies.invoker.services.model_installer
|
||||
try:
|
||||
new_config = record_store.update_model(key, config=info_dict)
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
try:
|
||||
info_dict = info.dict()
|
||||
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
|
||||
record_store = ApiDependencies.invoker.services.model_record_store
|
||||
model_install = ApiDependencies.invoker.services.model_installer
|
||||
new_config = record_store.update_model(key, config=info_dict)
|
||||
# In the event that the model's name, type or base has changed, and the model itself
|
||||
# resides in the invokeai root models directory, then the next statement will move
|
||||
# the model file into its new canonical location.
|
||||
@ -115,9 +120,6 @@ async def update_model(
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
return model_response
|
||||
|
||||
@ -198,8 +200,8 @@ async def import_model(
|
||||
responses={
|
||||
201: {"description": "The model added successfully"},
|
||||
404: {"description": "The model could not be found"},
|
||||
424: {"description": "The model appeared to add successfully, but could not be found in the model manager"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=InvokeAIModelConfig,
|
||||
@ -213,16 +215,22 @@ async def add_model(
|
||||
"""
|
||||
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
path = info.path
|
||||
installer = ApiDependencies.invoker.services.model_installer
|
||||
record_store = ApiDependencies.invoker.services.model_record_store
|
||||
try:
|
||||
path = info.path
|
||||
installer = ApiDependencies.invoker.services.model_installer
|
||||
record_store = ApiDependencies.invoker.services.model_record_store
|
||||
key = installer.install_path(path)
|
||||
logger.info(f"Created model {key} for {path}")
|
||||
except DuplicateModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
|
||||
# update with the provided info
|
||||
# update with the provided info
|
||||
try:
|
||||
info_dict = info.dict()
|
||||
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
|
||||
new_config = record_store.update_model(key, new_config=info_dict)
|
||||
return parse_obj_as(InvokeAIModelConfig, new_config.dict())
|
||||
except UnknownModelException as e:
|
||||
@ -405,24 +413,19 @@ async def merge_models(
|
||||
)
|
||||
async def list_install_jobs() -> List[ModelImportStatus]:
|
||||
"""List active and pending model installation jobs."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
job_mgr = ApiDependencies.invoker.services.model_installer
|
||||
try:
|
||||
jobs = job_mgr.list_install_jobs()
|
||||
return [
|
||||
ModelImportStatus(
|
||||
job_id=x.id,
|
||||
source=x.source,
|
||||
priority=x.priority,
|
||||
bytes=x.bytes,
|
||||
total_bytes=x.total_bytes,
|
||||
status=x.status,
|
||||
)
|
||||
for x in jobs
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
jobs = job_mgr.list_install_jobs()
|
||||
return [
|
||||
ModelImportStatus(
|
||||
job_id=x.id,
|
||||
source=x.source,
|
||||
priority=x.priority,
|
||||
bytes=x.bytes,
|
||||
total_bytes=x.total_bytes,
|
||||
status=x.status,
|
||||
)
|
||||
for x in jobs
|
||||
]
|
||||
|
||||
|
||||
@models_router.patch(
|
||||
@ -459,11 +462,11 @@ async def control_install_jobs(
|
||||
elif operation == JobControlOperation.CANCEL:
|
||||
job_mgr.cancel_job(job_id)
|
||||
|
||||
elif operation == JobControlOperation.CHANGE_PRIORITY:
|
||||
elif operation == JobControlOperation.CHANGE_PRIORITY and priority_delta is not None:
|
||||
job_mgr.change_job_priority(job_id, priority_delta)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown operation {JobControlOperation.value}")
|
||||
raise ValueError(f"Unknown operation {operation.value}")
|
||||
|
||||
return ModelImportStatus(
|
||||
job_id=job_id,
|
||||
@ -478,9 +481,6 @@ async def control_install_jobs(
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@models_router.patch(
|
||||
@ -490,39 +490,27 @@ async def control_install_jobs(
|
||||
204: {"description": "All jobs cancelled successfully"},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=ModelImportStatus,
|
||||
)
|
||||
async def cancel_install_jobs():
|
||||
"""Cancel all pending install jobs."""
|
||||
"""Cancel all model installation jobs."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
job_mgr = ApiDependencies.invoker.services.model_installer
|
||||
logger.info("Cancelling all running model installation jobs.")
|
||||
job_mgr.cancel_all_jobs()
|
||||
return Response(status_code=204)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
job_mgr = ApiDependencies.invoker.services.model_installer
|
||||
logger.info("Cancelling all model installation jobs.")
|
||||
job_mgr.cancel_all_jobs()
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@models_router.patch(
|
||||
"/jobs/prune",
|
||||
operation_id="prune_jobs",
|
||||
responses={
|
||||
204: {"description": "All jobs cancelled successfully"},
|
||||
204: {"description": "All completed jobs have been pruned"},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=ModelImportStatus,
|
||||
)
|
||||
async def prune_jobs():
|
||||
"""Prune all completed and errored jobs."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
mgr = ApiDependencies.invoker.services.model_installer
|
||||
mgr.prune_jobs()
|
||||
return Response(status_code=204)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
mgr = ApiDependencies.invoker.services.model_installer
|
||||
mgr.prune_jobs()
|
||||
return Response(status_code=204)
|
||||
|
@ -126,7 +126,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
e.g. `max_parallel_dl`.
|
||||
"""
|
||||
self._event_bus = event_bus
|
||||
self._queue = DownloadQueue()
|
||||
self._queue = DownloadQueue(**kwargs)
|
||||
|
||||
def create_download_job(
|
||||
self,
|
||||
|
@ -92,7 +92,7 @@ class ModelInstallServiceBase(ModelInstallBase): # This is an ABC
|
||||
|
||||
@abstractmethod
|
||||
def cancel_all_jobs(self):
|
||||
"""Cancel all active jobs."""
|
||||
"""Cancel all installation jobs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
@ -62,7 +62,7 @@ class ModelLoadServiceBase(ABC):
|
||||
class ModelLoadService(ModelLoadServiceBase):
|
||||
"""Responsible for managing models on disk and in memory."""
|
||||
|
||||
_loader: ModelLoad = Field(description="InvokeAIAppConfig object for the current process")
|
||||
_loader: ModelLoad
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -12,6 +12,4 @@ from .model_manager import ( # noqa F401
|
||||
SilenceWarnings,
|
||||
SubModelType,
|
||||
)
|
||||
from .model_manager.install import ModelInstall # noqa F401
|
||||
from .model_manager.loader import ModelLoad # noqa F401
|
||||
from .util.devices import get_precision # noqa F401
|
||||
|
@ -18,7 +18,6 @@ from invokeai.backend.model_manager.install import ModelInstall, ModelInstallJob
|
||||
|
||||
# name of the starter models file
|
||||
INITIAL_MODELS = "INITIAL_MODELS.yaml"
|
||||
ACCESS_TOKEN = HfFolder.get_token()
|
||||
|
||||
|
||||
class UnifiedModelInfo(BaseModel):
|
||||
@ -173,7 +172,7 @@ class InstallHelper(object):
|
||||
model.source,
|
||||
subfolder=model.subfolder,
|
||||
variant="fp16" if self._config.precision == "float16" else None,
|
||||
access_token=ACCESS_TOKEN, # this is a global,
|
||||
access_token=HfFolder.get_token(),
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
@ -185,14 +185,17 @@ class ProgressBar:
|
||||
|
||||
# ---------------------------------------------
|
||||
def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs):
|
||||
logger.addFilter(lambda x: "fp16 is not a valid" not in x.getMessage())
|
||||
|
||||
model = model_class.from_pretrained(
|
||||
model_name,
|
||||
resume_download=True,
|
||||
**kwargs,
|
||||
)
|
||||
model.save_pretrained(destination, safe_serialization=True)
|
||||
filter = lambda x: "fp16 is not a valid" not in x.getMessage()
|
||||
logger.addFilter(filter)
|
||||
try:
|
||||
model = model_class.from_pretrained(
|
||||
model_name,
|
||||
resume_download=True,
|
||||
**kwargs,
|
||||
)
|
||||
model.save_pretrained(destination, safe_serialization=True)
|
||||
finally:
|
||||
logger.removeFilter(filter)
|
||||
return destination
|
||||
|
||||
|
||||
|
@ -212,7 +212,7 @@ class DownloadQueueBase(ABC):
|
||||
@abstractmethod
|
||||
def cancel_all_jobs(self, preserve_partial: bool = False):
|
||||
"""
|
||||
Cancel all active and enquedjobs.
|
||||
Cancel all jobs (those in enqueued, running and paused states).
|
||||
|
||||
:param preserve_partial: Keep partially downloaded files [False].
|
||||
"""
|
||||
|
@ -12,8 +12,7 @@ from queue import PriorityQueue
|
||||
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import requests
|
||||
from huggingface_hub import HfApi, hf_hub_url
|
||||
from pydantic import Field, parse_obj_as, validator
|
||||
from pydantic import Field
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests import HTTPError
|
||||
|
||||
@ -59,7 +58,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
_jobs: Dict[int, DownloadJobBase]
|
||||
_worker_pool: Set[threading.Thread]
|
||||
_queue: PriorityQueue
|
||||
_lock: threading.RLock
|
||||
_lock: threading.RLock # to allow methods called within the same thread to lock without blocking
|
||||
_logger: Logger
|
||||
_event_handlers: List[DownloadEventHandler] = Field(default_factory=list)
|
||||
_next_job_id: int = 0
|
||||
@ -136,13 +135,10 @@ class DownloadQueue(DownloadQueueBase):
|
||||
# add the queue's handlers
|
||||
for handler in self._event_handlers:
|
||||
job.add_event_handler(handler)
|
||||
try:
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
job.id = self._next_job_id
|
||||
self._jobs[job.id] = job
|
||||
self._next_job_id += 1
|
||||
finally:
|
||||
self._lock.release()
|
||||
if start:
|
||||
self.start_job(job)
|
||||
return job
|
||||
@ -163,29 +159,25 @@ class DownloadQueue(DownloadQueueBase):
|
||||
|
||||
def change_priority(self, job: DownloadJobBase, delta: int):
|
||||
"""Change the priority of a job. Smaller priorities run first."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
assert isinstance(self._jobs[job.id], DownloadJobBase)
|
||||
job.priority += delta
|
||||
except (AssertionError, KeyError) as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
finally:
|
||||
self._lock.release()
|
||||
with self._lock:
|
||||
try:
|
||||
assert isinstance(self._jobs[job.id], DownloadJobBase)
|
||||
job.priority += delta
|
||||
except (AssertionError, KeyError) as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
|
||||
def prune_jobs(self):
|
||||
"""Prune completed and errored queue items from the job list."""
|
||||
try:
|
||||
with self._lock:
|
||||
to_delete = set()
|
||||
self._lock.acquire()
|
||||
for job_id, job in self._jobs.items():
|
||||
if self._in_terminal_state(job):
|
||||
to_delete.add(job_id)
|
||||
for job_id in to_delete:
|
||||
del self._jobs[job_id]
|
||||
except KeyError as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
finally:
|
||||
self._lock.release()
|
||||
try:
|
||||
for job_id, job in self._jobs.items():
|
||||
if self._in_terminal_state(job):
|
||||
to_delete.add(job_id)
|
||||
for job_id in to_delete:
|
||||
del self._jobs[job_id]
|
||||
except KeyError as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
|
||||
def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False):
|
||||
"""
|
||||
@ -194,16 +186,14 @@ class DownloadQueue(DownloadQueueBase):
|
||||
If it is running it will be stopped.
|
||||
job.status will be set to DownloadJobStatus.CANCELLED
|
||||
"""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
assert isinstance(self._jobs[job.id], DownloadJobBase)
|
||||
job.preserve_partial_downloads = preserve_partial
|
||||
self._update_job_status(job, DownloadJobStatus.CANCELLED)
|
||||
job.cleanup()
|
||||
except (AssertionError, KeyError) as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
finally:
|
||||
self._lock.release()
|
||||
with self._lock:
|
||||
try:
|
||||
assert isinstance(self._jobs[job.id], DownloadJobBase)
|
||||
job.preserve_partial_downloads = preserve_partial
|
||||
self._update_job_status(job, DownloadJobStatus.CANCELLED)
|
||||
job.cleanup()
|
||||
except (AssertionError, KeyError) as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
|
||||
def id_to_job(self, id: int) -> DownloadJobBase:
|
||||
"""Translate a job ID into a DownloadJobBase object."""
|
||||
@ -214,12 +204,13 @@ class DownloadQueue(DownloadQueueBase):
|
||||
|
||||
def start_job(self, job: DownloadJobBase):
|
||||
"""Enqueue (start) the indicated job."""
|
||||
try:
|
||||
assert isinstance(self._jobs[job.id], DownloadJobBase)
|
||||
self._update_job_status(job, DownloadJobStatus.ENQUEUED)
|
||||
self._queue.put(job)
|
||||
except (AssertionError, KeyError) as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
with self._lock:
|
||||
try:
|
||||
assert isinstance(self._jobs[job.id], DownloadJobBase)
|
||||
self._update_job_status(job, DownloadJobStatus.ENQUEUED)
|
||||
self._queue.put(job)
|
||||
except (AssertionError, KeyError) as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
|
||||
def pause_job(self, job: DownloadJobBase):
|
||||
"""
|
||||
@ -228,45 +219,34 @@ class DownloadQueue(DownloadQueueBase):
|
||||
The job can be restarted with start_job() and the download will pick up
|
||||
from where it left off.
|
||||
"""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
assert isinstance(self._jobs[job.id], DownloadJobBase)
|
||||
self._update_job_status(job, DownloadJobStatus.PAUSED)
|
||||
job.cleanup()
|
||||
except (AssertionError, KeyError) as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
finally:
|
||||
self._lock.release()
|
||||
with self._lock:
|
||||
try:
|
||||
assert isinstance(self._jobs[job.id], DownloadJobBase)
|
||||
self._update_job_status(job, DownloadJobStatus.PAUSED)
|
||||
job.cleanup()
|
||||
except (AssertionError, KeyError) as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
|
||||
def start_all_jobs(self):
|
||||
"""Start (enqueue) all jobs that are idle or paused."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
for job in self._jobs.values():
|
||||
if job.status in [DownloadJobStatus.IDLE, DownloadJobStatus.PAUSED]:
|
||||
self.start_job(job)
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def pause_all_jobs(self):
|
||||
"""Pause all running jobs."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
for job in self._jobs.values():
|
||||
if job.status == DownloadJobStatus.RUNNING:
|
||||
self.pause_job(job)
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def cancel_all_jobs(self, preserve_partial: bool = False):
|
||||
"""Cancel all running jobs."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
"""Cancel all jobs (those not in enqueued, running or paused state)."""
|
||||
with self._lock:
|
||||
for job in self._jobs.values():
|
||||
if not self._in_terminal_state(job):
|
||||
self.cancel_job(job, preserve_partial)
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _in_terminal_state(self, job: DownloadJobBase):
|
||||
return job.status in [
|
||||
@ -288,26 +268,26 @@ class DownloadQueue(DownloadQueueBase):
|
||||
while not done:
|
||||
job = self._queue.get()
|
||||
|
||||
try: # this is for debugging priority
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
job.job_sequence = self._sequence
|
||||
self._sequence += 1
|
||||
|
||||
try:
|
||||
if job == STOP_JOB: # marker that queue is done
|
||||
done = True
|
||||
|
||||
if (
|
||||
job.status == DownloadJobStatus.ENQUEUED
|
||||
): # Don't do anything for non-enqueued jobs (shouldn't happen)
|
||||
if not self._quiet:
|
||||
self._logger.info(f"{job.source}: Downloading to {job.destination}")
|
||||
do_download = self.select_downloader(job)
|
||||
do_download(job)
|
||||
|
||||
if job.status == DownloadJobStatus.CANCELLED:
|
||||
self._cleanup_cancelled_job(job)
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
if job == STOP_JOB: # marker that queue is done
|
||||
done = True
|
||||
|
||||
if job.status == DownloadJobStatus.ENQUEUED: # Don't do anything for non-enqueued jobs (shouldn't happen)
|
||||
if not self._quiet:
|
||||
self._logger.info(f"{job.source}: Downloading to {job.destination}")
|
||||
do_download = self.select_downloader(job)
|
||||
do_download(job)
|
||||
|
||||
if job.status == DownloadJobStatus.CANCELLED:
|
||||
self._cleanup_cancelled_job(job)
|
||||
|
||||
self._queue.task_done()
|
||||
self._queue.task_done()
|
||||
|
||||
def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]:
|
||||
"""Based on the job type select the download method."""
|
||||
@ -397,11 +377,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
self._update_job_status(job, DownloadJobStatus.COMPLETED)
|
||||
except KeyboardInterrupt as excp:
|
||||
raise excp
|
||||
except DuplicateModelException as excp:
|
||||
self._logger.error(f"A model with the same hash as {dest} is already installed.")
|
||||
job.error = excp
|
||||
self._update_job_status(job, DownloadJobStatus.ERROR)
|
||||
except Exception as excp:
|
||||
except (HTTPError, OSError) as excp:
|
||||
self._logger.error(f"An error occurred while downloading/installing {job.source}: {str(excp)}")
|
||||
print(traceback.format_exc())
|
||||
job.error = excp
|
||||
|
@ -15,6 +15,8 @@ from typing import Dict, Union
|
||||
|
||||
from imohash import hashfile
|
||||
|
||||
from .models import InvalidModelException
|
||||
|
||||
|
||||
class FastModelHash(object):
|
||||
"""FastModelHash obect provides one public class method, hash()."""
|
||||
@ -32,9 +34,6 @@ class FastModelHash(object):
|
||||
elif model_location.is_dir():
|
||||
return cls._hash_dir(model_location)
|
||||
else:
|
||||
# avoid circular import
|
||||
from .models import InvalidModelException
|
||||
|
||||
raise InvalidModelException(f"Not a valid file or directory: {model_location}")
|
||||
|
||||
@classmethod
|
||||
@ -54,7 +53,8 @@ class FastModelHash(object):
|
||||
|
||||
for root, dirs, files in os.walk(model_location):
|
||||
for file in files:
|
||||
# only tally tensor files
|
||||
# only tally tensor files because diffusers config files change slightly
|
||||
# depending on how the model was downloaded/converted.
|
||||
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
|
||||
continue
|
||||
path = (Path(root) / file).as_posix()
|
||||
|
@ -72,14 +72,11 @@ class ModelConfigStoreSQL(ModelConfigStore):
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = lock
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
# Enable foreign keys
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
assert (
|
||||
str(self.version) == CONFIG_FILE_VERSION
|
||||
), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
|
||||
@ -189,52 +186,49 @@ class ModelConfigStoreSQL(ModelConfigStore):
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
|
||||
json_serialized = json.dumps(record.dict()) # and turn it into a json string.
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_config (
|
||||
id,
|
||||
base_model,
|
||||
model_type,
|
||||
model_name,
|
||||
model_path,
|
||||
config
|
||||
)
|
||||
VALUES (?,?,?,?,?,?);
|
||||
""",
|
||||
(
|
||||
key,
|
||||
record.base_model,
|
||||
record.model_type,
|
||||
record.name,
|
||||
record.path,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
if record.tags:
|
||||
self._update_tags(key, record.tags)
|
||||
self._conn.commit()
|
||||
with self._lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_config (
|
||||
id,
|
||||
base_model,
|
||||
model_type,
|
||||
model_name,
|
||||
model_path,
|
||||
config
|
||||
)
|
||||
VALUES (?,?,?,?,?,?);
|
||||
""",
|
||||
(
|
||||
key,
|
||||
record.base_model,
|
||||
record.model_type,
|
||||
record.name,
|
||||
record.path,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
if record.tags:
|
||||
self._update_tags(key, record.tags)
|
||||
self._conn.commit()
|
||||
|
||||
except sqlite3.IntegrityError as e:
|
||||
self._conn.rollback()
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
raise DuplicateModelException(f"A model with key '{key}' is already installed") from e
|
||||
else:
|
||||
except sqlite3.IntegrityError as e:
|
||||
self._conn.rollback()
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
raise DuplicateModelException(f"A model with key '{key}' is already installed") from e
|
||||
else:
|
||||
raise e
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
return self.get_model(key)
|
||||
|
||||
@property
|
||||
def version(self) -> str:
|
||||
"""Return the version of the database schema."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata_value FROM model_manager_metadata
|
||||
@ -246,8 +240,6 @@ class ModelConfigStoreSQL(ModelConfigStore):
|
||||
if not rows:
|
||||
raise KeyError("Models database does not have metadata key 'version'")
|
||||
return rows[0]
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _update_tags(self, key: str, tags: List[str]) -> None:
|
||||
"""Update tags for model with key."""
|
||||
@ -301,23 +293,21 @@ class ModelConfigStoreSQL(ModelConfigStore):
|
||||
|
||||
Can raise an UnknownModelException
|
||||
"""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_config
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownModelException
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
with self._lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_config
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownModelException
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
|
||||
def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
|
||||
"""
|
||||
@ -329,30 +319,29 @@ class ModelConfigStoreSQL(ModelConfigStore):
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
|
||||
json_serialized = json.dumps(record.dict()) # and turn it into a json string.
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_config
|
||||
SET base_model=?,
|
||||
model_type=?,
|
||||
model_name=?,
|
||||
model_path=?,
|
||||
config=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(record.base_model, record.model_type, record.name, record.path, json_serialized, key),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownModelException
|
||||
if record.tags:
|
||||
self._update_tags(key, record.tags)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
with self._lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_config
|
||||
SET base_model=?,
|
||||
model_type=?,
|
||||
model_name=?,
|
||||
model_path=?,
|
||||
config=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(record.base_model, record.model_type, record.name, record.path, json_serialized, key),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownModelException
|
||||
if record.tags:
|
||||
self._update_tags(key, record.tags)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_model(key)
|
||||
|
||||
def get_model(self, key: str) -> AnyModelConfig:
|
||||
@ -363,8 +352,7 @@ class ModelConfigStoreSQL(ModelConfigStore):
|
||||
|
||||
Exceptions: UnknownModelException
|
||||
"""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config FROM model_config
|
||||
@ -376,8 +364,6 @@ class ModelConfigStoreSQL(ModelConfigStore):
|
||||
if not rows:
|
||||
raise UnknownModelException
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]))
|
||||
finally:
|
||||
self._lock.release()
|
||||
return model
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
@ -387,20 +373,18 @@ class ModelConfigStoreSQL(ModelConfigStore):
|
||||
:param key: Unique key for the model to be deleted
|
||||
"""
|
||||
count = 0
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select count(*) FROM model_config
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
count = self._cursor.fetchone()[0]
|
||||
except sqlite3.Error as e:
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
with self._lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select count(*) FROM model_config
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
count = self._cursor.fetchone()[0]
|
||||
except sqlite3.Error as e:
|
||||
raise e
|
||||
return count > 0
|
||||
|
||||
def search_by_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
|
||||
@ -408,34 +392,32 @@ class ModelConfigStoreSQL(ModelConfigStore):
|
||||
# rather than create a hairy SQL cross-product, we intersect
|
||||
# tag results in a stepwise fashion at the python level.
|
||||
results = []
|
||||
try:
|
||||
self._lock.acquire()
|
||||
matches: Set[str] = set()
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT a.id FROM model_tag AS a,
|
||||
tags AS b
|
||||
WHERE a.tag_id=b.tag_id
|
||||
AND b.tag_text=?;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
model_keys = {x[0] for x in self._cursor.fetchall()}
|
||||
matches = matches.intersection(model_keys) if len(matches) > 0 else model_keys
|
||||
if matches:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config FROM model_config
|
||||
WHERE id IN ({','.join('?' * len(matches))});
|
||||
""",
|
||||
tuple(matches),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
|
||||
except sqlite3.Error as e:
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
with self._lock:
|
||||
try:
|
||||
matches: Set[str] = set()
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT a.id FROM model_tag AS a,
|
||||
tags AS b
|
||||
WHERE a.tag_id=b.tag_id
|
||||
AND b.tag_text=?;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
model_keys = {x[0] for x in self._cursor.fetchall()}
|
||||
matches = matches.intersection(model_keys) if len(matches) > 0 else model_keys
|
||||
if matches:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config FROM model_config
|
||||
WHERE id IN ({','.join('?' * len(matches))});
|
||||
""",
|
||||
tuple(matches),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
|
||||
except sqlite3.Error as e:
|
||||
raise e
|
||||
return results
|
||||
|
||||
def search_by_name(
|
||||
@ -467,20 +449,18 @@ class ModelConfigStoreSQL(ModelConfigStore):
|
||||
where_clause.append("model_type=?")
|
||||
bindings.append(model_type)
|
||||
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
select config FROM model_config
|
||||
{where};
|
||||
""",
|
||||
tuple(bindings),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
|
||||
except sqlite3.Error as e:
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
with self._lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
select config FROM model_config
|
||||
{where};
|
||||
""",
|
||||
tuple(bindings),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
|
||||
except sqlite3.Error as e:
|
||||
raise e
|
||||
return results
|
||||
|
||||
def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]:
|
||||
|
@ -80,24 +80,18 @@ class ModelConfigStoreYAML(ModelConfigStore):
|
||||
raise ConfigFileVersionMismatchException
|
||||
|
||||
def _initialize_yaml(self):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
self._filename.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(self._filename, "w") as yaml_file:
|
||||
yaml_file.write(yaml.dump({"__metadata__": {"version": CONFIG_FILE_VERSION}}))
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _commit(self):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
newfile = Path(str(self._filename) + ".new")
|
||||
yaml_str = OmegaConf.to_yaml(self._config)
|
||||
with open(newfile, "w", encoding="utf-8") as outfile:
|
||||
outfile.write(yaml_str)
|
||||
newfile.replace(self._filename)
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
@property
|
||||
def version(self) -> str:
|
||||
@ -116,8 +110,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config, key) # ensure it is a valid config obect
|
||||
dict_fields = record.dict() # and back to a dict with valid fields
|
||||
try:
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
if key in self._config:
|
||||
existing_model = self.get_model(key)
|
||||
raise DuplicateModelException(
|
||||
@ -125,8 +118,6 @@ class ModelConfigStoreYAML(ModelConfigStore):
|
||||
)
|
||||
self._config[key] = self._fix_enums(dict_fields)
|
||||
self._commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get_model(key)
|
||||
|
||||
def _fix_enums(self, original: dict) -> dict:
|
||||
@ -144,14 +135,11 @@ class ModelConfigStoreYAML(ModelConfigStore):
|
||||
|
||||
Can raise an UnknownModelException
|
||||
"""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
if key not in self._config:
|
||||
raise UnknownModelException(f"Unknown key '{key}' for model config")
|
||||
self._config.pop(key)
|
||||
self._commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
|
||||
"""
|
||||
@ -163,14 +151,11 @@ class ModelConfigStoreYAML(ModelConfigStore):
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config, key) # ensure it is a valid config obect
|
||||
dict_fields = record.dict() # and back to a dict with valid fields
|
||||
try:
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
if key not in self._config:
|
||||
raise UnknownModelException(f"Unknown key '{key}' for model config")
|
||||
self._config[key] = self._fix_enums(dict_fields)
|
||||
self._commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get_model(key)
|
||||
|
||||
def get_model(self, key: str) -> AnyModelConfig:
|
||||
@ -203,15 +188,12 @@ class ModelConfigStoreYAML(ModelConfigStore):
|
||||
"""
|
||||
results = []
|
||||
tags = set(tags)
|
||||
try:
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
for config in self.all_models():
|
||||
config_tags = set(config.tags or [])
|
||||
if tags.difference(config_tags): # not all tags in the model
|
||||
continue
|
||||
results.append(config)
|
||||
finally:
|
||||
self._lock.release()
|
||||
return results
|
||||
|
||||
def search_by_name(
|
||||
@ -231,8 +213,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
|
||||
models in the database.
|
||||
"""
|
||||
results: List[ModelConfigBase] = list()
|
||||
try:
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
for key, record in self._config.items():
|
||||
if key == "__metadata__":
|
||||
continue
|
||||
@ -244,20 +225,15 @@ class ModelConfigStoreYAML(ModelConfigStore):
|
||||
if model_type and model.model_type != model_type:
|
||||
continue
|
||||
results.append(model)
|
||||
finally:
|
||||
self._lock.release()
|
||||
return results
|
||||
|
||||
def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]:
|
||||
"""Return the model with the indicated path, or None."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
for key, record in self._config.items():
|
||||
if key == "__metadata__":
|
||||
continue
|
||||
model = ModelConfigFactory.make_config(record, str(key))
|
||||
if model.path == path:
|
||||
return model
|
||||
finally:
|
||||
self._lock.release()
|
||||
return None
|
||||
|
@ -270,8 +270,6 @@ def test_bad_urls():
|
||||
|
||||
def test_pause_cancel_url(): # this one is tricky because of potential race conditions
|
||||
def event_handler(job: DownloadJobBase):
|
||||
if job.id == 0:
|
||||
print(job.status, job.bytes)
|
||||
time.sleep(0.5) # slow down the thread so that we can recover the paused state
|
||||
|
||||
queue = ModelDownloadQueue(requests_session=session, event_handlers=[event_handler])
|
||||
|
Reference in New Issue
Block a user