further changes for ruff

This commit is contained in:
Lincoln Stein 2023-11-26 17:13:31 -05:00
parent 8f4f4d48d5
commit 8ef596eac7
15 changed files with 245 additions and 212 deletions

View File

@ -88,7 +88,9 @@ class ApiDependencies:
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")) latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
model_manager = ModelManagerService(config, logger) model_manager = ModelManagerService(config, logger)
model_record_service = ModelRecordServiceSQL(db=db) model_record_service = ModelRecordServiceSQL(db=db)
model_install_service = ModelInstallService(app_config=config, record_store=model_record_service, event_bus=events) model_install_service = ModelInstallService(
app_config=config, record_store=model_record_service, event_bus=events
)
names = SimpleNameService() names = SimpleNameService()
performance_statistics = InvocationStatsService() performance_statistics = InvocationStatsService()
processor = DefaultInvocationProcessor() processor = DefaultInvocationProcessor()

View File

@ -51,7 +51,9 @@ async def list_model_records(
found_models: list[AnyModelConfig] = [] found_models: list[AnyModelConfig] = []
if base_models: if base_models:
for base_model in base_models: for base_model in base_models:
found_models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type, model_name=model_name)) found_models.extend(
record_store.search_by_attr(base_model=base_model, model_type=model_type, model_name=model_name)
)
else: else:
found_models.extend(record_store.search_by_attr(model_type=model_type, model_name=model_name)) found_models.extend(record_store.search_by_attr(model_type=model_type, model_name=model_name))
return ModelsList(models=found_models) return ModelsList(models=found_models)
@ -184,25 +186,25 @@ async def add_model_record(
status_code=201, status_code=201,
) )
async def import_model( async def import_model(
source: ModelSource = Body( source: ModelSource = Body(
description="A model path, repo_id or URL to import. NOTE: only model path is implemented currently!" description="A model path, repo_id or URL to import. NOTE: only model path is implemented currently!"
), ),
config: Optional[Dict[str, Any]] = Body( config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ", description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
default=None, default=None,
), ),
variant: Optional[str] = Body( variant: Optional[str] = Body(
description="When fetching a repo_id, force variant type to fetch such as 'fp16'", description="When fetching a repo_id, force variant type to fetch such as 'fp16'",
default=None, default=None,
), ),
subfolder: Optional[str] = Body( subfolder: Optional[str] = Body(
description="When fetching a repo_id, specify subfolder to fetch model from", description="When fetching a repo_id, specify subfolder to fetch model from",
default=None, default=None,
), ),
access_token: Optional[str] = Body( access_token: Optional[str] = Body(
description="When fetching a repo_id or URL, access token for web access", description="When fetching a repo_id or URL, access token for web access",
default=None, default=None,
), ),
) -> ModelInstallJob: ) -> ModelInstallJob:
"""Add a model using its local path, repo_id, or remote URL. """Add a model using its local path, repo_id, or remote URL.
@ -250,14 +252,16 @@ async def import_model(
raise HTTPException(status_code=409, detail=str(e)) raise HTTPException(status_code=409, detail=str(e))
return result return result
@model_records_router.get( @model_records_router.get(
"/import", "/import",
operation_id="list_model_install_jobs", operation_id="list_model_install_jobs",
) )
async def list_model_install_jobs( async def list_model_install_jobs(
source: Optional[str] = Query(description="Filter list by install source, partial string match.", source: Optional[str] = Query(
default=None, description="Filter list by install source, partial string match.",
) default=None,
),
) -> List[ModelInstallJob]: ) -> List[ModelInstallJob]:
""" """
Return list of model install jobs. Return list of model install jobs.
@ -268,6 +272,7 @@ async def list_model_install_jobs(
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs(source) jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs(source)
return jobs return jobs
@model_records_router.patch( @model_records_router.patch(
"/import", "/import",
operation_id="prune_model_install_jobs", operation_id="prune_model_install_jobs",
@ -276,14 +281,14 @@ async def list_model_install_jobs(
400: {"description": "Bad request"}, 400: {"description": "Bad request"},
}, },
) )
async def prune_model_install_jobs( async def prune_model_install_jobs() -> Response:
) -> Response:
""" """
Prune all completed and errored jobs from the install job list. Prune all completed and errored jobs from the install job list.
""" """
ApiDependencies.invoker.services.model_install.prune_jobs() ApiDependencies.invoker.services.model_install.prune_jobs()
return Response(status_code=204) return Response(status_code=204)
@model_records_router.patch( @model_records_router.patch(
"/sync", "/sync",
operation_id="sync_models_to_config", operation_id="sync_models_to_config",
@ -292,8 +297,7 @@ async def prune_model_install_jobs(
400: {"description": "Bad request"}, 400: {"description": "Bad request"},
}, },
) )
async def sync_models_to_config( async def sync_models_to_config() -> Response:
) -> Response:
""" """
Traverse the models and autoimport directories. Model files without a corresponding Traverse the models and autoimport directories. Model files without a corresponding
record in the database are added. Orphan records without a models file are deleted. record in the database are added. Orphan records without a models file are deleted.

View File

@ -37,9 +37,5 @@ class SocketIO:
if "queue_id" in data: if "queue_id" in data:
await self.__sio.leave_room(sid, data["queue_id"]) await self.__sio.leave_room(sid, data["queue_id"])
async def _handle_model_event(self, event: Event) -> None: async def _handle_model_event(self, event: Event) -> None:
await self.__sio.emit( await self.__sio.emit(event=event[1]["event"], data=event[1]["data"])
event=event[1]["event"],
data=event[1]["data"]
)

View File

@ -2,4 +2,4 @@
from .config_default import InvokeAIAppConfig, get_invokeai_config from .config_default import InvokeAIAppConfig, get_invokeai_config
__all__ = ['InvokeAIAppConfig', 'get_invokeai_config'] __all__ = ["InvokeAIAppConfig", "get_invokeai_config"]

View File

@ -331,9 +331,7 @@ class EventServiceBase:
""" """
self.__emit_model_event( self.__emit_model_event(
event_name="model_install_started", event_name="model_install_started",
payload={ payload={"source": source},
"source": source
},
) )
def emit_model_install_completed(self, source: str, key: str) -> None: def emit_model_install_completed(self, source: str, key: str) -> None:
@ -351,11 +349,12 @@ class EventServiceBase:
}, },
) )
def emit_model_install_progress(self, def emit_model_install_progress(
source: str, self,
current_bytes: int, source: str,
total_bytes: int, current_bytes: int,
) -> None: total_bytes: int,
) -> None:
""" """
Emitted while the install job is in progress. Emitted while the install job is in progress.
(Downloaded models only) (Downloaded models only)
@ -373,12 +372,12 @@ class EventServiceBase:
}, },
) )
def emit_model_install_error(
def emit_model_install_error(self, self,
source: str, source: str,
error_type: str, error_type: str,
error: str, error: str,
) -> None: ) -> None:
""" """
Emitted when an install job encounters an exception. Emitted when an install job encounters an exception.

View File

@ -9,10 +9,11 @@ from .model_install_base import (
) )
from .model_install_default import ModelInstallService from .model_install_default import ModelInstallService
__all__ = ['ModelInstallServiceBase', __all__ = [
'ModelInstallService', "ModelInstallServiceBase",
'InstallStatus', "ModelInstallService",
'ModelInstallJob', "InstallStatus",
'UnknownInstallJobException', "ModelInstallJob",
'ModelSource', "UnknownInstallJobException",
] "ModelSource",
]

View File

@ -17,10 +17,10 @@ from invokeai.backend.model_manager import AnyModelConfig
class InstallStatus(str, Enum): class InstallStatus(str, Enum):
"""State of an install job running in the background.""" """State of an install job running in the background."""
WAITING = "waiting" # waiting to be dequeued WAITING = "waiting" # waiting to be dequeued
RUNNING = "running" # being processed RUNNING = "running" # being processed
COMPLETED = "completed" # finished running COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message ERROR = "error" # terminated with an error message
class UnknownInstallJobException(Exception): class UnknownInstallJobException(Exception):
@ -32,10 +32,17 @@ ModelSource = Union[str, Path, AnyHttpUrl]
class ModelInstallJob(BaseModel): class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request.""" """Object that tracks the current status of an install request."""
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process") status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
config_in: Dict[str, Any] = Field(default_factory=dict, description="Configuration information (e.g. 'description') to apply to model.") config_in: Dict[str, Any] = Field(
config_out: Optional[AnyModelConfig] = Field(default=None, description="After successful installation, this will hold the configuration object.") default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
inplace: bool = Field(default=False, description="Leave model in its current location; otherwise install under models directory") )
config_out: Optional[AnyModelConfig] = Field(
default=None, description="After successful installation, this will hold the configuration object."
)
inplace: bool = Field(
default=False, description="Leave model in its current location; otherwise install under models directory"
)
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model") source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source") local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
error_type: Optional[str] = Field(default=None, description="Class name of the exception that led to status==ERROR") error_type: Optional[str] = Field(default=None, description="Class name of the exception that led to status==ERROR")
@ -53,10 +60,10 @@ class ModelInstallServiceBase(ABC):
@abstractmethod @abstractmethod
def __init__( def __init__(
self, self,
app_config: InvokeAIAppConfig, app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase, record_store: ModelRecordServiceBase,
event_bus: Optional["EventServiceBase"] = None, event_bus: Optional["EventServiceBase"] = None,
): ):
""" """
Create ModelInstallService object. Create ModelInstallService object.
@ -86,9 +93,9 @@ class ModelInstallServiceBase(ABC):
@abstractmethod @abstractmethod
def register_path( def register_path(
self, self,
model_path: Union[Path, str], model_path: Union[Path, str],
config: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None,
) -> str: ) -> str:
""" """
Probe and register the model at model_path. Probe and register the model at model_path.
@ -114,9 +121,9 @@ class ModelInstallServiceBase(ABC):
@abstractmethod @abstractmethod
def install_path( def install_path(
self, self,
model_path: Union[Path, str], model_path: Union[Path, str],
config: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None,
) -> str: ) -> str:
""" """
Probe, register and install the model in the models directory. Probe, register and install the model in the models directory.
@ -131,13 +138,13 @@ class ModelInstallServiceBase(ABC):
@abstractmethod @abstractmethod
def import_model( def import_model(
self, self,
source: Union[str, Path, AnyHttpUrl], source: Union[str, Path, AnyHttpUrl],
inplace: bool = False, inplace: bool = False,
variant: Optional[str] = None, variant: Optional[str] = None,
subfolder: Optional[str] = None, subfolder: Optional[str] = None,
config: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None, access_token: Optional[str] = None,
) -> ModelInstallJob: ) -> ModelInstallJob:
"""Install the indicated model. """Install the indicated model.
@ -189,7 +196,7 @@ class ModelInstallServiceBase(ABC):
"""Return the ModelInstallJob corresponding to the provided source.""" """Return the ModelInstallJob corresponding to the provided source."""
@abstractmethod @abstractmethod
def list_jobs(self, source: Optional[ModelSource]=None) -> List[ModelInstallJob]: # noqa D102 def list_jobs(self, source: Optional[ModelSource] = None) -> List[ModelInstallJob]: # noqa D102
""" """
List active and complete install jobs. List active and complete install jobs.

View File

@ -46,11 +46,12 @@ class ModelInstallService(ModelInstallServiceBase):
_cached_model_paths: Set[Path] _cached_model_paths: Set[Path]
_models_installed: Set[str] _models_installed: Set[str]
def __init__(self, def __init__(
app_config: InvokeAIAppConfig, self,
record_store: ModelRecordServiceBase, app_config: InvokeAIAppConfig,
event_bus: Optional[EventServiceBase] = None record_store: ModelRecordServiceBase,
): event_bus: Optional[EventServiceBase] = None,
):
""" """
Initialize the installer object. Initialize the installer object.
@ -73,11 +74,11 @@ class ModelInstallService(ModelInstallServiceBase):
return self._app_config return self._app_config
@property @property
def record_store(self) -> ModelRecordServiceBase: # noqa D102 def record_store(self) -> ModelRecordServiceBase: # noqa D102
return self._record_store return self._record_store
@property @property
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102 def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
return self._event_bus return self._event_bus
def _start_installer_thread(self) -> None: def _start_installer_thread(self) -> None:
@ -129,25 +130,25 @@ class ModelInstallService(ModelInstallServiceBase):
self._event_bus.emit_model_install_error(str(job.source), error_type, error) self._event_bus.emit_model_install_error(str(job.source), error_type, error)
def register_path( def register_path(
self, self,
model_path: Union[Path, str], model_path: Union[Path, str],
config: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None,
) -> str: # noqa D102 ) -> str: # noqa D102
model_path = Path(model_path) model_path = Path(model_path)
config = config or {} config = config or {}
if config.get('source') is None: if config.get("source") is None:
config['source'] = model_path.resolve().as_posix() config["source"] = model_path.resolve().as_posix()
return self._register(model_path, config) return self._register(model_path, config)
def install_path( def install_path(
self, self,
model_path: Union[Path, str], model_path: Union[Path, str],
config: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None,
) -> str: # noqa D102 ) -> str: # noqa D102
model_path = Path(model_path) model_path = Path(model_path)
config = config or {} config = config or {}
if config.get('source') is None: if config.get("source") is None:
config['source'] = model_path.resolve().as_posix() config["source"] = model_path.resolve().as_posix()
info: AnyModelConfig = self._probe_model(Path(model_path), config) info: AnyModelConfig = self._probe_model(Path(model_path), config)
@ -164,14 +165,14 @@ class ModelInstallService(ModelInstallServiceBase):
) )
def import_model( def import_model(
self, self,
source: ModelSource, source: ModelSource,
inplace: bool = False, inplace: bool = False,
variant: Optional[str] = None, variant: Optional[str] = None,
subfolder: Optional[str] = None, subfolder: Optional[str] = None,
config: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None, access_token: Optional[str] = None,
) -> ModelInstallJob: # noqa D102 ) -> ModelInstallJob: # noqa D102
# Clean up a common source of error. Doesn't work with Paths. # Clean up a common source of error. Doesn't work with Paths.
if isinstance(source, str): if isinstance(source, str):
source = source.strip() source = source.strip()
@ -181,11 +182,12 @@ class ModelInstallService(ModelInstallServiceBase):
# Installing a local path # Installing a local path
if isinstance(source, (str, Path)) and Path(source).exists(): # a path that is already on disk if isinstance(source, (str, Path)) and Path(source).exists(): # a path that is already on disk
job = ModelInstallJob(config_in=config, job = ModelInstallJob(
source=source, config_in=config,
inplace=inplace, source=source,
local_path=Path(source), inplace=inplace,
) local_path=Path(source),
)
self._install_jobs[source] = job self._install_jobs[source] = job
self._install_queue.put(job) self._install_queue.put(job)
return job return job
@ -193,7 +195,7 @@ class ModelInstallService(ModelInstallServiceBase):
else: # here is where we'd download a URL or repo_id. Implementation pending download queue. else: # here is where we'd download a URL or repo_id. Implementation pending download queue.
raise UnknownModelException("File or directory not found") raise UnknownModelException("File or directory not found")
def list_jobs(self, source: Optional[ModelSource]=None) -> List[ModelInstallJob]: # noqa D102 def list_jobs(self, source: Optional[ModelSource] = None) -> List[ModelInstallJob]: # noqa D102
jobs = self._install_jobs jobs = self._install_jobs
if not source: if not source:
return list(jobs.values()) return list(jobs.values())
@ -205,17 +207,19 @@ class ModelInstallService(ModelInstallServiceBase):
try: try:
return self._install_jobs[source] return self._install_jobs[source]
except KeyError: except KeyError:
raise UnknownInstallJobException(f'{source}: unknown install job') raise UnknownInstallJobException(f"{source}: unknown install job")
def wait_for_installs(self) -> Dict[ModelSource, ModelInstallJob]: # noqa D102 def wait_for_installs(self) -> Dict[ModelSource, ModelInstallJob]: # noqa D102
self._install_queue.join() self._install_queue.join()
return self._install_jobs return self._install_jobs
def prune_jobs(self) -> None: def prune_jobs(self) -> None:
"""Prune all completed and errored jobs.""" """Prune all completed and errored jobs."""
finished_jobs = [source for source in self._install_jobs finished_jobs = [
if self._install_jobs[source].status in [InstallStatus.COMPLETED, InstallStatus.ERROR] source
] for source in self._install_jobs
if self._install_jobs[source].status in [InstallStatus.COMPLETED, InstallStatus.ERROR]
]
for source in finished_jobs: for source in finished_jobs:
del self._install_jobs[source] del self._install_jobs[source]
@ -228,7 +232,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.info(f"{len(installed)} new models registered") self._logger.info(f"{len(installed)} new models registered")
self._logger.info("Model installer (re)initialized") self._logger.info("Model installer (re)initialized")
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102 def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()} self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()}
callback = self._scan_install if install else self._scan_register callback = self._scan_install if install else self._scan_register
search = ModelSearch(on_model_found=callback) search = ModelSearch(on_model_found=callback)
@ -295,7 +299,6 @@ class ModelInstallService(ModelInstallServiceBase):
self.record_store.update_model(key, model) self.record_store.update_model(key, model)
return model return model
def _scan_register(self, model: Path) -> bool: def _scan_register(self, model: Path) -> bool:
if model in self._cached_model_paths: if model in self._cached_model_paths:
return True return True
@ -308,7 +311,6 @@ class ModelInstallService(ModelInstallServiceBase):
pass pass
return True return True
def _scan_install(self, model: Path) -> bool: def _scan_install(self, model: Path) -> bool:
if model in self._cached_model_paths: if model in self._cached_model_paths:
return True return True
@ -320,7 +322,7 @@ class ModelInstallService(ModelInstallServiceBase):
pass pass
return True return True
def unregister(self, key: str) -> None: # noqa D102 def unregister(self, key: str) -> None: # noqa D102
self.record_store.del_model(key) self.record_store.del_model(key)
def delete(self, key: str) -> None: # noqa D102 def delete(self, key: str) -> None: # noqa D102
@ -333,7 +335,7 @@ class ModelInstallService(ModelInstallServiceBase):
else: else:
self.unregister(key) self.unregister(key)
def unconditionally_delete(self, key: str) -> None: # noqa D102 def unconditionally_delete(self, key: str) -> None: # noqa D102
model = self.record_store.get_model(key) model = self.record_store.get_model(key)
path = self.app_config.models_path / model.path path = self.app_config.models_path / model.path
if path.is_dir(): if path.is_dir():
@ -378,11 +380,9 @@ class ModelInstallService(ModelInstallServiceBase):
def _create_key(self) -> str: def _create_key(self) -> str:
return sha256(randbytes(100)).hexdigest()[0:32] return sha256(randbytes(100)).hexdigest()[0:32]
def _register(self, def _register(
model_path: Path, self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
config: Optional[Dict[str, Any]] = None, ) -> str:
info: Optional[AnyModelConfig] = None) -> str:
info = info or ModelProbe.probe(model_path, config) info = info or ModelProbe.probe(model_path, config)
key = self._create_key() key = self._create_key()
@ -393,7 +393,7 @@ class ModelInstallService(ModelInstallServiceBase):
info.path = model_path.as_posix() info.path = model_path.as_posix()
# add 'main' specific fields # add 'main' specific fields
if hasattr(info, 'config'): if hasattr(info, "config"):
# make config relative to our root # make config relative to our root
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve() legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve()
info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix() info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()

View File

@ -8,9 +8,9 @@ from .model_records_base import ( # noqa F401
from .model_records_sql import ModelRecordServiceSQL # noqa F401 from .model_records_sql import ModelRecordServiceSQL # noqa F401
__all__ = [ __all__ = [
'ModelRecordServiceBase', "ModelRecordServiceBase",
'ModelRecordServiceSQL', "ModelRecordServiceSQL",
'DuplicateModelException', "DuplicateModelException",
'InvalidModelException', "InvalidModelException",
'UnknownModelException', "UnknownModelException",
] ]

View File

@ -123,8 +123,8 @@ class ModelProbe(object):
base_type=base_type, base_type=base_type,
variant_type=variant_type, variant_type=variant_type,
prediction_type=prediction_type, prediction_type=prediction_type,
name = name, name=name,
description = description, description=description,
upcast_attention=( upcast_attention=(
base_type == BaseModelType.StableDiffusion2 base_type == BaseModelType.StableDiffusion2
and prediction_type == SchedulerPredictionType.VPrediction and prediction_type == SchedulerPredictionType.VPrediction
@ -150,7 +150,7 @@ class ModelProbe(object):
@classmethod @classmethod
def get_model_name(cls, model_path: Path) -> str: def get_model_name(cls, model_path: Path) -> str:
if model_path.suffix in {'.safetensors', '.bin', '.pt', '.ckpt'}: if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
return model_path.stem return model_path.stem
else: else:
return model_path.name return model_path.name

View File

@ -14,15 +14,16 @@ from .config import (
from .probe import ModelProbe from .probe import ModelProbe
from .search import ModelSearch from .search import ModelSearch
__all__ = ['ModelProbe', 'ModelSearch', __all__ = [
'InvalidModelConfigException', "ModelProbe",
'ModelConfigFactory', "ModelSearch",
'BaseModelType', "InvalidModelConfigException",
'ModelType', "ModelConfigFactory",
'SubModelType', "BaseModelType",
'ModelVariantType', "ModelType",
'ModelFormat', "SubModelType",
'SchedulerPredictionType', "ModelVariantType",
'AnyModelConfig', "ModelFormat",
] "SchedulerPredictionType",
"AnyModelConfig",
]

View File

@ -49,6 +49,7 @@ LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[Sched
}, },
} }
class ProbeBase(object): class ProbeBase(object):
"""Base class for probes.""" """Base class for probes."""
@ -71,6 +72,7 @@ class ProbeBase(object):
"""Get model scheduler prediction type.""" """Get model scheduler prediction type."""
return None return None
class ModelProbe(object): class ModelProbe(object):
PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = { PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
"diffusers": {}, "diffusers": {},
@ -100,9 +102,9 @@ class ModelProbe(object):
@classmethod @classmethod
def heuristic_probe( def heuristic_probe(
cls, cls,
model_path: Path, model_path: Path,
fields: Optional[Dict[str, Any]] = None, fields: Optional[Dict[str, Any]] = None,
) -> AnyModelConfig: ) -> AnyModelConfig:
return cls.probe(model_path, fields) return cls.probe(model_path, fields)
@ -138,29 +140,38 @@ class ModelProbe(object):
hash = FastModelHash.hash(model_path) hash = FastModelHash.hash(model_path)
probe = probe_class(model_path) probe = probe_class(model_path)
fields['path'] = model_path.as_posix() fields["path"] = model_path.as_posix()
fields['type'] = fields.get('type') or model_type fields["type"] = fields.get("type") or model_type
fields['base'] = fields.get('base') or probe.get_base_type() fields["base"] = fields.get("base") or probe.get_base_type()
fields['variant'] = fields.get('variant') or probe.get_variant_type() fields["variant"] = fields.get("variant") or probe.get_variant_type()
fields['prediction_type'] = fields.get('prediction_type') or probe.get_scheduler_prediction_type() fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type()
fields['name'] = fields.get('name') or cls.get_model_name(model_path) fields["name"] = fields.get("name") or cls.get_model_name(model_path)
fields['description'] = fields.get('description') or f"{fields['base'].value} {fields['type'].value} model {fields['name']}" fields["description"] = (
fields['format'] = fields.get('format') or probe.get_format() fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
fields['original_hash'] = fields.get('original_hash') or hash )
fields['current_hash'] = fields.get('current_hash') or hash fields["format"] = fields.get("format") or probe.get_format()
fields["original_hash"] = fields.get("original_hash") or hash
fields["current_hash"] = fields.get("current_hash") or hash
# additional fields needed for main and controlnet models # additional fields needed for main and controlnet models
if fields['type'] in [ModelType.Main, ModelType.ControlNet] and fields['format'] == ModelFormat.Checkpoint: if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint:
fields['config'] = cls._get_checkpoint_config_path(model_path, fields["config"] = cls._get_checkpoint_config_path(
model_type=fields['type'], model_path,
base_type=fields['base'], model_type=fields["type"],
variant_type=fields['variant'], base_type=fields["base"],
prediction_type=fields['prediction_type']).as_posix() variant_type=fields["variant"],
prediction_type=fields["prediction_type"],
).as_posix()
# additional fields needed for main non-checkpoint models # additional fields needed for main non-checkpoint models
elif fields['type'] == ModelType.Main and fields['format'] in [ModelFormat.Onnx, ModelFormat.Olive, ModelFormat.Diffusers]: elif fields["type"] == ModelType.Main and fields["format"] in [
fields['upcast_attention'] = fields.get('upcast_attention') or ( ModelFormat.Onnx,
fields['base'] == BaseModelType.StableDiffusion2 and fields['prediction_type'] == SchedulerPredictionType.VPrediction ModelFormat.Olive,
ModelFormat.Diffusers,
]:
fields["upcast_attention"] = fields.get("upcast_attention") or (
fields["base"] == BaseModelType.StableDiffusion2
and fields["prediction_type"] == SchedulerPredictionType.VPrediction
) )
model_info = ModelConfigFactory.make_config(fields) model_info = ModelConfigFactory.make_config(fields)
@ -168,7 +179,7 @@ class ModelProbe(object):
@classmethod @classmethod
def get_model_name(cls, model_path: Path) -> str: def get_model_name(cls, model_path: Path) -> str:
if model_path.suffix in {'.safetensors', '.bin', '.pt', '.ckpt'}: if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
return model_path.stem return model_path.stem
else: else:
return model_path.name return model_path.name
@ -247,13 +258,14 @@ class ModelProbe(object):
) )
@classmethod @classmethod
def _get_checkpoint_config_path(cls, def _get_checkpoint_config_path(
model_path: Path, cls,
model_type: ModelType, model_path: Path,
base_type: BaseModelType, model_type: ModelType,
variant_type: ModelVariantType, base_type: BaseModelType,
prediction_type: SchedulerPredictionType) -> Path: variant_type: ModelVariantType,
prediction_type: SchedulerPredictionType,
) -> Path:
# look for a YAML file adjacent to the model file first # look for a YAML file adjacent to the model file first
possible_conf = model_path.with_suffix(".yaml") possible_conf = model_path.with_suffix(".yaml")
if possible_conf.exists(): if possible_conf.exists():
@ -264,9 +276,13 @@ class ModelProbe(object):
if isinstance(config_file, dict): # need another tier for sd-2.x models if isinstance(config_file, dict): # need another tier for sd-2.x models
config_file = config_file[prediction_type] config_file = config_file[prediction_type]
elif model_type == ModelType.ControlNet: elif model_type == ModelType.ControlNet:
config_file = "../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml" config_file = (
"../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml"
)
else: else:
raise InvalidModelConfigException(f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}") raise InvalidModelConfigException(
f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}"
)
assert isinstance(config_file, str) assert isinstance(config_file, str)
return Path(config_file) return Path(config_file)
@ -297,6 +313,7 @@ class ModelProbe(object):
# Checkpoint probing # Checkpoint probing
# ##################################################3 # ##################################################3
class CheckpointProbeBase(ProbeBase): class CheckpointProbeBase(ProbeBase):
def __init__(self, model_path: Path): def __init__(self, model_path: Path):
super().__init__(model_path) super().__init__(model_path)
@ -446,7 +463,6 @@ class T2IAdapterCheckpointProbe(CheckpointProbeBase):
# classes for probing folders # classes for probing folders
####################################################### #######################################################
class FolderProbeBase(ProbeBase): class FolderProbeBase(ProbeBase):
def get_variant_type(self) -> ModelVariantType: def get_variant_type(self) -> ModelVariantType:
return ModelVariantType.Normal return ModelVariantType.Normal
@ -537,7 +553,9 @@ class TextualInversionFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType: def get_base_type(self) -> BaseModelType:
path = self.model_path / "learned_embeds.bin" path = self.model_path / "learned_embeds.bin"
if not path.exists(): if not path.exists():
raise InvalidModelConfigException(f"{self.model_path.as_posix()} does not contain expected 'learned_embeds.bin' file") raise InvalidModelConfigException(
f"{self.model_path.as_posix()} does not contain expected 'learned_embeds.bin' file"
)
return TextualInversionCheckpointProbe(path).get_base_type() return TextualInversionCheckpointProbe(path).get_base_type()
@ -608,7 +626,9 @@ class IPAdapterFolderProbe(FolderProbeBase):
elif cross_attention_dim == 2048: elif cross_attention_dim == 2048:
return BaseModelType.StableDiffusionXL return BaseModelType.StableDiffusionXL
else: else:
raise InvalidModelConfigException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.") raise InvalidModelConfigException(
f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}."
)
class CLIPVisionFolderProbe(FolderProbeBase): class CLIPVisionFolderProbe(FolderProbeBase):

View File

@ -165,14 +165,14 @@ class ModelSearch(ModelSearchBase):
self.scanned_dirs.add(path) self.scanned_dirs.add(path)
continue continue
if any( if any(
(path / x).exists() (path / x).exists()
for x in [ for x in [
"config.json", "config.json",
"model_index.json", "model_index.json",
"learned_embeds.bin", "learned_embeds.bin",
"pytorch_lora_weights.bin", "pytorch_lora_weights.bin",
"image_encoder.txt", "image_encoder.txt",
] ]
): ):
self.scanned_dirs.add(path) self.scanned_dirs.add(path)
try: try:

View File

@ -14,4 +14,4 @@ from .devices import ( # noqa: F401
from .logging import InvokeAILogger from .logging import InvokeAILogger
from .util import Chdir, ask_user, download_with_resume, instantiate_from_config, url_attachment_name # noqa: F401 from .util import Chdir, ask_user, download_with_resume, instantiate_from_config, url_attachment_name # noqa: F401
__all__ = ['Chdir', 'InvokeAILogger', 'choose_precision', 'choose_torch_device'] __all__ = ["Chdir", "InvokeAILogger", "choose_precision", "choose_torch_device"]

View File

@ -44,12 +44,12 @@ def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
@pytest.fixture @pytest.fixture
def installer(app_config: InvokeAIAppConfig, def installer(app_config: InvokeAIAppConfig, store: ModelRecordServiceBase) -> ModelInstallServiceBase:
store: ModelRecordServiceBase) -> ModelInstallServiceBase: return ModelInstallService(
return ModelInstallService(app_config=app_config, app_config=app_config,
record_store=store, record_store=store,
event_bus=DummyEventService(), event_bus=DummyEventService(),
) )
class DummyEvent(BaseModel): class DummyEvent(BaseModel):
@ -70,10 +70,8 @@ class DummyEventService(EventServiceBase):
def dispatch(self, event_name: str, payload: Any) -> None: def dispatch(self, event_name: str, payload: Any) -> None:
"""Dispatch an event by appending it to self.events.""" """Dispatch an event by appending it to self.events."""
self.events.append( self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"]))
DummyEvent(event_name=payload['event'],
payload=payload['data'])
)
def test_registration(installer: ModelInstallServiceBase, test_file: Path) -> None: def test_registration(installer: ModelInstallServiceBase, test_file: Path) -> None:
store = installer.record_store store = installer.record_store
@ -83,6 +81,7 @@ def test_registration(installer: ModelInstallServiceBase, test_file: Path) -> No
assert key is not None assert key is not None
assert len(key) == 32 assert len(key) == 32
def test_registration_meta(installer: ModelInstallServiceBase, test_file: Path) -> None: def test_registration_meta(installer: ModelInstallServiceBase, test_file: Path) -> None:
store = installer.record_store store = installer.record_store
key = installer.register_path(test_file) key = installer.register_path(test_file)
@ -91,31 +90,30 @@ def test_registration_meta(installer: ModelInstallServiceBase, test_file: Path)
assert model_record.name == "test_embedding" assert model_record.name == "test_embedding"
assert model_record.type == ModelType.TextualInversion assert model_record.type == ModelType.TextualInversion
assert Path(model_record.path) == test_file assert Path(model_record.path) == test_file
assert model_record.base == BaseModelType('sd-1') assert model_record.base == BaseModelType("sd-1")
assert model_record.description is not None assert model_record.description is not None
assert model_record.source is not None assert model_record.source is not None
assert Path(model_record.source) == test_file assert Path(model_record.source) == test_file
def test_registration_meta_override_fail(installer: ModelInstallServiceBase, test_file: Path) -> None: def test_registration_meta_override_fail(installer: ModelInstallServiceBase, test_file: Path) -> None:
key = None key = None
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
key = installer.register_path(test_file, {"name": "banana_sushi", "type": ModelType("lora")}) key = installer.register_path(test_file, {"name": "banana_sushi", "type": ModelType("lora")})
assert key is None assert key is None
def test_registration_meta_override_succeed(installer: ModelInstallServiceBase, test_file: Path) -> None: def test_registration_meta_override_succeed(installer: ModelInstallServiceBase, test_file: Path) -> None:
store = installer.record_store store = installer.record_store
key = installer.register_path(test_file, key = installer.register_path(
{ test_file, {"name": "banana_sushi", "source": "fake/repo_id", "current_hash": "New Hash"}
"name": "banana_sushi", )
"source": "fake/repo_id",
"current_hash": "New Hash"
}
)
model_record = store.get_model(key) model_record = store.get_model(key)
assert model_record.name == "banana_sushi" assert model_record.name == "banana_sushi"
assert model_record.source == "fake/repo_id" assert model_record.source == "fake/repo_id"
assert model_record.current_hash == "New Hash" assert model_record.current_hash == "New Hash"
def test_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None: def test_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None:
store = installer.record_store store = installer.record_store
key = installer.install_path(test_file) key = installer.install_path(test_file)
@ -123,6 +121,7 @@ def test_install(installer: ModelInstallServiceBase, test_file: Path, app_config
assert model_record.path == "sd-1/embedding/test_embedding.safetensors" assert model_record.path == "sd-1/embedding/test_embedding.safetensors"
assert model_record.source == test_file.as_posix() assert model_record.source == test_file.as_posix()
def test_background_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None: def test_background_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None:
"""Note: may want to break this down into several smaller unit tests.""" """Note: may want to break this down into several smaller unit tests."""
source = test_file source = test_file
@ -142,7 +141,7 @@ def test_background_install(installer: ModelInstallServiceBase, test_file: Path,
# test that the expected events were issued # test that the expected events were issued
bus = installer.event_bus bus = installer.event_bus
assert bus is not None # sigh - ruff is a stickler for type checking assert bus is not None # sigh - ruff is a stickler for type checking
assert isinstance(bus, DummyEventService) assert isinstance(bus, DummyEventService)
assert len(bus.events) == 2 assert len(bus.events) == 2
event_names = [x.event_name for x in bus.events] event_names = [x.event_name for x in bus.events]
@ -167,6 +166,7 @@ def test_background_install(installer: ModelInstallServiceBase, test_file: Path,
with pytest.raises(UnknownInstallJobException): with pytest.raises(UnknownInstallJobException):
assert installer.get_job(source) assert installer.get_job(source)
def test_delete_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig): def test_delete_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig):
store = installer.record_store store = installer.record_store
key = installer.install_path(test_file) key = installer.install_path(test_file)
@ -174,11 +174,14 @@ def test_delete_install(installer: ModelInstallServiceBase, test_file: Path, app
assert Path(app_config.models_dir / model_record.path).exists() assert Path(app_config.models_dir / model_record.path).exists()
assert test_file.exists() # original should still be there after installation assert test_file.exists() # original should still be there after installation
installer.delete(key) installer.delete(key)
assert not Path(app_config.models_dir / model_record.path).exists() # after deletion, installed copy should not exist assert not Path(
app_config.models_dir / model_record.path
).exists() # after deletion, installed copy should not exist
assert test_file.exists() # but original should still be there assert test_file.exists() # but original should still be there
with pytest.raises(UnknownModelException): with pytest.raises(UnknownModelException):
store.get_model(key) store.get_model(key)
def test_delete_register(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig): def test_delete_register(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig):
store = installer.record_store store = installer.record_store
key = installer.register_path(test_file) key = installer.register_path(test_file)