address all PR 4252 comments from ryan through October 5

This commit is contained in:
Lincoln Stein
2023-10-09 00:28:21 -04:00
parent ce2baa36a9
commit fe1038665c
13 changed files with 260 additions and 342 deletions

View File

@ -98,13 +98,18 @@ async def update_model(
) -> InvokeAIModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger
try:
info_dict = info.dict()
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
record_store = ApiDependencies.invoker.services.model_record_store
model_install = ApiDependencies.invoker.services.model_installer
try:
new_config = record_store.update_model(key, config=info_dict)
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
try:
# In the event that the model's name, type or base has changed, and the model itself
# resides in the invokeai root models directory, then the next statement will move
# the model file into its new canonical location.
@ -115,9 +120,6 @@ async def update_model(
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
except Exception as e:
logger.error(str(e))
raise HTTPException(status_code=400, detail=str(e))
return model_response
@ -198,8 +200,8 @@ async def import_model(
responses={
201: {"description": "The model added successfully"},
404: {"description": "The model could not be found"},
424: {"description": "The model appeared to add successfully, but could not be found in the model manager"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
415: {"description": "Unrecognized file/folder format"},
},
status_code=201,
response_model=InvokeAIModelConfig,
@ -213,16 +215,22 @@ async def add_model(
"""
logger = ApiDependencies.invoker.services.logger
try:
path = info.path
installer = ApiDependencies.invoker.services.model_installer
record_store = ApiDependencies.invoker.services.model_record_store
try:
key = installer.install_path(path)
logger.info(f"Created model {key} for {path}")
except DuplicateModelException as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
# update with the provided info
try:
info_dict = info.dict()
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
new_config = record_store.update_model(key, new_config=info_dict)
return parse_obj_as(InvokeAIModelConfig, new_config.dict())
except UnknownModelException as e:
@ -405,9 +413,7 @@ async def merge_models(
)
async def list_install_jobs() -> List[ModelImportStatus]:
"""List active and pending model installation jobs."""
logger = ApiDependencies.invoker.services.logger
job_mgr = ApiDependencies.invoker.services.model_installer
try:
jobs = job_mgr.list_install_jobs()
return [
ModelImportStatus(
@ -420,9 +426,6 @@ async def list_install_jobs() -> List[ModelImportStatus]:
)
for x in jobs
]
except Exception as e:
logger.error(str(e))
raise HTTPException(status_code=400, detail=str(e))
@models_router.patch(
@ -459,11 +462,11 @@ async def control_install_jobs(
elif operation == JobControlOperation.CANCEL:
job_mgr.cancel_job(job_id)
elif operation == JobControlOperation.CHANGE_PRIORITY:
elif operation == JobControlOperation.CHANGE_PRIORITY and priority_delta is not None:
job_mgr.change_job_priority(job_id, priority_delta)
else:
raise ValueError(f"Unknown operation {JobControlOperation.value}")
raise ValueError(f"Unknown operation {operation.value}")
return ModelImportStatus(
job_id=job_id,
@ -478,9 +481,6 @@ async def control_install_jobs(
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
except Exception as e:
logger.error(str(e))
raise HTTPException(status_code=400, detail=str(e))
@models_router.patch(
@ -490,39 +490,27 @@ async def control_install_jobs(
204: {"description": "All jobs cancelled successfully"},
400: {"description": "Bad request"},
},
status_code=200,
response_model=ModelImportStatus,
)
async def cancel_install_jobs():
"""Cancel all pending install jobs."""
"""Cancel all model installation jobs."""
logger = ApiDependencies.invoker.services.logger
try:
job_mgr = ApiDependencies.invoker.services.model_installer
logger.info("Cancelling all running model installation jobs.")
logger.info("Cancelling all model installation jobs.")
job_mgr.cancel_all_jobs()
return Response(status_code=204)
except Exception as e:
logger.error(str(e))
raise HTTPException(status_code=400, detail=str(e))
@models_router.patch(
"/jobs/prune",
operation_id="prune_jobs",
responses={
204: {"description": "All jobs cancelled successfully"},
204: {"description": "All completed jobs have been pruned"},
400: {"description": "Bad request"},
},
status_code=200,
response_model=ModelImportStatus,
)
async def prune_jobs():
"""Prune all completed and errored jobs."""
logger = ApiDependencies.invoker.services.logger
try:
mgr = ApiDependencies.invoker.services.model_installer
mgr.prune_jobs()
return Response(status_code=204)
except Exception as e:
logger.error(str(e))
raise HTTPException(status_code=400, detail=str(e))

View File

@ -126,7 +126,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
e.g. `max_parallel_dl`.
"""
self._event_bus = event_bus
self._queue = DownloadQueue()
self._queue = DownloadQueue(**kwargs)
def create_download_job(
self,

View File

@ -92,7 +92,7 @@ class ModelInstallServiceBase(ModelInstallBase): # This is an ABC
@abstractmethod
def cancel_all_jobs(self):
"""Cancel all active jobs."""
"""Cancel all installation jobs."""
pass
@abstractmethod

View File

@ -62,7 +62,7 @@ class ModelLoadServiceBase(ABC):
class ModelLoadService(ModelLoadServiceBase):
"""Responsible for managing models on disk and in memory."""
_loader: ModelLoad = Field(description="InvokeAIAppConfig object for the current process")
_loader: ModelLoad
def __init__(
self,

View File

@ -12,6 +12,4 @@ from .model_manager import ( # noqa F401
SilenceWarnings,
SubModelType,
)
from .model_manager.install import ModelInstall # noqa F401
from .model_manager.loader import ModelLoad # noqa F401
from .util.devices import get_precision # noqa F401

View File

@ -18,7 +18,6 @@ from invokeai.backend.model_manager.install import ModelInstall, ModelInstallJob
# name of the starter models file
INITIAL_MODELS = "INITIAL_MODELS.yaml"
ACCESS_TOKEN = HfFolder.get_token()
class UnifiedModelInfo(BaseModel):
@ -173,7 +172,7 @@ class InstallHelper(object):
model.source,
subfolder=model.subfolder,
variant="fp16" if self._config.precision == "float16" else None,
access_token=ACCESS_TOKEN, # this is a global,
access_token=HfFolder.get_token(),
metadata=metadata,
)

View File

@ -185,14 +185,17 @@ class ProgressBar:
# ---------------------------------------------
def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs):
logger.addFilter(lambda x: "fp16 is not a valid" not in x.getMessage())
filter = lambda x: "fp16 is not a valid" not in x.getMessage()
logger.addFilter(filter)
try:
model = model_class.from_pretrained(
model_name,
resume_download=True,
**kwargs,
)
model.save_pretrained(destination, safe_serialization=True)
finally:
logger.removeFilter(filter)
return destination

View File

@ -212,7 +212,7 @@ class DownloadQueueBase(ABC):
@abstractmethod
def cancel_all_jobs(self, preserve_partial: bool = False):
"""
Cancel all active and enquedjobs.
Cancel all jobs (those in enqueued, running and paused states).
:param preserve_partial: Keep partially downloaded files [False].
"""

View File

@ -12,8 +12,7 @@ from queue import PriorityQueue
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
import requests
from huggingface_hub import HfApi, hf_hub_url
from pydantic import Field, parse_obj_as, validator
from pydantic import Field
from pydantic.networks import AnyHttpUrl
from requests import HTTPError
@ -59,7 +58,7 @@ class DownloadQueue(DownloadQueueBase):
_jobs: Dict[int, DownloadJobBase]
_worker_pool: Set[threading.Thread]
_queue: PriorityQueue
_lock: threading.RLock
_lock: threading.RLock # to allow methods called within the same thread to lock without blocking
_logger: Logger
_event_handlers: List[DownloadEventHandler] = Field(default_factory=list)
_next_job_id: int = 0
@ -136,13 +135,10 @@ class DownloadQueue(DownloadQueueBase):
# add the queue's handlers
for handler in self._event_handlers:
job.add_event_handler(handler)
try:
self._lock.acquire()
with self._lock:
job.id = self._next_job_id
self._jobs[job.id] = job
self._next_job_id += 1
finally:
self._lock.release()
if start:
self.start_job(job)
return job
@ -163,20 +159,18 @@ class DownloadQueue(DownloadQueueBase):
def change_priority(self, job: DownloadJobBase, delta: int):
"""Change the priority of a job. Smaller priorities run first."""
with self._lock:
try:
self._lock.acquire()
assert isinstance(self._jobs[job.id], DownloadJobBase)
job.priority += delta
except (AssertionError, KeyError) as excp:
raise UnknownJobIDException("Unrecognized job") from excp
finally:
self._lock.release()
def prune_jobs(self):
"""Prune completed and errored queue items from the job list."""
try:
with self._lock:
to_delete = set()
self._lock.acquire()
try:
for job_id, job in self._jobs.items():
if self._in_terminal_state(job):
to_delete.add(job_id)
@ -184,8 +178,6 @@ class DownloadQueue(DownloadQueueBase):
del self._jobs[job_id]
except KeyError as excp:
raise UnknownJobIDException("Unrecognized job") from excp
finally:
self._lock.release()
def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False):
"""
@ -194,16 +186,14 @@ class DownloadQueue(DownloadQueueBase):
If it is running it will be stopped.
job.status will be set to DownloadJobStatus.CANCELLED
"""
with self._lock:
try:
self._lock.acquire()
assert isinstance(self._jobs[job.id], DownloadJobBase)
job.preserve_partial_downloads = preserve_partial
self._update_job_status(job, DownloadJobStatus.CANCELLED)
job.cleanup()
except (AssertionError, KeyError) as excp:
raise UnknownJobIDException("Unrecognized job") from excp
finally:
self._lock.release()
def id_to_job(self, id: int) -> DownloadJobBase:
"""Translate a job ID into a DownloadJobBase object."""
@ -214,6 +204,7 @@ class DownloadQueue(DownloadQueueBase):
def start_job(self, job: DownloadJobBase):
"""Enqueue (start) the indicated job."""
with self._lock:
try:
assert isinstance(self._jobs[job.id], DownloadJobBase)
self._update_job_status(job, DownloadJobStatus.ENQUEUED)
@ -228,45 +219,34 @@ class DownloadQueue(DownloadQueueBase):
The job can be restarted with start_job() and the download will pick up
from where it left off.
"""
with self._lock:
try:
self._lock.acquire()
assert isinstance(self._jobs[job.id], DownloadJobBase)
self._update_job_status(job, DownloadJobStatus.PAUSED)
job.cleanup()
except (AssertionError, KeyError) as excp:
raise UnknownJobIDException("Unrecognized job") from excp
finally:
self._lock.release()
def start_all_jobs(self):
"""Start (enqueue) all jobs that are idle or paused."""
try:
self._lock.acquire()
with self._lock:
for job in self._jobs.values():
if job.status in [DownloadJobStatus.IDLE, DownloadJobStatus.PAUSED]:
self.start_job(job)
finally:
self._lock.release()
def pause_all_jobs(self):
"""Pause all running jobs."""
try:
self._lock.acquire()
with self._lock:
for job in self._jobs.values():
if job.status == DownloadJobStatus.RUNNING:
self.pause_job(job)
finally:
self._lock.release()
def cancel_all_jobs(self, preserve_partial: bool = False):
"""Cancel all running jobs."""
try:
self._lock.acquire()
"""Cancel all jobs (those in enqueued, running or paused state)."""
with self._lock:
for job in self._jobs.values():
if not self._in_terminal_state(job):
self.cancel_job(job, preserve_partial)
finally:
self._lock.release()
def _in_terminal_state(self, job: DownloadJobBase):
return job.status in [
@ -288,17 +268,17 @@ class DownloadQueue(DownloadQueueBase):
while not done:
job = self._queue.get()
try: # this is for debugging priority
self._lock.acquire()
with self._lock:
job.job_sequence = self._sequence
self._sequence += 1
finally:
self._lock.release()
try:
if job == STOP_JOB: # marker that queue is done
done = True
if job.status == DownloadJobStatus.ENQUEUED: # Don't do anything for non-enqueued jobs (shouldn't happen)
if (
job.status == DownloadJobStatus.ENQUEUED
): # Don't do anything for non-enqueued jobs (shouldn't happen)
if not self._quiet:
self._logger.info(f"{job.source}: Downloading to {job.destination}")
do_download = self.select_downloader(job)
@ -306,7 +286,7 @@ class DownloadQueue(DownloadQueueBase):
if job.status == DownloadJobStatus.CANCELLED:
self._cleanup_cancelled_job(job)
finally:
self._queue.task_done()
def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]:
@ -397,11 +377,7 @@ class DownloadQueue(DownloadQueueBase):
self._update_job_status(job, DownloadJobStatus.COMPLETED)
except KeyboardInterrupt as excp:
raise excp
except DuplicateModelException as excp:
self._logger.error(f"A model with the same hash as {dest} is already installed.")
job.error = excp
self._update_job_status(job, DownloadJobStatus.ERROR)
except Exception as excp:
except (HTTPError, OSError) as excp:
self._logger.error(f"An error occurred while downloading/installing {job.source}: {str(excp)}")
print(traceback.format_exc())
job.error = excp

View File

@ -15,6 +15,8 @@ from typing import Dict, Union
from imohash import hashfile
from .models import InvalidModelException
class FastModelHash(object):
"""FastModelHash object provides one public class method, hash()."""
@ -32,9 +34,6 @@ class FastModelHash(object):
elif model_location.is_dir():
return cls._hash_dir(model_location)
else:
# avoid circular import
from .models import InvalidModelException
raise InvalidModelException(f"Not a valid file or directory: {model_location}")
@classmethod
@ -54,7 +53,8 @@ class FastModelHash(object):
for root, dirs, files in os.walk(model_location):
for file in files:
# only tally tensor files
# only tally tensor files because diffusers config files change slightly
# depending on how the model was downloaded/converted.
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
continue
path = (Path(root) / file).as_posix()

View File

@ -72,14 +72,11 @@ class ModelConfigStoreSQL(ModelConfigStore):
self._cursor = self._conn.cursor()
self._lock = lock
try:
self._lock.acquire()
with self._lock:
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
assert (
str(self.version) == CONFIG_FILE_VERSION
), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
@ -189,8 +186,8 @@ class ModelConfigStoreSQL(ModelConfigStore):
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config object.
json_serialized = json.dumps(record.dict()) # and turn it into a json string.
with self._lock:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT INTO model_config (
@ -225,16 +222,13 @@ class ModelConfigStoreSQL(ModelConfigStore):
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
return self.get_model(key)
@property
def version(self) -> str:
"""Return the version of the database schema."""
try:
self._lock.acquire()
with self._lock:
self._cursor.execute(
"""--sql
SELECT metadata_value FROM model_manager_metadata
@ -246,8 +240,6 @@ class ModelConfigStoreSQL(ModelConfigStore):
if not rows:
raise KeyError("Models database does not have metadata key 'version'")
return rows[0]
finally:
self._lock.release()
def _update_tags(self, key: str, tags: List[str]) -> None:
"""Update tags for model with key."""
@ -301,8 +293,8 @@ class ModelConfigStoreSQL(ModelConfigStore):
Can raise an UnknownModelException
"""
with self._lock:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM model_config
@ -316,8 +308,6 @@ class ModelConfigStoreSQL(ModelConfigStore):
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
"""
@ -329,8 +319,8 @@ class ModelConfigStoreSQL(ModelConfigStore):
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config object
json_serialized = json.dumps(record.dict()) # and turn it into a json string.
with self._lock:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
UPDATE model_config
@ -351,8 +341,7 @@ class ModelConfigStoreSQL(ModelConfigStore):
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
return self.get_model(key)
def get_model(self, key: str) -> AnyModelConfig:
@ -363,8 +352,7 @@ class ModelConfigStoreSQL(ModelConfigStore):
Exceptions: UnknownModelException
"""
try:
self._lock.acquire()
with self._lock:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
@ -376,8 +364,6 @@ class ModelConfigStoreSQL(ModelConfigStore):
if not rows:
raise UnknownModelException
model = ModelConfigFactory.make_config(json.loads(rows[0]))
finally:
self._lock.release()
return model
def exists(self, key: str) -> bool:
@ -387,8 +373,8 @@ class ModelConfigStoreSQL(ModelConfigStore):
:param key: Unique key for the model to be deleted
"""
count = 0
with self._lock:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
select count(*) FROM model_config
@ -399,8 +385,6 @@ class ModelConfigStoreSQL(ModelConfigStore):
count = self._cursor.fetchone()[0]
except sqlite3.Error as e:
raise e
finally:
self._lock.release()
return count > 0
def search_by_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
@ -408,8 +392,8 @@ class ModelConfigStoreSQL(ModelConfigStore):
# rather than create a hairy SQL cross-product, we intersect
# tag results in a stepwise fashion at the python level.
results = []
with self._lock:
try:
self._lock.acquire()
matches: Set[str] = set()
for tag in tags:
self._cursor.execute(
@ -434,8 +418,6 @@ class ModelConfigStoreSQL(ModelConfigStore):
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
except sqlite3.Error as e:
raise e
finally:
self._lock.release()
return results
def search_by_name(
@ -467,8 +449,8 @@ class ModelConfigStoreSQL(ModelConfigStore):
where_clause.append("model_type=?")
bindings.append(model_type)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
with self._lock:
try:
self._lock.acquire()
self._cursor.execute(
f"""--sql
select config FROM model_config
@ -479,8 +461,6 @@ class ModelConfigStoreSQL(ModelConfigStore):
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
except sqlite3.Error as e:
raise e
finally:
self._lock.release()
return results
def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]:

View File

@ -80,24 +80,18 @@ class ModelConfigStoreYAML(ModelConfigStore):
raise ConfigFileVersionMismatchException
def _initialize_yaml(self):
try:
self._lock.acquire()
with self._lock:
self._filename.parent.mkdir(parents=True, exist_ok=True)
with open(self._filename, "w") as yaml_file:
yaml_file.write(yaml.dump({"__metadata__": {"version": CONFIG_FILE_VERSION}}))
finally:
self._lock.release()
def _commit(self):
try:
self._lock.acquire()
with self._lock:
newfile = Path(str(self._filename) + ".new")
yaml_str = OmegaConf.to_yaml(self._config)
with open(newfile, "w", encoding="utf-8") as outfile:
outfile.write(yaml_str)
newfile.replace(self._filename)
finally:
self._lock.release()
@property
def version(self) -> str:
@ -116,8 +110,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
"""
record = ModelConfigFactory.make_config(config, key) # ensure it is a valid config object
dict_fields = record.dict() # and back to a dict with valid fields
try:
self._lock.acquire()
with self._lock:
if key in self._config:
existing_model = self.get_model(key)
raise DuplicateModelException(
@ -125,8 +118,6 @@ class ModelConfigStoreYAML(ModelConfigStore):
)
self._config[key] = self._fix_enums(dict_fields)
self._commit()
finally:
self._lock.release()
return self.get_model(key)
def _fix_enums(self, original: dict) -> dict:
@ -144,14 +135,11 @@ class ModelConfigStoreYAML(ModelConfigStore):
Can raise an UnknownModelException
"""
try:
self._lock.acquire()
with self._lock:
if key not in self._config:
raise UnknownModelException(f"Unknown key '{key}' for model config")
self._config.pop(key)
self._commit()
finally:
self._lock.release()
def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
"""
@ -163,14 +151,11 @@ class ModelConfigStoreYAML(ModelConfigStore):
"""
record = ModelConfigFactory.make_config(config, key) # ensure it is a valid config object
dict_fields = record.dict() # and back to a dict with valid fields
try:
self._lock.acquire()
with self._lock:
if key not in self._config:
raise UnknownModelException(f"Unknown key '{key}' for model config")
self._config[key] = self._fix_enums(dict_fields)
self._commit()
finally:
self._lock.release()
return self.get_model(key)
def get_model(self, key: str) -> AnyModelConfig:
@ -203,15 +188,12 @@ class ModelConfigStoreYAML(ModelConfigStore):
"""
results = []
tags = set(tags)
try:
self._lock.acquire()
with self._lock:
for config in self.all_models():
config_tags = set(config.tags or [])
if tags.difference(config_tags): # not all tags in the model
continue
results.append(config)
finally:
self._lock.release()
return results
def search_by_name(
@ -231,8 +213,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
models in the database.
"""
results: List[ModelConfigBase] = list()
try:
self._lock.acquire()
with self._lock:
for key, record in self._config.items():
if key == "__metadata__":
continue
@ -244,20 +225,15 @@ class ModelConfigStoreYAML(ModelConfigStore):
if model_type and model.model_type != model_type:
continue
results.append(model)
finally:
self._lock.release()
return results
def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]:
"""Return the model with the indicated path, or None."""
try:
self._lock.acquire()
with self._lock:
for key, record in self._config.items():
if key == "__metadata__":
continue
model = ModelConfigFactory.make_config(record, str(key))
if model.path == path:
return model
finally:
self._lock.release()
return None

View File

@ -270,8 +270,6 @@ def test_bad_urls():
def test_pause_cancel_url(): # this one is tricky because of potential race conditions
def event_handler(job: DownloadJobBase):
if job.id == 0:
print(job.status, job.bytes)
time.sleep(0.5) # slow down the thread so that we can recover the paused state
queue = ModelDownloadQueue(requests_session=session, event_handlers=[event_handler])