mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor(mm): remove unused metadata logic, fix tests
- Metadata is merged with the config. We can simplify the MM substantially and remove the handling for metadata. - Per discussion, we don't have an ETA for frontend implementation of tags, and with the realization that the tags from CivitAI are largely useless, there's no reason to keep tags in the MM right now. When we are ready to implement tags on the frontend, we can refer back to the implementation here and use it if it supports the design. - Fix all tests.
This commit is contained in:
parent
0b9a212363
commit
44c40d7d1a
@ -26,7 +26,6 @@ from ..services.invocation_services import InvocationServices
|
|||||||
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
|
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||||
from ..services.invoker import Invoker
|
from ..services.invoker import Invoker
|
||||||
from ..services.model_manager.model_manager_default import ModelManagerService
|
from ..services.model_manager.model_manager_default import ModelManagerService
|
||||||
from ..services.model_metadata import ModelMetadataStoreSQL
|
|
||||||
from ..services.model_records import ModelRecordServiceSQL
|
from ..services.model_records import ModelRecordServiceSQL
|
||||||
from ..services.names.names_default import SimpleNameService
|
from ..services.names.names_default import SimpleNameService
|
||||||
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
|
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||||
@ -93,10 +92,9 @@ class ApiDependencies:
|
|||||||
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
|
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
|
||||||
)
|
)
|
||||||
download_queue_service = DownloadQueueService(event_bus=events)
|
download_queue_service = DownloadQueueService(event_bus=events)
|
||||||
model_metadata_service = ModelMetadataStoreSQL(db=db)
|
|
||||||
model_manager = ModelManagerService.build_model_manager(
|
model_manager = ModelManagerService.build_model_manager(
|
||||||
app_config=configuration,
|
app_config=configuration,
|
||||||
model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service),
|
model_record_service=ModelRecordServiceSQL(db=db),
|
||||||
download_queue=download_queue_service,
|
download_queue=download_queue_service,
|
||||||
events=events,
|
events=events,
|
||||||
)
|
)
|
||||||
|
@ -12,7 +12,6 @@ from starlette.exceptions import HTTPException
|
|||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from invokeai.app.services.model_install import ModelInstallJob
|
from invokeai.app.services.model_install import ModelInstallJob
|
||||||
from invokeai.app.services.model_metadata.metadata_store_base import ModelMetadataChanges
|
|
||||||
from invokeai.app.services.model_records import (
|
from invokeai.app.services.model_records import (
|
||||||
DuplicateModelException,
|
DuplicateModelException,
|
||||||
InvalidModelException,
|
InvalidModelException,
|
||||||
@ -30,8 +29,6 @@ from invokeai.backend.model_manager.config import (
|
|||||||
SubModelType,
|
SubModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
||||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
|
||||||
from invokeai.backend.model_manager.metadata.metadata_base import BaseMetadata
|
|
||||||
from invokeai.backend.model_manager.search import ModelSearch
|
from invokeai.backend.model_manager.search import ModelSearch
|
||||||
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
@ -90,44 +87,6 @@ example_model_input = {
|
|||||||
"variant": "normal",
|
"variant": "normal",
|
||||||
}
|
}
|
||||||
|
|
||||||
example_model_metadata = {
|
|
||||||
"name": "ip_adapter_sd_image_encoder",
|
|
||||||
"author": "InvokeAI",
|
|
||||||
"tags": [
|
|
||||||
"transformers",
|
|
||||||
"safetensors",
|
|
||||||
"clip_vision_model",
|
|
||||||
"endpoints_compatible",
|
|
||||||
"region:us",
|
|
||||||
"has_space",
|
|
||||||
"license:apache-2.0",
|
|
||||||
],
|
|
||||||
"files": [
|
|
||||||
{
|
|
||||||
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/README.md",
|
|
||||||
"path": "ip_adapter_sd_image_encoder/README.md",
|
|
||||||
"size": 628,
|
|
||||||
"sha256": None,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/config.json",
|
|
||||||
"path": "ip_adapter_sd_image_encoder/config.json",
|
|
||||||
"size": 560,
|
|
||||||
"sha256": None,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/model.safetensors",
|
|
||||||
"path": "ip_adapter_sd_image_encoder/model.safetensors",
|
|
||||||
"size": 2528373448,
|
|
||||||
"sha256": "6ca9667da1ca9e0b0f75e46bb030f7e011f44f86cbfb8d5a36590fcd7507b030",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"type": "huggingface",
|
|
||||||
"id": "InvokeAI/ip_adapter_sd_image_encoder",
|
|
||||||
"tag_dict": {"license": "apache-2.0"},
|
|
||||||
"last_modified": "2023-09-23T17:33:25Z",
|
|
||||||
}
|
|
||||||
|
|
||||||
##############################################################################
|
##############################################################################
|
||||||
# ROUTES
|
# ROUTES
|
||||||
##############################################################################
|
##############################################################################
|
||||||
@ -219,79 +178,6 @@ async def list_model_summary(
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@model_manager_router.get(
|
|
||||||
"/i/{key}/metadata",
|
|
||||||
operation_id="get_model_metadata",
|
|
||||||
responses={
|
|
||||||
200: {
|
|
||||||
"description": "The model metadata was retrieved successfully",
|
|
||||||
"content": {"application/json": {"example": example_model_metadata}},
|
|
||||||
},
|
|
||||||
400: {"description": "Bad request"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
async def get_model_metadata(
|
|
||||||
key: str = Path(description="Key of the model repo metadata to fetch."),
|
|
||||||
) -> Optional[AnyModelRepoMetadata]:
|
|
||||||
"""Get a model metadata object."""
|
|
||||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
|
||||||
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
@model_manager_router.patch(
|
|
||||||
"/i/{key}/metadata",
|
|
||||||
operation_id="update_model_metadata",
|
|
||||||
responses={
|
|
||||||
201: {
|
|
||||||
"description": "The model metadata was updated successfully",
|
|
||||||
"content": {"application/json": {"example": example_model_metadata}},
|
|
||||||
},
|
|
||||||
400: {"description": "Bad request"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
async def update_model_metadata(
|
|
||||||
key: str = Path(description="Key of the model repo metadata to fetch."),
|
|
||||||
changes: ModelMetadataChanges = Body(description="The changes"),
|
|
||||||
) -> Optional[AnyModelRepoMetadata]:
|
|
||||||
"""Updates or creates a model metadata object."""
|
|
||||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
|
||||||
metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store
|
|
||||||
|
|
||||||
try:
|
|
||||||
original_metadata = record_store.get_metadata(key)
|
|
||||||
if original_metadata:
|
|
||||||
if changes.default_settings:
|
|
||||||
original_metadata.default_settings = changes.default_settings
|
|
||||||
|
|
||||||
metadata_store.update_metadata(key, original_metadata)
|
|
||||||
else:
|
|
||||||
metadata_store.add_metadata(
|
|
||||||
key, BaseMetadata(name="", author="", default_settings=changes.default_settings)
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail=f"An error occurred while updating the model metadata: {e}",
|
|
||||||
)
|
|
||||||
|
|
||||||
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
@model_manager_router.get(
|
|
||||||
"/tags",
|
|
||||||
operation_id="list_tags",
|
|
||||||
)
|
|
||||||
async def list_tags() -> Set[str]:
|
|
||||||
"""Get a unique set of all the model tags."""
|
|
||||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
|
||||||
result: Set[str] = record_store.list_tags()
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class FoundModel(BaseModel):
|
class FoundModel(BaseModel):
|
||||||
path: str = Field(description="Path to the model")
|
path: str = Field(description="Path to the model")
|
||||||
is_installed: bool = Field(description="Whether or not the model is already installed")
|
is_installed: bool = Field(description="Whether or not the model is already installed")
|
||||||
@ -361,19 +247,6 @@ async def scan_for_models(
|
|||||||
return scan_results
|
return scan_results
|
||||||
|
|
||||||
|
|
||||||
@model_manager_router.get(
|
|
||||||
"/tags/search",
|
|
||||||
operation_id="search_by_metadata_tags",
|
|
||||||
)
|
|
||||||
async def search_by_metadata_tags(
|
|
||||||
tags: Set[str] = Query(default=None, description="Tags to search for"),
|
|
||||||
) -> ModelsList:
|
|
||||||
"""Get a list of models."""
|
|
||||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
|
||||||
results = record_store.search_by_metadata_tag(tags)
|
|
||||||
return ModelsList(models=results)
|
|
||||||
|
|
||||||
|
|
||||||
@model_manager_router.patch(
|
@model_manager_router.patch(
|
||||||
"/i/{key}",
|
"/i/{key}",
|
||||||
operation_id="update_model_record",
|
operation_id="update_model_record",
|
||||||
@ -562,9 +435,8 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
|
|||||||
* "cancelled" -- Job was cancelled before completion.
|
* "cancelled" -- Job was cancelled before completion.
|
||||||
|
|
||||||
Once completed, information about the model such as its size, base
|
Once completed, information about the model such as its size, base
|
||||||
model, type, and metadata can be retrieved from the `config_out`
|
model and type can be retrieved from the `config_out` field. For multi-file models such as diffusers,
|
||||||
field. For multi-file models such as diffusers, information on individual files
|
information on individual files can be retrieved from `download_parts`.
|
||||||
can be retrieved from `download_parts`.
|
|
||||||
|
|
||||||
See the example and schema below for more information.
|
See the example and schema below for more information.
|
||||||
"""
|
"""
|
||||||
@ -708,10 +580,6 @@ async def convert_model(
|
|||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=409, detail=str(e))
|
raise HTTPException(status_code=409, detail=str(e))
|
||||||
|
|
||||||
# get the original metadata
|
|
||||||
if orig_metadata := store.get_metadata(key):
|
|
||||||
store.metadata_store.add_metadata(new_key, orig_metadata)
|
|
||||||
|
|
||||||
# delete the original safetensors file
|
# delete the original safetensors file
|
||||||
installer.delete(key)
|
installer.delete(key)
|
||||||
|
|
||||||
|
@ -21,8 +21,6 @@ from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
|||||||
from invokeai.backend.model_manager.config import ModelSourceType
|
from invokeai.backend.model_manager.config import ModelSourceType
|
||||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||||
|
|
||||||
from ..model_metadata import ModelMetadataStoreBase
|
|
||||||
|
|
||||||
|
|
||||||
class InstallStatus(str, Enum):
|
class InstallStatus(str, Enum):
|
||||||
"""State of an install job running in the background."""
|
"""State of an install job running in the background."""
|
||||||
@ -268,7 +266,6 @@ class ModelInstallServiceBase(ABC):
|
|||||||
app_config: InvokeAIAppConfig,
|
app_config: InvokeAIAppConfig,
|
||||||
record_store: ModelRecordServiceBase,
|
record_store: ModelRecordServiceBase,
|
||||||
download_queue: DownloadQueueServiceBase,
|
download_queue: DownloadQueueServiceBase,
|
||||||
metadata_store: ModelMetadataStoreBase,
|
|
||||||
event_bus: Optional["EventServiceBase"] = None,
|
event_bus: Optional["EventServiceBase"] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -94,7 +94,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._running = False
|
self._running = False
|
||||||
self._session = session
|
self._session = session
|
||||||
self._next_job_id = 0
|
self._next_job_id = 0
|
||||||
self._metadata_store = record_store.metadata_store # for convenience
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def app_config(self) -> InvokeAIAppConfig: # noqa D102
|
def app_config(self) -> InvokeAIAppConfig: # noqa D102
|
||||||
|
@ -1,9 +0,0 @@
|
|||||||
"""Init file for ModelMetadataStoreService module."""
|
|
||||||
|
|
||||||
from .metadata_store_base import ModelMetadataStoreBase
|
|
||||||
from .metadata_store_sql import ModelMetadataStoreSQL
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"ModelMetadataStoreBase",
|
|
||||||
"ModelMetadataStoreSQL",
|
|
||||||
]
|
|
@ -1,81 +0,0 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
|
||||||
"""
|
|
||||||
Storage for Model Metadata
|
|
||||||
"""
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import List, Optional, Set, Tuple
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
|
||||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
|
||||||
from invokeai.backend.model_manager.metadata.metadata_base import ModelDefaultSettings
|
|
||||||
|
|
||||||
|
|
||||||
class ModelMetadataChanges(BaseModelExcludeNull, extra="allow"):
|
|
||||||
"""A set of changes to apply to model metadata.
|
|
||||||
Only limited changes are valid:
|
|
||||||
- `default_settings`: the user-configured default settings for this model
|
|
||||||
"""
|
|
||||||
|
|
||||||
default_settings: Optional[ModelDefaultSettings] = Field(
|
|
||||||
default=None, description="The user-configured default settings for this model"
|
|
||||||
)
|
|
||||||
"""The user-configured default settings for this model"""
|
|
||||||
|
|
||||||
|
|
||||||
class ModelMetadataStoreBase(ABC):
|
|
||||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
|
||||||
"""
|
|
||||||
Add a block of repo metadata to a model record.
|
|
||||||
|
|
||||||
The model record config must already exist in the database with the
|
|
||||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
|
||||||
|
|
||||||
:param model_key: Existing model key in the `model_config` table
|
|
||||||
:param metadata: ModelRepoMetadata object to store
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
|
||||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
|
||||||
"""Dump out all the metadata."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
|
||||||
"""
|
|
||||||
Update metadata corresponding to the model with the indicated key.
|
|
||||||
|
|
||||||
:param model_key: Existing model key in the `model_config` table
|
|
||||||
:param metadata: ModelRepoMetadata object to update
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def list_tags(self) -> Set[str]:
|
|
||||||
"""Return all tags in the tags table."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
|
||||||
"""Return the keys of models containing all of the listed tags."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def search_by_author(self, author: str) -> Set[str]:
|
|
||||||
"""Return the keys of models authored by the indicated author."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def search_by_name(self, name: str) -> Set[str]:
|
|
||||||
"""
|
|
||||||
Return the keys of models with the indicated name.
|
|
||||||
|
|
||||||
Note that this is the name of the model given to it by
|
|
||||||
the remote source. The user may have changed the local
|
|
||||||
name. The local name will be located in the model config
|
|
||||||
record object.
|
|
||||||
"""
|
|
@ -1,223 +0,0 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
|
||||||
"""
|
|
||||||
SQL Storage for Model Metadata
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sqlite3
|
|
||||||
from typing import List, Optional, Set, Tuple
|
|
||||||
|
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
|
||||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
|
|
||||||
from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase
|
|
||||||
|
|
||||||
from .metadata_store_base import ModelMetadataStoreBase
|
|
||||||
|
|
||||||
|
|
||||||
class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
|
||||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
|
||||||
|
|
||||||
def __init__(self, db: SqliteDatabase):
|
|
||||||
"""
|
|
||||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
|
||||||
|
|
||||||
:param conn: sqlite3 connection object
|
|
||||||
:param lock: threading Lock object
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self._db = db
|
|
||||||
self._cursor = self._db.conn.cursor()
|
|
||||||
|
|
||||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
|
||||||
"""
|
|
||||||
Add a block of repo metadata to a model record.
|
|
||||||
|
|
||||||
The model record config must already exist in the database with the
|
|
||||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
|
||||||
|
|
||||||
:param model_key: Existing model key in the `model_config` table
|
|
||||||
:param metadata: ModelRepoMetadata object to store
|
|
||||||
"""
|
|
||||||
json_serialized = metadata.model_dump_json()
|
|
||||||
with self._db.lock:
|
|
||||||
try:
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
INSERT INTO model_metadata(
|
|
||||||
id,
|
|
||||||
metadata
|
|
||||||
)
|
|
||||||
VALUES (?,?);
|
|
||||||
""",
|
|
||||||
(
|
|
||||||
model_key,
|
|
||||||
json_serialized,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self._update_tags(model_key, metadata.tags)
|
|
||||||
self._db.conn.commit()
|
|
||||||
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
|
|
||||||
self._db.conn.rollback()
|
|
||||||
raise UnknownMetadataException from excp
|
|
||||||
except sqlite3.Error as excp:
|
|
||||||
self._db.conn.rollback()
|
|
||||||
raise excp
|
|
||||||
|
|
||||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
|
||||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
|
||||||
with self._db.lock:
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
SELECT metadata FROM model_metadata
|
|
||||||
WHERE id=?;
|
|
||||||
""",
|
|
||||||
(model_key,),
|
|
||||||
)
|
|
||||||
rows = self._cursor.fetchone()
|
|
||||||
if not rows:
|
|
||||||
raise UnknownMetadataException("model metadata not found")
|
|
||||||
return ModelMetadataFetchBase.from_json(rows[0])
|
|
||||||
|
|
||||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
|
||||||
"""Dump out all the metadata."""
|
|
||||||
with self._db.lock:
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
SELECT id,metadata FROM model_metadata;
|
|
||||||
""",
|
|
||||||
(),
|
|
||||||
)
|
|
||||||
rows = self._cursor.fetchall()
|
|
||||||
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
|
|
||||||
|
|
||||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
|
||||||
"""
|
|
||||||
Update metadata corresponding to the model with the indicated key.
|
|
||||||
|
|
||||||
:param model_key: Existing model key in the `model_config` table
|
|
||||||
:param metadata: ModelRepoMetadata object to update
|
|
||||||
"""
|
|
||||||
json_serialized = metadata.model_dump_json() # turn it into a json string.
|
|
||||||
with self._db.lock:
|
|
||||||
try:
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
UPDATE model_metadata
|
|
||||||
SET
|
|
||||||
metadata=?
|
|
||||||
WHERE id=?;
|
|
||||||
""",
|
|
||||||
(json_serialized, model_key),
|
|
||||||
)
|
|
||||||
if self._cursor.rowcount == 0:
|
|
||||||
raise UnknownMetadataException("model metadata not found")
|
|
||||||
self._update_tags(model_key, metadata.tags)
|
|
||||||
self._db.conn.commit()
|
|
||||||
except sqlite3.Error as e:
|
|
||||||
self._db.conn.rollback()
|
|
||||||
raise e
|
|
||||||
|
|
||||||
return self.get_metadata(model_key)
|
|
||||||
|
|
||||||
def list_tags(self) -> Set[str]:
|
|
||||||
"""Return all tags in the tags table."""
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
select tag_text from tags;
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
return {x[0] for x in self._cursor.fetchall()}
|
|
||||||
|
|
||||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
|
||||||
"""Return the keys of models containing all of the listed tags."""
|
|
||||||
with self._db.lock:
|
|
||||||
try:
|
|
||||||
matches: Optional[Set[str]] = None
|
|
||||||
for tag in tags:
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
SELECT a.model_id FROM model_tags AS a,
|
|
||||||
tags AS b
|
|
||||||
WHERE a.tag_id=b.tag_id
|
|
||||||
AND b.tag_text=?;
|
|
||||||
""",
|
|
||||||
(tag,),
|
|
||||||
)
|
|
||||||
model_keys = {x[0] for x in self._cursor.fetchall()}
|
|
||||||
if matches is None:
|
|
||||||
matches = model_keys
|
|
||||||
matches = matches.intersection(model_keys)
|
|
||||||
except sqlite3.Error as e:
|
|
||||||
raise e
|
|
||||||
return matches if matches else set()
|
|
||||||
|
|
||||||
def search_by_author(self, author: str) -> Set[str]:
|
|
||||||
"""Return the keys of models authored by the indicated author."""
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
SELECT id FROM model_metadata
|
|
||||||
WHERE author=?;
|
|
||||||
""",
|
|
||||||
(author,),
|
|
||||||
)
|
|
||||||
return {x[0] for x in self._cursor.fetchall()}
|
|
||||||
|
|
||||||
def search_by_name(self, name: str) -> Set[str]:
|
|
||||||
"""
|
|
||||||
Return the keys of models with the indicated name.
|
|
||||||
|
|
||||||
Note that this is the name of the model given to it by
|
|
||||||
the remote source. The user may have changed the local
|
|
||||||
name. The local name will be located in the model config
|
|
||||||
record object.
|
|
||||||
"""
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
SELECT id FROM model_metadata
|
|
||||||
WHERE name=?;
|
|
||||||
""",
|
|
||||||
(name,),
|
|
||||||
)
|
|
||||||
return {x[0] for x in self._cursor.fetchall()}
|
|
||||||
|
|
||||||
def _update_tags(self, model_key: str, tags: Optional[Set[str]]) -> None:
|
|
||||||
"""Update tags for the model referenced by model_key."""
|
|
||||||
if tags:
|
|
||||||
# remove previous tags from this model
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
DELETE FROM model_tags
|
|
||||||
WHERE model_id=?;
|
|
||||||
""",
|
|
||||||
(model_key,),
|
|
||||||
)
|
|
||||||
|
|
||||||
for tag in tags:
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
INSERT OR IGNORE INTO tags (
|
|
||||||
tag_text
|
|
||||||
)
|
|
||||||
VALUES (?);
|
|
||||||
""",
|
|
||||||
(tag,),
|
|
||||||
)
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
SELECT tag_id
|
|
||||||
FROM tags
|
|
||||||
WHERE tag_text = ?
|
|
||||||
LIMIT 1;
|
|
||||||
""",
|
|
||||||
(tag,),
|
|
||||||
)
|
|
||||||
tag_id = self._cursor.fetchone()[0]
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
INSERT OR IGNORE INTO model_tags (
|
|
||||||
model_id,
|
|
||||||
tag_id
|
|
||||||
)
|
|
||||||
VALUES (?,?);
|
|
||||||
""",
|
|
||||||
(model_key, tag_id),
|
|
||||||
)
|
|
@ -6,7 +6,7 @@ Abstract base class for storing and retrieving model configuration records.
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
from typing import Any, Dict, List, Optional, Set, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@ -17,9 +17,6 @@ from invokeai.backend.model_manager import (
|
|||||||
ModelFormat,
|
ModelFormat,
|
||||||
ModelType,
|
ModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
|
||||||
|
|
||||||
from ..model_metadata import ModelMetadataStoreBase
|
|
||||||
|
|
||||||
|
|
||||||
class DuplicateModelException(Exception):
|
class DuplicateModelException(Exception):
|
||||||
@ -109,40 +106,6 @@ class ModelRecordServiceBase(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def metadata_store(self) -> ModelMetadataStoreBase:
|
|
||||||
"""Return a ModelMetadataStore initialized on the same database."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
|
||||||
"""
|
|
||||||
Retrieve metadata (if any) from when model was downloaded from a repo.
|
|
||||||
|
|
||||||
:param key: Model key
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
|
|
||||||
"""List metadata for all models that have it."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
|
|
||||||
"""
|
|
||||||
Search model metadata for ones with all listed tags and return their corresponding configs.
|
|
||||||
|
|
||||||
:param tags: Set of tags to search for. All tags must be present.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def list_tags(self) -> Set[str]:
|
|
||||||
"""Return a unique set of all the model tags in the metadata database."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def list_models(
|
def list_models(
|
||||||
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
||||||
|
@ -43,7 +43,7 @@ import json
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
@ -53,9 +53,7 @@ from invokeai.backend.model_manager.config import (
|
|||||||
ModelFormat,
|
ModelFormat,
|
||||||
ModelType,
|
ModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
|
|
||||||
|
|
||||||
from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
|
|
||||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||||
from .model_records_base import (
|
from .model_records_base import (
|
||||||
DuplicateModelException,
|
DuplicateModelException,
|
||||||
@ -69,7 +67,7 @@ from .model_records_base import (
|
|||||||
class ModelRecordServiceSQL(ModelRecordServiceBase):
|
class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||||
"""Implementation of the ModelConfigStore ABC using a SQL database."""
|
"""Implementation of the ModelConfigStore ABC using a SQL database."""
|
||||||
|
|
||||||
def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase):
|
def __init__(self, db: SqliteDatabase):
|
||||||
"""
|
"""
|
||||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||||
|
|
||||||
@ -78,7 +76,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self._db = db
|
self._db = db
|
||||||
self._cursor = db.conn.cursor()
|
self._cursor = db.conn.cursor()
|
||||||
self._metadata_store = metadata_store
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def db(self) -> SqliteDatabase:
|
def db(self) -> SqliteDatabase:
|
||||||
@ -242,9 +239,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
If none of the optional filters are passed, will return all
|
If none of the optional filters are passed, will return all
|
||||||
models in the database.
|
models in the database.
|
||||||
"""
|
"""
|
||||||
results = []
|
where_clause: list[str] = []
|
||||||
where_clause = []
|
bindings: list[str] = []
|
||||||
bindings = []
|
|
||||||
if model_name:
|
if model_name:
|
||||||
where_clause.append("name=?")
|
where_clause.append("name=?")
|
||||||
bindings.append(model_name)
|
bindings.append(model_name)
|
||||||
@ -302,55 +298,17 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
]
|
]
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@property
|
|
||||||
def metadata_store(self) -> ModelMetadataStoreBase:
|
|
||||||
"""Return a ModelMetadataStore initialized on the same database."""
|
|
||||||
return self._metadata_store
|
|
||||||
|
|
||||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
|
||||||
"""
|
|
||||||
Retrieve metadata (if any) from when model was downloaded from a repo.
|
|
||||||
|
|
||||||
:param key: Model key
|
|
||||||
"""
|
|
||||||
store = self.metadata_store
|
|
||||||
try:
|
|
||||||
metadata = store.get_metadata(key)
|
|
||||||
return metadata
|
|
||||||
except UnknownMetadataException:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
|
|
||||||
"""
|
|
||||||
Search model metadata for ones with all listed tags and return their corresponding configs.
|
|
||||||
|
|
||||||
:param tags: Set of tags to search for. All tags must be present.
|
|
||||||
"""
|
|
||||||
store = ModelMetadataStoreSQL(self._db)
|
|
||||||
keys = store.search_by_tag(tags)
|
|
||||||
return [self.get_model(x) for x in keys]
|
|
||||||
|
|
||||||
def list_tags(self) -> Set[str]:
|
|
||||||
"""Return a unique set of all the model tags in the metadata database."""
|
|
||||||
store = ModelMetadataStoreSQL(self._db)
|
|
||||||
return store.list_tags()
|
|
||||||
|
|
||||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
|
|
||||||
"""List metadata for all models that have it."""
|
|
||||||
store = ModelMetadataStoreSQL(self._db)
|
|
||||||
return store.list_all_metadata()
|
|
||||||
|
|
||||||
def list_models(
|
def list_models(
|
||||||
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
||||||
) -> PaginatedResults[ModelSummary]:
|
) -> PaginatedResults[ModelSummary]:
|
||||||
"""Return a paginated summary listing of each model in the database."""
|
"""Return a paginated summary listing of each model in the database."""
|
||||||
assert isinstance(order_by, ModelRecordOrderBy)
|
assert isinstance(order_by, ModelRecordOrderBy)
|
||||||
ordering = {
|
ordering = {
|
||||||
ModelRecordOrderBy.Default: "a.type, a.base, a.format, a.name",
|
ModelRecordOrderBy.Default: "type, base, format, name",
|
||||||
ModelRecordOrderBy.Type: "a.type",
|
ModelRecordOrderBy.Type: "type",
|
||||||
ModelRecordOrderBy.Base: "a.base",
|
ModelRecordOrderBy.Base: "base",
|
||||||
ModelRecordOrderBy.Name: "a.name",
|
ModelRecordOrderBy.Name: "name",
|
||||||
ModelRecordOrderBy.Format: "a.format",
|
ModelRecordOrderBy.Format: "format",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Lock so that the database isn't updated while we're doing the two queries.
|
# Lock so that the database isn't updated while we're doing the two queries.
|
||||||
@ -364,7 +322,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
)
|
)
|
||||||
total = int(self._cursor.fetchone()[0])
|
total = int(self._cursor.fetchone()[0])
|
||||||
|
|
||||||
# query2: fetch key fields from the join of models and model_metadata
|
# query2: fetch key fields
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""--sql
|
f"""--sql
|
||||||
SELECT config
|
SELECT config
|
||||||
|
@ -6,6 +6,15 @@ from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import
|
|||||||
class Migration7Callback:
|
class Migration7Callback:
|
||||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||||
self._create_models_table(cursor)
|
self._create_models_table(cursor)
|
||||||
|
self._drop_old_models_tables(cursor)
|
||||||
|
|
||||||
|
def _drop_old_models_tables(self, cursor: sqlite3.Cursor) -> None:
|
||||||
|
"""Drops the old model_records, model_metadata, model_tags and tags tables."""
|
||||||
|
|
||||||
|
tables = ["model_records", "model_metadata", "model_tags", "tags"]
|
||||||
|
|
||||||
|
for table in tables:
|
||||||
|
cursor.execute(f"DROP TABLE IF EXISTS {table};")
|
||||||
|
|
||||||
def _create_models_table(self, cursor: sqlite3.Cursor) -> None:
|
def _create_models_table(self, cursor: sqlite3.Cursor) -> None:
|
||||||
"""Creates the v4.0.0 models table."""
|
"""Creates the v4.0.0 models table."""
|
||||||
@ -67,7 +76,7 @@ def build_migration_7() -> Migration:
|
|||||||
|
|
||||||
This migration does the following:
|
This migration does the following:
|
||||||
- Adds the new models table
|
- Adds the new models table
|
||||||
- TODO(MM2): Drops the old model_records, model_metadata, model_tags and tags tables.
|
- Drops the old model_records, model_metadata, model_tags and tags tables.
|
||||||
- TODO(MM2): Migrates model names and descriptions from `models.yaml` to the new table (?).
|
- TODO(MM2): Migrates model names and descriptions from `models.yaml` to the new table (?).
|
||||||
"""
|
"""
|
||||||
migration_7 = Migration(
|
migration_7 = Migration(
|
||||||
|
@ -1,55 +0,0 @@
|
|||||||
import json
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import ValidationError
|
|
||||||
|
|
||||||
from invokeai.app.services.shared.graph import Edge
|
|
||||||
|
|
||||||
|
|
||||||
def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
|
|
||||||
"""
|
|
||||||
Parses raw session string, returning a dict of the graph.
|
|
||||||
|
|
||||||
Only the general graph shape is validated; none of the fields are validated.
|
|
||||||
|
|
||||||
Any `metadata_accumulator` nodes and edges are removed.
|
|
||||||
|
|
||||||
Any validation failure will return None.
|
|
||||||
"""
|
|
||||||
|
|
||||||
graph = json.loads(session_raw).get("graph", None)
|
|
||||||
|
|
||||||
# sanity check make sure the graph is at least reasonably shaped
|
|
||||||
if (
|
|
||||||
not isinstance(graph, dict)
|
|
||||||
or "nodes" not in graph
|
|
||||||
or not isinstance(graph["nodes"], dict)
|
|
||||||
or "edges" not in graph
|
|
||||||
or not isinstance(graph["edges"], list)
|
|
||||||
):
|
|
||||||
# something has gone terribly awry, return an empty dict
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# delete the `metadata_accumulator` node
|
|
||||||
del graph["nodes"]["metadata_accumulator"]
|
|
||||||
except KeyError:
|
|
||||||
# no accumulator node, all good
|
|
||||||
pass
|
|
||||||
|
|
||||||
# delete any edges to or from it
|
|
||||||
for i, edge in enumerate(graph["edges"]):
|
|
||||||
try:
|
|
||||||
# try to parse the edge
|
|
||||||
Edge(**edge)
|
|
||||||
except ValidationError:
|
|
||||||
# something has gone terribly awry, return an empty dict
|
|
||||||
return None
|
|
||||||
|
|
||||||
if (
|
|
||||||
edge["source"]["node_id"] == "metadata_accumulator"
|
|
||||||
or edge["destination"]["node_id"] == "metadata_accumulator"
|
|
||||||
):
|
|
||||||
del graph["edges"][i]
|
|
||||||
|
|
||||||
return graph
|
|
@ -29,6 +29,8 @@ from diffusers.models.modeling_utils import ModelMixin
|
|||||||
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
||||||
from typing_extensions import Annotated, Any, Dict
|
from typing_extensions import Annotated, Any, Dict
|
||||||
|
|
||||||
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
|
||||||
from ..raw_model import RawModel
|
from ..raw_model import RawModel
|
||||||
|
|
||||||
# ModelMixin is the base class for all diffusers and transformers models
|
# ModelMixin is the base class for all diffusers and transformers models
|
||||||
@ -132,7 +134,7 @@ class ModelSourceType(str, Enum):
|
|||||||
class ModelConfigBase(BaseModel):
|
class ModelConfigBase(BaseModel):
|
||||||
"""Base class for model configuration information."""
|
"""Base class for model configuration information."""
|
||||||
|
|
||||||
key: str = Field(description="A unique key for this model.")
|
key: str = Field(description="A unique key for this model.", default_factory=uuid_string)
|
||||||
hash: str = Field(description="The hash of the model file(s).")
|
hash: str = Field(description="The hash of the model file(s).")
|
||||||
path: str = Field(
|
path: str = Field(
|
||||||
description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory."
|
description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory."
|
||||||
@ -142,7 +144,9 @@ class ModelConfigBase(BaseModel):
|
|||||||
description: Optional[str] = Field(description="Model description", default=None)
|
description: Optional[str] = Field(description="Model description", default=None)
|
||||||
source: str = Field(description="The original source of the model (path, URL or repo_id).")
|
source: str = Field(description="The original source of the model (path, URL or repo_id).")
|
||||||
source_type: ModelSourceType = Field(description="The type of source")
|
source_type: ModelSourceType = Field(description="The type of source")
|
||||||
source_api_response: Optional[str] = Field(description="The original API response from the source, as stringified JSON.", default=None)
|
source_api_response: Optional[str] = Field(
|
||||||
|
description="The original API response from the source, as stringified JSON.", default=None
|
||||||
|
)
|
||||||
trigger_words: Optional[set[str]] = Field(description="Set of trigger words for this model", default=None)
|
trigger_words: Optional[set[str]] = Field(description="Set of trigger words for this model", default=None)
|
||||||
|
|
||||||
model_config = ConfigDict(use_enum_values=False, validate_assignment=True)
|
model_config = ConfigDict(use_enum_values=False, validate_assignment=True)
|
||||||
|
@ -91,7 +91,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return HuggingFaceMetadata(
|
return HuggingFaceMetadata(
|
||||||
id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__)
|
id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__, default=str)
|
||||||
)
|
)
|
||||||
|
|
||||||
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
||||||
|
@ -1,221 +0,0 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
|
||||||
"""
|
|
||||||
SQL Storage for Model Metadata
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sqlite3
|
|
||||||
from typing import List, Optional, Set, Tuple
|
|
||||||
|
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
|
||||||
|
|
||||||
from .fetch import ModelMetadataFetchBase
|
|
||||||
from .metadata_base import AnyModelRepoMetadata, UnknownMetadataException
|
|
||||||
|
|
||||||
|
|
||||||
class ModelMetadataStore:
|
|
||||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
|
||||||
|
|
||||||
def __init__(self, db: SqliteDatabase):
|
|
||||||
"""
|
|
||||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
|
||||||
|
|
||||||
:param conn: sqlite3 connection object
|
|
||||||
:param lock: threading Lock object
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self._db = db
|
|
||||||
self._cursor = self._db.conn.cursor()
|
|
||||||
|
|
||||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
|
||||||
"""
|
|
||||||
Add a block of repo metadata to a model record.
|
|
||||||
|
|
||||||
The model record config must already exist in the database with the
|
|
||||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
|
||||||
|
|
||||||
:param model_key: Existing model key in the `model_config` table
|
|
||||||
:param metadata: ModelRepoMetadata object to store
|
|
||||||
"""
|
|
||||||
json_serialized = metadata.model_dump_json()
|
|
||||||
with self._db.lock:
|
|
||||||
try:
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
INSERT INTO model_metadata(
|
|
||||||
id,
|
|
||||||
metadata
|
|
||||||
)
|
|
||||||
VALUES (?,?);
|
|
||||||
""",
|
|
||||||
(
|
|
||||||
model_key,
|
|
||||||
json_serialized,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self._update_tags(model_key, metadata.tags)
|
|
||||||
self._db.conn.commit()
|
|
||||||
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
|
|
||||||
self._db.conn.rollback()
|
|
||||||
raise UnknownMetadataException from excp
|
|
||||||
except sqlite3.Error as excp:
|
|
||||||
self._db.conn.rollback()
|
|
||||||
raise excp
|
|
||||||
|
|
||||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
|
||||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
|
||||||
with self._db.lock:
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
SELECT metadata FROM model_metadata
|
|
||||||
WHERE id=?;
|
|
||||||
""",
|
|
||||||
(model_key,),
|
|
||||||
)
|
|
||||||
rows = self._cursor.fetchone()
|
|
||||||
if not rows:
|
|
||||||
raise UnknownMetadataException("model metadata not found")
|
|
||||||
return ModelMetadataFetchBase.from_json(rows[0])
|
|
||||||
|
|
||||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
|
||||||
"""Dump out all the metadata."""
|
|
||||||
with self._db.lock:
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
SELECT id,metadata FROM model_metadata;
|
|
||||||
""",
|
|
||||||
(),
|
|
||||||
)
|
|
||||||
rows = self._cursor.fetchall()
|
|
||||||
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
|
|
||||||
|
|
||||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
|
||||||
"""
|
|
||||||
Update metadata corresponding to the model with the indicated key.
|
|
||||||
|
|
||||||
:param model_key: Existing model key in the `model_config` table
|
|
||||||
:param metadata: ModelRepoMetadata object to update
|
|
||||||
"""
|
|
||||||
json_serialized = metadata.model_dump_json() # turn it into a json string.
|
|
||||||
with self._db.lock:
|
|
||||||
try:
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
UPDATE model_metadata
|
|
||||||
SET
|
|
||||||
metadata=?
|
|
||||||
WHERE id=?;
|
|
||||||
""",
|
|
||||||
(json_serialized, model_key),
|
|
||||||
)
|
|
||||||
if self._cursor.rowcount == 0:
|
|
||||||
raise UnknownMetadataException("model metadata not found")
|
|
||||||
self._update_tags(model_key, metadata.tags)
|
|
||||||
self._db.conn.commit()
|
|
||||||
except sqlite3.Error as e:
|
|
||||||
self._db.conn.rollback()
|
|
||||||
raise e
|
|
||||||
|
|
||||||
return self.get_metadata(model_key)
|
|
||||||
|
|
||||||
def list_tags(self) -> Set[str]:
|
|
||||||
"""Return all tags in the tags table."""
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
select tag_text from tags;
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
return {x[0] for x in self._cursor.fetchall()}
|
|
||||||
|
|
||||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
|
||||||
"""Return the keys of models containing all of the listed tags."""
|
|
||||||
with self._db.lock:
|
|
||||||
try:
|
|
||||||
matches: Optional[Set[str]] = None
|
|
||||||
for tag in tags:
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
SELECT a.model_id FROM model_tags AS a,
|
|
||||||
tags AS b
|
|
||||||
WHERE a.tag_id=b.tag_id
|
|
||||||
AND b.tag_text=?;
|
|
||||||
""",
|
|
||||||
(tag,),
|
|
||||||
)
|
|
||||||
model_keys = {x[0] for x in self._cursor.fetchall()}
|
|
||||||
if matches is None:
|
|
||||||
matches = model_keys
|
|
||||||
matches = matches.intersection(model_keys)
|
|
||||||
except sqlite3.Error as e:
|
|
||||||
raise e
|
|
||||||
return matches if matches else set()
|
|
||||||
|
|
||||||
def search_by_author(self, author: str) -> Set[str]:
|
|
||||||
"""Return the keys of models authored by the indicated author."""
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
SELECT id FROM model_metadata
|
|
||||||
WHERE author=?;
|
|
||||||
""",
|
|
||||||
(author,),
|
|
||||||
)
|
|
||||||
return {x[0] for x in self._cursor.fetchall()}
|
|
||||||
|
|
||||||
def search_by_name(self, name: str) -> Set[str]:
|
|
||||||
"""
|
|
||||||
Return the keys of models with the indicated name.
|
|
||||||
|
|
||||||
Note that this is the name of the model given to it by
|
|
||||||
the remote source. The user may have changed the local
|
|
||||||
name. The local name will be located in the model config
|
|
||||||
record object.
|
|
||||||
"""
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
SELECT id FROM model_metadata
|
|
||||||
WHERE name=?;
|
|
||||||
""",
|
|
||||||
(name,),
|
|
||||||
)
|
|
||||||
return {x[0] for x in self._cursor.fetchall()}
|
|
||||||
|
|
||||||
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
|
|
||||||
"""Update tags for the model referenced by model_key."""
|
|
||||||
# remove previous tags from this model
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
DELETE FROM model_tags
|
|
||||||
WHERE model_id=?;
|
|
||||||
""",
|
|
||||||
(model_key,),
|
|
||||||
)
|
|
||||||
|
|
||||||
for tag in tags:
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
INSERT OR IGNORE INTO tags (
|
|
||||||
tag_text
|
|
||||||
)
|
|
||||||
VALUES (?);
|
|
||||||
""",
|
|
||||||
(tag,),
|
|
||||||
)
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
SELECT tag_id
|
|
||||||
FROM tags
|
|
||||||
WHERE tag_text = ?
|
|
||||||
LIMIT 1;
|
|
||||||
""",
|
|
||||||
(tag,),
|
|
||||||
)
|
|
||||||
tag_id = self._cursor.fetchone()[0]
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
INSERT OR IGNORE INTO model_tags (
|
|
||||||
model_id,
|
|
||||||
tag_id
|
|
||||||
)
|
|
||||||
VALUES (?,?);
|
|
||||||
""",
|
|
||||||
(model_key, tag_id),
|
|
||||||
)
|
|
@ -18,6 +18,7 @@ from .config import (
|
|||||||
ModelConfigFactory,
|
ModelConfigFactory,
|
||||||
ModelFormat,
|
ModelFormat,
|
||||||
ModelRepoVariant,
|
ModelRepoVariant,
|
||||||
|
ModelSourceType,
|
||||||
ModelType,
|
ModelType,
|
||||||
ModelVariantType,
|
ModelVariantType,
|
||||||
SchedulerPredictionType,
|
SchedulerPredictionType,
|
||||||
@ -150,7 +151,7 @@ class ModelProbe(object):
|
|||||||
|
|
||||||
probe = probe_class(model_path)
|
probe = probe_class(model_path)
|
||||||
|
|
||||||
fields["source_type"] = fields.get("source_type")
|
fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
|
||||||
fields["source"] = fields.get("source") or model_path.as_posix()
|
fields["source"] = fields.get("source") or model_path.as_posix()
|
||||||
fields["key"] = fields.get("key", uuid_string())
|
fields["key"] = fields.get("key", uuid_string())
|
||||||
fields["path"] = model_path.as_posix()
|
fields["path"] = model_path.as_posix()
|
||||||
|
@ -3,29 +3,26 @@ Test the refactored model config classes.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
|
|
||||||
from invokeai.app.services.model_records import (
|
from invokeai.app.services.model_records import (
|
||||||
DuplicateModelException,
|
DuplicateModelException,
|
||||||
ModelRecordOrderBy,
|
|
||||||
ModelRecordServiceBase,
|
ModelRecordServiceBase,
|
||||||
ModelRecordServiceSQL,
|
ModelRecordServiceSQL,
|
||||||
UnknownModelException,
|
UnknownModelException,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
MainCheckpointConfig,
|
|
||||||
MainDiffusersConfig,
|
MainDiffusersConfig,
|
||||||
ModelFormat,
|
ModelFormat,
|
||||||
|
ModelSourceType,
|
||||||
ModelType,
|
ModelType,
|
||||||
TextualInversionFileConfig,
|
TextualInversionFileConfig,
|
||||||
VaeDiffusersConfig,
|
VaeDiffusersConfig,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.metadata import BaseMetadata
|
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||||
from tests.fixtures.sqlite_database import create_mock_sqlite_database
|
from tests.fixtures.sqlite_database import create_mock_sqlite_database
|
||||||
@ -38,11 +35,13 @@ def store(
|
|||||||
config = InvokeAIAppConfig(root=datadir)
|
config = InvokeAIAppConfig(root=datadir)
|
||||||
logger = InvokeAILogger.get_logger(config=config)
|
logger = InvokeAILogger.get_logger(config=config)
|
||||||
db = create_mock_sqlite_database(config, logger)
|
db = create_mock_sqlite_database(config, logger)
|
||||||
return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
return ModelRecordServiceSQL(db)
|
||||||
|
|
||||||
|
|
||||||
def example_config() -> TextualInversionFileConfig:
|
def example_config(key: Optional[str] = None) -> TextualInversionFileConfig:
|
||||||
return TextualInversionFileConfig(
|
config = TextualInversionFileConfig(
|
||||||
|
source="test/source/",
|
||||||
|
source_type=ModelSourceType.Path,
|
||||||
path="/tmp/pokemon.bin",
|
path="/tmp/pokemon.bin",
|
||||||
name="old name",
|
name="old name",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
@ -50,59 +49,45 @@ def example_config() -> TextualInversionFileConfig:
|
|||||||
format=ModelFormat.EmbeddingFile,
|
format=ModelFormat.EmbeddingFile,
|
||||||
hash="ABC123",
|
hash="ABC123",
|
||||||
)
|
)
|
||||||
|
if key is not None:
|
||||||
|
config.key = key
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
def test_type(store: ModelRecordServiceBase):
|
def test_type(store: ModelRecordServiceBase):
|
||||||
config = example_config()
|
config = example_config("key1")
|
||||||
store.add_model("key1", config)
|
store.add_model(config)
|
||||||
config1 = store.get_model("key1")
|
config1 = store.get_model("key1")
|
||||||
assert type(config1) == TextualInversionFileConfig
|
assert isinstance(config1, TextualInversionFileConfig)
|
||||||
|
|
||||||
|
|
||||||
def test_add(store: ModelRecordServiceBase):
|
def test_raises_on_violating_uniqueness(store: ModelRecordServiceBase):
|
||||||
raw = {
|
# Models have a uniqueness constraint by their name, base and type
|
||||||
"path": "/tmp/foo.ckpt",
|
config1 = example_config("key1")
|
||||||
"name": "model1",
|
config2 = config1.model_copy(deep=True)
|
||||||
"base": BaseModelType.StableDiffusion1,
|
config2.key = "key2"
|
||||||
"type": "main",
|
store.add_model(config1)
|
||||||
"config_path": "/tmp/foo.yaml",
|
|
||||||
"variant": "normal",
|
|
||||||
"format": "checkpoint",
|
|
||||||
"original_hash": "111222333444",
|
|
||||||
}
|
|
||||||
store.add_model("key1", raw)
|
|
||||||
config1 = store.get_model("key1")
|
|
||||||
assert config1 is not None
|
|
||||||
assert type(config1) == MainCheckpointConfig
|
|
||||||
assert config1.base == BaseModelType.StableDiffusion1
|
|
||||||
assert config1.name == "model1"
|
|
||||||
assert config1.hash == "111222333444"
|
|
||||||
|
|
||||||
|
|
||||||
def test_dup(store: ModelRecordServiceBase):
|
|
||||||
config = example_config()
|
|
||||||
store.add_model("key1", example_config())
|
|
||||||
with pytest.raises(DuplicateModelException):
|
with pytest.raises(DuplicateModelException):
|
||||||
store.add_model("key1", config)
|
store.add_model(config1)
|
||||||
with pytest.raises(DuplicateModelException):
|
with pytest.raises(DuplicateModelException):
|
||||||
store.add_model("key2", config)
|
store.add_model(config2)
|
||||||
|
|
||||||
|
|
||||||
def test_update(store: ModelRecordServiceBase):
|
def test_update(store: ModelRecordServiceBase):
|
||||||
config = example_config()
|
config = example_config("key1")
|
||||||
store.add_model("key1", config)
|
store.add_model(config)
|
||||||
config = store.get_model("key1")
|
config = store.get_model("key1")
|
||||||
assert config.name == "old name"
|
assert config.name == "old name"
|
||||||
|
|
||||||
config.name = "new name"
|
config.name = "new name"
|
||||||
store.update_model("key1", config)
|
store.update_model(config.key, config)
|
||||||
new_config = store.get_model("key1")
|
new_config = store.get_model("key1")
|
||||||
assert new_config.name == "new name"
|
assert new_config.name == "new name"
|
||||||
|
|
||||||
|
|
||||||
def test_rename(store: ModelRecordServiceBase):
|
def test_rename(store: ModelRecordServiceBase):
|
||||||
config = example_config()
|
config = example_config("key1")
|
||||||
store.add_model("key1", config)
|
store.add_model(config)
|
||||||
config = store.get_model("key1")
|
config = store.get_model("key1")
|
||||||
assert config.name == "old name"
|
assert config.name == "old name"
|
||||||
|
|
||||||
@ -112,15 +97,15 @@ def test_rename(store: ModelRecordServiceBase):
|
|||||||
|
|
||||||
|
|
||||||
def test_unknown_key(store: ModelRecordServiceBase):
|
def test_unknown_key(store: ModelRecordServiceBase):
|
||||||
config = example_config()
|
config = example_config("key1")
|
||||||
store.add_model("key1", config)
|
store.add_model(config)
|
||||||
with pytest.raises(UnknownModelException):
|
with pytest.raises(UnknownModelException):
|
||||||
store.update_model("unknown_key", config)
|
store.update_model("unknown_key", config)
|
||||||
|
|
||||||
|
|
||||||
def test_delete(store: ModelRecordServiceBase):
|
def test_delete(store: ModelRecordServiceBase):
|
||||||
config = example_config()
|
config = example_config("key1")
|
||||||
store.add_model("key1", config)
|
store.add_model(config)
|
||||||
config = store.get_model("key1")
|
config = store.get_model("key1")
|
||||||
store.del_model("key1")
|
store.del_model("key1")
|
||||||
with pytest.raises(UnknownModelException):
|
with pytest.raises(UnknownModelException):
|
||||||
@ -128,36 +113,45 @@ def test_delete(store: ModelRecordServiceBase):
|
|||||||
|
|
||||||
|
|
||||||
def test_exists(store: ModelRecordServiceBase):
|
def test_exists(store: ModelRecordServiceBase):
|
||||||
config = example_config()
|
config = example_config("key1")
|
||||||
store.add_model("key1", config)
|
store.add_model(config)
|
||||||
assert store.exists("key1")
|
assert store.exists("key1")
|
||||||
assert not store.exists("key2")
|
assert not store.exists("key2")
|
||||||
|
|
||||||
|
|
||||||
def test_filter(store: ModelRecordServiceBase):
|
def test_filter(store: ModelRecordServiceBase):
|
||||||
config1 = MainDiffusersConfig(
|
config1 = MainDiffusersConfig(
|
||||||
|
key="config1",
|
||||||
path="/tmp/config1",
|
path="/tmp/config1",
|
||||||
name="config1",
|
name="config1",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
type=ModelType.Main,
|
type=ModelType.Main,
|
||||||
hash="CONFIG1HASH",
|
hash="CONFIG1HASH",
|
||||||
|
source="test/source",
|
||||||
|
source_type=ModelSourceType.Path,
|
||||||
)
|
)
|
||||||
config2 = MainDiffusersConfig(
|
config2 = MainDiffusersConfig(
|
||||||
|
key="config2",
|
||||||
path="/tmp/config2",
|
path="/tmp/config2",
|
||||||
name="config2",
|
name="config2",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
type=ModelType.Main,
|
type=ModelType.Main,
|
||||||
hash="CONFIG2HASH",
|
hash="CONFIG2HASH",
|
||||||
|
source="test/source",
|
||||||
|
source_type=ModelSourceType.Path,
|
||||||
)
|
)
|
||||||
config3 = VaeDiffusersConfig(
|
config3 = VaeDiffusersConfig(
|
||||||
|
key="config3",
|
||||||
path="/tmp/config3",
|
path="/tmp/config3",
|
||||||
name="config3",
|
name="config3",
|
||||||
base=BaseModelType("sd-2"),
|
base=BaseModelType("sd-2"),
|
||||||
type=ModelType.Vae,
|
type=ModelType.Vae,
|
||||||
hash="CONFIG3HASH",
|
hash="CONFIG3HASH",
|
||||||
|
source="test/source",
|
||||||
|
source_type=ModelSourceType.Path,
|
||||||
)
|
)
|
||||||
for c in config1, config2, config3:
|
for c in config1, config2, config3:
|
||||||
store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c)
|
store.add_model(c)
|
||||||
matches = store.search_by_attr(model_type=ModelType.Main)
|
matches = store.search_by_attr(model_type=ModelType.Main)
|
||||||
assert len(matches) == 2
|
assert len(matches) == 2
|
||||||
assert matches[0].name in {"config1", "config2"}
|
assert matches[0].name in {"config1", "config2"}
|
||||||
@ -165,7 +159,7 @@ def test_filter(store: ModelRecordServiceBase):
|
|||||||
matches = store.search_by_attr(model_type=ModelType.Vae)
|
matches = store.search_by_attr(model_type=ModelType.Vae)
|
||||||
assert len(matches) == 1
|
assert len(matches) == 1
|
||||||
assert matches[0].name == "config3"
|
assert matches[0].name == "config3"
|
||||||
assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
|
assert matches[0].key == "config3"
|
||||||
assert isinstance(matches[0].type, ModelType) # This tests that we get proper enums back
|
assert isinstance(matches[0].type, ModelType) # This tests that we get proper enums back
|
||||||
|
|
||||||
matches = store.search_by_hash("CONFIG1HASH")
|
matches = store.search_by_hash("CONFIG1HASH")
|
||||||
@ -183,6 +177,8 @@ def test_unique(store: ModelRecordServiceBase):
|
|||||||
type=ModelType.Main,
|
type=ModelType.Main,
|
||||||
name="nonuniquename",
|
name="nonuniquename",
|
||||||
hash="CONFIG1HASH",
|
hash="CONFIG1HASH",
|
||||||
|
source="test/source/",
|
||||||
|
source_type=ModelSourceType.Path,
|
||||||
)
|
)
|
||||||
config2 = MainDiffusersConfig(
|
config2 = MainDiffusersConfig(
|
||||||
path="/tmp/config2",
|
path="/tmp/config2",
|
||||||
@ -190,6 +186,8 @@ def test_unique(store: ModelRecordServiceBase):
|
|||||||
type=ModelType.Main,
|
type=ModelType.Main,
|
||||||
name="nonuniquename",
|
name="nonuniquename",
|
||||||
hash="CONFIG1HASH",
|
hash="CONFIG1HASH",
|
||||||
|
source="test/source/",
|
||||||
|
source_type=ModelSourceType.Path,
|
||||||
)
|
)
|
||||||
config3 = VaeDiffusersConfig(
|
config3 = VaeDiffusersConfig(
|
||||||
path="/tmp/config3",
|
path="/tmp/config3",
|
||||||
@ -197,6 +195,8 @@ def test_unique(store: ModelRecordServiceBase):
|
|||||||
type=ModelType.Vae,
|
type=ModelType.Vae,
|
||||||
name="nonuniquename",
|
name="nonuniquename",
|
||||||
hash="CONFIG1HASH",
|
hash="CONFIG1HASH",
|
||||||
|
source="test/source/",
|
||||||
|
source_type=ModelSourceType.Path,
|
||||||
)
|
)
|
||||||
config4 = MainDiffusersConfig(
|
config4 = MainDiffusersConfig(
|
||||||
path="/tmp/config4",
|
path="/tmp/config4",
|
||||||
@ -204,15 +204,19 @@ def test_unique(store: ModelRecordServiceBase):
|
|||||||
type=ModelType.Main,
|
type=ModelType.Main,
|
||||||
name="nonuniquename",
|
name="nonuniquename",
|
||||||
hash="CONFIG1HASH",
|
hash="CONFIG1HASH",
|
||||||
|
source="test/source/",
|
||||||
|
source_type=ModelSourceType.Path,
|
||||||
)
|
)
|
||||||
# config1, config2 and config3 are compatible because they have unique combos
|
# config1, config2 and config3 are compatible because they have unique combos
|
||||||
# of name, type and base
|
# of name, type and base
|
||||||
for c in config1, config2, config3:
|
for c in config1, config2, config3:
|
||||||
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
|
c.key = sha256(c.path.encode("utf-8")).hexdigest()
|
||||||
|
store.add_model(c)
|
||||||
|
|
||||||
# config4 clashes with config1 and should raise an integrity error
|
# config4 clashes with config1 and should raise an integrity error
|
||||||
with pytest.raises(DuplicateModelException):
|
with pytest.raises(DuplicateModelException):
|
||||||
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), config4)
|
config4.key = sha256(config4.path.encode("utf-8")).hexdigest()
|
||||||
|
store.add_model(config4)
|
||||||
|
|
||||||
|
|
||||||
def test_filter_2(store: ModelRecordServiceBase):
|
def test_filter_2(store: ModelRecordServiceBase):
|
||||||
@ -222,6 +226,8 @@ def test_filter_2(store: ModelRecordServiceBase):
|
|||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
type=ModelType.Main,
|
type=ModelType.Main,
|
||||||
hash="CONFIG1HASH",
|
hash="CONFIG1HASH",
|
||||||
|
source="test/source/",
|
||||||
|
source_type=ModelSourceType.Path,
|
||||||
)
|
)
|
||||||
config2 = MainDiffusersConfig(
|
config2 = MainDiffusersConfig(
|
||||||
path="/tmp/config2",
|
path="/tmp/config2",
|
||||||
@ -229,6 +235,8 @@ def test_filter_2(store: ModelRecordServiceBase):
|
|||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
type=ModelType.Main,
|
type=ModelType.Main,
|
||||||
hash="CONFIG2HASH",
|
hash="CONFIG2HASH",
|
||||||
|
source="test/source/",
|
||||||
|
source_type=ModelSourceType.Path,
|
||||||
)
|
)
|
||||||
config3 = MainDiffusersConfig(
|
config3 = MainDiffusersConfig(
|
||||||
path="/tmp/config3",
|
path="/tmp/config3",
|
||||||
@ -236,6 +244,8 @@ def test_filter_2(store: ModelRecordServiceBase):
|
|||||||
base=BaseModelType("sd-2"),
|
base=BaseModelType("sd-2"),
|
||||||
type=ModelType.Main,
|
type=ModelType.Main,
|
||||||
hash="CONFIG3HASH",
|
hash="CONFIG3HASH",
|
||||||
|
source="test/source/",
|
||||||
|
source_type=ModelSourceType.Path,
|
||||||
)
|
)
|
||||||
config4 = MainDiffusersConfig(
|
config4 = MainDiffusersConfig(
|
||||||
path="/tmp/config4",
|
path="/tmp/config4",
|
||||||
@ -243,6 +253,8 @@ def test_filter_2(store: ModelRecordServiceBase):
|
|||||||
base=BaseModelType("sdxl"),
|
base=BaseModelType("sdxl"),
|
||||||
type=ModelType.Main,
|
type=ModelType.Main,
|
||||||
hash="CONFIG3HASH",
|
hash="CONFIG3HASH",
|
||||||
|
source="test/source/",
|
||||||
|
source_type=ModelSourceType.Path,
|
||||||
)
|
)
|
||||||
config5 = VaeDiffusersConfig(
|
config5 = VaeDiffusersConfig(
|
||||||
path="/tmp/config5",
|
path="/tmp/config5",
|
||||||
@ -250,9 +262,11 @@ def test_filter_2(store: ModelRecordServiceBase):
|
|||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
type=ModelType.Vae,
|
type=ModelType.Vae,
|
||||||
hash="CONFIG3HASH",
|
hash="CONFIG3HASH",
|
||||||
|
source="test/source/",
|
||||||
|
source_type=ModelSourceType.Path,
|
||||||
)
|
)
|
||||||
for c in config1, config2, config3, config4, config5:
|
for c in config1, config2, config3, config4, config5:
|
||||||
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
|
store.add_model(c)
|
||||||
|
|
||||||
matches = store.search_by_attr(
|
matches = store.search_by_attr(
|
||||||
model_type=ModelType.Main,
|
model_type=ModelType.Main,
|
||||||
@ -272,50 +286,3 @@ def test_filter_2(store: ModelRecordServiceBase):
|
|||||||
model_name="dup_name1",
|
model_name="dup_name1",
|
||||||
)
|
)
|
||||||
assert len(matches) == 1
|
assert len(matches) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_summary(mm2_record_store: ModelRecordServiceSQL) -> None:
|
|
||||||
# The fixture provides us with five configs.
|
|
||||||
for x in range(1, 5):
|
|
||||||
key = f"test_config_{x}"
|
|
||||||
name = f"name_{x}"
|
|
||||||
author = f"author_{x}"
|
|
||||||
tags = {f"tag{y}" for y in range(1, x)}
|
|
||||||
mm2_record_store.metadata_store.add_metadata(
|
|
||||||
model_key=key, metadata=BaseMetadata(name=name, author=author, tags=tags)
|
|
||||||
)
|
|
||||||
# sanity check that the tags sent in all right
|
|
||||||
assert mm2_record_store.get_metadata("test_config_3").tags == {"tag1", "tag2"}
|
|
||||||
assert mm2_record_store.get_metadata("test_config_4").tags == {"tag1", "tag2", "tag3"}
|
|
||||||
|
|
||||||
# get summary
|
|
||||||
summary1 = mm2_record_store.list_models(page=0, per_page=100)
|
|
||||||
assert summary1.page == 0
|
|
||||||
assert summary1.pages == 1
|
|
||||||
assert summary1.per_page == 100
|
|
||||||
assert summary1.total == 5
|
|
||||||
assert len(summary1.items) == 5
|
|
||||||
assert summary1.items[0].name == "test5" # lora / sd-1 / diffusers / test5
|
|
||||||
|
|
||||||
# find test_config_3
|
|
||||||
config3 = [x for x in summary1.items if x.key == "test_config_3"][0]
|
|
||||||
assert config3.description == "This is test 3"
|
|
||||||
assert config3.tags == {"tag1", "tag2"}
|
|
||||||
|
|
||||||
# find test_config_5
|
|
||||||
config5 = [x for x in summary1.items if x.key == "test_config_5"][0]
|
|
||||||
assert config5.tags == set()
|
|
||||||
assert config5.description == ""
|
|
||||||
|
|
||||||
# test paging
|
|
||||||
summary2 = mm2_record_store.list_models(page=1, per_page=2)
|
|
||||||
assert summary2.page == 1
|
|
||||||
assert summary2.per_page == 2
|
|
||||||
assert summary2.pages == 3
|
|
||||||
assert summary1.items[2].name == summary2.items[0].name
|
|
||||||
|
|
||||||
# test sorting
|
|
||||||
summary = mm2_record_store.list_models(page=0, per_page=100, order_by=ModelRecordOrderBy.Name)
|
|
||||||
print(summary.items)
|
|
||||||
assert summary.items[0].name == "model1"
|
|
||||||
assert summary.items[-1].name == "test5"
|
|
||||||
|
@ -18,12 +18,17 @@ from invokeai.app.services.events.events_base import EventServiceBase
|
|||||||
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
|
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
|
||||||
from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase
|
from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase
|
||||||
from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase
|
from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase
|
||||||
from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
|
|
||||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
|
LoRADiffusersConfig,
|
||||||
|
MainCheckpointConfig,
|
||||||
|
MainDiffusersConfig,
|
||||||
ModelFormat,
|
ModelFormat,
|
||||||
|
ModelSourceType,
|
||||||
ModelType,
|
ModelType,
|
||||||
|
ModelVariantType,
|
||||||
|
VaeDiffusersConfig,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
|
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
@ -107,11 +112,6 @@ def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> Downloa
|
|||||||
return download_queue
|
return download_queue
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase:
|
|
||||||
return mm2_record_store.metadata_store
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
|
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
|
||||||
ram_cache = ModelCache(
|
ram_cache = ModelCache(
|
||||||
@ -137,7 +137,7 @@ def mm2_installer(
|
|||||||
logger = InvokeAILogger.get_logger()
|
logger = InvokeAILogger.get_logger()
|
||||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||||
events = DummyEventService()
|
events = DummyEventService()
|
||||||
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
store = ModelRecordServiceSQL(db)
|
||||||
|
|
||||||
installer = ModelInstallService(
|
installer = ModelInstallService(
|
||||||
app_config=mm2_app_config,
|
app_config=mm2_app_config,
|
||||||
@ -160,61 +160,71 @@ def mm2_installer(
|
|||||||
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
|
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
|
||||||
logger = InvokeAILogger.get_logger(config=mm2_app_config)
|
logger = InvokeAILogger.get_logger(config=mm2_app_config)
|
||||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||||
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
store = ModelRecordServiceSQL(db)
|
||||||
# add five simple config records to the database
|
# add five simple config records to the database
|
||||||
raw1 = {
|
config1 = VaeDiffusersConfig(
|
||||||
"path": "/tmp/foo1",
|
key="test_config_1",
|
||||||
"format": ModelFormat("diffusers"),
|
path="/tmp/foo1",
|
||||||
"name": "test2",
|
format=ModelFormat.Diffusers,
|
||||||
"base": BaseModelType("sd-2"),
|
name="test2",
|
||||||
"type": ModelType("vae"),
|
base=BaseModelType.StableDiffusion2,
|
||||||
"original_hash": "111222333444",
|
type=ModelType.Vae,
|
||||||
"source": "stabilityai/sdxl-vae",
|
hash="111222333444",
|
||||||
}
|
source="stabilityai/sdxl-vae",
|
||||||
raw2 = {
|
source_type=ModelSourceType.HFRepoID,
|
||||||
"path": "/tmp/foo2.ckpt",
|
)
|
||||||
"name": "model1",
|
config2 = MainCheckpointConfig(
|
||||||
"format": ModelFormat("checkpoint"),
|
key="test_config_2",
|
||||||
"base": BaseModelType("sd-1"),
|
path="/tmp/foo2.ckpt",
|
||||||
"type": "main",
|
name="model1",
|
||||||
"config_path": "/tmp/foo.yaml",
|
format=ModelFormat.Checkpoint,
|
||||||
"variant": "normal",
|
base=BaseModelType.StableDiffusion1,
|
||||||
"original_hash": "111222333444",
|
type=ModelType.Main,
|
||||||
"source": "https://civitai.com/models/206883/split",
|
config_path="/tmp/foo.yaml",
|
||||||
}
|
variant=ModelVariantType.Normal,
|
||||||
raw3 = {
|
hash="111222333444",
|
||||||
"path": "/tmp/foo3",
|
source="https://civitai.com/models/206883/split",
|
||||||
"format": ModelFormat("diffusers"),
|
source_type=ModelSourceType.CivitAI,
|
||||||
"name": "test3",
|
)
|
||||||
"base": BaseModelType("sdxl"),
|
config3 = MainDiffusersConfig(
|
||||||
"type": ModelType("main"),
|
key="test_config_3",
|
||||||
"original_hash": "111222333444",
|
path="/tmp/foo3",
|
||||||
"source": "author3/model3",
|
format=ModelFormat.Diffusers,
|
||||||
"description": "This is test 3",
|
name="test3",
|
||||||
}
|
base=BaseModelType.StableDiffusionXL,
|
||||||
raw4 = {
|
type=ModelType.Main,
|
||||||
"path": "/tmp/foo4",
|
hash="111222333444",
|
||||||
"format": ModelFormat("diffusers"),
|
source="author3/model3",
|
||||||
"name": "test4",
|
description="This is test 3",
|
||||||
"base": BaseModelType("sdxl"),
|
source_type=ModelSourceType.HFRepoID,
|
||||||
"type": ModelType("lora"),
|
)
|
||||||
"original_hash": "111222333444",
|
config4 = LoRADiffusersConfig(
|
||||||
"source": "author4/model4",
|
key="test_config_4",
|
||||||
}
|
path="/tmp/foo4",
|
||||||
raw5 = {
|
format=ModelFormat.Diffusers,
|
||||||
"path": "/tmp/foo5",
|
name="test4",
|
||||||
"format": ModelFormat("diffusers"),
|
base=BaseModelType.StableDiffusionXL,
|
||||||
"name": "test5",
|
type=ModelType.Lora,
|
||||||
"base": BaseModelType("sd-1"),
|
hash="111222333444",
|
||||||
"type": ModelType("lora"),
|
source="author4/model4",
|
||||||
"original_hash": "111222333444",
|
source_type=ModelSourceType.HFRepoID,
|
||||||
"source": "author4/model5",
|
)
|
||||||
}
|
config5 = LoRADiffusersConfig(
|
||||||
store.add_model("test_config_1", raw1)
|
key="test_config_5",
|
||||||
store.add_model("test_config_2", raw2)
|
path="/tmp/foo5",
|
||||||
store.add_model("test_config_3", raw3)
|
format=ModelFormat.Diffusers,
|
||||||
store.add_model("test_config_4", raw4)
|
name="test5",
|
||||||
store.add_model("test_config_5", raw5)
|
base=BaseModelType.StableDiffusion1,
|
||||||
|
type=ModelType.Lora,
|
||||||
|
hash="111222333444",
|
||||||
|
source="author4/model5",
|
||||||
|
source_type=ModelSourceType.HFRepoID,
|
||||||
|
)
|
||||||
|
store.add_model(config1)
|
||||||
|
store.add_model(config2)
|
||||||
|
store.add_model(config3)
|
||||||
|
store.add_model(config4)
|
||||||
|
store.add_model(config5)
|
||||||
return store
|
return store
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,202 +0,0 @@
|
|||||||
"""
|
|
||||||
Test model metadata fetching and storage.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pydantic.networks import HttpUrl
|
|
||||||
from requests.sessions import Session
|
|
||||||
|
|
||||||
from invokeai.app.services.model_metadata import ModelMetadataStoreBase
|
|
||||||
from invokeai.backend.model_manager.config import ModelRepoVariant
|
|
||||||
from invokeai.backend.model_manager.metadata import (
|
|
||||||
CivitaiMetadata,
|
|
||||||
CivitaiMetadataFetch,
|
|
||||||
CommercialUsage,
|
|
||||||
HuggingFaceMetadata,
|
|
||||||
HuggingFaceMetadataFetch,
|
|
||||||
UnknownMetadataException,
|
|
||||||
)
|
|
||||||
from invokeai.backend.model_manager.util import select_hf_files
|
|
||||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
|
||||||
|
|
||||||
|
|
||||||
def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStoreBase) -> None:
|
|
||||||
tags = {"text-to-image", "diffusers"}
|
|
||||||
input_metadata = HuggingFaceMetadata(
|
|
||||||
name="sdxl-vae",
|
|
||||||
author="stabilityai",
|
|
||||||
tags=tags,
|
|
||||||
id="stabilityai/sdxl-vae",
|
|
||||||
tag_dict={"license": "other"},
|
|
||||||
last_modified=datetime.datetime.now(),
|
|
||||||
)
|
|
||||||
mm2_metadata_store.add_metadata("test_config_1", input_metadata)
|
|
||||||
output_metadata = mm2_metadata_store.get_metadata("test_config_1")
|
|
||||||
assert input_metadata == output_metadata
|
|
||||||
with pytest.raises(UnknownMetadataException):
|
|
||||||
mm2_metadata_store.add_metadata("unknown_key", input_metadata)
|
|
||||||
assert mm2_metadata_store.list_tags() == tags
|
|
||||||
|
|
||||||
|
|
||||||
def test_metadata_store_update(mm2_metadata_store: ModelMetadataStoreBase) -> None:
|
|
||||||
input_metadata = HuggingFaceMetadata(
|
|
||||||
name="sdxl-vae",
|
|
||||||
author="stabilityai",
|
|
||||||
tags={"text-to-image", "diffusers"},
|
|
||||||
id="stabilityai/sdxl-vae",
|
|
||||||
tag_dict={"license": "other"},
|
|
||||||
last_modified=datetime.datetime.now(),
|
|
||||||
)
|
|
||||||
mm2_metadata_store.add_metadata("test_config_1", input_metadata)
|
|
||||||
input_metadata.name = "new-name"
|
|
||||||
mm2_metadata_store.update_metadata("test_config_1", input_metadata)
|
|
||||||
output_metadata = mm2_metadata_store.get_metadata("test_config_1")
|
|
||||||
assert output_metadata.name == "new-name"
|
|
||||||
assert input_metadata == output_metadata
|
|
||||||
|
|
||||||
|
|
||||||
def test_metadata_search(mm2_metadata_store: ModelMetadataStoreBase) -> None:
|
|
||||||
metadata1 = HuggingFaceMetadata(
|
|
||||||
name="sdxl-vae",
|
|
||||||
author="stabilityai",
|
|
||||||
tags={"text-to-image", "diffusers"},
|
|
||||||
id="stabilityai/sdxl-vae",
|
|
||||||
tag_dict={"license": "other"},
|
|
||||||
last_modified=datetime.datetime.now(),
|
|
||||||
)
|
|
||||||
metadata2 = HuggingFaceMetadata(
|
|
||||||
name="model2",
|
|
||||||
author="stabilityai",
|
|
||||||
tags={"text-to-image", "diffusers", "community-contributed"},
|
|
||||||
id="author2/model2",
|
|
||||||
tag_dict={"license": "other"},
|
|
||||||
last_modified=datetime.datetime.now(),
|
|
||||||
)
|
|
||||||
metadata3 = HuggingFaceMetadata(
|
|
||||||
name="model3",
|
|
||||||
author="author3",
|
|
||||||
tags={"text-to-image", "checkpoint", "community-contributed"},
|
|
||||||
id="author3/model3",
|
|
||||||
tag_dict={"license": "other"},
|
|
||||||
last_modified=datetime.datetime.now(),
|
|
||||||
)
|
|
||||||
mm2_metadata_store.add_metadata("test_config_1", metadata1)
|
|
||||||
mm2_metadata_store.add_metadata("test_config_2", metadata2)
|
|
||||||
mm2_metadata_store.add_metadata("test_config_3", metadata3)
|
|
||||||
|
|
||||||
matches = mm2_metadata_store.search_by_author("stabilityai")
|
|
||||||
assert len(matches) == 2
|
|
||||||
assert "test_config_1" in matches
|
|
||||||
assert "test_config_2" in matches
|
|
||||||
matches = mm2_metadata_store.search_by_author("Sherlock Holmes")
|
|
||||||
assert not matches
|
|
||||||
|
|
||||||
matches = mm2_metadata_store.search_by_name("model3")
|
|
||||||
assert len(matches) == 1
|
|
||||||
assert "test_config_3" in matches
|
|
||||||
|
|
||||||
matches = mm2_metadata_store.search_by_tag({"text-to-image"})
|
|
||||||
assert len(matches) == 3
|
|
||||||
|
|
||||||
matches = mm2_metadata_store.search_by_tag({"text-to-image", "diffusers"})
|
|
||||||
assert len(matches) == 2
|
|
||||||
assert "test_config_1" in matches
|
|
||||||
assert "test_config_2" in matches
|
|
||||||
|
|
||||||
matches = mm2_metadata_store.search_by_tag({"checkpoint", "community-contributed"})
|
|
||||||
assert len(matches) == 1
|
|
||||||
assert "test_config_3" in matches
|
|
||||||
|
|
||||||
# does the tag table update correctly?
|
|
||||||
matches = mm2_metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
|
|
||||||
assert not matches
|
|
||||||
assert mm2_metadata_store.list_tags() == {"text-to-image", "diffusers", "community-contributed", "checkpoint"}
|
|
||||||
metadata3.tags.add("licensed-for-commercial-use")
|
|
||||||
mm2_metadata_store.update_metadata("test_config_3", metadata3)
|
|
||||||
assert mm2_metadata_store.list_tags() == {
|
|
||||||
"text-to-image",
|
|
||||||
"diffusers",
|
|
||||||
"community-contributed",
|
|
||||||
"checkpoint",
|
|
||||||
"licensed-for-commercial-use",
|
|
||||||
}
|
|
||||||
matches = mm2_metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
|
|
||||||
assert len(matches) == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_metadata_civitai_fetch(mm2_session: Session) -> None:
|
|
||||||
fetcher = CivitaiMetadataFetch(mm2_session)
|
|
||||||
metadata = fetcher.from_url(HttpUrl("https://civitai.com/models/215485/SDXL-turbo"))
|
|
||||||
assert isinstance(metadata, CivitaiMetadata)
|
|
||||||
assert metadata.id == 215485
|
|
||||||
assert metadata.author == "test_author" # note that this is not the same as the original from Civitai
|
|
||||||
assert metadata.allow_commercial_use # changed to make sure we are reading locally not remotely
|
|
||||||
assert CommercialUsage("RentCivit") in metadata.restrictions.AllowCommercialUse
|
|
||||||
assert metadata.version_id == 242807
|
|
||||||
assert metadata.tags == {"tool", "turbo", "sdxl turbo"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_metadata_hf_fetch(mm2_session: Session) -> None:
|
|
||||||
fetcher = HuggingFaceMetadataFetch(mm2_session)
|
|
||||||
metadata = fetcher.from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
|
|
||||||
assert isinstance(metadata, HuggingFaceMetadata)
|
|
||||||
assert metadata.author == "test_author" # this is not the same as the original
|
|
||||||
assert metadata.files
|
|
||||||
assert metadata.tags == {
|
|
||||||
"diffusers",
|
|
||||||
"onnx",
|
|
||||||
"safetensors",
|
|
||||||
"text-to-image",
|
|
||||||
"license:other",
|
|
||||||
"has_space",
|
|
||||||
"diffusers:StableDiffusionXLPipeline",
|
|
||||||
"region:us",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_metadata_hf_filter(mm2_session: Session) -> None:
|
|
||||||
metadata = HuggingFaceMetadataFetch(mm2_session).from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
|
|
||||||
assert isinstance(metadata, HuggingFaceMetadata)
|
|
||||||
files = [x.path for x in metadata.files]
|
|
||||||
fp16_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp16"))
|
|
||||||
assert Path("sdxl-turbo/text_encoder/model.fp16.safetensors") in fp16_files
|
|
||||||
assert Path("sdxl-turbo/text_encoder/model.safetensors") not in fp16_files
|
|
||||||
|
|
||||||
fp32_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp32"))
|
|
||||||
assert Path("sdxl-turbo/text_encoder/model.safetensors") in fp32_files
|
|
||||||
assert Path("sdxl-turbo/text_encoder/model.16.safetensors") not in fp32_files
|
|
||||||
|
|
||||||
onnx_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("onnx"))
|
|
||||||
assert Path("sdxl-turbo/text_encoder/model.onnx") in onnx_files
|
|
||||||
assert Path("sdxl-turbo/text_encoder/model.safetensors") not in onnx_files
|
|
||||||
|
|
||||||
default_files = select_hf_files.filter_files(files)
|
|
||||||
assert Path("sdxl-turbo/text_encoder/model.safetensors") in default_files
|
|
||||||
assert Path("sdxl-turbo/text_encoder/model.16.safetensors") not in default_files
|
|
||||||
|
|
||||||
openvino_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("openvino"))
|
|
||||||
print(openvino_files)
|
|
||||||
assert len(openvino_files) == 0
|
|
||||||
|
|
||||||
flax_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("flax"))
|
|
||||||
print(flax_files)
|
|
||||||
assert not flax_files
|
|
||||||
|
|
||||||
metadata = HuggingFaceMetadataFetch(mm2_session).from_url(
|
|
||||||
HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo-nofp16")
|
|
||||||
)
|
|
||||||
assert isinstance(metadata, HuggingFaceMetadata)
|
|
||||||
files = [x.path for x in metadata.files]
|
|
||||||
filtered_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp16"))
|
|
||||||
assert (
|
|
||||||
Path("sdxl-turbo-nofp16/text_encoder/model.safetensors") in filtered_files
|
|
||||||
) # confirm that default is returned
|
|
||||||
assert Path("sdxl-turbo-nofp16/text_encoder/model.16.safetensors") not in filtered_files
|
|
||||||
|
|
||||||
|
|
||||||
def test_metadata_hf_urls(mm2_session: Session) -> None:
|
|
||||||
metadata = HuggingFaceMetadataFetch(mm2_session).from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
|
|
||||||
assert isinstance(metadata, HuggingFaceMetadata)
|
|
Loading…
Reference in New Issue
Block a user