mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add unit tests and documentation
This commit is contained in:
parent
1940169925
commit
a626ca3e1c
@ -16,6 +16,11 @@ model. These are the:
|
|||||||
information. It is also responsible for managing the InvokeAI
|
information. It is also responsible for managing the InvokeAI
|
||||||
`models` directory and its contents.
|
`models` directory and its contents.
|
||||||
|
|
||||||
|
* _ModelMetadataStore_ and _ModelMetaDataFetch_ Backend modules that
|
||||||
|
are able to retrieve metadata from online model repositories,
|
||||||
|
transform them into Pydantic models, and cache them to the InvokeAI
|
||||||
|
SQL database.
|
||||||
|
|
||||||
* _DownloadQueueServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**)
|
* _DownloadQueueServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**)
|
||||||
A multithreaded downloader responsible
|
A multithreaded downloader responsible
|
||||||
for downloading models from a remote source to disk. The download
|
for downloading models from a remote source to disk. The download
|
||||||
@ -1184,3 +1189,246 @@ other resources that it might have been using.
|
|||||||
This will start/pause/cancel all jobs that have been submitted to the
|
This will start/pause/cancel all jobs that have been submitted to the
|
||||||
queue and have not yet reached a terminal state.
|
queue and have not yet reached a terminal state.
|
||||||
|
|
||||||
|
***
|
||||||
|
|
||||||
|
## This Meta be Good: Model Metadata Storage
|
||||||
|
|
||||||
|
The modules found under `invokeai.backend.model_manager.metadata`
|
||||||
|
provide a straightforward API for fetching model metadatda from online
|
||||||
|
repositories. Currently two repositories are supported: HuggingFace
|
||||||
|
and Civitai. However, the modules are easily extended for additional
|
||||||
|
repos, provided that they have defined APIs for metadata access.
|
||||||
|
|
||||||
|
Metadata comprises any descriptive information that is not essential
|
||||||
|
for getting the model to run. For example "author" is metadata, while
|
||||||
|
"type", "base" and "format" are not. The latter fields are part of the
|
||||||
|
model's config, as defined in `invokeai.backend.model_manager.config`.
|
||||||
|
|
||||||
|
### Example Usage:
|
||||||
|
|
||||||
|
```
|
||||||
|
from invokeai.backend.model_manager.metadata import (
|
||||||
|
AnyModelRepoMetadata,
|
||||||
|
CivitaiMetadataFetch,
|
||||||
|
CivitaiMetadata
|
||||||
|
ModelMetadataStore,
|
||||||
|
)
|
||||||
|
# to access the initialized sql database
|
||||||
|
from invokeai.app.api.dependencies import ApiDependencies
|
||||||
|
|
||||||
|
civitai = CivitaiMetadataFetch()
|
||||||
|
|
||||||
|
# fetch the metadata
|
||||||
|
model_metadata = civitai.from_url("https://civitai.com/models/215796")
|
||||||
|
|
||||||
|
# get some common metadata fields
|
||||||
|
author = model_metadata.author
|
||||||
|
tags = model_metadata.tags
|
||||||
|
|
||||||
|
# get some Civitai-specific fields
|
||||||
|
assert isinstance(model_metadata, CivitaiMetadata)
|
||||||
|
|
||||||
|
trained_words = model_metadata.trained_words
|
||||||
|
base_model = model_metadata.base_model_trained_on
|
||||||
|
thumbnail = model_metadata.thumbnail_url
|
||||||
|
|
||||||
|
# cache the metadata to the database using the key corresponding to
|
||||||
|
# an existing model config record in the `model_config` table
|
||||||
|
sql_cache = ModelMetadataStore(ApiDependencies.invoker.services.db)
|
||||||
|
sql_cache.add_metadata('fb237ace520b6716adc98bcb16e8462c', model_metadata)
|
||||||
|
|
||||||
|
# now we can search the database by tag, author or model name
|
||||||
|
# matches will contain a list of model keys that match the search
|
||||||
|
matches = sql_cache.search_by_tag({"tool", "turbo"})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Structure of the Metadata objects
|
||||||
|
|
||||||
|
There is a short class hierarchy of Metadata objects, all of which
|
||||||
|
descend from the Pydantic `BaseModel`.
|
||||||
|
|
||||||
|
#### `ModelMetadataBase`
|
||||||
|
|
||||||
|
This is the common base class for metadata:
|
||||||
|
|
||||||
|
| **Field Name** | **Type** | **Description** |
|
||||||
|
|----------------|-----------------|------------------|
|
||||||
|
| `name` | str | Repository's name for the model |
|
||||||
|
| `author` | str | Model's author |
|
||||||
|
| `tags` | Set[str] | Model tags |
|
||||||
|
|
||||||
|
|
||||||
|
Note that the model config record also has a `name` field. It is
|
||||||
|
intended that the config record version be locally customizable, while
|
||||||
|
the metadata version is read-only. However, enforcing this is expected
|
||||||
|
to be part of the business logic.
|
||||||
|
|
||||||
|
Descendents of the base add additional fields.
|
||||||
|
|
||||||
|
#### `HuggingFaceMetadata`
|
||||||
|
|
||||||
|
This descends from `ModelMetadataBase` and adds the following fields:
|
||||||
|
|
||||||
|
| **Field Name** | **Type** | **Description** |
|
||||||
|
|----------------|-----------------|------------------|
|
||||||
|
| `type` | Literal["huggingface"] | Used for the discriminated union of metadata classes|
|
||||||
|
| `id` | str | HuggingFace repo_id |
|
||||||
|
| `tag_dict` | Dict[str, Any] | A dictionary of tag/value pairs provided in addition to `tags` |
|
||||||
|
| `last_modified`| datetime | Date of last commit of this model to the repo |
|
||||||
|
| `files` | List[Path] | List of the files in the model repo |
|
||||||
|
|
||||||
|
|
||||||
|
#### `CivitaiMetadata`
|
||||||
|
|
||||||
|
This descends from `ModelMetadataBase` and adds the following fields:
|
||||||
|
|
||||||
|
| **Field Name** | **Type** | **Description** |
|
||||||
|
|----------------|-----------------|------------------|
|
||||||
|
| `type` | Literal["civitai"] | Used for the discriminated union of metadata classes|
|
||||||
|
| `id` | int | Civitai model id |
|
||||||
|
| `version_name` | str | Name of this version of the model (distinct from model name) |
|
||||||
|
| `version_id` | int | Civitai model version id (distinct from model id) |
|
||||||
|
| `created` | datetime | Date the model was uploaded to Civitai; no modification date provided |
|
||||||
|
| `description` | str | Model description. Quite verbose and contains HTML tags |
|
||||||
|
| `version_description` | str | Model version description, usually describes changes to the model |
|
||||||
|
| `nsfw` | bool | Whether the model tends to generate NSFW content |
|
||||||
|
| `restrictions` | LicenseRestrictions | An object that describes what is and isn't allowed with this model |
|
||||||
|
| `trained_words`| Set[str] | Trigger words for this model, if any |
|
||||||
|
| `download_url` | AnyHttpUrl | URL for downloading this version of the model |
|
||||||
|
| `base_model_trained_on` | str | Name of the model that this version was trained on |
|
||||||
|
| `thumbnail_url` | AnyHttpUrl | URL to access a representative thumbnail image of the model's output |
|
||||||
|
| `weight_min` | int | For LoRA sliders, the minimum suggested weight to apply |
|
||||||
|
| `weight_max` | int | For LoRA sliders, the maximum suggested weight to apply |
|
||||||
|
|
||||||
|
Note that `weight_min` and `weight_max` are not currently populated
|
||||||
|
and take the default values of (-1.0, +2.0). The issue is that these
|
||||||
|
values aren't part of the structured data but appear in the text
|
||||||
|
description. Some regular expression or LLM coding may be able to
|
||||||
|
extract these values.
|
||||||
|
|
||||||
|
Also be aware that `base_model_trained_on` is free text and doesn't
|
||||||
|
correspond to our `ModelType` enum.
|
||||||
|
|
||||||
|
`CivitaiMetadata` also defines some convenience properties relating to
|
||||||
|
licensing restrictions: `credit_required`, `allow_commercial_use`,
|
||||||
|
`allow_derivatives` and `allow_different_license`.
|
||||||
|
|
||||||
|
#### `AnyModelRepoMetadata`
|
||||||
|
|
||||||
|
This is a discriminated Union of `CivitaiMetadata` and
|
||||||
|
`HuggingFaceMetadata`.
|
||||||
|
|
||||||
|
### Fetching Metadata from Online Repos
|
||||||
|
|
||||||
|
The `HuggingFaceMetadataFetch` and `CivitaiMetadataFetch` classes will
|
||||||
|
retrieve metadata from their corresponding repositories and return
|
||||||
|
`AnyModelRepoMetadata` objects. Their base class
|
||||||
|
`ModelMetadataFetchBase` is an abstract class that defines two
|
||||||
|
methods: `from_url()` and `from_id()`. The former accepts the type of
|
||||||
|
model URLs that the user will try to cut and paste into the model
|
||||||
|
import form. The latter accepts a string ID in the format recognized
|
||||||
|
by the repository of choice. Both methods return an
|
||||||
|
`AnyModelRepoMetadata`.
|
||||||
|
|
||||||
|
The base class also has a class method `from_json()` which will take
|
||||||
|
the JSON representation of a `ModelMetadata` object, validate it, and
|
||||||
|
return the corresponding `AnyModelRepoMetadata` object.
|
||||||
|
|
||||||
|
When initializing one of the metadata fetching classes, you may
|
||||||
|
provide a `requests.Session` argument. This allows you to customize
|
||||||
|
the low-level HTTP fetch requests and is used, for instance, in the
|
||||||
|
testing suite to avoid hitting the internet.
|
||||||
|
|
||||||
|
The HuggingFace and Civitai fetcher subclasses add additional
|
||||||
|
repo-specific fetching methods:
|
||||||
|
|
||||||
|
|
||||||
|
#### HuggingFaceMetadataFetch
|
||||||
|
|
||||||
|
This overrides its base class `from_json()` method to return a
|
||||||
|
`HuggingFaceMetadata` object directly.
|
||||||
|
|
||||||
|
#### CivitaiMetadataFetch
|
||||||
|
|
||||||
|
This adds the following methods:
|
||||||
|
|
||||||
|
`from_civitai_modelid()` This takes the ID of a model, finds the
|
||||||
|
default version of the model, and then retrieves the metadata for
|
||||||
|
that version, returning a `CivitaiMetadata` object directly.
|
||||||
|
|
||||||
|
`from_civitai_versionid()` This takes the ID of a model version and
|
||||||
|
retrieves its metadata. Functionally equivalent to `from_id()`, the
|
||||||
|
only difference is that it returna a `CivitaiMetadata` object rather
|
||||||
|
than an `AnyModelRepoMetadata`.
|
||||||
|
|
||||||
|
|
||||||
|
### Metadata Storage
|
||||||
|
|
||||||
|
The `ModelMetadataStore` provides a simple facility to store model
|
||||||
|
metadata in the `invokeai.db` database. The data is stored as a JSON
|
||||||
|
blob, with a few common fields (`name`, `author`, `tags`) broken out
|
||||||
|
to be searchable.
|
||||||
|
|
||||||
|
When a metadata object is saved to the database, it is identified
|
||||||
|
using the model key, _and this key must correspond to an existing
|
||||||
|
model key in the model_config table_. There is a foreign key integrity
|
||||||
|
constraint between the `model_config.id` field and the
|
||||||
|
`model_metadata.id` field such that if you attempt to save metadata
|
||||||
|
under an unknown key, the attempt will result in an
|
||||||
|
`UnknownModelException`. Likewise, when a model is deleted from
|
||||||
|
`model_config`, the deletion of the corresponding metadata record will
|
||||||
|
be triggered.
|
||||||
|
|
||||||
|
Tags are stored in a normalized fashion in the tables `model_tags` and
|
||||||
|
`tags`. Triggers keep the tag table in sync with the `model_metadata`
|
||||||
|
table.
|
||||||
|
|
||||||
|
To create the storage object, initialize it with the InvokeAI
|
||||||
|
`SqliteDatabase` object. This is often done this way:
|
||||||
|
|
||||||
|
```
|
||||||
|
from invokeai.app.api.dependencies import ApiDependencies
|
||||||
|
metadata_store = ModelMetadataStore(ApiDependencies.invoker.services.db)
|
||||||
|
```
|
||||||
|
|
||||||
|
You can then access the storage with the following methods:
|
||||||
|
|
||||||
|
#### `add_metadata(key, metadata)`
|
||||||
|
|
||||||
|
Add the metadata using a previously-defined model key.
|
||||||
|
|
||||||
|
There is currently no `delete_metadata()` method. The metadata will
|
||||||
|
persist until the matching config is deleted from the `model_config`
|
||||||
|
table.
|
||||||
|
|
||||||
|
#### `get_metadata(key) -> AnyModelRepoMetadata`
|
||||||
|
|
||||||
|
Retrieve the metadata corresponding to the model key.
|
||||||
|
|
||||||
|
#### `update_metadata(key, new_metadata)`
|
||||||
|
|
||||||
|
Update an existing metadata record with new metadata.
|
||||||
|
|
||||||
|
#### `search_by_tag(tags: Set[str]) -> Set[str]`
|
||||||
|
|
||||||
|
Given a set of tags, find models that are tagged with them. If
|
||||||
|
multiple tags are provided then a matching model must be tagged with
|
||||||
|
*all* the tags in the set. This method returns a set of model keys and
|
||||||
|
is intended to be used in conjunction with the `ModelRecordService`:
|
||||||
|
|
||||||
|
```
|
||||||
|
model_config_store = ApiDependencies.invoker.services.model_records
|
||||||
|
matches = metadata_store.search_by_tag({'license:other'})
|
||||||
|
models = [model_config_store.get(x) for x in matches]
|
||||||
|
```
|
||||||
|
|
||||||
|
#### `search_by_name(name: str) -> Set[str]
|
||||||
|
|
||||||
|
Find all model metadata records that have the given name and return a
|
||||||
|
set of keys to the corresponding model config objects.
|
||||||
|
|
||||||
|
#### `search_by_author(author: str) -> Set[str]
|
||||||
|
|
||||||
|
Find all model metadata records that have the given author and return
|
||||||
|
a set of keys to the corresponding model config objects.
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
|||||||
migrator = SqliteMigrator(db=db)
|
migrator = SqliteMigrator(db=db)
|
||||||
migrator.register_migration(build_migration_1())
|
migrator.register_migration(build_migration_1())
|
||||||
migrator.register_migration(build_migration_2(image_files=image_files, logger=logger))
|
migrator.register_migration(build_migration_2(image_files=image_files, logger=logger))
|
||||||
migrator.register_migration(build_migration_3())
|
migrator.register_migration(build_migration_3(app_config=config, logger=logger))
|
||||||
migrator.register_migration(build_migration_4())
|
migrator.register_migration(build_migration_4())
|
||||||
migrator.run_migrations()
|
migrator.run_migrations()
|
||||||
|
|
||||||
|
@ -11,8 +11,6 @@ from invokeai.app.services.workflow_records.workflow_records_common import (
|
|||||||
UnsafeWorkflowWithVersionValidator,
|
UnsafeWorkflowWithVersionValidator,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .util.migrate_yaml_config_1 import MigrateModelYamlToDb1
|
|
||||||
|
|
||||||
|
|
||||||
class Migration2Callback:
|
class Migration2Callback:
|
||||||
def __init__(self, image_files: ImageFileStorageBase, logger: Logger):
|
def __init__(self, image_files: ImageFileStorageBase, logger: Logger):
|
||||||
@ -25,8 +23,6 @@ class Migration2Callback:
|
|||||||
self._drop_old_workflow_tables(cursor)
|
self._drop_old_workflow_tables(cursor)
|
||||||
self._add_workflow_library(cursor)
|
self._add_workflow_library(cursor)
|
||||||
self._drop_model_manager_metadata(cursor)
|
self._drop_model_manager_metadata(cursor)
|
||||||
self._recreate_model_config(cursor)
|
|
||||||
self._migrate_model_config_records(cursor)
|
|
||||||
self._migrate_embedded_workflows(cursor)
|
self._migrate_embedded_workflows(cursor)
|
||||||
|
|
||||||
def _add_images_has_workflow(self, cursor: sqlite3.Cursor) -> None:
|
def _add_images_has_workflow(self, cursor: sqlite3.Cursor) -> None:
|
||||||
@ -100,45 +96,6 @@ class Migration2Callback:
|
|||||||
"""Drops the `model_manager_metadata` table."""
|
"""Drops the `model_manager_metadata` table."""
|
||||||
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
|
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
|
||||||
|
|
||||||
def _recreate_model_config(self, cursor: sqlite3.Cursor) -> None:
|
|
||||||
"""
|
|
||||||
Drops the `model_config` table, recreating it.
|
|
||||||
|
|
||||||
In 3.4.0, this table used explicit columns but was changed to use json_extract 3.5.0.
|
|
||||||
|
|
||||||
Because this table is not used in production, we are able to simply drop it and recreate it.
|
|
||||||
"""
|
|
||||||
|
|
||||||
cursor.execute("DROP TABLE IF EXISTS model_config;")
|
|
||||||
|
|
||||||
cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE TABLE IF NOT EXISTS model_config (
|
|
||||||
id TEXT NOT NULL PRIMARY KEY,
|
|
||||||
-- The next 3 fields are enums in python, unrestricted string here
|
|
||||||
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
|
|
||||||
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
|
|
||||||
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
|
|
||||||
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
|
|
||||||
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
|
|
||||||
original_hash TEXT, -- could be null
|
|
||||||
-- Serialized JSON representation of the whole config object,
|
|
||||||
-- which will contain additional fields from subclasses
|
|
||||||
config TEXT NOT NULL,
|
|
||||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
|
||||||
-- Updated via trigger
|
|
||||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
|
||||||
-- unique constraint on combo of name, base and type
|
|
||||||
UNIQUE(name, base, type)
|
|
||||||
);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
def _migrate_model_config_records(self, cursor: sqlite3.Cursor) -> None:
|
|
||||||
"""After updating the model config table, we repopulate it."""
|
|
||||||
model_record_migrator = MigrateModelYamlToDb1(cursor)
|
|
||||||
model_record_migrator.migrate()
|
|
||||||
|
|
||||||
def _migrate_embedded_workflows(self, cursor: sqlite3.Cursor) -> None:
|
def _migrate_embedded_workflows(self, cursor: sqlite3.Cursor) -> None:
|
||||||
"""
|
"""
|
||||||
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in
|
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in
|
||||||
|
@ -1,13 +1,16 @@
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
|
from logging import Logger
|
||||||
|
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||||
|
|
||||||
from .util.migrate_yaml_config_1 import MigrateModelYamlToDb1
|
from .util.migrate_yaml_config_1 import MigrateModelYamlToDb1
|
||||||
|
|
||||||
|
|
||||||
class Migration3Callback:
|
class Migration3Callback:
|
||||||
def __init__(self) -> None:
|
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
|
||||||
pass
|
self._app_config = app_config
|
||||||
|
self._logger = logger
|
||||||
|
|
||||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||||
self._drop_model_manager_metadata(cursor)
|
self._drop_model_manager_metadata(cursor)
|
||||||
@ -54,11 +57,12 @@ class Migration3Callback:
|
|||||||
|
|
||||||
def _migrate_model_config_records(self, cursor: sqlite3.Cursor) -> None:
|
def _migrate_model_config_records(self, cursor: sqlite3.Cursor) -> None:
|
||||||
"""After updating the model config table, we repopulate it."""
|
"""After updating the model config table, we repopulate it."""
|
||||||
model_record_migrator = MigrateModelYamlToDb1(cursor)
|
self._logger.info("Migrating model config records from models.yaml to database")
|
||||||
|
model_record_migrator = MigrateModelYamlToDb1(self._app_config, self._logger, cursor)
|
||||||
model_record_migrator.migrate()
|
model_record_migrator.migrate()
|
||||||
|
|
||||||
|
|
||||||
def build_migration_3() -> Migration:
|
def build_migration_3(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
|
||||||
"""
|
"""
|
||||||
Build the migration from database version 2 to 3.
|
Build the migration from database version 2 to 3.
|
||||||
|
|
||||||
@ -69,7 +73,7 @@ def build_migration_3() -> Migration:
|
|||||||
migration_3 = Migration(
|
migration_3 = Migration(
|
||||||
from_version=2,
|
from_version=2,
|
||||||
to_version=3,
|
to_version=3,
|
||||||
callback=Migration3Callback(),
|
callback=Migration3Callback(app_config=app_config, logger=logger),
|
||||||
)
|
)
|
||||||
|
|
||||||
return migration_3
|
return migration_3
|
||||||
|
@ -23,7 +23,6 @@ from invokeai.backend.model_manager.config import (
|
|||||||
ModelType,
|
ModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.hash import FastModelHash
|
from invokeai.backend.model_manager.hash import FastModelHash
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
|
||||||
|
|
||||||
ModelsValidator = TypeAdapter(AnyModelConfig)
|
ModelsValidator = TypeAdapter(AnyModelConfig)
|
||||||
|
|
||||||
@ -46,10 +45,9 @@ class MigrateModelYamlToDb1:
|
|||||||
logger: Logger
|
logger: Logger
|
||||||
cursor: sqlite3.Cursor
|
cursor: sqlite3.Cursor
|
||||||
|
|
||||||
def __init__(self, cursor: sqlite3.Cursor = None) -> None:
|
def __init__(self, config: InvokeAIAppConfig, logger: Logger, cursor: sqlite3.Cursor = None) -> None:
|
||||||
self.config = InvokeAIAppConfig.get_config()
|
self.config = config
|
||||||
self.config.parse_args()
|
self.logger = logger
|
||||||
self.logger = InvokeAILogger.get_logger()
|
|
||||||
self.cursor = cursor
|
self.cursor = cursor
|
||||||
|
|
||||||
def get_yaml(self) -> DictConfig:
|
def get_yaml(self) -> DictConfig:
|
||||||
|
@ -19,6 +19,7 @@ if data.allow_commercial_use:
|
|||||||
print("Commercial use of this model is allowed")
|
print("Commercial use of this model is allowed")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch
|
||||||
from .metadata_base import (
|
from .metadata_base import (
|
||||||
AnyModelRepoMetadata,
|
AnyModelRepoMetadata,
|
||||||
AnyModelRepoMetadataValidator,
|
AnyModelRepoMetadataValidator,
|
||||||
@ -37,4 +38,6 @@ __all__ = [
|
|||||||
"HuggingFaceMetadata",
|
"HuggingFaceMetadata",
|
||||||
"CivitaiMetadata",
|
"CivitaiMetadata",
|
||||||
"ModelMetadataStore",
|
"ModelMetadataStore",
|
||||||
|
"CivitaiMetadataFetch",
|
||||||
|
"HuggingFaceMetadataFetch",
|
||||||
]
|
]
|
||||||
|
@ -77,10 +77,10 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
|||||||
return self.from_civitai_modelid(int(model_id))
|
return self.from_civitai_modelid(int(model_id))
|
||||||
elif match := re.match(CIVITAI_VERSION_PAGE_RE, str(url)):
|
elif match := re.match(CIVITAI_VERSION_PAGE_RE, str(url)):
|
||||||
version_id = match.group(1)
|
version_id = match.group(1)
|
||||||
return self._from_civitai_versionid(int(version_id))
|
return self.from_civitai_versionid(int(version_id))
|
||||||
elif match := re.match(CIVITAI_DOWNLOAD_RE, str(url)):
|
elif match := re.match(CIVITAI_DOWNLOAD_RE, str(url)):
|
||||||
version_id = match.group(1)
|
version_id = match.group(1)
|
||||||
return self._from_civitai_versionid(int(version_id))
|
return self.from_civitai_versionid(int(version_id))
|
||||||
raise UnknownModelException("The url '{url}' does not match any known Civitai URL patterns")
|
raise UnknownModelException("The url '{url}' does not match any known Civitai URL patterns")
|
||||||
|
|
||||||
def from_id(self, id: str) -> AnyModelRepoMetadata:
|
def from_id(self, id: str) -> AnyModelRepoMetadata:
|
||||||
@ -89,7 +89,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
|||||||
|
|
||||||
May raise an `UnknownModelException`.
|
May raise an `UnknownModelException`.
|
||||||
"""
|
"""
|
||||||
return self._from_civitai_versionid(int(id))
|
return self.from_civitai_versionid(int(id))
|
||||||
|
|
||||||
def from_civitai_modelid(self, model_id: int) -> CivitaiMetadata:
|
def from_civitai_modelid(self, model_id: int) -> CivitaiMetadata:
|
||||||
"""
|
"""
|
||||||
@ -100,9 +100,9 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
|||||||
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
||||||
model = self._requests.get(model_url).json()
|
model = self._requests.get(model_url).json()
|
||||||
default_version = model["modelVersions"][0]["id"]
|
default_version = model["modelVersions"][0]["id"]
|
||||||
return self._from_civitai_versionid(default_version, model)
|
return self.from_civitai_versionid(default_version, model)
|
||||||
|
|
||||||
def _from_civitai_versionid(
|
def from_civitai_versionid(
|
||||||
self, version_id: int, model_metadata: Optional[Dict[str, Any]] = None
|
self, version_id: int, model_metadata: Optional[Dict[str, Any]] = None
|
||||||
) -> CivitaiMetadata:
|
) -> CivitaiMetadata:
|
||||||
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
|
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
|
||||||
|
@ -28,6 +28,8 @@ from invokeai.app.services.model_records import UnknownModelException
|
|||||||
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator, HuggingFaceMetadata
|
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator, HuggingFaceMetadata
|
||||||
from .fetch_base import ModelMetadataFetchBase
|
from .fetch_base import ModelMetadataFetchBase
|
||||||
|
|
||||||
|
HF_MODEL_RE = r"https?://huggingface.co/([\w\-.]+/[\w\-.]+)"
|
||||||
|
|
||||||
|
|
||||||
class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
||||||
"""Fetch model metadata from HuggingFace."""
|
"""Fetch model metadata from HuggingFace."""
|
||||||
@ -68,7 +70,6 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
|||||||
|
|
||||||
In the case of an invalid or missing URL, raises a ModelNotFound exception.
|
In the case of an invalid or missing URL, raises a ModelNotFound exception.
|
||||||
"""
|
"""
|
||||||
HF_MODEL_RE = r"https?://huggingface.co/([\w\-.]+/[\w\-.]+)"
|
|
||||||
if match := re.match(HF_MODEL_RE, str(url)):
|
if match := re.match(HF_MODEL_RE, str(url)):
|
||||||
repo_id = match.group(1)
|
repo_id = match.group(1)
|
||||||
return self.from_id(repo_id)
|
return self.from_id(repo_id)
|
||||||
|
@ -1,15 +1,17 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||||
|
|
||||||
"""
|
"""This module defines core text-to-image model metadata fields.
|
||||||
This module defines core text-to-image model metadata fields.
|
|
||||||
|
|
||||||
Metadata comprises any descriptive information that is not essential
|
Metadata comprises any descriptive information that is not essential
|
||||||
for getting the model to run. For example "author" is metadata, while
|
for getting the model to run. For example "author" is metadata, while
|
||||||
"type", "base" and "format" are not. The latter fields are part of the
|
"type", "base" and "format" are not. The latter fields are part of the
|
||||||
model's config, as defined in invokeai.backend.model_manager.config.
|
model's config, as defined in invokeai.backend.model_manager.config.
|
||||||
|
|
||||||
Note that the "name" and "description" are also present in `config`.
|
Note that the "name" and "description" are also present in `config`
|
||||||
This may need reworking.
|
records. This is intentional. The config record fields are intended to
|
||||||
|
be editable by the user as a form of customization. The metadata
|
||||||
|
versions of these fields are intended to be kept in sync with the
|
||||||
|
remote repo.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@ -78,7 +80,7 @@ class CivitaiMetadata(ModelMetadataBase):
|
|||||||
description="text description of the model's reversion; usually change history; may contain HTML"
|
description="text description of the model's reversion; usually change history; may contain HTML"
|
||||||
)
|
)
|
||||||
nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False)
|
nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False)
|
||||||
restrictions: LicenseRestrictions = Field(description="license terms", default=LicenseRestrictions)
|
restrictions: LicenseRestrictions = Field(description="license terms", default_factory=LicenseRestrictions)
|
||||||
trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set)
|
trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set)
|
||||||
download_url: AnyHttpUrl = Field(description="download URL for this model")
|
download_url: AnyHttpUrl = Field(description="download URL for this model")
|
||||||
base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)")
|
base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)")
|
||||||
@ -98,7 +100,7 @@ class CivitaiMetadata(ModelMetadataBase):
|
|||||||
@property
|
@property
|
||||||
def allow_commercial_use(self) -> bool:
|
def allow_commercial_use(self) -> bool:
|
||||||
"""Return True if commercial use is allowed."""
|
"""Return True if commercial use is allowed."""
|
||||||
return self.restrictions.AllowCommercialUse == CommercialUsage("None")
|
return self.restrictions.AllowCommercialUse != CommercialUsage("None")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def allow_derivatives(self) -> bool:
|
def allow_derivatives(self) -> bool:
|
||||||
|
@ -4,7 +4,7 @@ SQL Storage for Model Metadata
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from typing import Set
|
from typing import Optional, Set
|
||||||
|
|
||||||
from invokeai.app.services.model_records import UnknownModelException
|
from invokeai.app.services.model_records import UnknownModelException
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||||
@ -59,9 +59,12 @@ class ModelMetadataStore:
|
|||||||
)
|
)
|
||||||
self._update_tags(model_key, metadata.tags)
|
self._update_tags(model_key, metadata.tags)
|
||||||
self._db.conn.commit()
|
self._db.conn.commit()
|
||||||
except sqlite3.Error as e:
|
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
|
||||||
self._db.conn.rollback()
|
self._db.conn.rollback()
|
||||||
raise e
|
raise UnknownModelException from excp
|
||||||
|
except sqlite3.Error as excp:
|
||||||
|
self._db.conn.rollback()
|
||||||
|
raise excp
|
||||||
|
|
||||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||||
@ -111,7 +114,7 @@ class ModelMetadataStore:
|
|||||||
"""Return the keys of models containing all of the listed tags."""
|
"""Return the keys of models containing all of the listed tags."""
|
||||||
with self._db.lock:
|
with self._db.lock:
|
||||||
try:
|
try:
|
||||||
matches: Set[str] = set()
|
matches: Optional[Set[str]] = None
|
||||||
for tag in tags:
|
for tag in tags:
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
@ -123,7 +126,9 @@ class ModelMetadataStore:
|
|||||||
(tag,),
|
(tag,),
|
||||||
)
|
)
|
||||||
model_keys = {x[0] for x in self._cursor.fetchall()}
|
model_keys = {x[0] for x in self._cursor.fetchall()}
|
||||||
matches = matches.intersection(model_keys) if len(matches) > 0 else model_keys
|
if matches is None:
|
||||||
|
matches = model_keys
|
||||||
|
matches = matches.intersection(model_keys)
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
raise e
|
raise e
|
||||||
return matches
|
return matches
|
||||||
@ -139,6 +144,24 @@ class ModelMetadataStore:
|
|||||||
)
|
)
|
||||||
return {x[0] for x in self._cursor.fetchall()}
|
return {x[0] for x in self._cursor.fetchall()}
|
||||||
|
|
||||||
|
def search_by_name(self, name: str) -> Set[str]:
|
||||||
|
"""
|
||||||
|
Return the keys of models with the indicated name.
|
||||||
|
|
||||||
|
Note that this is the name of the model given to it by
|
||||||
|
the remote source. The user may have changed the local
|
||||||
|
name. The local name will be located in the model config
|
||||||
|
record object.
|
||||||
|
"""
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT id FROM model_metadata
|
||||||
|
WHERE name=?;
|
||||||
|
""",
|
||||||
|
(name,),
|
||||||
|
)
|
||||||
|
return {x[0] for x in self._cursor.fetchall()}
|
||||||
|
|
||||||
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
|
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
|
||||||
"""Update tags for the model referenced by model_key."""
|
"""Update tags for the model referenced by model_key."""
|
||||||
# remove previous tags from this model
|
# remove previous tags from this model
|
||||||
|
17
tests/app/services/model_metadata/metadata_examples.py
Normal file
17
tests/app/services/model_metadata/metadata_examples.py
Normal file
File diff suppressed because one or more lines are too long
@ -1,34 +1,39 @@
|
|||||||
"""
|
"""
|
||||||
Test model metadata fetching and storage.
|
Test model metadata fetching and storage.
|
||||||
"""
|
"""
|
||||||
import pytest
|
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List
|
|
||||||
from pydantic import BaseModel, ValidationError
|
import pytest
|
||||||
|
import requests
|
||||||
|
from pydantic.networks import HttpUrl
|
||||||
|
from requests.sessions import Session
|
||||||
|
from requests_testadapter import TestAdapter
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.app.services.model_records import ModelRecordServiceSQL, UnknownModelException
|
||||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
|
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
MainCheckpointConfig,
|
ModelFormat,
|
||||||
MainDiffusersConfig,
|
|
||||||
ModelType,
|
ModelType,
|
||||||
TextualInversionConfig,
|
|
||||||
VaeDiffusersConfig,
|
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.metadata import (
|
from invokeai.backend.model_manager.metadata import (
|
||||||
ModelMetadataStore,
|
|
||||||
AnyModelRepoMetadata,
|
|
||||||
CommercialUsage,
|
|
||||||
LicenseRestrictions,
|
|
||||||
HuggingFaceMetadata,
|
|
||||||
CivitaiMetadata,
|
CivitaiMetadata,
|
||||||
|
CivitaiMetadataFetch,
|
||||||
|
CommercialUsage,
|
||||||
|
HuggingFaceMetadata,
|
||||||
|
HuggingFaceMetadataFetch,
|
||||||
|
ModelMetadataStore,
|
||||||
|
)
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
from tests.app.services.model_metadata.metadata_examples import (
|
||||||
|
RepoCivitaiModelMetadata1,
|
||||||
|
RepoCivitaiVersionMetadata1,
|
||||||
|
RepoHFMetadata1,
|
||||||
)
|
)
|
||||||
from tests.fixtures.sqlite_database import create_mock_sqlite_database
|
from tests.fixtures.sqlite_database import create_mock_sqlite_database
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def app_config(datadir: Path) -> InvokeAIAppConfig:
|
def app_config(datadir: Path) -> InvokeAIAppConfig:
|
||||||
return InvokeAIAppConfig(
|
return InvokeAIAppConfig(
|
||||||
@ -36,49 +41,204 @@ def app_config(datadir: Path) -> InvokeAIAppConfig:
|
|||||||
models_dir=datadir / "root/models",
|
models_dir=datadir / "root/models",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def record_store(app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL:
|
def record_store(app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL:
|
||||||
logger = InvokeAILogger.get_logger(config=app_config)
|
logger = InvokeAILogger.get_logger(config=app_config)
|
||||||
db = create_mock_sqlite_database(app_config, logger)
|
db = create_mock_sqlite_database(app_config, logger)
|
||||||
store = ModelRecordServiceSQL(db)
|
store = ModelRecordServiceSQL(db)
|
||||||
# add two config records to the database
|
# add three simple config records to the database
|
||||||
raw1 = {
|
raw1 = {
|
||||||
"path": "/tmp/foo2.ckpt",
|
"path": "/tmp/foo1",
|
||||||
|
"format": ModelFormat("diffusers"),
|
||||||
"name": "test2",
|
"name": "test2",
|
||||||
"base": BaseModelType("sd-2"),
|
"base": BaseModelType("sd-2"),
|
||||||
"type": ModelType("vae"),
|
"type": ModelType("vae"),
|
||||||
"original_hash":"111222333444",
|
"original_hash": "111222333444",
|
||||||
"source": "stabilityai/sdxl-vae",
|
"source": "stabilityai/sdxl-vae",
|
||||||
}
|
}
|
||||||
raw2 = {
|
raw2 = {
|
||||||
"path": "/tmp/foo1.ckpt",
|
"path": "/tmp/foo2.ckpt",
|
||||||
"name": "model1",
|
"name": "model1",
|
||||||
|
"format": ModelFormat("checkpoint"),
|
||||||
"base": BaseModelType("sd-1"),
|
"base": BaseModelType("sd-1"),
|
||||||
"type": "main",
|
"type": "main",
|
||||||
"config": "/tmp/foo.yaml",
|
"config": "/tmp/foo.yaml",
|
||||||
"variant": "normal",
|
"variant": "normal",
|
||||||
"format": "checkpoint",
|
|
||||||
"original_hash": "111222333444",
|
"original_hash": "111222333444",
|
||||||
"source": "https://civitai.com/models/206883/split",
|
"source": "https://civitai.com/models/206883/split",
|
||||||
}
|
}
|
||||||
store.add_model('test_config_1', raw1)
|
raw3 = {
|
||||||
store.add_model('test_config_2', raw2)
|
"path": "/tmp/foo3",
|
||||||
|
"format": ModelFormat("diffusers"),
|
||||||
|
"name": "test3",
|
||||||
|
"base": BaseModelType("sdxl"),
|
||||||
|
"type": ModelType("main"),
|
||||||
|
"original_hash": "111222333444",
|
||||||
|
"source": "author3/model3",
|
||||||
|
}
|
||||||
|
store.add_model("test_config_1", raw1)
|
||||||
|
store.add_model("test_config_2", raw2)
|
||||||
|
store.add_model("test_config_3", raw3)
|
||||||
return store
|
return store
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def session() -> Session:
|
||||||
|
sess = requests.Session()
|
||||||
|
sess.mount(
|
||||||
|
"https://huggingface.co/api/models/stabilityai/sdxl-turbo",
|
||||||
|
TestAdapter(
|
||||||
|
RepoHFMetadata1,
|
||||||
|
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
sess.mount(
|
||||||
|
"https://civitai.com/api/v1/model-versions/242807",
|
||||||
|
TestAdapter(
|
||||||
|
RepoCivitaiVersionMetadata1,
|
||||||
|
headers={
|
||||||
|
"Content-Length": len(RepoCivitaiVersionMetadata1),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
sess.mount(
|
||||||
|
"https://civitai.com/api/v1/models/215485",
|
||||||
|
TestAdapter(
|
||||||
|
RepoCivitaiModelMetadata1,
|
||||||
|
headers={
|
||||||
|
"Content-Length": len(RepoCivitaiModelMetadata1),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return sess
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def metadata_store(record_store: ModelRecordServiceSQL) -> ModelMetadataStore:
|
def metadata_store(record_store: ModelRecordServiceSQL) -> ModelMetadataStore:
|
||||||
db = record_store._db # to ensure we are sharing the same database
|
db = record_store._db # to ensure we are sharing the same database
|
||||||
return ModelMetadataStore(db)
|
return ModelMetadataStore(db)
|
||||||
|
|
||||||
|
|
||||||
def test_metadata_store_put_get(metadata_store: ModelMetadataStore) -> None:
|
def test_metadata_store_put_get(metadata_store: ModelMetadataStore) -> None:
|
||||||
input_metadata = HuggingFaceMetadata(name="sdxl-vae",
|
input_metadata = HuggingFaceMetadata(
|
||||||
|
name="sdxl-vae",
|
||||||
author="stabilityai",
|
author="stabilityai",
|
||||||
tags={"text-to-image","diffusers"},
|
tags={"text-to-image", "diffusers"},
|
||||||
id="stabilityai/sdxl-vae",
|
id="stabilityai/sdxl-vae",
|
||||||
tag_dict={"license":"other"},
|
tag_dict={"license": "other"},
|
||||||
last_modified=datetime.datetime.now(),
|
last_modified=datetime.datetime.now(),
|
||||||
)
|
)
|
||||||
metadata_store.add_metadata('test_config_1',input_metadata)
|
metadata_store.add_metadata("test_config_1", input_metadata)
|
||||||
output_metadata = metadata_store.get_metadata('test_config_1')
|
output_metadata = metadata_store.get_metadata("test_config_1")
|
||||||
|
assert input_metadata == output_metadata
|
||||||
|
with pytest.raises(UnknownModelException):
|
||||||
|
metadata_store.add_metadata("unknown_key", input_metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def test_metadata_store_update(metadata_store: ModelMetadataStore) -> None:
|
||||||
|
input_metadata = HuggingFaceMetadata(
|
||||||
|
name="sdxl-vae",
|
||||||
|
author="stabilityai",
|
||||||
|
tags={"text-to-image", "diffusers"},
|
||||||
|
id="stabilityai/sdxl-vae",
|
||||||
|
tag_dict={"license": "other"},
|
||||||
|
last_modified=datetime.datetime.now(),
|
||||||
|
)
|
||||||
|
metadata_store.add_metadata("test_config_1", input_metadata)
|
||||||
|
input_metadata.name = "new-name"
|
||||||
|
metadata_store.update_metadata("test_config_1", input_metadata)
|
||||||
|
output_metadata = metadata_store.get_metadata("test_config_1")
|
||||||
|
assert output_metadata.name == "new-name"
|
||||||
assert input_metadata == output_metadata
|
assert input_metadata == output_metadata
|
||||||
|
|
||||||
|
|
||||||
|
def test_metadata_search(metadata_store: ModelMetadataStore) -> None:
|
||||||
|
metadata1 = HuggingFaceMetadata(
|
||||||
|
name="sdxl-vae",
|
||||||
|
author="stabilityai",
|
||||||
|
tags={"text-to-image", "diffusers"},
|
||||||
|
id="stabilityai/sdxl-vae",
|
||||||
|
tag_dict={"license": "other"},
|
||||||
|
last_modified=datetime.datetime.now(),
|
||||||
|
)
|
||||||
|
metadata2 = HuggingFaceMetadata(
|
||||||
|
name="model2",
|
||||||
|
author="stabilityai",
|
||||||
|
tags={"text-to-image", "diffusers", "community-contributed"},
|
||||||
|
id="author2/model2",
|
||||||
|
tag_dict={"license": "other"},
|
||||||
|
last_modified=datetime.datetime.now(),
|
||||||
|
)
|
||||||
|
metadata3 = HuggingFaceMetadata(
|
||||||
|
name="model3",
|
||||||
|
author="author3",
|
||||||
|
tags={"text-to-image", "checkpoint", "community-contributed"},
|
||||||
|
id="author3/model3",
|
||||||
|
tag_dict={"license": "other"},
|
||||||
|
last_modified=datetime.datetime.now(),
|
||||||
|
)
|
||||||
|
metadata_store.add_metadata("test_config_1", metadata1)
|
||||||
|
metadata_store.add_metadata("test_config_2", metadata2)
|
||||||
|
metadata_store.add_metadata("test_config_3", metadata3)
|
||||||
|
|
||||||
|
matches = metadata_store.search_by_author("stabilityai")
|
||||||
|
assert len(matches) == 2
|
||||||
|
assert "test_config_1" in matches
|
||||||
|
assert "test_config_2" in matches
|
||||||
|
matches = metadata_store.search_by_author("Sherlock Holmes")
|
||||||
|
assert not matches
|
||||||
|
|
||||||
|
matches = metadata_store.search_by_name("model3")
|
||||||
|
assert len(matches) == 1
|
||||||
|
assert "test_config_3" in matches
|
||||||
|
|
||||||
|
matches = metadata_store.search_by_tag({"text-to-image"})
|
||||||
|
assert len(matches) == 3
|
||||||
|
|
||||||
|
matches = metadata_store.search_by_tag({"text-to-image", "diffusers"})
|
||||||
|
assert len(matches) == 2
|
||||||
|
assert "test_config_1" in matches
|
||||||
|
assert "test_config_2" in matches
|
||||||
|
|
||||||
|
matches = metadata_store.search_by_tag({"checkpoint", "community-contributed"})
|
||||||
|
assert len(matches) == 1
|
||||||
|
assert "test_config_3" in matches
|
||||||
|
|
||||||
|
# does the tag table update correctly?
|
||||||
|
matches = metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
|
||||||
|
assert not matches
|
||||||
|
metadata3.tags.add("licensed-for-commercial-use")
|
||||||
|
metadata_store.update_metadata("test_config_3", metadata3)
|
||||||
|
matches = metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
|
||||||
|
assert len(matches) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_metadata_civitai_fetch(session: Session) -> None:
|
||||||
|
fetcher = CivitaiMetadataFetch(session)
|
||||||
|
metadata = fetcher.from_url(HttpUrl("https://civitai.com/models/215485/SDXL-turbo"))
|
||||||
|
assert isinstance(metadata, CivitaiMetadata)
|
||||||
|
assert metadata.id == 215485
|
||||||
|
assert metadata.author == "test_author" # note that this is not the same as the original from Civitai
|
||||||
|
assert metadata.allow_commercial_use # changed to make sure we are reading locally not remotely
|
||||||
|
assert metadata.restrictions.AllowCommercialUse == CommercialUsage("RentCivit")
|
||||||
|
assert metadata.version_id == 242807
|
||||||
|
assert metadata.tags == {"tool", "turbo", "sdxl turbo"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_metadata_hf_fetch(session: Session) -> None:
|
||||||
|
fetcher = HuggingFaceMetadataFetch(session)
|
||||||
|
metadata = fetcher.from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||||
|
assert isinstance(metadata, HuggingFaceMetadata)
|
||||||
|
assert metadata.author == "test_author" # this is not the same as the original
|
||||||
|
assert metadata.files
|
||||||
|
assert metadata.tags == {
|
||||||
|
"diffusers",
|
||||||
|
"onnx",
|
||||||
|
"safetensors",
|
||||||
|
"text-to-image",
|
||||||
|
"license:other",
|
||||||
|
"has_space",
|
||||||
|
"diffusers:StableDiffusionXLPipeline",
|
||||||
|
"region:us",
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user