Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
address all PR 4252 comments from ryan through October 5
@@ -98,13 +98,18 @@ async def update_model(
 ) -> InvokeAIModelConfig:
     """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
     logger = ApiDependencies.invoker.services.logger
+    info_dict = info.dict()
+    record_store = ApiDependencies.invoker.services.model_record_store
+    model_install = ApiDependencies.invoker.services.model_installer
+    try:
+        new_config = record_store.update_model(key, config=info_dict)
+    except UnknownModelException as e:
+        raise HTTPException(status_code=404, detail=str(e))
+    except ValueError as e:
+        logger.error(str(e))
+        raise HTTPException(status_code=409, detail=str(e))

     try:
-        info_dict = info.dict()
-        info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
-        record_store = ApiDependencies.invoker.services.model_record_store
-        model_install = ApiDependencies.invoker.services.model_installer
-        new_config = record_store.update_model(key, config=info_dict)
         # In the event that the model's name, type or base has changed, and the model itself
         # resides in the invokeai root models directory, then the next statement will move
         # the model file into its new canonical location.
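The error handling introduced above maps each domain exception onto a specific HTTP status: an unknown key becomes a 404, a conflicting rename a 409. A minimal standalone sketch of that pattern, with a hypothetical in-memory helper in place of InvokeAI's record store:

from fastapi import FastAPI, HTTPException

app = FastAPI()


class UnknownModelException(Exception):
    """Raised when no record exists for a model key."""


def update_record(key: str, info: dict) -> dict:
    # hypothetical stand-in for record_store.update_model()
    if key != "known-key":
        raise UnknownModelException(f"model '{key}' not found")
    if not info.get("name"):
        raise ValueError("model name may not be empty")
    return {"key": key, **info}


@app.patch("/models/{key}")
async def update_model(key: str, info: dict) -> dict:
    try:
        return update_record(key, info)
    except UnknownModelException as e:
        raise HTTPException(status_code=404, detail=str(e))  # unknown key -> 404
    except ValueError as e:
        raise HTTPException(status_code=409, detail=str(e))  # conflict -> 409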
@@ -115,9 +120,6 @@ async def update_model(
     except ValueError as e:
         logger.error(str(e))
         raise HTTPException(status_code=409, detail=str(e))
-    except Exception as e:
-        logger.error(str(e))
-        raise HTTPException(status_code=400, detail=str(e))

     return model_response

@@ -198,8 +200,8 @@ async def import_model(
     responses={
         201: {"description": "The model added successfully"},
         404: {"description": "The model could not be found"},
-        424: {"description": "The model appeared to add successfully, but could not be found in the model manager"},
         409: {"description": "There is already a model corresponding to this path or repo_id"},
+        415: {"description": "Unrecognized file/folder format"},
     },
     status_code=201,
     response_model=InvokeAIModelConfig,
@@ -213,16 +215,22 @@ async def add_model(
     """

     logger = ApiDependencies.invoker.services.logger
+    path = info.path
+    installer = ApiDependencies.invoker.services.model_installer
+    record_store = ApiDependencies.invoker.services.model_record_store
     try:
-        path = info.path
-        installer = ApiDependencies.invoker.services.model_installer
-        record_store = ApiDependencies.invoker.services.model_record_store
         key = installer.install_path(path)
         logger.info(f"Created model {key} for {path}")
+    except DuplicateModelException as e:
+        logger.error(str(e))
+        raise HTTPException(status_code=409, detail=str(e))
+    except InvalidModelException as e:
+        logger.error(str(e))
+        raise HTTPException(status_code=415)

     # update with the provided info
+    try:
         info_dict = info.dict()
-        info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
         new_config = record_store.update_model(key, new_config=info_dict)
         return parse_obj_as(InvokeAIModelConfig, new_config.dict())
     except UnknownModelException as e:
@@ -405,24 +413,19 @@ async def merge_models(
 )
 async def list_install_jobs() -> List[ModelImportStatus]:
     """List active and pending model installation jobs."""
-    logger = ApiDependencies.invoker.services.logger
     job_mgr = ApiDependencies.invoker.services.model_installer
-    try:
-        jobs = job_mgr.list_install_jobs()
-        return [
-            ModelImportStatus(
-                job_id=x.id,
-                source=x.source,
-                priority=x.priority,
-                bytes=x.bytes,
-                total_bytes=x.total_bytes,
-                status=x.status,
-            )
-            for x in jobs
-        ]
-    except Exception as e:
-        logger.error(str(e))
-        raise HTTPException(status_code=400, detail=str(e))
+    jobs = job_mgr.list_install_jobs()
+    return [
+        ModelImportStatus(
+            job_id=x.id,
+            source=x.source,
+            priority=x.priority,
+            bytes=x.bytes,
+            total_bytes=x.total_bytes,
+            status=x.status,
+        )
+        for x in jobs
+    ]


 @models_router.patch(
@@ -459,11 +462,11 @@ async def control_install_jobs(
         elif operation == JobControlOperation.CANCEL:
             job_mgr.cancel_job(job_id)

-        elif operation == JobControlOperation.CHANGE_PRIORITY:
+        elif operation == JobControlOperation.CHANGE_PRIORITY and priority_delta is not None:
             job_mgr.change_job_priority(job_id, priority_delta)

         else:
-            raise ValueError(f"Unknown operation {JobControlOperation.value}")
+            raise ValueError(f"Unknown operation {operation.value}")

         return ModelImportStatus(
             job_id=job_id,
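The second fix above matters at runtime: `.value` exists on enum members, not on the Enum class, so the old f-string would itself raise before the intended error message could be built. A small sketch (the member values here are invented for illustration):

from enum import Enum


class JobControlOperation(Enum):
    START = "Start"
    CANCEL = "Cancel"
    CHANGE_PRIORITY = "Change Priority"


def describe(operation: JobControlOperation) -> str:
    # operation.value is the member's string value.
    # JobControlOperation.value would raise AttributeError, because the
    # class itself has no 'value' attribute -- only its members do.
    return f"Unknown operation {operation.value}"


print(describe(JobControlOperation.START))  # -> Unknown operation Start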
@@ -478,9 +481,6 @@ async def control_install_jobs(
     except ValueError as e:
         logger.error(str(e))
         raise HTTPException(status_code=409, detail=str(e))
-    except Exception as e:
-        logger.error(str(e))
-        raise HTTPException(status_code=400, detail=str(e))


 @models_router.patch(
@@ -490,39 +490,27 @@ async def control_install_jobs(
         204: {"description": "All jobs cancelled successfully"},
         400: {"description": "Bad request"},
     },
-    status_code=200,
-    response_model=ModelImportStatus,
 )
 async def cancel_install_jobs():
-    """Cancel all pending install jobs."""
+    """Cancel all model installation jobs."""
     logger = ApiDependencies.invoker.services.logger
-    try:
-        job_mgr = ApiDependencies.invoker.services.model_installer
-        logger.info("Cancelling all running model installation jobs.")
-        job_mgr.cancel_all_jobs()
-        return Response(status_code=204)
-    except Exception as e:
-        logger.error(str(e))
-        raise HTTPException(status_code=400, detail=str(e))
+    job_mgr = ApiDependencies.invoker.services.model_installer
+    logger.info("Cancelling all model installation jobs.")
+    job_mgr.cancel_all_jobs()
+    return Response(status_code=204)


 @models_router.patch(
     "/jobs/prune",
     operation_id="prune_jobs",
     responses={
-        204: {"description": "All jobs cancelled successfully"},
+        204: {"description": "All completed jobs have been pruned"},
         400: {"description": "Bad request"},
     },
-    status_code=200,
-    response_model=ModelImportStatus,
 )
 async def prune_jobs():
     """Prune all completed and errored jobs."""
     logger = ApiDependencies.invoker.services.logger
-    try:
-        mgr = ApiDependencies.invoker.services.model_installer
-        mgr.prune_jobs()
-        return Response(status_code=204)
-    except Exception as e:
-        logger.error(str(e))
-        raise HTTPException(status_code=400, detail=str(e))
+    mgr = ApiDependencies.invoker.services.model_installer
+    mgr.prune_jobs()
+    return Response(status_code=204)
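Both endpoints now answer with an empty 204 rather than declaring `status_code=200` and a `response_model` they never used. A hedged sketch of the 204 pattern (route path and app object are illustrative, not InvokeAI's router):

from fastapi import FastAPI, Response

app = FastAPI()


@app.patch(
    "/jobs/cancel_all",
    responses={204: {"description": "All jobs cancelled successfully"}},
)
async def cancel_all_jobs() -> Response:
    # ... cancel the jobs here ...
    return Response(status_code=204)  # empty body; no response_model applies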
@@ -126,7 +126,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
         e.g. `max_parallel_dl`.
         """
         self._event_bus = event_bus
-        self._queue = DownloadQueue()
+        self._queue = DownloadQueue(**kwargs)

     def create_download_job(
         self,
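Passing `**kwargs` through to `DownloadQueue` lets callers tune queue options such as `max_parallel_dl` from the service constructor. A toy sketch of the forwarding (class names and the `quiet` option are stand-ins):

class ToyQueue:
    def __init__(self, max_parallel_dl: int = 5, quiet: bool = False):
        self.max_parallel_dl = max_parallel_dl
        self.quiet = quiet


class ToyQueueService:
    def __init__(self, event_bus=None, **kwargs):
        self._event_bus = event_bus
        self._queue = ToyQueue(**kwargs)  # options flow through unchanged


service = ToyQueueService(max_parallel_dl=10)
assert service._queue.max_parallel_dl == 10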
@@ -92,7 +92,7 @@ class ModelInstallServiceBase(ModelInstallBase):  # This is an ABC

     @abstractmethod
     def cancel_all_jobs(self):
-        """Cancel all active jobs."""
+        """Cancel all installation jobs."""
         pass

     @abstractmethod
@@ -62,7 +62,7 @@ class ModelLoadServiceBase(ABC):
 class ModelLoadService(ModelLoadServiceBase):
     """Responsible for managing models on disk and in memory."""

-    _loader: ModelLoad = Field(description="InvokeAIAppConfig object for the current process")
+    _loader: ModelLoad

     def __init__(
         self,
@@ -12,6 +12,4 @@ from .model_manager import (  # noqa F401
     SilenceWarnings,
     SubModelType,
 )
-from .model_manager.install import ModelInstall  # noqa F401
-from .model_manager.loader import ModelLoad  # noqa F401
 from .util.devices import get_precision  # noqa F401
@@ -18,7 +18,6 @@ from invokeai.backend.model_manager.install import ModelInstall, ModelInstallJob

 # name of the starter models file
 INITIAL_MODELS = "INITIAL_MODELS.yaml"
-ACCESS_TOKEN = HfFolder.get_token()


 class UnifiedModelInfo(BaseModel):
@@ -173,7 +172,7 @@ class InstallHelper(object):
                 model.source,
                 subfolder=model.subfolder,
                 variant="fp16" if self._config.precision == "float16" else None,
-                access_token=ACCESS_TOKEN,  # this is a global,
+                access_token=HfFolder.get_token(),
                 metadata=metadata,
             )

@@ -185,14 +185,17 @@ class ProgressBar:

 # ---------------------------------------------
 def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs):
-    logger.addFilter(lambda x: "fp16 is not a valid" not in x.getMessage())
-    model = model_class.from_pretrained(
-        model_name,
-        resume_download=True,
-        **kwargs,
-    )
-    model.save_pretrained(destination, safe_serialization=True)
+    filter = lambda x: "fp16 is not a valid" not in x.getMessage()
+    logger.addFilter(filter)
+    try:
+        model = model_class.from_pretrained(
+            model_name,
+            resume_download=True,
+            **kwargs,
+        )
+        model.save_pretrained(destination, safe_serialization=True)
+    finally:
+        logger.removeFilter(filter)
     return destination

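The filter is now held in a variable so it can be removed in a `finally:` block, guaranteeing the logger is restored even when `from_pretrained` raises. The same idea as a reusable context manager (a sketch; the filter text comes from the hunk above):

import logging
from contextlib import contextmanager


@contextmanager
def suppressed_messages(logger: logging.Logger, needle: str):
    flt = lambda record: needle not in record.getMessage()
    logger.addFilter(flt)
    try:
        yield
    finally:
        logger.removeFilter(flt)  # restored even if the body raises


logger = logging.getLogger("model_install")
with suppressed_messages(logger, "fp16 is not a valid"):
    logger.warning("fp16 is not a valid torch_dtype")  # swallowed
logger.warning("after the block, filtering is off again")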
@@ -212,7 +212,7 @@ class DownloadQueueBase(ABC):
     @abstractmethod
     def cancel_all_jobs(self, preserve_partial: bool = False):
         """
-        Cancel all active and enquedjobs.
+        Cancel all jobs (those in enqueued, running and paused states).

         :param preserve_partial: Keep partially downloaded files [False].
         """
@@ -12,8 +12,7 @@ from queue import PriorityQueue
 from typing import Callable, Dict, List, Optional, Set, Tuple, Union

 import requests
-from huggingface_hub import HfApi, hf_hub_url
-from pydantic import Field, parse_obj_as, validator
+from pydantic import Field
 from pydantic.networks import AnyHttpUrl
 from requests import HTTPError

@@ -59,7 +58,7 @@ class DownloadQueue(DownloadQueueBase):
     _jobs: Dict[int, DownloadJobBase]
     _worker_pool: Set[threading.Thread]
     _queue: PriorityQueue
-    _lock: threading.RLock
+    _lock: threading.RLock  # to allow methods called within the same thread to lock without blocking
     _logger: Logger
     _event_handlers: List[DownloadEventHandler] = Field(default_factory=list)
     _next_job_id: int = 0
@@ -136,13 +135,10 @@ class DownloadQueue(DownloadQueueBase):
         # add the queue's handlers
         for handler in self._event_handlers:
             job.add_event_handler(handler)
-        try:
-            self._lock.acquire()
+        with self._lock:
             job.id = self._next_job_id
             self._jobs[job.id] = job
             self._next_job_id += 1
-        finally:
-            self._lock.release()
         if start:
             self.start_job(job)
         return job
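`with self._lock:` is functionally the acquire/`try`/`finally`/release sequence it replaces, with no way to forget the release. A side-by-side sketch of the equivalence:

import threading

lock = threading.RLock()
counter = 0


def bump_explicit():
    global counter
    lock.acquire()
    try:
        counter += 1
    finally:
        lock.release()


def bump_with():
    global counter
    with lock:  # acquires on entry, always releases on exit
        counter += 1


bump_explicit()
bump_with()
assert counter == 2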
@@ -163,29 +159,25 @@ class DownloadQueue(DownloadQueueBase):

     def change_priority(self, job: DownloadJobBase, delta: int):
         """Change the priority of a job. Smaller priorities run first."""
-        try:
-            self._lock.acquire()
+        with self._lock:
+            try:
                 assert isinstance(self._jobs[job.id], DownloadJobBase)
                 job.priority += delta
             except (AssertionError, KeyError) as excp:
                 raise UnknownJobIDException("Unrecognized job") from excp
-        finally:
-            self._lock.release()

     def prune_jobs(self):
         """Prune completed and errored queue items from the job list."""
-        try:
+        with self._lock:
             to_delete = set()
-            self._lock.acquire()
+            try:
                 for job_id, job in self._jobs.items():
                     if self._in_terminal_state(job):
                         to_delete.add(job_id)
                 for job_id in to_delete:
                     del self._jobs[job_id]
             except KeyError as excp:
                 raise UnknownJobIDException("Unrecognized job") from excp
-        finally:
-            self._lock.release()

     def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False):
         """
@@ -194,16 +186,14 @@ class DownloadQueue(DownloadQueueBase):
         If it is running it will be stopped.
         job.status will be set to DownloadJobStatus.CANCELLED
         """
-        try:
-            self._lock.acquire()
+        with self._lock:
+            try:
                 assert isinstance(self._jobs[job.id], DownloadJobBase)
                 job.preserve_partial_downloads = preserve_partial
                 self._update_job_status(job, DownloadJobStatus.CANCELLED)
                 job.cleanup()
             except (AssertionError, KeyError) as excp:
                 raise UnknownJobIDException("Unrecognized job") from excp
-        finally:
-            self._lock.release()

     def id_to_job(self, id: int) -> DownloadJobBase:
         """Translate a job ID into a DownloadJobBase object."""
@@ -214,12 +204,13 @@ class DownloadQueue(DownloadQueueBase):

     def start_job(self, job: DownloadJobBase):
         """Enqueue (start) the indicated job."""
-        try:
-            assert isinstance(self._jobs[job.id], DownloadJobBase)
-            self._update_job_status(job, DownloadJobStatus.ENQUEUED)
-            self._queue.put(job)
-        except (AssertionError, KeyError) as excp:
-            raise UnknownJobIDException("Unrecognized job") from excp
+        with self._lock:
+            try:
+                assert isinstance(self._jobs[job.id], DownloadJobBase)
+                self._update_job_status(job, DownloadJobStatus.ENQUEUED)
+                self._queue.put(job)
+            except (AssertionError, KeyError) as excp:
+                raise UnknownJobIDException("Unrecognized job") from excp

     def pause_job(self, job: DownloadJobBase):
         """
@@ -228,45 +219,34 @@ class DownloadQueue(DownloadQueueBase):
         The job can be restarted with start_job() and the download will pick up
         from where it left off.
         """
-        try:
-            self._lock.acquire()
+        with self._lock:
+            try:
                 assert isinstance(self._jobs[job.id], DownloadJobBase)
                 self._update_job_status(job, DownloadJobStatus.PAUSED)
                 job.cleanup()
             except (AssertionError, KeyError) as excp:
                 raise UnknownJobIDException("Unrecognized job") from excp
-        finally:
-            self._lock.release()

     def start_all_jobs(self):
         """Start (enqueue) all jobs that are idle or paused."""
-        try:
-            self._lock.acquire()
+        with self._lock:
             for job in self._jobs.values():
                 if job.status in [DownloadJobStatus.IDLE, DownloadJobStatus.PAUSED]:
                     self.start_job(job)
-        finally:
-            self._lock.release()

     def pause_all_jobs(self):
         """Pause all running jobs."""
-        try:
-            self._lock.acquire()
+        with self._lock:
             for job in self._jobs.values():
                 if job.status == DownloadJobStatus.RUNNING:
                     self.pause_job(job)
-        finally:
-            self._lock.release()

     def cancel_all_jobs(self, preserve_partial: bool = False):
-        """Cancel all running jobs."""
-        try:
-            self._lock.acquire()
+        """Cancel all jobs (those not in enqueued, running or paused state)."""
+        with self._lock:
             for job in self._jobs.values():
                 if not self._in_terminal_state(job):
                     self.cancel_job(job, preserve_partial)
-        finally:
-            self._lock.release()

     def _in_terminal_state(self, job: DownloadJobBase):
         return job.status in [
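These methods also show why `_lock` is an `RLock` (see the field comment earlier): `cancel_all_jobs()` takes the lock and then calls `cancel_job()`, which takes it again on the same thread. A reduced sketch of that re-entrancy; a plain `Lock` would deadlock on the nested acquisition:

import threading


class ToyJobRegistry:
    def __init__(self):
        self._lock = threading.RLock()
        self._jobs = {1: "running", 2: "enqueued"}

    def cancel_job(self, job_id: int):
        with self._lock:  # second acquisition by the same thread: fine for RLock
            self._jobs[job_id] = "cancelled"

    def cancel_all_jobs(self):
        with self._lock:
            for job_id in list(self._jobs):
                self.cancel_job(job_id)


registry = ToyJobRegistry()
registry.cancel_all_jobs()
assert all(status == "cancelled" for status in registry._jobs.values())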
@@ -288,26 +268,26 @@ class DownloadQueue(DownloadQueueBase):
         while not done:
             job = self._queue.get()

-            try:  # this is for debugging priority
-                self._lock.acquire()
+            with self._lock:
                 job.job_sequence = self._sequence
                 self._sequence += 1

+            try:
+                if job == STOP_JOB:  # marker that queue is done
+                    done = True
+
+                if (
+                    job.status == DownloadJobStatus.ENQUEUED
+                ):  # Don't do anything for non-enqueued jobs (shouldn't happen)
+                    if not self._quiet:
+                        self._logger.info(f"{job.source}: Downloading to {job.destination}")
+                    do_download = self.select_downloader(job)
+                    do_download(job)
+
+                if job.status == DownloadJobStatus.CANCELLED:
+                    self._cleanup_cancelled_job(job)
             finally:
-                self._lock.release()
+                self._queue.task_done()

-            if job == STOP_JOB:  # marker that queue is done
-                done = True
-
-            if job.status == DownloadJobStatus.ENQUEUED:  # Don't do anything for non-enqueued jobs (shouldn't happen)
-                if not self._quiet:
-                    self._logger.info(f"{job.source}: Downloading to {job.destination}")
-                do_download = self.select_downloader(job)
-                do_download(job)
-
-            if job.status == DownloadJobStatus.CANCELLED:
-                self._cleanup_cancelled_job(job)
-
-            self._queue.task_done()

     def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]:
         """Based on the job type select the download method."""
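Moving `task_done()` into a `finally:` block keeps the queue's unfinished-task count balanced with every `get()`, so `queue.join()` cannot hang when a download raises. A minimal worker showing the invariant (a sketch, not the class above):

import queue
import threading

q: "queue.Queue[int]" = queue.Queue()


def worker():
    while True:
        item = q.get()
        try:
            if item < 0:
                raise ValueError(f"bad job {item}")
            print(f"processed {item}")
        except ValueError as e:
            print(f"error: {e}")
        finally:
            q.task_done()  # always paired with get(), so join() returns


threading.Thread(target=worker, daemon=True).start()
for i in (1, -2, 3):
    q.put(i)
q.join()  # returns even though one item raised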
@@ -397,11 +377,7 @@ class DownloadQueue(DownloadQueueBase):
             self._update_job_status(job, DownloadJobStatus.COMPLETED)
         except KeyboardInterrupt as excp:
             raise excp
-        except DuplicateModelException as excp:
-            self._logger.error(f"A model with the same hash as {dest} is already installed.")
-            job.error = excp
-            self._update_job_status(job, DownloadJobStatus.ERROR)
-        except Exception as excp:
+        except (HTTPError, OSError) as excp:
             self._logger.error(f"An error occurred while downloading/installing {job.source}: {str(excp)}")
             print(traceback.format_exc())
             job.error = excp
@@ -15,6 +15,8 @@ from typing import Dict, Union

 from imohash import hashfile

+from .models import InvalidModelException
+

 class FastModelHash(object):
     """FastModelHash obect provides one public class method, hash()."""
@@ -32,9 +34,6 @@ class FastModelHash(object):
         elif model_location.is_dir():
             return cls._hash_dir(model_location)
         else:
-            # avoid circular import
-            from .models import InvalidModelException
-
             raise InvalidModelException(f"Not a valid file or directory: {model_location}")

     @classmethod
@@ -54,7 +53,8 @@ class FastModelHash(object):

         for root, dirs, files in os.walk(model_location):
             for file in files:
-                # only tally tensor files
+                # only tally tensor files because diffusers config files change slightly
+                # depending on how the model was downloaded/converted.
                 if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
                     continue
                 path = (Path(root) / file).as_posix()
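The expanded comment captures the design choice: hashing only weight files keeps the hash stable across downloads, since diffusers config files can differ slightly between conversions. A sketch of that directory hash using stdlib hashlib (the real class imports imohash, so digests here would not be comparable to InvokeAI's):

import hashlib
import os
from pathlib import Path

WEIGHT_SUFFIXES = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")


def hash_model_dir(model_location: Path) -> str:
    digest = hashlib.sha256()
    for root, _dirs, files in sorted(os.walk(model_location)):
        for file in sorted(files):
            if not file.endswith(WEIGHT_SUFFIXES):
                continue  # config/JSON files are deliberately excluded
            with open(Path(root) / file, "rb") as f:
                for chunk in iter(lambda: f.read(1 << 20), b""):
                    digest.update(chunk)
    return digest.hexdigest()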
@@ -72,14 +72,11 @@ class ModelConfigStoreSQL(ModelConfigStore):
         self._cursor = self._conn.cursor()
         self._lock = lock

-        try:
-            self._lock.acquire()
+        with self._lock:
             # Enable foreign keys
             self._conn.execute("PRAGMA foreign_keys = ON;")
             self._create_tables()
             self._conn.commit()
-        finally:
-            self._lock.release()
         assert (
             str(self.version) == CONFIG_FILE_VERSION
         ), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
@@ -189,52 +186,49 @@ class ModelConfigStoreSQL(ModelConfigStore):
         """
         record = ModelConfigFactory.make_config(config, key=key)  # ensure it is a valid config obect.
         json_serialized = json.dumps(record.dict())  # and turn it into a json string.
-        try:
-            self._lock.acquire()
+        with self._lock:
+            try:
                 self._cursor.execute(
                     """--sql
                     INSERT INTO model_config (
                         id,
                         base_model,
                         model_type,
                         model_name,
                         model_path,
                         config
                     )
                     VALUES (?,?,?,?,?,?);
                     """,
                     (
                         key,
                         record.base_model,
                         record.model_type,
                         record.name,
                         record.path,
                         json_serialized,
                     ),
                 )
                 if record.tags:
                     self._update_tags(key, record.tags)
                 self._conn.commit()

             except sqlite3.IntegrityError as e:
                 self._conn.rollback()
                 if "UNIQUE constraint failed" in str(e):
                     raise DuplicateModelException(f"A model with key '{key}' is already installed") from e
                 else:
+                    raise e
+            except sqlite3.Error as e:
+                self._conn.rollback()
                 raise e
-        except sqlite3.Error as e:
-            self._conn.rollback()
-            raise e
-        finally:
-            self._lock.release()

         return self.get_model(key)

     @property
     def version(self) -> str:
         """Return the version of the database schema."""
-        try:
-            self._lock.acquire()
+        with self._lock:
             self._cursor.execute(
                 """--sql
                 SELECT metadata_value FROM model_manager_metadata
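The `except` chain above distinguishes a UNIQUE-constraint violation (surfaced as a `DuplicateModelException`) from other SQLite errors, which roll back and re-raise. A runnable reduction with an in-memory database (table and names simplified for illustration):

import sqlite3


class DuplicateModelException(Exception):
    pass


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE model_config (id TEXT PRIMARY KEY, config TEXT)")


def add_model(key: str, config: str) -> None:
    try:
        conn.execute("INSERT INTO model_config (id, config) VALUES (?, ?)", (key, config))
        conn.commit()
    except sqlite3.IntegrityError as e:
        conn.rollback()
        if "UNIQUE constraint failed" in str(e):
            raise DuplicateModelException(f"A model with key '{key}' is already installed") from e
        raise
    except sqlite3.Error:
        conn.rollback()
        raise


add_model("abc123", "{}")
try:
    add_model("abc123", "{}")
except DuplicateModelException as e:
    print(e)  # -> A model with key 'abc123' is already installed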
@@ -246,8 +240,6 @@ class ModelConfigStoreSQL(ModelConfigStore):
             if not rows:
                 raise KeyError("Models database does not have metadata key 'version'")
             return rows[0]
-        finally:
-            self._lock.release()

     def _update_tags(self, key: str, tags: List[str]) -> None:
         """Update tags for model with key."""
@@ -301,23 +293,21 @@ class ModelConfigStoreSQL(ModelConfigStore):

         Can raise an UnknownModelException
         """
-        try:
-            self._lock.acquire()
+        with self._lock:
+            try:
                 self._cursor.execute(
                     """--sql
                     DELETE FROM model_config
                     WHERE id=?;
                     """,
                     (key,),
                 )
                 if self._cursor.rowcount == 0:
                     raise UnknownModelException
                 self._conn.commit()
             except sqlite3.Error as e:
                 self._conn.rollback()
                 raise e
-        finally:
-            self._lock.release()

     def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
         """
@@ -329,30 +319,29 @@ class ModelConfigStoreSQL(ModelConfigStore):
         """
         record = ModelConfigFactory.make_config(config, key=key)  # ensure it is a valid config obect
         json_serialized = json.dumps(record.dict())  # and turn it into a json string.
-        try:
-            self._lock.acquire()
+        with self._lock:
+            try:
                 self._cursor.execute(
                     """--sql
                     UPDATE model_config
                     SET base_model=?,
                         model_type=?,
                         model_name=?,
                         model_path=?,
                         config=?
                     WHERE id=?;
                     """,
                     (record.base_model, record.model_type, record.name, record.path, json_serialized, key),
                 )
                 if self._cursor.rowcount == 0:
                     raise UnknownModelException
                 if record.tags:
                     self._update_tags(key, record.tags)
                 self._conn.commit()
             except sqlite3.Error as e:
                 self._conn.rollback()
                 raise e
-        finally:
-            self._lock.release()
         return self.get_model(key)

     def get_model(self, key: str) -> AnyModelConfig:
@@ -363,8 +352,7 @@ class ModelConfigStoreSQL(ModelConfigStore):

         Exceptions: UnknownModelException
         """
-        try:
-            self._lock.acquire()
+        with self._lock:
             self._cursor.execute(
                 """--sql
                 SELECT config FROM model_config
@@ -376,8 +364,6 @@ class ModelConfigStoreSQL(ModelConfigStore):
             if not rows:
                 raise UnknownModelException
             model = ModelConfigFactory.make_config(json.loads(rows[0]))
-        finally:
-            self._lock.release()
         return model

     def exists(self, key: str) -> bool:
@@ -387,20 +373,18 @@ class ModelConfigStoreSQL(ModelConfigStore):
         :param key: Unique key for the model to be deleted
         """
         count = 0
-        try:
-            self._lock.acquire()
+        with self._lock:
+            try:
                 self._cursor.execute(
                     """--sql
                     select count(*) FROM model_config
                     WHERE id=?;
                     """,
                     (key,),
                 )
                 count = self._cursor.fetchone()[0]
             except sqlite3.Error as e:
                 raise e
-        finally:
-            self._lock.release()
         return count > 0

     def search_by_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
@@ -408,34 +392,32 @@ class ModelConfigStoreSQL(ModelConfigStore):
         # rather than create a hairy SQL cross-product, we intersect
         # tag results in a stepwise fashion at the python level.
         results = []
-        try:
-            self._lock.acquire()
+        with self._lock:
+            try:
                 matches: Set[str] = set()
                 for tag in tags:
                     self._cursor.execute(
                         """--sql
                         SELECT a.id FROM model_tag AS a,
                                          tags AS b
                         WHERE a.tag_id=b.tag_id
                           AND b.tag_text=?;
                         """,
                         (tag,),
                     )
                     model_keys = {x[0] for x in self._cursor.fetchall()}
                     matches = matches.intersection(model_keys) if len(matches) > 0 else model_keys
                 if matches:
                     self._cursor.execute(
                         f"""--sql
                         SELECT config FROM model_config
                         WHERE id IN ({','.join('?' * len(matches))});
                         """,
                         tuple(matches),
                     )
                     results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
             except sqlite3.Error as e:
                 raise e
-        finally:
-            self._lock.release()
         return results

     def search_by_name(
@@ -467,20 +449,18 @@ class ModelConfigStoreSQL(ModelConfigStore):
             where_clause.append("model_type=?")
             bindings.append(model_type)
         where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
-        try:
-            self._lock.acquire()
+        with self._lock:
+            try:
                 self._cursor.execute(
                     f"""--sql
                     select config FROM model_config
                     {where};
                     """,
                     tuple(bindings),
                 )
                 results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
             except sqlite3.Error as e:
                 raise e
-        finally:
-            self._lock.release()
         return results

     def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]:
@@ -80,24 +80,18 @@ class ModelConfigStoreYAML(ModelConfigStore):
             raise ConfigFileVersionMismatchException

     def _initialize_yaml(self):
-        try:
-            self._lock.acquire()
+        with self._lock:
             self._filename.parent.mkdir(parents=True, exist_ok=True)
             with open(self._filename, "w") as yaml_file:
                 yaml_file.write(yaml.dump({"__metadata__": {"version": CONFIG_FILE_VERSION}}))
-        finally:
-            self._lock.release()

     def _commit(self):
-        try:
-            self._lock.acquire()
+        with self._lock:
             newfile = Path(str(self._filename) + ".new")
             yaml_str = OmegaConf.to_yaml(self._config)
             with open(newfile, "w", encoding="utf-8") as outfile:
                 outfile.write(yaml_str)
             newfile.replace(self._filename)
-        finally:
-            self._lock.release()

     @property
     def version(self) -> str:
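`_commit()` writes the new YAML beside the old file and renames over it, so a crash mid-write never leaves a truncated config behind. A standalone sketch of that write-then-replace idiom (file name and YAML content are illustrative):

from pathlib import Path


def commit(filename: Path, yaml_str: str) -> None:
    newfile = Path(str(filename) + ".new")
    with open(newfile, "w", encoding="utf-8") as outfile:
        outfile.write(yaml_str)
    newfile.replace(filename)  # atomic rename on the same filesystem


commit(Path("models.yaml"), "__metadata__:\n  version: '3.0'\n")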
@@ -116,8 +110,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
         """
         record = ModelConfigFactory.make_config(config, key)  # ensure it is a valid config obect
         dict_fields = record.dict()  # and back to a dict with valid fields
-        try:
-            self._lock.acquire()
+        with self._lock:
             if key in self._config:
                 existing_model = self.get_model(key)
                 raise DuplicateModelException(
@@ -125,8 +118,6 @@ class ModelConfigStoreYAML(ModelConfigStore):
                 )
             self._config[key] = self._fix_enums(dict_fields)
             self._commit()
-        finally:
-            self._lock.release()
         return self.get_model(key)

     def _fix_enums(self, original: dict) -> dict:
@@ -144,14 +135,11 @@ class ModelConfigStoreYAML(ModelConfigStore):

         Can raise an UnknownModelException
         """
-        try:
-            self._lock.acquire()
+        with self._lock:
             if key not in self._config:
                 raise UnknownModelException(f"Unknown key '{key}' for model config")
             self._config.pop(key)
             self._commit()
-        finally:
-            self._lock.release()

     def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
         """
@@ -163,14 +151,11 @@ class ModelConfigStoreYAML(ModelConfigStore):
         """
         record = ModelConfigFactory.make_config(config, key)  # ensure it is a valid config obect
         dict_fields = record.dict()  # and back to a dict with valid fields
-        try:
-            self._lock.acquire()
+        with self._lock:
             if key not in self._config:
                 raise UnknownModelException(f"Unknown key '{key}' for model config")
             self._config[key] = self._fix_enums(dict_fields)
             self._commit()
-        finally:
-            self._lock.release()
         return self.get_model(key)

     def get_model(self, key: str) -> AnyModelConfig:
@@ -203,15 +188,12 @@ class ModelConfigStoreYAML(ModelConfigStore):
         """
         results = []
         tags = set(tags)
-        try:
-            self._lock.acquire()
+        with self._lock:
             for config in self.all_models():
                 config_tags = set(config.tags or [])
                 if tags.difference(config_tags):  # not all tags in the model
                     continue
                 results.append(config)
-        finally:
-            self._lock.release()
         return results

     def search_by_name(
@@ -231,8 +213,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
         models in the database.
         """
         results: List[ModelConfigBase] = list()
-        try:
-            self._lock.acquire()
+        with self._lock:
             for key, record in self._config.items():
                 if key == "__metadata__":
                     continue
@@ -244,20 +225,15 @@ class ModelConfigStoreYAML(ModelConfigStore):
                 if model_type and model.model_type != model_type:
                     continue
                 results.append(model)
-        finally:
-            self._lock.release()
         return results

     def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]:
         """Return the model with the indicated path, or None."""
-        try:
-            self._lock.acquire()
+        with self._lock:
             for key, record in self._config.items():
                 if key == "__metadata__":
                     continue
                 model = ModelConfigFactory.make_config(record, str(key))
                 if model.path == path:
                     return model
-        finally:
-            self._lock.release()
         return None
@@ -270,8 +270,6 @@ def test_bad_urls():

 def test_pause_cancel_url():  # this one is tricky because of potential race conditions
     def event_handler(job: DownloadJobBase):
-        if job.id == 0:
-            print(job.status, job.bytes)
         time.sleep(0.5)  # slow down the thread so that we can recover the paused state

     queue = ModelDownloadQueue(requests_session=session, event_handlers=[event_handler])