all backend features in place; config scanning is failing on controlnet

This commit is contained in:
Lincoln Stein 2023-11-24 19:37:46 -05:00
parent 80bc9be3ab
commit 19baea1883
7 changed files with 342 additions and 65 deletions

View File

@ -1,6 +1,5 @@
""" """Init file for InvokeAI configure package."""
Init file for InvokeAI configure package
"""
from .config_base import PagingArgumentParser # noqa F401 from .config_default import InvokeAIAppConfig, get_invokeai_config
from .config_default import InvokeAIAppConfig, get_invokeai_config # noqa F401
__all__ = ['InvokeAIAppConfig', 'get_invokeai_config']

View File

@ -323,21 +323,23 @@ class EventServiceBase:
""" """
self.__emit_queue_event( self.__emit_queue_event(
event_name="model_install_started", event_name="model_install_started",
payload={"source": source}, payload={
"source": source
},
) )
def emit_model_install_completed(self, source: str, dest: str) -> None: def emit_model_install_completed(self, source: str, key: str) -> None:
""" """
Emitted when an install job is completed successfully. Emitted when an install job is completed successfully.
:param source: Source of the model; local path, repo_id or url :param source: Source of the model; local path, repo_id or url
:param dest: Destination of the model files; always a local path :param key: Model config record key
""" """
self.__emit_queue_event( self.__emit_queue_event(
event_name="model_install_completed", event_name="model_install_completed",
payload={ payload={
"source": source, "source": source,
"dest": dest, "key": key,
}, },
) )

View File

@ -1,2 +1,6 @@
from .model_install_base import ModelInstallServiceBase # noqa F401 """Initialization file for model install service package."""
from .model_install_default import ModelInstallService # noqa F401
from .model_install_base import InstallStatus, ModelInstallServiceBase, ModelInstallJob, UnknownInstallJobException
from .model_install_default import ModelInstallService
__all__ = ['ModelInstallServiceBase', 'ModelInstallService', 'InstallStatus', 'ModelInstallJob', 'UnknownInstallJobException']

View File

@ -21,12 +21,21 @@ class InstallStatus(str, Enum):
ERROR = "error" # terminated with an error message ERROR = "error" # terminated with an error message
class UnknownInstallJobException(Exception):
"""Raised when the status of an unknown job is requested."""
ModelSource = Union[str, Path, AnyHttpUrl]
class ModelInstallJob(BaseModel): class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request.""" """Object that tracks the current status of an install request."""
source: Union[str, Path, AnyHttpUrl] = Field(description="Source (URL, repo_id, or local path) of model")
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process") status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
local_path: Optional[Path] = Field(default=None, description="Path to locally-downloaded model") metadata: Dict[str, Any] = Field(default_factory=dict, description="Configuration metadata to apply to model before installing it")
inplace: bool = Field(default=False, description="Leave model in its current location; otherwise install under models directory")
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
key: Optional[str] = Field(default=None, description="After model is installed, this is its config record key")
error_type: str = Field(default="", description="Class name of the exception that led to status==ERROR") error_type: str = Field(default="", description="Class name of the exception that led to status==ERROR")
error: str = Field(default="", description="Error traceback") # noqa #501 error: str = Field(default="", description="Error traceback") # noqa #501
@ -65,6 +74,11 @@ class ModelInstallServiceBase(ABC):
def record_store(self) -> ModelRecordServiceBase: def record_store(self) -> ModelRecordServiceBase:
"""Return the ModelRecoreService object associated with the installer.""" """Return the ModelRecoreService object associated with the installer."""
@property
@abstractmethod
def event_bus(self) -> Optional[EventServiceBase]:
"""Return the event service base object associated with the installer."""
@abstractmethod @abstractmethod
def register_path( def register_path(
self, self,
@ -86,16 +100,12 @@ class ModelInstallServiceBase(ABC):
"""Remove model with indicated key from the database.""" """Remove model with indicated key from the database."""
@abstractmethod @abstractmethod
def delete(self, key: str) -> None: def delete(self, key: str) -> None: # noqa D102
"""Remove model with indicated key from the database and delete weight files from disk.""" """Remove model with indicated key from the database. Delete its files only if they are within our models directory."""
@abstractmethod @abstractmethod
def conditionally_delete(self, key: str) -> None: def unconditionally_delete(self, key: str) -> None:
""" """Remove model with indicated key from the database and unconditionally delete weight files from disk."""
Remove model with indicated key from the database
and conditeionally delete weight files from disk
if they reside within InvokeAI's models directory.
"""
@abstractmethod @abstractmethod
def install_path( def install_path(
@ -121,7 +131,7 @@ class ModelInstallServiceBase(ABC):
inplace: bool = True, inplace: bool = True,
variant: Optional[str] = None, variant: Optional[str] = None,
subfolder: Optional[str] = None, subfolder: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None, metadata: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None, access_token: Optional[str] = None,
) -> ModelInstallJob: ) -> ModelInstallJob:
"""Install the indicated model. """Install the indicated model.
@ -168,6 +178,18 @@ class ModelInstallServiceBase(ABC):
""" """
@abstractmethod
def get_job(self, source: ModelSource) -> ModelInstallJob:
"""Return the ModelInstallJob corresponding to the provided source."""
@abstractmethod
def get_jobs(self) -> Dict[ModelSource, ModelInstallJob]: # noqa D102
"""Return a dict in which keys are model sources and values are corresponding model install jobs."""
@abstractmethod
def prune_jobs(self) -> None:
"""Prune all completed and errored jobs."""
@abstractmethod @abstractmethod
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], ModelInstallJob]: def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], ModelInstallJob]:
""" """
@ -194,4 +216,4 @@ class ModelInstallServiceBase(ABC):
@abstractmethod @abstractmethod
def sync_to_config(self) -> None: def sync_to_config(self) -> None:
"""Synchronize models on disk to those in memory.""" """Synchronize models on disk to those in the model record database."""

View File

@ -6,7 +6,7 @@ from pathlib import Path
from queue import Queue from queue import Queue
from random import randbytes from random import randbytes
from shutil import move, rmtree from shutil import move, rmtree
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Set, Optional, Union
from pydantic.networks import AnyHttpUrl from pydantic.networks import AnyHttpUrl
@ -18,14 +18,16 @@ from invokeai.backend.model_manager.config import (
DuplicateModelException, DuplicateModelException,
InvalidModelConfigException, InvalidModelConfigException,
) )
from invokeai.backend.model_manager.config import ModelType, BaseModelType
from invokeai.backend.model_manager.hash import FastModelHash from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.model_manager.probe import ModelProbe from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import Chdir, InvokeAILogger
from .model_install_base import InstallStatus, ModelInstallJob, ModelInstallServiceBase from .model_install_base import ModelSource, InstallStatus, ModelInstallJob, ModelInstallServiceBase, UnknownInstallJobException
# marker that the queue is done and that thread should exit # marker that the queue is done and that thread should exit
STOP_JOB = ModelInstallJob(source="stop") STOP_JOB = ModelInstallJob(source="stop", local_path=Path("/dev/null"))
class ModelInstallService(ModelInstallServiceBase): class ModelInstallService(ModelInstallServiceBase):
@ -35,8 +37,10 @@ class ModelInstallService(ModelInstallServiceBase):
_record_store: ModelRecordServiceBase _record_store: ModelRecordServiceBase
_event_bus: Optional[EventServiceBase] = None _event_bus: Optional[EventServiceBase] = None
_install_queue: Queue[ModelInstallJob] _install_queue: Queue[ModelInstallJob]
_install_jobs: Dict[Union[str, Path, AnyHttpUrl], ModelInstallJob] _install_jobs: Dict[ModelSource, ModelInstallJob]
_logger: InvokeAILogger _logger: InvokeAILogger
_cached_model_paths: Set[Path]
_models_installed: Set[str]
def __init__(self, def __init__(self,
app_config: InvokeAIAppConfig, app_config: InvokeAIAppConfig,
@ -52,9 +56,12 @@ class ModelInstallService(ModelInstallServiceBase):
""" """
self._app_config = app_config self._app_config = app_config
self._record_store = record_store self._record_store = record_store
self._install_queue = Queue()
self._event_bus = event_bus self._event_bus = event_bus
self._logger = InvokeAILogger.get_logger(name=self.__class__.__name__) self._logger = InvokeAILogger.get_logger(name=self.__class__.__name__)
self._install_jobs = {}
self._install_queue = Queue()
self._cached_model_paths = set()
self._models_installed = set()
self._start_installer_thread() self._start_installer_thread()
@property @property
@ -65,6 +72,13 @@ class ModelInstallService(ModelInstallServiceBase):
def record_store(self) -> ModelRecordServiceBase: # noqa D102 def record_store(self) -> ModelRecordServiceBase: # noqa D102
return self._record_store return self._record_store
@property
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
return self._event_bus
def get_jobs(self) -> Dict[ModelSource, ModelInstallJob]: # noqa D102
return self._install_jobs
def _start_installer_thread(self) -> None: def _start_installer_thread(self) -> None:
threading.Thread(target=self._install_next_item, daemon=True).start() threading.Thread(target=self._install_next_item, daemon=True).start()
@ -74,14 +88,20 @@ class ModelInstallService(ModelInstallServiceBase):
job = self._install_queue.get() job = self._install_queue.get()
if job == STOP_JOB: if job == STOP_JOB:
done = True done = True
elif job.status == InstallStatus.WAITING: continue
assert job.local_path is not None assert job.local_path is not None
try: try:
self._signal_job_running(job) self._signal_job_running(job)
self.register_path(job.local_path) if job.inplace:
job.key = self.register_path(job.local_path, job.metadata)
else:
job.key = self.install_path(job.local_path, job.metadata)
self._signal_job_completed(job) self._signal_job_completed(job)
except (OSError, DuplicateModelException, InvalidModelConfigException) as excp: except (OSError, DuplicateModelException, InvalidModelConfigException) as excp:
self._signal_job_errored(job, excp) self._signal_job_errored(job, excp)
finally:
self._install_queue.task_done()
def _signal_job_running(self, job: ModelInstallJob) -> None: def _signal_job_running(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.RUNNING job.status = InstallStatus.RUNNING
@ -92,7 +112,7 @@ class ModelInstallService(ModelInstallServiceBase):
job.status = InstallStatus.COMPLETED job.status = InstallStatus.COMPLETED
if self._event_bus: if self._event_bus:
assert job.local_path is not None assert job.local_path is not None
self._event_bus.emit_model_install_completed(str(job.source), job.local_path.as_posix()) self._event_bus.emit_model_install_completed(str(job.source), job.key)
def _signal_job_errored(self, job: ModelInstallJob, excp: Exception) -> None: def _signal_job_errored(self, job: ModelInstallJob, excp: Exception) -> None:
job.set_error(excp) job.set_error(excp)
@ -136,29 +156,165 @@ class ModelInstallService(ModelInstallServiceBase):
def import_model( def import_model(
self, self,
source: Union[str, Path, AnyHttpUrl], source: ModelSource,
inplace: bool = True, inplace: bool = True,
variant: Optional[str] = None, variant: Optional[str] = None,
subfolder: Optional[str] = None, subfolder: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None, metadata: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None, access_token: Optional[str] = None,
) -> ModelInstallJob: # noqa D102 ) -> ModelInstallJob: # noqa D102
# Clean up a common source of error. Doesn't work with Paths.
if isinstance(source, str):
source = source.strip()
if not metadata:
metadata = {}
# Installing a local path
if isinstance(source, (str, Path)) and Path(source).exists(): # a path that is already on disk
job = ModelInstallJob(metadata=metadata,
source=source,
inplace=inplace,
local_path=Path(source),
)
self._install_jobs[source] = job
self._install_queue.put(job)
return job
else: # waiting for download queue implementation
raise NotImplementedError raise NotImplementedError
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], ModelInstallJob]: # noqa D102 def get_job(self, source: ModelSource) -> ModelInstallJob: # noqa D102
try:
return self._install_jobs[source]
except KeyError:
raise UnknownInstallJobException(f'{source}: unknown install job')
def wait_for_installs(self) -> Dict[ModelSource, ModelInstallJob]: # noqa D102
self._install_queue.join() self._install_queue.join()
return self._install_jobs return self._install_jobs
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102 def prune_jobs(self) -> None:
raise NotImplementedError """Prune all completed and errored jobs."""
finished_jobs = [source for source in self._install_jobs
if self._install_jobs[source].status in [InstallStatus.COMPLETED, InstallStatus.ERROR]
]
for source in finished_jobs:
del self._install_jobs[source]
def sync_to_config(self) -> None: # noqa D102 def sync_to_config(self) -> None:
raise NotImplementedError """Synchronize models on disk to those in the config record store database."""
self._scan_models_directory()
if autoimport := self._app_config.autoimport_dir:
self._logger.info("Scanning autoimport directory for new models")
self.scan_directory(self._app_config.root_path / autoimport)
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()}
callback = self._scan_install if install else self._scan_register
search = ModelSearch(on_model_found=callback)
self._models_installed: Set[str] = set()
search.search(scan_dir)
return list(self._models_installed)
def _scan_models_directory(self) -> None:
"""
Scan the models directory for new and missing models.
New models will be added to the storage backend. Missing models
will be deleted.
"""
defunct_models = set()
installed = set()
with Chdir(self._app_config.models_path):
self._logger.info("Checking for models that have been moved or deleted from disk")
for model_config in self.record_store.all_models():
path = Path(model_config.path)
if not path.exists():
self._logger.info(f"{model_config.name}: path {path.as_posix()} no longer exists. Unregistering")
defunct_models.add(model_config.key)
for key in defunct_models:
self.unregister(key)
self._logger.info(f"Scanning {self._app_config.models_path} for new models")
for cur_base_model in BaseModelType:
for cur_model_type in ModelType:
models_dir = Path(cur_base_model.value, cur_model_type.value)
installed.update(self.scan_directory(models_dir))
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
def _sync_model_path(self, key: str, ignore_hash_change: bool = False) -> AnyModelConfig:
"""
Move model into the location indicated by its basetype, type and name.
Call this after updating a model's attributes in order to move
the model's path into the location indicated by its basetype, type and
name. Applies only to models whose paths are within the root `models_dir`
directory.
May raise an UnknownModelException.
"""
model = self.record_store.get_model(key)
old_path = Path(model.path)
models_dir = self.app_config.models_path
if not old_path.is_relative_to(models_dir):
return model
new_path = models_dir / model.base.value / model.type.value / model.name
self._logger.info(f"Moving {model.name} to {new_path}.")
new_path = self._move_model(old_path, new_path)
new_hash = FastModelHash.hash(new_path)
model.path = new_path.relative_to(models_dir).as_posix()
if model.current_hash != new_hash:
assert (
ignore_hash_change
), f"{model.name}: Model hash changed during installation, model is possibly corrupted"
model.current_hash = new_hash
self._logger.info(f"Model has new hash {model.current_hash}, but will continue to be identified by {key}")
self.record_store.update_model(key, model)
return model
def _scan_register(self, model: Path) -> bool:
if model in self._cached_model_paths:
return True
try:
id = self.register_path(model)
self._sync_model_path(id) # possibly move it to right place in `models`
self._logger.info(f"Registered {model.name} with id {id}")
self._models_installed.add(id)
except DuplicateModelException:
pass
return True
def _scan_install(self, model: Path) -> bool:
if model in self._cached_model_paths:
return True
try:
id = self.install_path(model)
self._logger.info(f"Installed {model} with id {id}")
self._models_installed.add(id)
except DuplicateModelException:
pass
return True
def unregister(self, key: str) -> None: # noqa D102 def unregister(self, key: str) -> None: # noqa D102
self.record_store.del_model(key) self.record_store.del_model(key)
def delete(self, key: str) -> None: # noqa D102 def delete(self, key: str) -> None: # noqa D102
"""Unregister the model. Delete its files only if they are within our models directory."""
model = self.record_store.get_model(key)
models_dir = self.app_config.models_path
model_path = models_dir / model.path
if model_path.is_relative_to(models_dir):
self.unconditionally_delete(key)
else:
self.unregister(key)
def unconditionally_delete(self, key: str) -> None: # noqa D102
model = self.record_store.get_model(key) model = self.record_store.get_model(key)
path = self.app_config.models_path / model.path path = self.app_config.models_path / model.path
if path.is_dir(): if path.is_dir():
@ -167,16 +323,6 @@ class ModelInstallService(ModelInstallServiceBase):
path.unlink() path.unlink()
self.unregister(key) self.unregister(key)
def conditionally_delete(self, key: str) -> None: # noqa D102
"""Unregister the model. Delete its files only if they are within our models directory."""
model = self.record_store.get_model(key)
models_dir = self.app_config.models_path
model_path = models_dir / model.path
if model_path.is_relative_to(models_dir):
self.delete(key)
else:
self.unregister(key)
def _move_model(self, old_path: Path, new_path: Path) -> Path: def _move_model(self, old_path: Path, new_path: Path) -> Path:
if old_path == new_path: if old_path == new_path:
return old_path return old_path

View File

@ -12,3 +12,6 @@ from .devices import ( # noqa: F401
torch_dtype, torch_dtype,
) )
from .util import Chdir, ask_user, download_with_resume, instantiate_from_config, url_attachment_name # noqa: F401 from .util import Chdir, ask_user, download_with_resume, instantiate_from_config, url_attachment_name # noqa: F401
from .logging import InvokeAILogger
__all__ = ['Chdir', 'InvokeAILogger', 'choose_precision', 'choose_torch_device']

View File

@ -2,15 +2,26 @@
Test the model installer Test the model installer
""" """
import pytest
from pathlib import Path from pathlib import Path
from pydantic import ValidationError from typing import List, Any, Dict
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.model_manager.config import ModelType, BaseModelType import pytest
from pydantic import ValidationError, BaseModel
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceSQL, ModelRecordServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.model_install import (
ModelInstallService,
ModelInstallServiceBase,
InstallStatus,
ModelInstallJob,
UnknownInstallJobException,
)
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
from invokeai.app.services.shared.sqlite import SqliteDatabase from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase from invokeai.backend.model_manager.config import BaseModelType, ModelType
from invokeai.backend.util.logging import InvokeAILogger
@pytest.fixture @pytest.fixture
def test_file(datadir: Path) -> Path: def test_file(datadir: Path) -> Path:
@ -36,10 +47,34 @@ def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
def installer(app_config: InvokeAIAppConfig, def installer(app_config: InvokeAIAppConfig,
store: ModelRecordServiceBase) -> ModelInstallServiceBase: store: ModelRecordServiceBase) -> ModelInstallServiceBase:
return ModelInstallService(app_config=app_config, return ModelInstallService(app_config=app_config,
record_store=store record_store=store,
event_bus=DummyEventService(),
) )
class DummyEvent(BaseModel):
"""Dummy Event to use with Dummy Event service."""
event_name: str
payload: Dict[str, Any]
class DummyEventService(EventServiceBase):
"""Dummy event service for testing."""
events: List[DummyEvent]
def __init__(self) -> None:
super().__init__()
self.events = []
def dispatch(self, event_name: str, payload: Any) -> None:
"""Dispatch an event by appending it to self.events."""
self.events.append(
DummyEvent(event_name=payload['event'],
payload=payload['data'])
)
def test_registration(installer: ModelInstallServiceBase, test_file: Path) -> None: def test_registration(installer: ModelInstallServiceBase, test_file: Path) -> None:
store = installer.record_store store = installer.record_store
matches = store.search_by_attr(model_name="test_embedding") matches = store.search_by_attr(model_name="test_embedding")
@ -87,3 +122,69 @@ def test_install(installer: ModelInstallServiceBase, test_file: Path, app_config
model_record = store.get_model(key) model_record = store.get_model(key)
assert model_record.path == "sd-1/embedding/test_embedding.safetensors" assert model_record.path == "sd-1/embedding/test_embedding.safetensors"
assert model_record.source == test_file.as_posix() assert model_record.source == test_file.as_posix()
def test_background_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None:
"""Note: may want to break this down into several smaller unit tests."""
source = test_file
description = "Test of metadata assignment"
job = installer.import_model(source, inplace=False, metadata={"description": description})
assert job is not None
assert isinstance(job, ModelInstallJob)
# See if job is registered properly
assert installer.get_job(source) == job
# test that the job object tracked installation correctly
jobs = installer.wait_for_installs()
assert jobs[source] is not None
assert jobs[source] == job
assert jobs[source].status == InstallStatus.COMPLETED
# test that the expected events were issued
bus = installer.event_bus
assert bus is not None # sigh - ruff is a stickler for type checking
assert isinstance(bus, DummyEventService)
assert len(bus.events) == 2
event_names = [x.event_name for x in bus.events]
assert "model_install_started" in event_names
assert "model_install_completed" in event_names
assert bus.events[0].payload["source"] == source.as_posix()
assert bus.events[1].payload["source"] == source.as_posix()
key = bus.events[1].payload["key"]
assert key is not None
# see if the thing actually got installed at the expected location
model_record = installer.record_store.get_model(key)
assert model_record is not None
assert model_record.path == "sd-1/embedding/test_embedding.safetensors"
assert Path(app_config.models_dir / model_record.path).exists()
# see if metadata was properly passed through
assert model_record.description == description
# see if prune works properly
installer.prune_jobs()
with pytest.raises(UnknownInstallJobException):
assert installer.get_job(source)
def test_delete_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig):
store = installer.record_store
key = installer.install_path(test_file)
model_record = store.get_model(key)
assert Path(app_config.models_dir / model_record.path).exists()
assert not test_file.exists() # original should not still be there after installation
installer.delete(key)
assert not Path(app_config.models_dir / model_record.path).exists() # but installed copy should not!
with pytest.raises(UnknownModelException):
store.get_model(key)
def test_delete_register(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig):
store = installer.record_store
key = installer.register_path(test_file)
model_record = store.get_model(key)
assert Path(app_config.models_dir / model_record.path).exists()
assert test_file.exists() # original should still be there after installation
installer.delete(key)
assert Path(app_config.models_dir / model_record.path).exists()
with pytest.raises(UnknownModelException):
store.get_model(key)