mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add unit tests and documentation
This commit is contained in:
parent
1940169925
commit
a626ca3e1c
@ -15,7 +15,12 @@ model. These are the:
|
||||
their metadata, and `ModelRecordServiceBase` to store that
|
||||
information. It is also responsible for managing the InvokeAI
|
||||
`models` directory and its contents.
|
||||
|
||||
|
||||
* _ModelMetadataStore_ and _ModelMetaDataFetch_ Backend modules that
|
||||
are able to retrieve metadata from online model repositories,
|
||||
transform them into Pydantic models, and cache them to the InvokeAI
|
||||
SQL database.
|
||||
|
||||
* _DownloadQueueServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**)
|
||||
A multithreaded downloader responsible
|
||||
for downloading models from a remote source to disk. The download
|
||||
@ -1184,3 +1189,246 @@ other resources that it might have been using.
|
||||
This will start/pause/cancel all jobs that have been submitted to the
|
||||
queue and have not yet reached a terminal state.
|
||||
|
||||
***
|
||||
|
||||
## This Meta be Good: Model Metadata Storage
|
||||
|
||||
The modules found under `invokeai.backend.model_manager.metadata`
|
||||
provide a straightforward API for fetching model metadatda from online
|
||||
repositories. Currently two repositories are supported: HuggingFace
|
||||
and Civitai. However, the modules are easily extended for additional
|
||||
repos, provided that they have defined APIs for metadata access.
|
||||
|
||||
Metadata comprises any descriptive information that is not essential
|
||||
for getting the model to run. For example "author" is metadata, while
|
||||
"type", "base" and "format" are not. The latter fields are part of the
|
||||
model's config, as defined in `invokeai.backend.model_manager.config`.
|
||||
|
||||
### Example Usage:
|
||||
|
||||
```
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
AnyModelRepoMetadata,
|
||||
CivitaiMetadataFetch,
|
||||
CivitaiMetadata
|
||||
ModelMetadataStore,
|
||||
)
|
||||
# to access the initialized sql database
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
|
||||
civitai = CivitaiMetadataFetch()
|
||||
|
||||
# fetch the metadata
|
||||
model_metadata = civitai.from_url("https://civitai.com/models/215796")
|
||||
|
||||
# get some common metadata fields
|
||||
author = model_metadata.author
|
||||
tags = model_metadata.tags
|
||||
|
||||
# get some Civitai-specific fields
|
||||
assert isinstance(model_metadata, CivitaiMetadata)
|
||||
|
||||
trained_words = model_metadata.trained_words
|
||||
base_model = model_metadata.base_model_trained_on
|
||||
thumbnail = model_metadata.thumbnail_url
|
||||
|
||||
# cache the metadata to the database using the key corresponding to
|
||||
# an existing model config record in the `model_config` table
|
||||
sql_cache = ModelMetadataStore(ApiDependencies.invoker.services.db)
|
||||
sql_cache.add_metadata('fb237ace520b6716adc98bcb16e8462c', model_metadata)
|
||||
|
||||
# now we can search the database by tag, author or model name
|
||||
# matches will contain a list of model keys that match the search
|
||||
matches = sql_cache.search_by_tag({"tool", "turbo"})
|
||||
```
|
||||
|
||||
### Structure of the Metadata objects
|
||||
|
||||
There is a short class hierarchy of Metadata objects, all of which
|
||||
descend from the Pydantic `BaseModel`.
|
||||
|
||||
#### `ModelMetadataBase`
|
||||
|
||||
This is the common base class for metadata:
|
||||
|
||||
| **Field Name** | **Type** | **Description** |
|
||||
|----------------|-----------------|------------------|
|
||||
| `name` | str | Repository's name for the model |
|
||||
| `author` | str | Model's author |
|
||||
| `tags` | Set[str] | Model tags |
|
||||
|
||||
|
||||
Note that the model config record also has a `name` field. It is
|
||||
intended that the config record version be locally customizable, while
|
||||
the metadata version is read-only. However, enforcing this is expected
|
||||
to be part of the business logic.
|
||||
|
||||
Descendents of the base add additional fields.
|
||||
|
||||
#### `HuggingFaceMetadata`
|
||||
|
||||
This descends from `ModelMetadataBase` and adds the following fields:
|
||||
|
||||
| **Field Name** | **Type** | **Description** |
|
||||
|----------------|-----------------|------------------|
|
||||
| `type` | Literal["huggingface"] | Used for the discriminated union of metadata classes|
|
||||
| `id` | str | HuggingFace repo_id |
|
||||
| `tag_dict` | Dict[str, Any] | A dictionary of tag/value pairs provided in addition to `tags` |
|
||||
| `last_modified`| datetime | Date of last commit of this model to the repo |
|
||||
| `files` | List[Path] | List of the files in the model repo |
|
||||
|
||||
|
||||
#### `CivitaiMetadata`
|
||||
|
||||
This descends from `ModelMetadataBase` and adds the following fields:
|
||||
|
||||
| **Field Name** | **Type** | **Description** |
|
||||
|----------------|-----------------|------------------|
|
||||
| `type` | Literal["civitai"] | Used for the discriminated union of metadata classes|
|
||||
| `id` | int | Civitai model id |
|
||||
| `version_name` | str | Name of this version of the model (distinct from model name) |
|
||||
| `version_id` | int | Civitai model version id (distinct from model id) |
|
||||
| `created` | datetime | Date the model was uploaded to Civitai; no modification date provided |
|
||||
| `description` | str | Model description. Quite verbose and contains HTML tags |
|
||||
| `version_description` | str | Model version description, usually describes changes to the model |
|
||||
| `nsfw` | bool | Whether the model tends to generate NSFW content |
|
||||
| `restrictions` | LicenseRestrictions | An object that describes what is and isn't allowed with this model |
|
||||
| `trained_words`| Set[str] | Trigger words for this model, if any |
|
||||
| `download_url` | AnyHttpUrl | URL for downloading this version of the model |
|
||||
| `base_model_trained_on` | str | Name of the model that this version was trained on |
|
||||
| `thumbnail_url` | AnyHttpUrl | URL to access a representative thumbnail image of the model's output |
|
||||
| `weight_min` | int | For LoRA sliders, the minimum suggested weight to apply |
|
||||
| `weight_max` | int | For LoRA sliders, the maximum suggested weight to apply |
|
||||
|
||||
Note that `weight_min` and `weight_max` are not currently populated
|
||||
and take the default values of (-1.0, +2.0). The issue is that these
|
||||
values aren't part of the structured data but appear in the text
|
||||
description. Some regular expression or LLM coding may be able to
|
||||
extract these values.
|
||||
|
||||
Also be aware that `base_model_trained_on` is free text and doesn't
|
||||
correspond to our `ModelType` enum.
|
||||
|
||||
`CivitaiMetadata` also defines some convenience properties relating to
|
||||
licensing restrictions: `credit_required`, `allow_commercial_use`,
|
||||
`allow_derivatives` and `allow_different_license`.
|
||||
|
||||
#### `AnyModelRepoMetadata`
|
||||
|
||||
This is a discriminated Union of `CivitaiMetadata` and
|
||||
`HuggingFaceMetadata`.
|
||||
|
||||
### Fetching Metadata from Online Repos
|
||||
|
||||
The `HuggingFaceMetadataFetch` and `CivitaiMetadataFetch` classes will
|
||||
retrieve metadata from their corresponding repositories and return
|
||||
`AnyModelRepoMetadata` objects. Their base class
|
||||
`ModelMetadataFetchBase` is an abstract class that defines two
|
||||
methods: `from_url()` and `from_id()`. The former accepts the type of
|
||||
model URLs that the user will try to cut and paste into the model
|
||||
import form. The latter accepts a string ID in the format recognized
|
||||
by the repository of choice. Both methods return an
|
||||
`AnyModelRepoMetadata`.
|
||||
|
||||
The base class also has a class method `from_json()` which will take
|
||||
the JSON representation of a `ModelMetadata` object, validate it, and
|
||||
return the corresponding `AnyModelRepoMetadata` object.
|
||||
|
||||
When initializing one of the metadata fetching classes, you may
|
||||
provide a `requests.Session` argument. This allows you to customize
|
||||
the low-level HTTP fetch requests and is used, for instance, in the
|
||||
testing suite to avoid hitting the internet.
|
||||
|
||||
The HuggingFace and Civitai fetcher subclasses add additional
|
||||
repo-specific fetching methods:
|
||||
|
||||
|
||||
#### HuggingFaceMetadataFetch
|
||||
|
||||
This overrides its base class `from_json()` method to return a
|
||||
`HuggingFaceMetadata` object directly.
|
||||
|
||||
#### CivitaiMetadataFetch
|
||||
|
||||
This adds the following methods:
|
||||
|
||||
`from_civitai_modelid()` This takes the ID of a model, finds the
|
||||
default version of the model, and then retrieves the metadata for
|
||||
that version, returning a `CivitaiMetadata` object directly.
|
||||
|
||||
`from_civitai_versionid()` This takes the ID of a model version and
|
||||
retrieves its metadata. Functionally equivalent to `from_id()`, the
|
||||
only difference is that it returna a `CivitaiMetadata` object rather
|
||||
than an `AnyModelRepoMetadata`.
|
||||
|
||||
|
||||
### Metadata Storage
|
||||
|
||||
The `ModelMetadataStore` provides a simple facility to store model
|
||||
metadata in the `invokeai.db` database. The data is stored as a JSON
|
||||
blob, with a few common fields (`name`, `author`, `tags`) broken out
|
||||
to be searchable.
|
||||
|
||||
When a metadata object is saved to the database, it is identified
|
||||
using the model key, _and this key must correspond to an existing
|
||||
model key in the model_config table_. There is a foreign key integrity
|
||||
constraint between the `model_config.id` field and the
|
||||
`model_metadata.id` field such that if you attempt to save metadata
|
||||
under an unknown key, the attempt will result in an
|
||||
`UnknownModelException`. Likewise, when a model is deleted from
|
||||
`model_config`, the deletion of the corresponding metadata record will
|
||||
be triggered.
|
||||
|
||||
Tags are stored in a normalized fashion in the tables `model_tags` and
|
||||
`tags`. Triggers keep the tag table in sync with the `model_metadata`
|
||||
table.
|
||||
|
||||
To create the storage object, initialize it with the InvokeAI
|
||||
`SqliteDatabase` object. This is often done this way:
|
||||
|
||||
```
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
metadata_store = ModelMetadataStore(ApiDependencies.invoker.services.db)
|
||||
```
|
||||
|
||||
You can then access the storage with the following methods:
|
||||
|
||||
#### `add_metadata(key, metadata)`
|
||||
|
||||
Add the metadata using a previously-defined model key.
|
||||
|
||||
There is currently no `delete_metadata()` method. The metadata will
|
||||
persist until the matching config is deleted from the `model_config`
|
||||
table.
|
||||
|
||||
#### `get_metadata(key) -> AnyModelRepoMetadata`
|
||||
|
||||
Retrieve the metadata corresponding to the model key.
|
||||
|
||||
#### `update_metadata(key, new_metadata)`
|
||||
|
||||
Update an existing metadata record with new metadata.
|
||||
|
||||
#### `search_by_tag(tags: Set[str]) -> Set[str]`
|
||||
|
||||
Given a set of tags, find models that are tagged with them. If
|
||||
multiple tags are provided then a matching model must be tagged with
|
||||
*all* the tags in the set. This method returns a set of model keys and
|
||||
is intended to be used in conjunction with the `ModelRecordService`:
|
||||
|
||||
```
|
||||
model_config_store = ApiDependencies.invoker.services.model_records
|
||||
matches = metadata_store.search_by_tag({'license:other'})
|
||||
models = [model_config_store.get(x) for x in matches]
|
||||
```
|
||||
|
||||
#### `search_by_name(name: str) -> Set[str]
|
||||
|
||||
Find all model metadata records that have the given name and return a
|
||||
set of keys to the corresponding model config objects.
|
||||
|
||||
#### `search_by_author(author: str) -> Set[str]
|
||||
|
||||
Find all model metadata records that have the given author and return
|
||||
a set of keys to the corresponding model config objects.
|
||||
|
||||
|
@ -29,7 +29,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator = SqliteMigrator(db=db)
|
||||
migrator.register_migration(build_migration_1())
|
||||
migrator.register_migration(build_migration_2(image_files=image_files, logger=logger))
|
||||
migrator.register_migration(build_migration_3())
|
||||
migrator.register_migration(build_migration_3(app_config=config, logger=logger))
|
||||
migrator.register_migration(build_migration_4())
|
||||
migrator.run_migrations()
|
||||
|
||||
|
@ -11,8 +11,6 @@ from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||
UnsafeWorkflowWithVersionValidator,
|
||||
)
|
||||
|
||||
from .util.migrate_yaml_config_1 import MigrateModelYamlToDb1
|
||||
|
||||
|
||||
class Migration2Callback:
|
||||
def __init__(self, image_files: ImageFileStorageBase, logger: Logger):
|
||||
@ -25,8 +23,6 @@ class Migration2Callback:
|
||||
self._drop_old_workflow_tables(cursor)
|
||||
self._add_workflow_library(cursor)
|
||||
self._drop_model_manager_metadata(cursor)
|
||||
self._recreate_model_config(cursor)
|
||||
self._migrate_model_config_records(cursor)
|
||||
self._migrate_embedded_workflows(cursor)
|
||||
|
||||
def _add_images_has_workflow(self, cursor: sqlite3.Cursor) -> None:
|
||||
@ -100,45 +96,6 @@ class Migration2Callback:
|
||||
"""Drops the `model_manager_metadata` table."""
|
||||
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
|
||||
|
||||
def _recreate_model_config(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""
|
||||
Drops the `model_config` table, recreating it.
|
||||
|
||||
In 3.4.0, this table used explicit columns but was changed to use json_extract 3.5.0.
|
||||
|
||||
Because this table is not used in production, we are able to simply drop it and recreate it.
|
||||
"""
|
||||
|
||||
cursor.execute("DROP TABLE IF EXISTS model_config;")
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS model_config (
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
-- The next 3 fields are enums in python, unrestricted string here
|
||||
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
|
||||
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
|
||||
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
|
||||
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
|
||||
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
|
||||
original_hash TEXT, -- could be null
|
||||
-- Serialized JSON representation of the whole config object,
|
||||
-- which will contain additional fields from subclasses
|
||||
config TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- unique constraint on combo of name, base and type
|
||||
UNIQUE(name, base, type)
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
def _migrate_model_config_records(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""After updating the model config table, we repopulate it."""
|
||||
model_record_migrator = MigrateModelYamlToDb1(cursor)
|
||||
model_record_migrator.migrate()
|
||||
|
||||
def _migrate_embedded_workflows(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""
|
||||
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in
|
||||
|
@ -1,13 +1,16 @@
|
||||
import sqlite3
|
||||
from logging import Logger
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
from .util.migrate_yaml_config_1 import MigrateModelYamlToDb1
|
||||
|
||||
|
||||
class Migration3Callback:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
|
||||
self._app_config = app_config
|
||||
self._logger = logger
|
||||
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._drop_model_manager_metadata(cursor)
|
||||
@ -54,11 +57,12 @@ class Migration3Callback:
|
||||
|
||||
def _migrate_model_config_records(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""After updating the model config table, we repopulate it."""
|
||||
model_record_migrator = MigrateModelYamlToDb1(cursor)
|
||||
self._logger.info("Migrating model config records from models.yaml to database")
|
||||
model_record_migrator = MigrateModelYamlToDb1(self._app_config, self._logger, cursor)
|
||||
model_record_migrator.migrate()
|
||||
|
||||
|
||||
def build_migration_3() -> Migration:
|
||||
def build_migration_3(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
|
||||
"""
|
||||
Build the migration from database version 2 to 3.
|
||||
|
||||
@ -69,7 +73,7 @@ def build_migration_3() -> Migration:
|
||||
migration_3 = Migration(
|
||||
from_version=2,
|
||||
to_version=3,
|
||||
callback=Migration3Callback(),
|
||||
callback=Migration3Callback(app_config=app_config, logger=logger),
|
||||
)
|
||||
|
||||
return migration_3
|
||||
|
@ -23,7 +23,6 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.hash import FastModelHash
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
ModelsValidator = TypeAdapter(AnyModelConfig)
|
||||
|
||||
@ -46,10 +45,9 @@ class MigrateModelYamlToDb1:
|
||||
logger: Logger
|
||||
cursor: sqlite3.Cursor
|
||||
|
||||
def __init__(self, cursor: sqlite3.Cursor = None) -> None:
|
||||
self.config = InvokeAIAppConfig.get_config()
|
||||
self.config.parse_args()
|
||||
self.logger = InvokeAILogger.get_logger()
|
||||
def __init__(self, config: InvokeAIAppConfig, logger: Logger, cursor: sqlite3.Cursor = None) -> None:
|
||||
self.config = config
|
||||
self.logger = logger
|
||||
self.cursor = cursor
|
||||
|
||||
def get_yaml(self) -> DictConfig:
|
||||
|
@ -19,6 +19,7 @@ if data.allow_commercial_use:
|
||||
print("Commercial use of this model is allowed")
|
||||
"""
|
||||
|
||||
from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch
|
||||
from .metadata_base import (
|
||||
AnyModelRepoMetadata,
|
||||
AnyModelRepoMetadataValidator,
|
||||
@ -37,4 +38,6 @@ __all__ = [
|
||||
"HuggingFaceMetadata",
|
||||
"CivitaiMetadata",
|
||||
"ModelMetadataStore",
|
||||
"CivitaiMetadataFetch",
|
||||
"HuggingFaceMetadataFetch",
|
||||
]
|
||||
|
@ -77,10 +77,10 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
return self.from_civitai_modelid(int(model_id))
|
||||
elif match := re.match(CIVITAI_VERSION_PAGE_RE, str(url)):
|
||||
version_id = match.group(1)
|
||||
return self._from_civitai_versionid(int(version_id))
|
||||
return self.from_civitai_versionid(int(version_id))
|
||||
elif match := re.match(CIVITAI_DOWNLOAD_RE, str(url)):
|
||||
version_id = match.group(1)
|
||||
return self._from_civitai_versionid(int(version_id))
|
||||
return self.from_civitai_versionid(int(version_id))
|
||||
raise UnknownModelException("The url '{url}' does not match any known Civitai URL patterns")
|
||||
|
||||
def from_id(self, id: str) -> AnyModelRepoMetadata:
|
||||
@ -89,7 +89,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
|
||||
May raise an `UnknownModelException`.
|
||||
"""
|
||||
return self._from_civitai_versionid(int(id))
|
||||
return self.from_civitai_versionid(int(id))
|
||||
|
||||
def from_civitai_modelid(self, model_id: int) -> CivitaiMetadata:
|
||||
"""
|
||||
@ -100,9 +100,9 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
||||
model = self._requests.get(model_url).json()
|
||||
default_version = model["modelVersions"][0]["id"]
|
||||
return self._from_civitai_versionid(default_version, model)
|
||||
return self.from_civitai_versionid(default_version, model)
|
||||
|
||||
def _from_civitai_versionid(
|
||||
def from_civitai_versionid(
|
||||
self, version_id: int, model_metadata: Optional[Dict[str, Any]] = None
|
||||
) -> CivitaiMetadata:
|
||||
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
|
||||
|
@ -28,6 +28,8 @@ from invokeai.app.services.model_records import UnknownModelException
|
||||
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator, HuggingFaceMetadata
|
||||
from .fetch_base import ModelMetadataFetchBase
|
||||
|
||||
HF_MODEL_RE = r"https?://huggingface.co/([\w\-.]+/[\w\-.]+)"
|
||||
|
||||
|
||||
class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
||||
"""Fetch model metadata from HuggingFace."""
|
||||
@ -68,7 +70,6 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
||||
|
||||
In the case of an invalid or missing URL, raises a ModelNotFound exception.
|
||||
"""
|
||||
HF_MODEL_RE = r"https?://huggingface.co/([\w\-.]+/[\w\-.]+)"
|
||||
if match := re.match(HF_MODEL_RE, str(url)):
|
||||
repo_id = match.group(1)
|
||||
return self.from_id(repo_id)
|
||||
|
@ -1,15 +1,17 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
"""
|
||||
This module defines core text-to-image model metadata fields.
|
||||
"""This module defines core text-to-image model metadata fields.
|
||||
|
||||
Metadata comprises any descriptive information that is not essential
|
||||
for getting the model to run. For example "author" is metadata, while
|
||||
"type", "base" and "format" are not. The latter fields are part of the
|
||||
model's config, as defined in invokeai.backend.model_manager.config.
|
||||
|
||||
Note that the "name" and "description" are also present in `config`.
|
||||
This may need reworking.
|
||||
Note that the "name" and "description" are also present in `config`
|
||||
records. This is intentional. The config record fields are intended to
|
||||
be editable by the user as a form of customization. The metadata
|
||||
versions of these fields are intended to be kept in sync with the
|
||||
remote repo.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
@ -78,7 +80,7 @@ class CivitaiMetadata(ModelMetadataBase):
|
||||
description="text description of the model's reversion; usually change history; may contain HTML"
|
||||
)
|
||||
nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False)
|
||||
restrictions: LicenseRestrictions = Field(description="license terms", default=LicenseRestrictions)
|
||||
restrictions: LicenseRestrictions = Field(description="license terms", default_factory=LicenseRestrictions)
|
||||
trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set)
|
||||
download_url: AnyHttpUrl = Field(description="download URL for this model")
|
||||
base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)")
|
||||
@ -98,7 +100,7 @@ class CivitaiMetadata(ModelMetadataBase):
|
||||
@property
|
||||
def allow_commercial_use(self) -> bool:
|
||||
"""Return True if commercial use is allowed."""
|
||||
return self.restrictions.AllowCommercialUse == CommercialUsage("None")
|
||||
return self.restrictions.AllowCommercialUse != CommercialUsage("None")
|
||||
|
||||
@property
|
||||
def allow_derivatives(self) -> bool:
|
||||
|
@ -4,7 +4,7 @@ SQL Storage for Model Metadata
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from typing import Set
|
||||
from typing import Optional, Set
|
||||
|
||||
from invokeai.app.services.model_records import UnknownModelException
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
@ -59,9 +59,12 @@ class ModelMetadataStore:
|
||||
)
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
raise UnknownModelException from excp
|
||||
except sqlite3.Error as excp:
|
||||
self._db.conn.rollback()
|
||||
raise excp
|
||||
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
@ -111,19 +114,21 @@ class ModelMetadataStore:
|
||||
"""Return the keys of models containing all of the listed tags."""
|
||||
with self._db.lock:
|
||||
try:
|
||||
matches: Set[str] = set()
|
||||
matches: Optional[Set[str]] = None
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT a.id FROM model_tags AS a,
|
||||
tags AS b
|
||||
tags AS b
|
||||
WHERE a.tag_id=b.tag_id
|
||||
AND b.tag_text=?;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
model_keys = {x[0] for x in self._cursor.fetchall()}
|
||||
matches = matches.intersection(model_keys) if len(matches) > 0 else model_keys
|
||||
if matches is None:
|
||||
matches = model_keys
|
||||
matches = matches.intersection(model_keys)
|
||||
except sqlite3.Error as e:
|
||||
raise e
|
||||
return matches
|
||||
@ -139,6 +144,24 @@ class ModelMetadataStore:
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_name(self, name: str) -> Set[str]:
|
||||
"""
|
||||
Return the keys of models with the indicated name.
|
||||
|
||||
Note that this is the name of the model given to it by
|
||||
the remote source. The user may have changed the local
|
||||
name. The local name will be located in the model config
|
||||
record object.
|
||||
"""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE name=?;
|
||||
""",
|
||||
(name,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
|
||||
"""Update tags for the model referenced by model_key."""
|
||||
# remove previous tags from this model
|
||||
|
17
tests/app/services/model_metadata/metadata_examples.py
Normal file
17
tests/app/services/model_metadata/metadata_examples.py
Normal file
File diff suppressed because one or more lines are too long
@ -1,34 +1,39 @@
|
||||
"""
|
||||
Test model metadata fetching and storage.
|
||||
"""
|
||||
import pytest
|
||||
import datetime
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from pydantic.networks import HttpUrl
|
||||
from requests.sessions import Session
|
||||
from requests_testadapter import TestAdapter
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
|
||||
from invokeai.app.services.model_records import ModelRecordServiceSQL, UnknownModelException
|
||||
from invokeai.backend.model_manager.config import (
|
||||
BaseModelType,
|
||||
MainCheckpointConfig,
|
||||
MainDiffusersConfig,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
TextualInversionConfig,
|
||||
VaeDiffusersConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
ModelMetadataStore,
|
||||
AnyModelRepoMetadata,
|
||||
CommercialUsage,
|
||||
LicenseRestrictions,
|
||||
HuggingFaceMetadata,
|
||||
CivitaiMetadata,
|
||||
CivitaiMetadataFetch,
|
||||
CommercialUsage,
|
||||
HuggingFaceMetadata,
|
||||
HuggingFaceMetadataFetch,
|
||||
ModelMetadataStore,
|
||||
)
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.app.services.model_metadata.metadata_examples import (
|
||||
RepoCivitaiModelMetadata1,
|
||||
RepoCivitaiVersionMetadata1,
|
||||
RepoHFMetadata1,
|
||||
)
|
||||
from tests.fixtures.sqlite_database import create_mock_sqlite_database
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_config(datadir: Path) -> InvokeAIAppConfig:
|
||||
return InvokeAIAppConfig(
|
||||
@ -36,49 +41,204 @@ def app_config(datadir: Path) -> InvokeAIAppConfig:
|
||||
models_dir=datadir / "root/models",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def record_store(app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL:
|
||||
logger = InvokeAILogger.get_logger(config=app_config)
|
||||
db = create_mock_sqlite_database(app_config, logger)
|
||||
store = ModelRecordServiceSQL(db)
|
||||
# add two config records to the database
|
||||
# add three simple config records to the database
|
||||
raw1 = {
|
||||
"path": "/tmp/foo2.ckpt",
|
||||
"path": "/tmp/foo1",
|
||||
"format": ModelFormat("diffusers"),
|
||||
"name": "test2",
|
||||
"base": BaseModelType("sd-2"),
|
||||
"type": ModelType("vae"),
|
||||
"original_hash":"111222333444",
|
||||
"original_hash": "111222333444",
|
||||
"source": "stabilityai/sdxl-vae",
|
||||
}
|
||||
raw2 = {
|
||||
"path": "/tmp/foo1.ckpt",
|
||||
"path": "/tmp/foo2.ckpt",
|
||||
"name": "model1",
|
||||
"format": ModelFormat("checkpoint"),
|
||||
"base": BaseModelType("sd-1"),
|
||||
"type": "main",
|
||||
"config": "/tmp/foo.yaml",
|
||||
"variant": "normal",
|
||||
"format": "checkpoint",
|
||||
"original_hash": "111222333444",
|
||||
"source": "https://civitai.com/models/206883/split",
|
||||
}
|
||||
store.add_model('test_config_1', raw1)
|
||||
store.add_model('test_config_2', raw2)
|
||||
raw3 = {
|
||||
"path": "/tmp/foo3",
|
||||
"format": ModelFormat("diffusers"),
|
||||
"name": "test3",
|
||||
"base": BaseModelType("sdxl"),
|
||||
"type": ModelType("main"),
|
||||
"original_hash": "111222333444",
|
||||
"source": "author3/model3",
|
||||
}
|
||||
store.add_model("test_config_1", raw1)
|
||||
store.add_model("test_config_2", raw2)
|
||||
store.add_model("test_config_3", raw3)
|
||||
return store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session() -> Session:
|
||||
sess = requests.Session()
|
||||
sess.mount(
|
||||
"https://huggingface.co/api/models/stabilityai/sdxl-turbo",
|
||||
TestAdapter(
|
||||
RepoHFMetadata1,
|
||||
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)},
|
||||
),
|
||||
)
|
||||
sess.mount(
|
||||
"https://civitai.com/api/v1/model-versions/242807",
|
||||
TestAdapter(
|
||||
RepoCivitaiVersionMetadata1,
|
||||
headers={
|
||||
"Content-Length": len(RepoCivitaiVersionMetadata1),
|
||||
},
|
||||
),
|
||||
)
|
||||
sess.mount(
|
||||
"https://civitai.com/api/v1/models/215485",
|
||||
TestAdapter(
|
||||
RepoCivitaiModelMetadata1,
|
||||
headers={
|
||||
"Content-Length": len(RepoCivitaiModelMetadata1),
|
||||
},
|
||||
),
|
||||
)
|
||||
return sess
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def metadata_store(record_store: ModelRecordServiceSQL) -> ModelMetadataStore:
|
||||
db = record_store._db # to ensure we are sharing the same database
|
||||
db = record_store._db # to ensure we are sharing the same database
|
||||
return ModelMetadataStore(db)
|
||||
|
||||
|
||||
def test_metadata_store_put_get(metadata_store: ModelMetadataStore) -> None:
|
||||
input_metadata = HuggingFaceMetadata(name="sdxl-vae",
|
||||
author="stabilityai",
|
||||
tags={"text-to-image","diffusers"},
|
||||
id="stabilityai/sdxl-vae",
|
||||
tag_dict={"license":"other"},
|
||||
last_modified=datetime.datetime.now(),
|
||||
)
|
||||
metadata_store.add_metadata('test_config_1',input_metadata)
|
||||
output_metadata = metadata_store.get_metadata('test_config_1')
|
||||
input_metadata = HuggingFaceMetadata(
|
||||
name="sdxl-vae",
|
||||
author="stabilityai",
|
||||
tags={"text-to-image", "diffusers"},
|
||||
id="stabilityai/sdxl-vae",
|
||||
tag_dict={"license": "other"},
|
||||
last_modified=datetime.datetime.now(),
|
||||
)
|
||||
metadata_store.add_metadata("test_config_1", input_metadata)
|
||||
output_metadata = metadata_store.get_metadata("test_config_1")
|
||||
assert input_metadata == output_metadata
|
||||
with pytest.raises(UnknownModelException):
|
||||
metadata_store.add_metadata("unknown_key", input_metadata)
|
||||
|
||||
|
||||
def test_metadata_store_update(metadata_store: ModelMetadataStore) -> None:
|
||||
input_metadata = HuggingFaceMetadata(
|
||||
name="sdxl-vae",
|
||||
author="stabilityai",
|
||||
tags={"text-to-image", "diffusers"},
|
||||
id="stabilityai/sdxl-vae",
|
||||
tag_dict={"license": "other"},
|
||||
last_modified=datetime.datetime.now(),
|
||||
)
|
||||
metadata_store.add_metadata("test_config_1", input_metadata)
|
||||
input_metadata.name = "new-name"
|
||||
metadata_store.update_metadata("test_config_1", input_metadata)
|
||||
output_metadata = metadata_store.get_metadata("test_config_1")
|
||||
assert output_metadata.name == "new-name"
|
||||
assert input_metadata == output_metadata
|
||||
|
||||
|
||||
def test_metadata_search(metadata_store: ModelMetadataStore) -> None:
|
||||
metadata1 = HuggingFaceMetadata(
|
||||
name="sdxl-vae",
|
||||
author="stabilityai",
|
||||
tags={"text-to-image", "diffusers"},
|
||||
id="stabilityai/sdxl-vae",
|
||||
tag_dict={"license": "other"},
|
||||
last_modified=datetime.datetime.now(),
|
||||
)
|
||||
metadata2 = HuggingFaceMetadata(
|
||||
name="model2",
|
||||
author="stabilityai",
|
||||
tags={"text-to-image", "diffusers", "community-contributed"},
|
||||
id="author2/model2",
|
||||
tag_dict={"license": "other"},
|
||||
last_modified=datetime.datetime.now(),
|
||||
)
|
||||
metadata3 = HuggingFaceMetadata(
|
||||
name="model3",
|
||||
author="author3",
|
||||
tags={"text-to-image", "checkpoint", "community-contributed"},
|
||||
id="author3/model3",
|
||||
tag_dict={"license": "other"},
|
||||
last_modified=datetime.datetime.now(),
|
||||
)
|
||||
metadata_store.add_metadata("test_config_1", metadata1)
|
||||
metadata_store.add_metadata("test_config_2", metadata2)
|
||||
metadata_store.add_metadata("test_config_3", metadata3)
|
||||
|
||||
matches = metadata_store.search_by_author("stabilityai")
|
||||
assert len(matches) == 2
|
||||
assert "test_config_1" in matches
|
||||
assert "test_config_2" in matches
|
||||
matches = metadata_store.search_by_author("Sherlock Holmes")
|
||||
assert not matches
|
||||
|
||||
matches = metadata_store.search_by_name("model3")
|
||||
assert len(matches) == 1
|
||||
assert "test_config_3" in matches
|
||||
|
||||
matches = metadata_store.search_by_tag({"text-to-image"})
|
||||
assert len(matches) == 3
|
||||
|
||||
matches = metadata_store.search_by_tag({"text-to-image", "diffusers"})
|
||||
assert len(matches) == 2
|
||||
assert "test_config_1" in matches
|
||||
assert "test_config_2" in matches
|
||||
|
||||
matches = metadata_store.search_by_tag({"checkpoint", "community-contributed"})
|
||||
assert len(matches) == 1
|
||||
assert "test_config_3" in matches
|
||||
|
||||
# does the tag table update correctly?
|
||||
matches = metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
|
||||
assert not matches
|
||||
metadata3.tags.add("licensed-for-commercial-use")
|
||||
metadata_store.update_metadata("test_config_3", metadata3)
|
||||
matches = metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
|
||||
assert len(matches) == 1
|
||||
|
||||
|
||||
def test_metadata_civitai_fetch(session: Session) -> None:
|
||||
fetcher = CivitaiMetadataFetch(session)
|
||||
metadata = fetcher.from_url(HttpUrl("https://civitai.com/models/215485/SDXL-turbo"))
|
||||
assert isinstance(metadata, CivitaiMetadata)
|
||||
assert metadata.id == 215485
|
||||
assert metadata.author == "test_author" # note that this is not the same as the original from Civitai
|
||||
assert metadata.allow_commercial_use # changed to make sure we are reading locally not remotely
|
||||
assert metadata.restrictions.AllowCommercialUse == CommercialUsage("RentCivit")
|
||||
assert metadata.version_id == 242807
|
||||
assert metadata.tags == {"tool", "turbo", "sdxl turbo"}
|
||||
|
||||
|
||||
def test_metadata_hf_fetch(session: Session) -> None:
|
||||
fetcher = HuggingFaceMetadataFetch(session)
|
||||
metadata = fetcher.from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||
assert isinstance(metadata, HuggingFaceMetadata)
|
||||
assert metadata.author == "test_author" # this is not the same as the original
|
||||
assert metadata.files
|
||||
assert metadata.tags == {
|
||||
"diffusers",
|
||||
"onnx",
|
||||
"safetensors",
|
||||
"text-to-image",
|
||||
"license:other",
|
||||
"has_space",
|
||||
"diffusers:StableDiffusionXLPipeline",
|
||||
"region:us",
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user