add unit tests and documentation

Lincoln Stein 2023-12-20 17:58:34 -05:00
parent 1940169925
commit a626ca3e1c
12 changed files with 518 additions and 105 deletions


@ -15,7 +15,12 @@ model. These are the:
their metadata, and `ModelRecordServiceBase` to store that
information. It is also responsible for managing the InvokeAI
`models` directory and its contents.
* _ModelMetadataStore_ and _ModelMetaDataFetch_ Backend modules that
are able to retrieve metadata from online model repositories,
transform them into Pydantic models, and cache them to the InvokeAI
SQL database.
* _DownloadQueueServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**)
A multithreaded downloader responsible
for downloading models from a remote source to disk. The download
@ -1184,3 +1189,246 @@ other resources that it might have been using.
This will start/pause/cancel all jobs that have been submitted to the
queue and have not yet reached a terminal state.
***
## This Meta be Good: Model Metadata Storage
The modules found under `invokeai.backend.model_manager.metadata`
provide a straightforward API for fetching model metadata from online
repositories. Currently two repositories are supported: HuggingFace
and Civitai. However, the modules are easily extended for additional
repos, provided that they have defined APIs for metadata access.
Metadata comprises any descriptive information that is not essential
for getting the model to run. For example "author" is metadata, while
"type", "base" and "format" are not. The latter fields are part of the
model's config, as defined in `invokeai.backend.model_manager.config`.
### Example Usage:
```
from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata,
CivitaiMetadataFetch,
CivitaiMetadata,
ModelMetadataStore,
)
# to access the initialized sql database
from invokeai.app.api.dependencies import ApiDependencies
civitai = CivitaiMetadataFetch()
# fetch the metadata
model_metadata = civitai.from_url("https://civitai.com/models/215796")
# get some common metadata fields
author = model_metadata.author
tags = model_metadata.tags
# get some Civitai-specific fields
assert isinstance(model_metadata, CivitaiMetadata)
trained_words = model_metadata.trained_words
base_model = model_metadata.base_model_trained_on
thumbnail = model_metadata.thumbnail_url
# cache the metadata to the database using the key corresponding to
# an existing model config record in the `model_config` table
sql_cache = ModelMetadataStore(ApiDependencies.invoker.services.db)
sql_cache.add_metadata('fb237ace520b6716adc98bcb16e8462c', model_metadata)
# now we can search the database by tag, author or model name
# matches will contain a list of model keys that match the search
matches = sql_cache.search_by_tag({"tool", "turbo"})
```
### Structure of the Metadata objects
There is a short class hierarchy of Metadata objects, all of which
descend from the Pydantic `BaseModel`.
#### `ModelMetadataBase`
This is the common base class for metadata:
| **Field Name** | **Type** | **Description** |
|----------------|-----------------|------------------|
| `name` | str | Repository's name for the model |
| `author` | str | Model's author |
| `tags` | Set[str] | Model tags |
Note that the model config record also has a `name` field. It is
intended that the config record version be locally customizable, while
the metadata version is read-only. However, enforcing this is expected
to be part of the business logic.
Descendants of the base class add additional fields.
#### `HuggingFaceMetadata`
This descends from `ModelMetadataBase` and adds the following fields:
| **Field Name** | **Type** | **Description** |
|----------------|-----------------|------------------|
| `type` | Literal["huggingface"] | Used for the discriminated union of metadata classes|
| `id` | str | HuggingFace repo_id |
| `tag_dict` | Dict[str, Any] | A dictionary of tag/value pairs provided in addition to `tags` |
| `last_modified`| datetime | Date of last commit of this model to the repo |
| `files` | List[Path] | List of the files in the model repo |
#### `CivitaiMetadata`
This descends from `ModelMetadataBase` and adds the following fields:
| **Field Name** | **Type** | **Description** |
|----------------|-----------------|------------------|
| `type` | Literal["civitai"] | Used for the discriminated union of metadata classes|
| `id` | int | Civitai model id |
| `version_name` | str | Name of this version of the model (distinct from model name) |
| `version_id` | int | Civitai model version id (distinct from model id) |
| `created` | datetime | Date the model was uploaded to Civitai; no modification date provided |
| `description` | str | Model description. Quite verbose and contains HTML tags |
| `version_description` | str | Model version description, usually describes changes to the model |
| `nsfw` | bool | Whether the model tends to generate NSFW content |
| `restrictions` | LicenseRestrictions | An object that describes what is and isn't allowed with this model |
| `trained_words`| Set[str] | Trigger words for this model, if any |
| `download_url` | AnyHttpUrl | URL for downloading this version of the model |
| `base_model_trained_on` | str | Name of the model that this version was trained on |
| `thumbnail_url` | AnyHttpUrl | URL to access a representative thumbnail image of the model's output |
| `weight_min` | int | For LoRA sliders, the minimum suggested weight to apply |
| `weight_max` | int | For LoRA sliders, the maximum suggested weight to apply |
Note that `weight_min` and `weight_max` are not currently populated
and take the default values of (-1.0, +2.0). The issue is that these
values aren't part of the structured data but appear in the text
description. Some regular expression or LLM coding may be able to
extract these values.
Also be aware that `base_model_trained_on` is free text and doesn't
correspond to our `BaseModelType` enum.
`CivitaiMetadata` also defines some convenience properties relating to
licensing restrictions: `credit_required`, `allow_commercial_use`,
`allow_derivatives` and `allow_different_license`.
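A minimal sketch of how such a convenience property can be derived from the restrictions object, using only the standard library. The dataclasses here are stand-ins for the real Pydantic models, the enum lists only the two values that appear in this commit (`"None"` and `"RentCivit"`), and the default value is an assumption:

```python
from dataclasses import dataclass, field
from enum import Enum


class CommercialUsage(str, Enum):
    """Partial, illustrative enum: only the values seen in this commit."""

    No = "None"
    RentCivit = "RentCivit"


@dataclass
class LicenseRestrictions:
    # Field name follows the source; the default is an assumption.
    AllowCommercialUse: CommercialUsage = CommercialUsage.No


@dataclass
class MetadataSketch:
    # default_factory builds a fresh restrictions object per instance,
    # rather than storing the class itself as the default value.
    restrictions: LicenseRestrictions = field(default_factory=LicenseRestrictions)

    @property
    def allow_commercial_use(self) -> bool:
        """Return True unless commercial use is explicitly 'None'."""
        return self.restrictions.AllowCommercialUse != CommercialUsage("None")
```

With these stand-ins, `MetadataSketch().allow_commercial_use` is `False`, and it becomes `True` once `AllowCommercialUse` is set to `CommercialUsage("RentCivit")`.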
#### `AnyModelRepoMetadata`
This is a discriminated Union of `CivitaiMetadata` and
`HuggingFaceMetadata`.
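To illustrate what the `type` discriminator buys you, here is a standard-library stand-in that picks the concrete class from the `type` field of a raw dictionary. The real code relies on Pydantic's discriminated unions; the class and function names below (`HFSketch`, `CivitaiSketch`, `parse_metadata`) are hypothetical and carry only a couple of fields:

```python
from dataclasses import dataclass
from typing import Any, Dict, Union


@dataclass
class HFSketch:
    name: str
    id: str  # HuggingFace repo_id
    type: str = "huggingface"


@dataclass
class CivitaiSketch:
    name: str
    id: int  # Civitai model id
    type: str = "civitai"


AnySketch = Union[HFSketch, CivitaiSketch]

# Map each discriminator value to the class it selects.
_BY_TYPE = {"huggingface": HFSketch, "civitai": CivitaiSketch}


def parse_metadata(raw: Dict[str, Any]) -> AnySketch:
    """Choose the concrete class by inspecting raw['type']."""
    try:
        cls = _BY_TYPE[raw["type"]]
    except KeyError as excp:
        raise ValueError(f"unknown or missing metadata type: {raw.get('type')!r}") from excp
    return cls(**raw)
```

Because the discriminator is stored alongside the data, a JSON blob read back from storage can be turned into the right subclass without the caller knowing which repository it came from.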
### Fetching Metadata from Online Repos
The `HuggingFaceMetadataFetch` and `CivitaiMetadataFetch` classes will
retrieve metadata from their corresponding repositories and return
`AnyModelRepoMetadata` objects. Their base class
`ModelMetadataFetchBase` is an abstract class that defines two
methods: `from_url()` and `from_id()`. The former accepts the type of
model URLs that the user will try to cut and paste into the model
import form. The latter accepts a string ID in the format recognized
by the repository of choice. Both methods return an
`AnyModelRepoMetadata`.
The base class also has a class method `from_json()` which will take
the JSON representation of a `ModelMetadata` object, validate it, and
return the corresponding `AnyModelRepoMetadata` object.
When initializing one of the metadata fetching classes, you may
provide a `requests.Session` argument. This allows you to customize
the low-level HTTP fetch requests and is used, for instance, in the
testing suite to avoid hitting the internet.
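The idea of injecting the session can be sketched without any network access. In the following standard-library example, `FakeSession`, `FakeResponse`, `FetcherSketch` and the endpoint URL are all hypothetical; the only behavior borrowed from the source is that the fetcher calls `get(url).json()` on whatever session it was given:

```python
import json
from typing import Any, Dict


class FakeResponse:
    """Wraps a canned JSON string and mimics `Response.json()`."""

    def __init__(self, payload: str) -> None:
        self._payload = payload

    def json(self) -> Any:
        return json.loads(self._payload)


class FakeSession:
    """Maps URLs to canned JSON strings; supports only `get(url)`."""

    def __init__(self, canned: Dict[str, str]) -> None:
        self._canned = canned

    def get(self, url: str) -> FakeResponse:
        return FakeResponse(self._canned[url])


class FetcherSketch:
    ENDPOINT = "https://example.invalid/api/models/"  # placeholder endpoint

    def __init__(self, session: Any) -> None:
        # The fetcher never creates its own session, so tests can swap it out.
        self._requests = session

    def from_id(self, id: str) -> Dict[str, Any]:
        return self._requests.get(self.ENDPOINT + id).json()


session = FakeSession({FetcherSketch.ENDPOINT + "42": '{"name": "demo", "author": "test_author"}'})
metadata = FetcherSketch(session).from_id("42")
```

The test suite in this commit takes the same approach with a real `requests.Session`, mounting test adapters on the repository URLs so that canned responses are returned.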
The HuggingFace and Civitai fetcher subclasses add additional
repo-specific fetching methods:
#### HuggingFaceMetadataFetch
This overrides its base class `from_json()` method to return a
`HuggingFaceMetadata` object directly.
#### CivitaiMetadataFetch
This adds the following methods:
`from_civitai_modelid()` This takes the ID of a model, finds the
default version of the model, and then retrieves the metadata for
that version, returning a `CivitaiMetadata` object directly.
`from_civitai_versionid()` This takes the ID of a model version and
retrieves its metadata. Functionally equivalent to `from_id()`, the
only difference is that it returns a `CivitaiMetadata` object rather
than an `AnyModelRepoMetadata`.
### Metadata Storage
The `ModelMetadataStore` provides a simple facility to store model
metadata in the `invokeai.db` database. The data is stored as a JSON
blob, with a few common fields (`name`, `author`, `tags`) broken out
to be searchable.
When a metadata object is saved to the database, it is identified
using the model key, _and this key must correspond to an existing
model key in the model_config table_. There is a foreign key integrity
constraint between the `model_config.id` field and the
`model_metadata.id` field such that if you attempt to save metadata
under an unknown key, the attempt will result in an
`UnknownModelException`. Likewise, when a model is deleted from
`model_config`, the deletion of the corresponding metadata record will
be triggered.
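The behavior described above can be sketched with an in-memory SQLite database. The table layouts here are heavily simplified, `ON DELETE CASCADE` is an assumption about how the deletion is propagated (the real schema may use a trigger instead), and `UnknownModelException` is a local stand-in for the InvokeAI class:

```python
import sqlite3


class UnknownModelException(Exception):
    """Local stand-in for the InvokeAI exception of the same name."""


conn = sqlite3.connect(":memory:")
# SQLite only enforces foreign keys when this pragma is switched on.
conn.execute("PRAGMA foreign_keys = ON;")
conn.execute("CREATE TABLE model_config (id TEXT NOT NULL PRIMARY KEY);")
conn.execute(
    """
    CREATE TABLE model_metadata (
        id TEXT NOT NULL PRIMARY KEY,
        metadata TEXT NOT NULL,  -- JSON blob
        FOREIGN KEY(id) REFERENCES model_config(id) ON DELETE CASCADE
    );
    """
)


def add_metadata(key: str, blob: str) -> None:
    """Insert metadata; translate a foreign key failure into UnknownModelException."""
    try:
        conn.execute("INSERT INTO model_metadata (id, metadata) VALUES (?, ?);", (key, blob))
        conn.commit()
    except sqlite3.IntegrityError as excp:
        conn.rollback()
        raise UnknownModelException(key) from excp


conn.execute("INSERT INTO model_config (id) VALUES ('known_key');")
conn.commit()
add_metadata("known_key", '{"name": "demo"}')
```

Calling `add_metadata("unknown_key", "{}")` in this sketch raises `UnknownModelException`, and deleting `known_key` from `model_config` removes its metadata row as well.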
Tags are stored in a normalized fashion in the tables `model_tags` and
`tags`. Triggers keep the tag table in sync with the `model_metadata`
table.
To create the storage object, initialize it with the InvokeAI
`SqliteDatabase` object. This is often done this way:
```
from invokeai.app.api.dependencies import ApiDependencies
metadata_store = ModelMetadataStore(ApiDependencies.invoker.services.db)
```
You can then access the storage with the following methods:
#### `add_metadata(key, metadata)`
Add the metadata using a previously-defined model key.
There is currently no `delete_metadata()` method. The metadata will
persist until the matching config is deleted from the `model_config`
table.
#### `get_metadata(key) -> AnyModelRepoMetadata`
Retrieve the metadata corresponding to the model key.
#### `update_metadata(key, new_metadata)`
Update an existing metadata record with new metadata.
#### `search_by_tag(tags: Set[str]) -> Set[str]`
Given a set of tags, find models that are tagged with them. If
multiple tags are provided then a matching model must be tagged with
*all* the tags in the set. This method returns a set of model keys and
is intended to be used in conjunction with the `ModelRecordService`:
```
model_config_store = ApiDependencies.invoker.services.model_records
matches = metadata_store.search_by_tag({'license:other'})
models = [model_config_store.get(x) for x in matches]
```
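The "all tags must match" rule amounts to intersecting the per-tag result sets. Below is a minimal sketch against an in-memory database; the table and column names follow the query used in this commit, but the rest of the schema and the sample rows are assumptions:

```python
import sqlite3
from typing import Optional, Set

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE tags (tag_id INTEGER PRIMARY KEY, tag_text TEXT UNIQUE);
    CREATE TABLE model_tags (id TEXT, tag_id INTEGER);
    INSERT INTO tags VALUES (1, 'text-to-image'), (2, 'diffusers'), (3, 'checkpoint');
    INSERT INTO model_tags VALUES
        ('model_a', 1), ('model_a', 2),
        ('model_b', 1), ('model_b', 3);
    """
)


def search_by_tag(tags: Set[str]) -> Set[str]:
    """Return the keys of models that carry every tag in `tags`."""
    matches: Optional[Set[str]] = None
    for tag in tags:
        rows = conn.execute(
            """
            SELECT a.id FROM model_tags AS a, tags AS b
             WHERE a.tag_id = b.tag_id AND b.tag_text = ?;
            """,
            (tag,),
        ).fetchall()
        model_keys = {row[0] for row in rows}
        # The first tag seeds the result; each later tag narrows it.
        matches = model_keys if matches is None else matches & model_keys
    # An empty tag set yields no matches in this sketch.
    return matches if matches is not None else set()
```

Starting from `None` rather than an empty set matters: it distinguishes "no tag processed yet" from "the tags processed so far have no model in common", so an early empty intersection is not overwritten by a later tag.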
#### `search_by_name(name: str) -> Set[str]`
Find all model metadata records that have the given name and return a
set of keys to the corresponding model config objects.
#### `search_by_author(author: str) -> Set[str]`
Find all model metadata records that have the given author and return
a set of keys to the corresponding model config objects.


@ -29,7 +29,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator = SqliteMigrator(db=db)
migrator.register_migration(build_migration_1())
migrator.register_migration(build_migration_2(image_files=image_files, logger=logger))
migrator.register_migration(build_migration_3())
migrator.register_migration(build_migration_3(app_config=config, logger=logger))
migrator.register_migration(build_migration_4())
migrator.run_migrations()


@ -11,8 +11,6 @@ from invokeai.app.services.workflow_records.workflow_records_common import (
UnsafeWorkflowWithVersionValidator,
)
from .util.migrate_yaml_config_1 import MigrateModelYamlToDb1
class Migration2Callback:
def __init__(self, image_files: ImageFileStorageBase, logger: Logger):
@ -25,8 +23,6 @@ class Migration2Callback:
self._drop_old_workflow_tables(cursor)
self._add_workflow_library(cursor)
self._drop_model_manager_metadata(cursor)
self._recreate_model_config(cursor)
self._migrate_model_config_records(cursor)
self._migrate_embedded_workflows(cursor)
def _add_images_has_workflow(self, cursor: sqlite3.Cursor) -> None:
@ -100,45 +96,6 @@ class Migration2Callback:
"""Drops the `model_manager_metadata` table."""
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
def _recreate_model_config(self, cursor: sqlite3.Cursor) -> None:
"""
Drops the `model_config` table, recreating it.
In 3.4.0, this table used explicit columns but was changed to use json_extract 3.5.0.
Because this table is not used in production, we are able to simply drop it and recreate it.
"""
cursor.execute("DROP TABLE IF EXISTS model_config;")
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY,
-- The next 3 fields are enums in python, unrestricted string here
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
original_hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
"""
)
def _migrate_model_config_records(self, cursor: sqlite3.Cursor) -> None:
"""After updating the model config table, we repopulate it."""
model_record_migrator = MigrateModelYamlToDb1(cursor)
model_record_migrator.migrate()
def _migrate_embedded_workflows(self, cursor: sqlite3.Cursor) -> None:
"""
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in


@ -1,13 +1,16 @@
import sqlite3
from logging import Logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
from .util.migrate_yaml_config_1 import MigrateModelYamlToDb1
class Migration3Callback:
def __init__(self) -> None:
pass
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
self._app_config = app_config
self._logger = logger
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._drop_model_manager_metadata(cursor)
@ -54,11 +57,12 @@ class Migration3Callback:
def _migrate_model_config_records(self, cursor: sqlite3.Cursor) -> None:
"""After updating the model config table, we repopulate it."""
model_record_migrator = MigrateModelYamlToDb1(cursor)
self._logger.info("Migrating model config records from models.yaml to database")
model_record_migrator = MigrateModelYamlToDb1(self._app_config, self._logger, cursor)
model_record_migrator.migrate()
def build_migration_3() -> Migration:
def build_migration_3(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
"""
Build the migration from database version 2 to 3.
@ -69,7 +73,7 @@ def build_migration_3() -> Migration:
migration_3 = Migration(
from_version=2,
to_version=3,
callback=Migration3Callback(),
callback=Migration3Callback(app_config=app_config, logger=logger),
)
return migration_3


@ -23,7 +23,6 @@ from invokeai.backend.model_manager.config import (
ModelType,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.util.logging import InvokeAILogger
ModelsValidator = TypeAdapter(AnyModelConfig)
@ -46,10 +45,9 @@ class MigrateModelYamlToDb1:
logger: Logger
cursor: sqlite3.Cursor
def __init__(self, cursor: sqlite3.Cursor = None) -> None:
self.config = InvokeAIAppConfig.get_config()
self.config.parse_args()
self.logger = InvokeAILogger.get_logger()
def __init__(self, config: InvokeAIAppConfig, logger: Logger, cursor: sqlite3.Cursor = None) -> None:
self.config = config
self.logger = logger
self.cursor = cursor
def get_yaml(self) -> DictConfig:


@ -19,6 +19,7 @@ if data.allow_commercial_use:
print("Commercial use of this model is allowed")
"""
from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch
from .metadata_base import (
AnyModelRepoMetadata,
AnyModelRepoMetadataValidator,
@ -37,4 +38,6 @@ __all__ = [
"HuggingFaceMetadata",
"CivitaiMetadata",
"ModelMetadataStore",
"CivitaiMetadataFetch",
"HuggingFaceMetadataFetch",
]


@ -77,10 +77,10 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
return self.from_civitai_modelid(int(model_id))
elif match := re.match(CIVITAI_VERSION_PAGE_RE, str(url)):
version_id = match.group(1)
return self._from_civitai_versionid(int(version_id))
return self.from_civitai_versionid(int(version_id))
elif match := re.match(CIVITAI_DOWNLOAD_RE, str(url)):
version_id = match.group(1)
return self._from_civitai_versionid(int(version_id))
return self.from_civitai_versionid(int(version_id))
raise UnknownModelException(f"The url '{url}' does not match any known Civitai URL patterns")
def from_id(self, id: str) -> AnyModelRepoMetadata:
@ -89,7 +89,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
May raise an `UnknownModelException`.
"""
return self._from_civitai_versionid(int(id))
return self.from_civitai_versionid(int(id))
def from_civitai_modelid(self, model_id: int) -> CivitaiMetadata:
"""
@ -100,9 +100,9 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
model = self._requests.get(model_url).json()
default_version = model["modelVersions"][0]["id"]
return self._from_civitai_versionid(default_version, model)
return self.from_civitai_versionid(default_version, model)
def _from_civitai_versionid(
def from_civitai_versionid(
self, version_id: int, model_metadata: Optional[Dict[str, Any]] = None
) -> CivitaiMetadata:
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)


@ -28,6 +28,8 @@ from invokeai.app.services.model_records import UnknownModelException
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator, HuggingFaceMetadata
from .fetch_base import ModelMetadataFetchBase
HF_MODEL_RE = r"https?://huggingface.co/([\w\-.]+/[\w\-.]+)"
class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
"""Fetch model metadata from HuggingFace."""
@ -68,7 +70,6 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
In the case of an invalid or missing URL, raises a ModelNotFound exception.
"""
HF_MODEL_RE = r"https?://huggingface.co/([\w\-.]+/[\w\-.]+)"
if match := re.match(HF_MODEL_RE, str(url)):
repo_id = match.group(1)
return self.from_id(repo_id)


@ -1,15 +1,17 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
This module defines core text-to-image model metadata fields.
"""This module defines core text-to-image model metadata fields.
Metadata comprises any descriptive information that is not essential
for getting the model to run. For example "author" is metadata, while
"type", "base" and "format" are not. The latter fields are part of the
model's config, as defined in invokeai.backend.model_manager.config.
Note that the "name" and "description" are also present in `config`.
This may need reworking.
Note that the "name" and "description" are also present in `config`
records. This is intentional. The config record fields are intended to
be editable by the user as a form of customization. The metadata
versions of these fields are intended to be kept in sync with the
remote repo.
"""
from datetime import datetime
@ -78,7 +80,7 @@ class CivitaiMetadata(ModelMetadataBase):
description="text description of the model's reversion; usually change history; may contain HTML"
)
nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False)
restrictions: LicenseRestrictions = Field(description="license terms", default=LicenseRestrictions)
restrictions: LicenseRestrictions = Field(description="license terms", default_factory=LicenseRestrictions)
trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set)
download_url: AnyHttpUrl = Field(description="download URL for this model")
base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)")
@ -98,7 +100,7 @@ class CivitaiMetadata(ModelMetadataBase):
@property
def allow_commercial_use(self) -> bool:
"""Return True if commercial use is allowed."""
return self.restrictions.AllowCommercialUse == CommercialUsage("None")
return self.restrictions.AllowCommercialUse != CommercialUsage("None")
@property
def allow_derivatives(self) -> bool:


@ -4,7 +4,7 @@ SQL Storage for Model Metadata
"""
import sqlite3
from typing import Set
from typing import Optional, Set
from invokeai.app.services.model_records import UnknownModelException
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
@ -59,9 +59,12 @@ class ModelMetadataStore:
)
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.Error as e:
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
self._db.conn.rollback()
raise e
raise UnknownModelException from excp
except sqlite3.Error as excp:
self._db.conn.rollback()
raise excp
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
@ -111,19 +114,21 @@ class ModelMetadataStore:
"""Return the keys of models containing all of the listed tags."""
with self._db.lock:
try:
matches: Set[str] = set()
matches: Optional[Set[str]] = None
for tag in tags:
self._cursor.execute(
"""--sql
SELECT a.id FROM model_tags AS a,
tags AS b
tags AS b
WHERE a.tag_id=b.tag_id
AND b.tag_text=?;
""",
(tag,),
)
model_keys = {x[0] for x in self._cursor.fetchall()}
matches = matches.intersection(model_keys) if len(matches) > 0 else model_keys
if matches is None:
matches = model_keys
matches = matches.intersection(model_keys)
except sqlite3.Error as e:
raise e
return matches
@ -139,6 +144,24 @@ class ModelMetadataStore:
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_name(self, name: str) -> Set[str]:
"""
Return the keys of models with the indicated name.
Note that this is the name of the model given to it by
the remote source. The user may have changed the local
name. The local name will be located in the model config
record object.
"""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE name=?;
""",
(name,),
)
return {x[0] for x in self._cursor.fetchall()}
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
"""Update tags for the model referenced by model_key."""
# remove previous tags from this model

File diff suppressed because one or more lines are too long


@ -1,34 +1,39 @@
"""
Test model metadata fetching and storage.
"""
import pytest
import datetime
from pathlib import Path
from typing import Any, Dict, List
from pydantic import BaseModel, ValidationError
import pytest
import requests
from pydantic.networks import HttpUrl
from requests.sessions import Session
from requests_testadapter import TestAdapter
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
from invokeai.app.services.model_records import ModelRecordServiceSQL, UnknownModelException
from invokeai.backend.model_manager.config import (
BaseModelType,
MainCheckpointConfig,
MainDiffusersConfig,
ModelFormat,
ModelType,
TextualInversionConfig,
VaeDiffusersConfig,
)
from invokeai.backend.model_manager.metadata import (
ModelMetadataStore,
AnyModelRepoMetadata,
CommercialUsage,
LicenseRestrictions,
HuggingFaceMetadata,
CivitaiMetadata,
CivitaiMetadataFetch,
CommercialUsage,
HuggingFaceMetadata,
HuggingFaceMetadataFetch,
ModelMetadataStore,
)
from invokeai.backend.util.logging import InvokeAILogger
from tests.app.services.model_metadata.metadata_examples import (
RepoCivitaiModelMetadata1,
RepoCivitaiVersionMetadata1,
RepoHFMetadata1,
)
from tests.fixtures.sqlite_database import create_mock_sqlite_database
@pytest.fixture
def app_config(datadir: Path) -> InvokeAIAppConfig:
return InvokeAIAppConfig(
@ -36,49 +41,204 @@ def app_config(datadir: Path) -> InvokeAIAppConfig:
models_dir=datadir / "root/models",
)
@pytest.fixture
def record_store(app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL:
logger = InvokeAILogger.get_logger(config=app_config)
db = create_mock_sqlite_database(app_config, logger)
store = ModelRecordServiceSQL(db)
# add two config records to the database
# add three simple config records to the database
raw1 = {
"path": "/tmp/foo2.ckpt",
"path": "/tmp/foo1",
"format": ModelFormat("diffusers"),
"name": "test2",
"base": BaseModelType("sd-2"),
"type": ModelType("vae"),
"original_hash":"111222333444",
"original_hash": "111222333444",
"source": "stabilityai/sdxl-vae",
}
raw2 = {
"path": "/tmp/foo1.ckpt",
"path": "/tmp/foo2.ckpt",
"name": "model1",
"format": ModelFormat("checkpoint"),
"base": BaseModelType("sd-1"),
"type": "main",
"config": "/tmp/foo.yaml",
"variant": "normal",
"format": "checkpoint",
"original_hash": "111222333444",
"source": "https://civitai.com/models/206883/split",
}
store.add_model('test_config_1', raw1)
store.add_model('test_config_2', raw2)
raw3 = {
"path": "/tmp/foo3",
"format": ModelFormat("diffusers"),
"name": "test3",
"base": BaseModelType("sdxl"),
"type": ModelType("main"),
"original_hash": "111222333444",
"source": "author3/model3",
}
store.add_model("test_config_1", raw1)
store.add_model("test_config_2", raw2)
store.add_model("test_config_3", raw3)
return store
@pytest.fixture
def session() -> Session:
sess = requests.Session()
sess.mount(
"https://huggingface.co/api/models/stabilityai/sdxl-turbo",
TestAdapter(
RepoHFMetadata1,
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)},
),
)
sess.mount(
"https://civitai.com/api/v1/model-versions/242807",
TestAdapter(
RepoCivitaiVersionMetadata1,
headers={
"Content-Length": len(RepoCivitaiVersionMetadata1),
},
),
)
sess.mount(
"https://civitai.com/api/v1/models/215485",
TestAdapter(
RepoCivitaiModelMetadata1,
headers={
"Content-Length": len(RepoCivitaiModelMetadata1),
},
),
)
return sess
@pytest.fixture
def metadata_store(record_store: ModelRecordServiceSQL) -> ModelMetadataStore:
db = record_store._db # to ensure we are sharing the same database
db = record_store._db # to ensure we are sharing the same database
return ModelMetadataStore(db)
def test_metadata_store_put_get(metadata_store: ModelMetadataStore) -> None:
input_metadata = HuggingFaceMetadata(name="sdxl-vae",
author="stabilityai",
tags={"text-to-image","diffusers"},
id="stabilityai/sdxl-vae",
tag_dict={"license":"other"},
last_modified=datetime.datetime.now(),
)
metadata_store.add_metadata('test_config_1',input_metadata)
output_metadata = metadata_store.get_metadata('test_config_1')
input_metadata = HuggingFaceMetadata(
name="sdxl-vae",
author="stabilityai",
tags={"text-to-image", "diffusers"},
id="stabilityai/sdxl-vae",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
metadata_store.add_metadata("test_config_1", input_metadata)
output_metadata = metadata_store.get_metadata("test_config_1")
assert input_metadata == output_metadata
with pytest.raises(UnknownModelException):
metadata_store.add_metadata("unknown_key", input_metadata)
def test_metadata_store_update(metadata_store: ModelMetadataStore) -> None:
input_metadata = HuggingFaceMetadata(
name="sdxl-vae",
author="stabilityai",
tags={"text-to-image", "diffusers"},
id="stabilityai/sdxl-vae",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
metadata_store.add_metadata("test_config_1", input_metadata)
input_metadata.name = "new-name"
metadata_store.update_metadata("test_config_1", input_metadata)
output_metadata = metadata_store.get_metadata("test_config_1")
assert output_metadata.name == "new-name"
assert input_metadata == output_metadata
def test_metadata_search(metadata_store: ModelMetadataStore) -> None:
metadata1 = HuggingFaceMetadata(
name="sdxl-vae",
author="stabilityai",
tags={"text-to-image", "diffusers"},
id="stabilityai/sdxl-vae",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
metadata2 = HuggingFaceMetadata(
name="model2",
author="stabilityai",
tags={"text-to-image", "diffusers", "community-contributed"},
id="author2/model2",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
metadata3 = HuggingFaceMetadata(
name="model3",
author="author3",
tags={"text-to-image", "checkpoint", "community-contributed"},
id="author3/model3",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
metadata_store.add_metadata("test_config_1", metadata1)
metadata_store.add_metadata("test_config_2", metadata2)
metadata_store.add_metadata("test_config_3", metadata3)
matches = metadata_store.search_by_author("stabilityai")
assert len(matches) == 2
assert "test_config_1" in matches
assert "test_config_2" in matches
matches = metadata_store.search_by_author("Sherlock Holmes")
assert not matches
matches = metadata_store.search_by_name("model3")
assert len(matches) == 1
assert "test_config_3" in matches
matches = metadata_store.search_by_tag({"text-to-image"})
assert len(matches) == 3
matches = metadata_store.search_by_tag({"text-to-image", "diffusers"})
assert len(matches) == 2
assert "test_config_1" in matches
assert "test_config_2" in matches
matches = metadata_store.search_by_tag({"checkpoint", "community-contributed"})
assert len(matches) == 1
assert "test_config_3" in matches
# does the tag table update correctly?
matches = metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
assert not matches
metadata3.tags.add("licensed-for-commercial-use")
metadata_store.update_metadata("test_config_3", metadata3)
matches = metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
assert len(matches) == 1
def test_metadata_civitai_fetch(session: Session) -> None:
fetcher = CivitaiMetadataFetch(session)
metadata = fetcher.from_url(HttpUrl("https://civitai.com/models/215485/SDXL-turbo"))
assert isinstance(metadata, CivitaiMetadata)
assert metadata.id == 215485
assert metadata.author == "test_author" # note that this is not the same as the original from Civitai
assert metadata.allow_commercial_use # changed to make sure we are reading locally not remotely
assert metadata.restrictions.AllowCommercialUse == CommercialUsage("RentCivit")
assert metadata.version_id == 242807
assert metadata.tags == {"tool", "turbo", "sdxl turbo"}
def test_metadata_hf_fetch(session: Session) -> None:
fetcher = HuggingFaceMetadataFetch(session)
metadata = fetcher.from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
assert isinstance(metadata, HuggingFaceMetadata)
assert metadata.author == "test_author" # this is not the same as the original
assert metadata.files
assert metadata.tags == {
"diffusers",
"onnx",
"safetensors",
"text-to-image",
"license:other",
"has_space",
"diffusers:StableDiffusionXLPipeline",
"region:us",
}