Compare commits

...

9 Commits

15 changed files with 1316 additions and 55 deletions

View File

@ -15,7 +15,12 @@ model. These are the:
their metadata, and `ModelRecordServiceBase` to store that
information. It is also responsible for managing the InvokeAI
`models` directory and its contents.
* _ModelMetadataStore_ and _ModelMetaDataFetch_ Backend modules that
are able to retrieve metadata from online model repositories,
transform them into Pydantic models, and cache them to the InvokeAI
SQL database.
* _DownloadQueueServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**)
A multithreaded downloader responsible
for downloading models from a remote source to disk. The download
@ -1184,3 +1189,248 @@ other resources that it might have been using.
This will start/pause/cancel all jobs that have been submitted to the
queue and have not yet reached a terminal state.
***
## This Meta be Good: Model Metadata Storage
The modules found under `invokeai.backend.model_manager.metadata`
provide a straightforward API for fetching model metadatda from online
repositories. Currently two repositories are supported: HuggingFace
and Civitai. However, the modules are easily extended for additional
repos, provided that they have defined APIs for metadata access.
Metadata comprises any descriptive information that is not essential
for getting the model to run. For example "author" is metadata, while
"type", "base" and "format" are not. The latter fields are part of the
model's config, as defined in `invokeai.backend.model_manager.config`.
### Example Usage:
```
from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata,
CivitaiMetadataFetch,
CivitaiMetadata
ModelMetadataStore,
)
# to access the initialized sql database
from invokeai.app.api.dependencies import ApiDependencies
civitai = CivitaiMetadataFetch()
# fetch the metadata
model_metadata = civitai.from_url("https://civitai.com/models/215796")
# get some common metadata fields
author = model_metadata.author
tags = model_metadata.tags
# get some Civitai-specific fields
assert isinstance(model_metadata, CivitaiMetadata)
trained_words = model_metadata.trained_words
base_model = model_metadata.base_model_trained_on
thumbnail = model_metadata.thumbnail_url
# cache the metadata to the database using the key corresponding to
# an existing model config record in the `model_config` table
sql_cache = ModelMetadataStore(ApiDependencies.invoker.services.db)
sql_cache.add_metadata('fb237ace520b6716adc98bcb16e8462c', model_metadata)
# now we can search the database by tag, author or model name
# matches will contain a list of model keys that match the search
matches = sql_cache.search_by_tag({"tool", "turbo"})
```
### Structure of the Metadata objects
There is a short class hierarchy of Metadata objects, all of which
descend from the Pydantic `BaseModel`.
#### `ModelMetadataBase`
This is the common base class for metadata:
| **Field Name** | **Type** | **Description** |
|----------------|-----------------|------------------|
| `name` | str | Repository's name for the model |
| `author` | str | Model's author |
| `tags` | Set[str] | Model tags |
Note that the model config record also has a `name` field. It is
intended that the config record version be locally customizable, while
the metadata version is read-only. However, enforcing this is expected
to be part of the business logic.
Descendents of the base add additional fields.
#### `HuggingFaceMetadata`
This descends from `ModelMetadataBase` and adds the following fields:
| **Field Name** | **Type** | **Description** |
|----------------|-----------------|------------------|
| `type` | Literal["huggingface"] | Used for the discriminated union of metadata classes|
| `id` | str | HuggingFace repo_id |
| `tag_dict` | Dict[str, Any] | A dictionary of tag/value pairs provided in addition to `tags` |
| `last_modified`| datetime | Date of last commit of this model to the repo |
| `files` | List[Path] | List of the files in the model repo |
#### `CivitaiMetadata`
This descends from `ModelMetadataBase` and adds the following fields:
| **Field Name** | **Type** | **Description** |
|----------------|-----------------|------------------|
| `type` | Literal["civitai"] | Used for the discriminated union of metadata classes|
| `id` | int | Civitai model id |
| `version_name` | str | Name of this version of the model (distinct from model name) |
| `version_id` | int | Civitai model version id (distinct from model id) |
| `created` | datetime | Date this version of the model was created |
| `updated` | datetime | Date this version of the model was last updated |
| `published` | datetime | Date this version of the model was published to Civitai |
| `description` | str | Model description. Quite verbose and contains HTML tags |
| `version_description` | str | Model version description, usually describes changes to the model |
| `nsfw` | bool | Whether the model tends to generate NSFW content |
| `restrictions` | LicenseRestrictions | An object that describes what is and isn't allowed with this model |
| `trained_words`| Set[str] | Trigger words for this model, if any |
| `download_url` | AnyHttpUrl | URL for downloading this version of the model |
| `base_model_trained_on` | str | Name of the model that this version was trained on |
| `thumbnail_url` | AnyHttpUrl | URL to access a representative thumbnail image of the model's output |
| `weight_min` | int | For LoRA sliders, the minimum suggested weight to apply |
| `weight_max` | int | For LoRA sliders, the maximum suggested weight to apply |
Note that `weight_min` and `weight_max` are not currently populated
and take the default values of (-1.0, +2.0). The issue is that these
values aren't part of the structured data but appear in the text
description. Some regular expression or LLM coding may be able to
extract these values.
Also be aware that `base_model_trained_on` is free text and doesn't
correspond to our `ModelType` enum.
`CivitaiMetadata` also defines some convenience properties relating to
licensing restrictions: `credit_required`, `allow_commercial_use`,
`allow_derivatives` and `allow_different_license`.
#### `AnyModelRepoMetadata`
This is a discriminated Union of `CivitaiMetadata` and
`HuggingFaceMetadata`.
### Fetching Metadata from Online Repos
The `HuggingFaceMetadataFetch` and `CivitaiMetadataFetch` classes will
retrieve metadata from their corresponding repositories and return
`AnyModelRepoMetadata` objects. Their base class
`ModelMetadataFetchBase` is an abstract class that defines two
methods: `from_url()` and `from_id()`. The former accepts the type of
model URLs that the user will try to cut and paste into the model
import form. The latter accepts a string ID in the format recognized
by the repository of choice. Both methods return an
`AnyModelRepoMetadata`.
The base class also has a class method `from_json()` which will take
the JSON representation of a `ModelMetadata` object, validate it, and
return the corresponding `AnyModelRepoMetadata` object.
When initializing one of the metadata fetching classes, you may
provide a `requests.Session` argument. This allows you to customize
the low-level HTTP fetch requests and is used, for instance, in the
testing suite to avoid hitting the internet.
The HuggingFace and Civitai fetcher subclasses add additional
repo-specific fetching methods:
#### HuggingFaceMetadataFetch
This overrides its base class `from_json()` method to return a
`HuggingFaceMetadata` object directly.
#### CivitaiMetadataFetch
This adds the following methods:
`from_civitai_modelid()` This takes the ID of a model, finds the
default version of the model, and then retrieves the metadata for
that version, returning a `CivitaiMetadata` object directly.
`from_civitai_versionid()` This takes the ID of a model version and
retrieves its metadata. Functionally equivalent to `from_id()`, the
only difference is that it returna a `CivitaiMetadata` object rather
than an `AnyModelRepoMetadata`.
### Metadata Storage
The `ModelMetadataStore` provides a simple facility to store model
metadata in the `invokeai.db` database. The data is stored as a JSON
blob, with a few common fields (`name`, `author`, `tags`) broken out
to be searchable.
When a metadata object is saved to the database, it is identified
using the model key, _and this key must correspond to an existing
model key in the model_config table_. There is a foreign key integrity
constraint between the `model_config.id` field and the
`model_metadata.id` field such that if you attempt to save metadata
under an unknown key, the attempt will result in an
`UnknownModelException`. Likewise, when a model is deleted from
`model_config`, the deletion of the corresponding metadata record will
be triggered.
Tags are stored in a normalized fashion in the tables `model_tags` and
`tags`. Triggers keep the tag table in sync with the `model_metadata`
table.
To create the storage object, initialize it with the InvokeAI
`SqliteDatabase` object. This is often done this way:
```
from invokeai.app.api.dependencies import ApiDependencies
metadata_store = ModelMetadataStore(ApiDependencies.invoker.services.db)
```
You can then access the storage with the following methods:
#### `add_metadata(key, metadata)`
Add the metadata using a previously-defined model key.
There is currently no `delete_metadata()` method. The metadata will
persist until the matching config is deleted from the `model_config`
table.
#### `get_metadata(key) -> AnyModelRepoMetadata`
Retrieve the metadata corresponding to the model key.
#### `update_metadata(key, new_metadata)`
Update an existing metadata record with new metadata.
#### `search_by_tag(tags: Set[str]) -> Set[str]`
Given a set of tags, find models that are tagged with them. If
multiple tags are provided then a matching model must be tagged with
*all* the tags in the set. This method returns a set of model keys and
is intended to be used in conjunction with the `ModelRecordService`:
```
model_config_store = ApiDependencies.invoker.services.model_records
matches = metadata_store.search_by_tag({'license:other'})
models = [model_config_store.get(x) for x in matches]
```
#### `search_by_name(name: str) -> Set[str]
Find all model metadata records that have the given name and return a
set of keys to the corresponding model config objects.
#### `search_by_author(author: str) -> Set[str]
Find all model metadata records that have the given author and return
a set of keys to the corresponding model config objects.

View File

@ -6,6 +6,7 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import build_migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import build_migration_2
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import build_migration_3
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@ -28,7 +29,8 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator = SqliteMigrator(db=db)
migrator.register_migration(build_migration_1())
migrator.register_migration(build_migration_2(image_files=image_files, logger=logger))
migrator.register_migration(build_migration_3())
migrator.register_migration(build_migration_3(app_config=config, logger=logger))
migrator.register_migration(build_migration_4())
migrator.run_migrations()
return db

View File

@ -11,8 +11,6 @@ from invokeai.app.services.workflow_records.workflow_records_common import (
UnsafeWorkflowWithVersionValidator,
)
from .util.migrate_yaml_config_1 import MigrateModelYamlToDb1
class Migration2Callback:
def __init__(self, image_files: ImageFileStorageBase, logger: Logger):
@ -25,8 +23,6 @@ class Migration2Callback:
self._drop_old_workflow_tables(cursor)
self._add_workflow_library(cursor)
self._drop_model_manager_metadata(cursor)
self._recreate_model_config(cursor)
self._migrate_model_config_records(cursor)
self._migrate_embedded_workflows(cursor)
def _add_images_has_workflow(self, cursor: sqlite3.Cursor) -> None:
@ -100,45 +96,6 @@ class Migration2Callback:
"""Drops the `model_manager_metadata` table."""
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
def _recreate_model_config(self, cursor: sqlite3.Cursor) -> None:
"""
Drops the `model_config` table, recreating it.
In 3.4.0, this table used explicit columns but was changed to use json_extract 3.5.0.
Because this table is not used in production, we are able to simply drop it and recreate it.
"""
cursor.execute("DROP TABLE IF EXISTS model_config;")
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY,
-- The next 3 fields are enums in python, unrestricted string here
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
original_hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
"""
)
def _migrate_model_config_records(self, cursor: sqlite3.Cursor) -> None:
"""After updating the model config table, we repopulate it."""
model_record_migrator = MigrateModelYamlToDb1(cursor)
model_record_migrator.migrate()
def _migrate_embedded_workflows(self, cursor: sqlite3.Cursor) -> None:
"""
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in

View File

@ -1,13 +1,16 @@
import sqlite3
from logging import Logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
from .util.migrate_yaml_config_1 import MigrateModelYamlToDb1
class Migration3Callback:
def __init__(self) -> None:
pass
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
self._app_config = app_config
self._logger = logger
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._drop_model_manager_metadata(cursor)
@ -54,11 +57,12 @@ class Migration3Callback:
def _migrate_model_config_records(self, cursor: sqlite3.Cursor) -> None:
"""After updating the model config table, we repopulate it."""
model_record_migrator = MigrateModelYamlToDb1(cursor)
self._logger.info("Migrating model config records from models.yaml to database")
model_record_migrator = MigrateModelYamlToDb1(self._app_config, self._logger, cursor)
model_record_migrator.migrate()
def build_migration_3() -> Migration:
def build_migration_3(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
"""
Build the migration from database version 2 to 3.
@ -69,7 +73,7 @@ def build_migration_3() -> Migration:
migration_3 = Migration(
from_version=2,
to_version=3,
callback=Migration3Callback(),
callback=Migration3Callback(app_config=app_config, logger=logger),
)
return migration_3

View File

@ -0,0 +1,94 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration4Callback:
"""Callback to do step 4 of migration."""
def __call__(self, cursor: sqlite3.Cursor) -> None: # noqa D102
self._create_model_metadata(cursor)
self._create_model_tags(cursor)
self._create_tags(cursor)
self._create_triggers(cursor)
def _create_model_metadata(self, cursor: sqlite3.Cursor) -> None:
"""Create the table used to store model metadata downloaded from remote sources."""
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_metadata (
id TEXT NOT NULL,
name TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.name')) VIRTUAL NOT NULL,
author TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.author')) VIRTUAL NOT NULL,
-- Serialized JSON representation of the whole metadata object,
-- which will contain additional fields from subclasses
metadata TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
FOREIGN KEY(id) REFERENCES model_config(id)
);
"""
)
def _create_model_tags(self, cursor: sqlite3.Cursor) -> None:
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_tags (
id TEXT NOT NULL,
tag_id INTEGER NOT NULL,
FOREIGN KEY(id) REFERENCES model_config(id),
FOREIGN KEY(tag_id) REFERENCES tags(tag_id),
UNIQUE(id,tag_id)
);
"""
)
def _create_tags(self, cursor: sqlite3.Cursor) -> None:
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS tags (
tag_id INTEGER NOT NULL PRIMARY KEY,
tag_text TEXT NOT NULL UNIQUE
);
"""
)
def _create_triggers(self, cursor: sqlite3.Cursor) -> None:
cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS model_metadata_updated_at
AFTER UPDATE
ON model_metadata FOR EACH ROW
BEGIN
UPDATE model_metadata SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
)
cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS model_config_deleted
AFTER DELETE
ON model_config
BEGIN
DELETE from model_metadata WHERE id=old.id;
DELETE from model_tags WHERE id=old.id;
END;
"""
)
def build_migration_4() -> Migration:
"""
Build the migration from database version 3 to 4.
Adds the tables needed to store model metadata and tags.
"""
migration_4 = Migration(
from_version=3,
to_version=4,
callback=Migration4Callback(),
)
return migration_4

View File

@ -23,7 +23,6 @@ from invokeai.backend.model_manager.config import (
ModelType,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.util.logging import InvokeAILogger
ModelsValidator = TypeAdapter(AnyModelConfig)
@ -46,10 +45,9 @@ class MigrateModelYamlToDb1:
logger: Logger
cursor: sqlite3.Cursor
def __init__(self, cursor: sqlite3.Cursor = None) -> None:
self.config = InvokeAIAppConfig.get_config()
self.config.parse_args()
self.logger = InvokeAILogger.get_logger()
def __init__(self, config: InvokeAIAppConfig, logger: Logger, cursor: sqlite3.Cursor = None) -> None:
self.config = config
self.logger = logger
self.cursor = cursor
def get_yaml(self) -> DictConfig:

View File

@ -0,0 +1,43 @@
"""
Initialization file for invokeai.backend.model_manager.metadata
Usage:
from invokeai.backend.model_manager.metadata import(
AnyModelRepoMetadata,
CommercialUsage,
LicenseRestrictions,
HuggingFaceMetadata,
CivitaiMetadata,
)
from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch
data = CivitaiMetadataFetch().from_url("https://civitai.com/models/206883/split")
assert isinstance(data, CivitaiMetadata)
if data.allow_commercial_use:
print("Commercial use of this model is allowed")
"""
from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch
from .metadata_base import (
AnyModelRepoMetadata,
AnyModelRepoMetadataValidator,
CivitaiMetadata,
CommercialUsage,
HuggingFaceMetadata,
LicenseRestrictions,
)
from .metadata_store import ModelMetadataStore
__all__ = [
"AnyModelRepoMetadata",
"AnyModelRepoMetadataValidator",
"CommercialUsage",
"LicenseRestrictions",
"HuggingFaceMetadata",
"CivitaiMetadata",
"ModelMetadataStore",
"CivitaiMetadataFetch",
"HuggingFaceMetadataFetch",
]

View File

@ -0,0 +1,21 @@
"""
Initialization file for invokeai.backend.model_manager.metadata.fetch
Usage:
from invokeai.backend.model_manager.metadata.fetch import (
CivitaiMetadataFetch,
HuggingFaceMetadataFetch,
)
from invokeai.backend.model_manager.metadata import CivitaiMetadata
data = CivitaiMetadataFetch().from_url("https://civitai.com/models/206883/split")
assert isinstance(data, CivitaiMetadata)
if data.allow_commercial_use:
print("Commercial use of this model is allowed")
"""
from .fetch_base import ModelMetadataFetchBase
from .fetch_civitai import CivitaiMetadataFetch
from .fetch_huggingface import HuggingFaceMetadataFetch
__all__ = ["ModelMetadataFetchBase", "CivitaiMetadataFetch", "HuggingFaceMetadataFetch"]

View File

@ -0,0 +1,63 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
This module is the base class for subclasses that fetch metadata from model repositories
Usage:
from invokeai.backend.model_manager.metadata.fetch import CivitAIMetadataFetch
fetcher = CivitaiMetadataFetch()
metadata = fetcher.from_url("https://civitai.com/models/206883/split")
print(metadata.trained_words)
"""
from abc import ABC, abstractmethod
from typing import Optional
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator
class ModelMetadataFetchBase(ABC):
"""Fetch metadata from remote generative model repositories."""
@abstractmethod
def __init__(self, session: Optional[Session] = None):
"""
Initialize the fetcher with an optional requests.sessions.Session object.
By providing a configurable Session object, we can support unit tests on
this module without an internet connection.
"""
pass
@abstractmethod
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
"""
Given a URL to a model repository, return a ModelMetadata object.
This method will raise an `invokeai.app.services.model_records.UnknownModelException`
in the event that the requested model metadata is not found at the provided location.
"""
pass
@abstractmethod
def from_id(self, id: str) -> AnyModelRepoMetadata:
"""
Given an ID for a model, return a ModelMetadata object.
This method will raise an `invokeai.app.services.model_records.UnknownModelException`
in the event that the requested model's metadata is not found at the provided id.
"""
pass
@classmethod
def from_json(self, json: str) -> AnyModelRepoMetadata:
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
metadata = AnyModelRepoMetadataValidator.validate_json(json)
return (
metadata # mypy complains that metadata is a <typing special form> and issues a type checking error. Why?
)

View File

@ -0,0 +1,157 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
This module fetches model metadata objects from the Civitai model repository.
In addition to the `from_url()` and `from_id()` methods inherited from the
`ModelMetadataFetchBase` base class.
Civitai has two separate ID spaces: a model ID and a version ID. The
version ID corresponds to a specific model, and is the ID accepted by
`from_id()`. The model ID corresponds to a family of related models,
such as different training checkpoints or 16 vs 32-bit versions. The
`from_civitai_modelid()` method will accept a model ID and return the
metadata from the default version within this model set. The default
version is the same as what the user sees when they click on a model's
thumbnail.
Usage:
from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch
fetcher = CivitaiMetadataFetch()
metadata = fetcher.from_url("https://civitai.com/models/206883/split")
print(metadata.trained_words)
"""
import re
from datetime import datetime
from typing import Any, Dict, Optional
import requests
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from invokeai.app.services.model_records import UnknownModelException
from ..metadata_base import (
AnyModelRepoMetadata,
AnyModelRepoMetadataValidator,
CivitaiMetadata,
CommercialUsage,
LicenseRestrictions,
)
from .fetch_base import ModelMetadataFetchBase
CIVITAI_MODEL_PAGE_RE = r"https?://civitai.com/models/(\d+)"
CIVITAI_VERSION_PAGE_RE = r"https?://civitai.com/models/(\d+)\?modelVersionId=(\d+)"
CIVITAI_DOWNLOAD_RE = r"https?://civitai.com/api/download/models/(\d+)"
CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"
class CivitaiMetadataFetch(ModelMetadataFetchBase):
"""Fetch model metadata from Civitai."""
_requests: Session
def __init__(self, session: Optional[Session] = None):
"""
Initialize the fetcher with an optional requests.sessions.Session object.
By providing a configurable Session object, we can support unit tests on
this module without an internet connection.
"""
self._requests = session or requests.Session()
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
"""
Given a URL to a CivitAI model or version page, return a ModelMetadata object.
In the event that the URL points to a model page without the particular version
indicated, the default model version is returned. Otherwise, the requested version
is returned.
"""
if match := re.match(CIVITAI_MODEL_PAGE_RE, str(url)):
model_id = match.group(1)
return self.from_civitai_modelid(int(model_id))
elif match := re.match(CIVITAI_VERSION_PAGE_RE, str(url)):
version_id = match.group(1)
return self.from_civitai_versionid(int(version_id))
elif match := re.match(CIVITAI_DOWNLOAD_RE, str(url)):
version_id = match.group(1)
return self.from_civitai_versionid(int(version_id))
raise UnknownModelException("The url '{url}' does not match any known Civitai URL patterns")
def from_id(self, id: str) -> AnyModelRepoMetadata:
"""
Given a Civitai model version ID, return a ModelRepoMetadata object.
May raise an `UnknownModelException`.
"""
return self.from_civitai_versionid(int(id))
def from_civitai_modelid(self, model_id: int) -> CivitaiMetadata:
"""
Return metadata from the default version of the indicated model.
May raise an `UnknownModelException`.
"""
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
model_json = self._requests.get(model_url).json()
return self._from_model_json(model_json)
def _from_model_json(self, model_json: Dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
version_id = version_id or model_json["modelVersions"][0]["id"]
# loop till we find the section containing the version requested
version_sections = [x for x in model_json["modelVersions"] if x["id"] == version_id]
if not version_sections:
raise UnknownModelException(f"Version {version_id} not found in model metadata")
version_json = version_sections[0]
safe_thumbnails = [x["url"] for x in version_json["images"] if x["nsfw"] == "None"]
return CivitaiMetadata(
id=model_json["id"],
name=model_json["name"],
version_id=version_json["id"],
version_name=version_json["name"],
created=datetime.fromisoformat(re.sub(r"Z$", "+00:00", version_json["createdAt"])),
updated=datetime.fromisoformat(re.sub(r"Z$", "+00:00", version_json["updatedAt"])),
published=datetime.fromisoformat(re.sub(r"Z$", "+00:00", version_json["publishedAt"])),
base_model_trained_on=version_json["baseModel"], # note - need a dictionary to turn into a BaseModelType
download_url=version_json["downloadUrl"],
thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None,
author=model_json["creator"]["username"],
description=model_json["description"],
version_description=version_json["description"] or "",
tags=model_json["tags"],
trained_words=version_json["trainedWords"],
nsfw=model_json["nsfw"],
restrictions=LicenseRestrictions(
AllowNoCredit=model_json["allowNoCredit"],
AllowCommercialUse=CommercialUsage(model_json["allowCommercialUse"]),
AllowDerivatives=model_json["allowDerivatives"],
AllowDifferentLicense=model_json["allowDifferentLicense"],
),
)
def from_civitai_versionid(self, version_id: int) -> CivitaiMetadata:
"""
Return a CivitaiMetadata object given a model version id.
May raise an `UnknownModelException`.
"""
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
version = self._requests.get(version_url).json()
model_url = CIVITAI_MODEL_ENDPOINT + str(version["modelId"])
model_json = self._requests.get(model_url).json()
return self._from_model_json(model_json, version_id)
@classmethod
def from_json(cls, json: str) -> CivitaiMetadata:
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
metadata = AnyModelRepoMetadataValidator.validate_json(json)
assert isinstance(metadata, CivitaiMetadata)
return metadata

View File

@ -0,0 +1,84 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
This module fetches model metadata objects from the HuggingFace model repository,
using either a `repo_id` or the model page URL.
Usage:
from invokeai.backend.model_manager.metadata.fetch import HuggingFaceMetadataFetch
fetcher = HuggingFaceMetadataFetch()
metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo")
print(metadata.tags)
"""
import re
from pathlib import Path
from typing import Optional
import requests
from huggingface_hub import HfApi, configure_http_backend
from huggingface_hub.utils._errors import RepositoryNotFoundError
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from invokeai.app.services.model_records import UnknownModelException
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator, HuggingFaceMetadata
from .fetch_base import ModelMetadataFetchBase
HF_MODEL_RE = r"https?://huggingface.co/([\w\-.]+/[\w\-.]+)"
class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
"""Fetch model metadata from HuggingFace."""
_requests: Session
def __init__(self, session: Optional[Session] = None):
"""
Initialize the fetcher with an optional requests.sessions.Session object.
By providing a configurable Session object, we can support unit tests on
this module without an internet connection.
"""
self._requests = session or requests.Session()
configure_http_backend(backend_factory=lambda: self._requests)
def from_id(self, id: str) -> AnyModelRepoMetadata:
"""Return a HuggingFaceMetadata object given the model's repo_id."""
try:
model_info = HfApi().model_info(repo_id=id, files_metadata=True)
except RepositoryNotFoundError as excp:
raise UnknownModelException(f"'{id}' not found. See trace for details.") from excp
_, name = id.split("/")
return HuggingFaceMetadata(
id=model_info.modelId,
author=model_info.author,
name=name,
last_modified=model_info.lastModified,
tag_dict=model_info.card_data.to_dict(),
tags=model_info.tags,
files=[Path(x.rfilename) for x in model_info.siblings],
)
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
"""
Return a HuggingFaceMetadata object given the model's web page URL.
In the case of an invalid or missing URL, raises a ModelNotFound exception.
"""
if match := re.match(HF_MODEL_RE, str(url)):
repo_id = match.group(1)
return self.from_id(repo_id)
else:
raise UnknownModelException(f"'{url}' does not look like a HuggingFace model page")
@classmethod
def from_json(cls, json: str) -> HuggingFaceMetadata:
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
metadata = AnyModelRepoMetadataValidator.validate_json(json)
assert isinstance(metadata, HuggingFaceMetadata)
return metadata

View File

@ -0,0 +1,119 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""This module defines core text-to-image model metadata fields.
Metadata comprises any descriptive information that is not essential
for getting the model to run. For example "author" is metadata, while
"type", "base" and "format" are not. The latter fields are part of the
model's config, as defined in invokeai.backend.model_manager.config.
Note that the "name" and "description" are also present in `config`
records. This is intentional. The config record fields are intended to
be editable by the user as a form of customization. The metadata
versions of these fields are intended to be kept in sync with the
remote repo.
"""
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Set, Union
from pydantic import BaseModel, Field, TypeAdapter
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
class CommercialUsage(str, Enum):
"""Type of commercial usage allowed."""
No = "None"
Image = "Image"
Rent = "Rent"
RentCivit = "RentCivit"
Sell = "Sell"
class LicenseRestrictions(BaseModel):
"""Broad categories of licensing restrictions."""
AllowNoCredit: bool = Field(
description="if true, model can be redistributed without crediting author", default=False
)
AllowDerivatives: bool = Field(description="if true, derivatives of this model can be redistributed", default=False)
AllowDifferentLicense: bool = Field(
description="if true, derivatives of this model be redistributed under a different license", default=False
)
AllowCommercialUse: CommercialUsage = Field(
description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default_factory=set
)
class ModelMetadataBase(BaseModel):
"""Base class for model metadata information."""
name: str = Field(description="model's name")
author: str = Field(description="model's author")
tags: Set[str] = Field(description="tags provided by model source")
class HuggingFaceMetadata(ModelMetadataBase):
"""Extended metadata fields provided by HuggingFace."""
type: Literal["huggingface"] = "huggingface"
id: str = Field(description="huggingface model id")
tag_dict: Dict[str, Any]
last_modified: datetime = Field(description="date of last commit to repo")
files: List[Path] = Field(description="sibling files that belong to this model", default_factory=list)
class CivitaiMetadata(ModelMetadataBase):
"""Extended metadata fields provided by Civitai."""
type: Literal["civitai"] = "civitai"
id: int = Field(description="Civitai model identifier")
version_name: str = Field(description="Version identifier, such as 'V2-alpha'")
version_id: int = Field(description="Civitai model version identifier")
created: datetime = Field(description="date the model was created")
updated: datetime = Field(description="date the model was last modified")
published: datetime = Field(description="date the model was published to Civitai")
description: str = Field(description="text description of model; may contain HTML")
version_description: str = Field(
description="text description of the model's reversion; usually change history; may contain HTML"
)
nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False)
restrictions: LicenseRestrictions = Field(description="license terms", default_factory=LicenseRestrictions)
trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set)
download_url: AnyHttpUrl = Field(description="download URL for this model")
base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)")
thumbnail_url: Optional[AnyHttpUrl] = Field(description="a thumbnail image for this model", default=None)
weight_min: float = Field(
description="minimum suggested value for a LoRA or other secondary model", default=-1.0
) # note: For future use; not currently easily
weight_max: float = Field(
description="maximum suggested value for a LoRA or other secondary model", default=+2.0
) # recoverable from metadata
@property
def credit_required(self) -> bool:
"""Return True if you must give credit for derivatives of this model and images generated from it."""
return not self.restrictions.AllowNoCredit
@property
def allow_commercial_use(self) -> bool:
"""Return True if commercial use is allowed."""
return self.restrictions.AllowCommercialUse != CommercialUsage("None")
@property
def allow_derivatives(self) -> bool:
"""Return True if derivatives of this model can be redistributed."""
return self.restrictions.AllowDerivatives
@property
def allow_different_license(self) -> bool:
"""Return true if derivatives of this model can use a different license."""
return self.restrictions.AllowDifferentLicense
AnyModelRepoMetadata = Annotated[Union[HuggingFaceMetadata, CivitaiMetadata], Field(discriminator="type")]
AnyModelRepoMetadataValidator = TypeAdapter(AnyModelRepoMetadata)

View File

@ -0,0 +1,208 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
SQL Storage for Model Metadata
"""
import sqlite3
from typing import Optional, Set
from invokeai.app.services.model_records import UnknownModelException
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from .fetch import ModelMetadataFetchBase
from .metadata_base import AnyModelRepoMetadata
class ModelMetadataStore:
"""Store, search and fetch model metadata retrieved from remote repositories."""
_db: SqliteDatabase
_cursor: sqlite3.Cursor
def __init__(self, db: SqliteDatabase):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
:param conn: sqlite3 connection object
:param lock: threading Lock object
"""
super().__init__()
self._db = db
self._cursor = self._db.conn.cursor()
self._enable_foreign_key_constraints()
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
"""
Add a block of repo metadata to a model record.
The model record config must already exist in the database with the
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to store
"""
json_serialized = metadata.model_dump_json()
with self._db.lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO model_metadata(
id,
metadata
)
VALUES (?,?);
""",
(
model_key,
json_serialized,
),
)
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
self._db.conn.rollback()
raise UnknownModelException from excp
except sqlite3.Error as excp:
self._db.conn.rollback()
raise excp
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT metadata FROM model_metadata
WHERE id=?;
""",
(model_key,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownModelException("model metadata not found")
return ModelMetadataFetchBase.from_json(rows[0])
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
"""
Update metadata corresponding to the model with the indicated key.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to update
"""
json_serialized = metadata.model_dump_json() # turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
UPDATE model_metadata
SET
metadata=?
WHERE id=?;
""",
(json_serialized, model_key),
)
if self._cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_metadata(model_key)
def search_by_tag(self, tags: Set[str]) -> Set[str]:
"""Return the keys of models containing all of the listed tags."""
with self._db.lock:
try:
matches: Optional[Set[str]] = None
for tag in tags:
self._cursor.execute(
"""--sql
SELECT a.id FROM model_tags AS a,
tags AS b
WHERE a.tag_id=b.tag_id
AND b.tag_text=?;
""",
(tag,),
)
model_keys = {x[0] for x in self._cursor.fetchall()}
if matches is None:
matches = model_keys
matches = matches.intersection(model_keys)
except sqlite3.Error as e:
raise e
return matches
def search_by_author(self, author: str) -> Set[str]:
"""Return the keys of models authored by the indicated author."""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE author=?;
""",
(author,),
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_name(self, name: str) -> Set[str]:
"""
Return the keys of models with the indicated name.
Note that this is the name of the model given to it by
the remote source. The user may have changed the local
name. The local name will be located in the model config
record object.
"""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE name=?;
""",
(name,),
)
return {x[0] for x in self._cursor.fetchall()}
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
"""Update tags for the model referenced by model_key."""
# remove previous tags from this model
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE id=?;
""",
(model_key,),
)
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)
def _enable_foreign_key_constraints(self) -> None:
self._cursor.execute("PRAGMA foreign_keys = ON;")

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,244 @@
"""
Test model metadata fetching and storage.
"""
import datetime
from pathlib import Path
import pytest
import requests
from pydantic.networks import HttpUrl
from requests.sessions import Session
from requests_testadapter import TestAdapter
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceSQL, UnknownModelException
from invokeai.backend.model_manager.config import (
BaseModelType,
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.metadata import (
CivitaiMetadata,
CivitaiMetadataFetch,
CommercialUsage,
HuggingFaceMetadata,
HuggingFaceMetadataFetch,
ModelMetadataStore,
)
from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager_2.model_metadata.metadata_examples import (
RepoCivitaiModelMetadata1,
RepoCivitaiVersionMetadata1,
RepoHFMetadata1,
)
from tests.fixtures.sqlite_database import create_mock_sqlite_database
@pytest.fixture
def app_config(datadir: Path) -> InvokeAIAppConfig:
return InvokeAIAppConfig(
root=datadir / "root",
models_dir=datadir / "root/models",
)
@pytest.fixture
def record_store(app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL:
logger = InvokeAILogger.get_logger(config=app_config)
db = create_mock_sqlite_database(app_config, logger)
store = ModelRecordServiceSQL(db)
# add three simple config records to the database
raw1 = {
"path": "/tmp/foo1",
"format": ModelFormat("diffusers"),
"name": "test2",
"base": BaseModelType("sd-2"),
"type": ModelType("vae"),
"original_hash": "111222333444",
"source": "stabilityai/sdxl-vae",
}
raw2 = {
"path": "/tmp/foo2.ckpt",
"name": "model1",
"format": ModelFormat("checkpoint"),
"base": BaseModelType("sd-1"),
"type": "main",
"config": "/tmp/foo.yaml",
"variant": "normal",
"original_hash": "111222333444",
"source": "https://civitai.com/models/206883/split",
}
raw3 = {
"path": "/tmp/foo3",
"format": ModelFormat("diffusers"),
"name": "test3",
"base": BaseModelType("sdxl"),
"type": ModelType("main"),
"original_hash": "111222333444",
"source": "author3/model3",
}
store.add_model("test_config_1", raw1)
store.add_model("test_config_2", raw2)
store.add_model("test_config_3", raw3)
return store
@pytest.fixture
def session() -> Session:
sess = requests.Session()
sess.mount(
"https://huggingface.co/api/models/stabilityai/sdxl-turbo",
TestAdapter(
RepoHFMetadata1,
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)},
),
)
sess.mount(
"https://civitai.com/api/v1/model-versions/242807",
TestAdapter(
RepoCivitaiVersionMetadata1,
headers={
"Content-Length": len(RepoCivitaiVersionMetadata1),
},
),
)
sess.mount(
"https://civitai.com/api/v1/models/215485",
TestAdapter(
RepoCivitaiModelMetadata1,
headers={
"Content-Length": len(RepoCivitaiModelMetadata1),
},
),
)
return sess
@pytest.fixture
def metadata_store(record_store: ModelRecordServiceSQL) -> ModelMetadataStore:
db = record_store._db # to ensure we are sharing the same database
return ModelMetadataStore(db)
def test_metadata_store_put_get(metadata_store: ModelMetadataStore) -> None:
input_metadata = HuggingFaceMetadata(
name="sdxl-vae",
author="stabilityai",
tags={"text-to-image", "diffusers"},
id="stabilityai/sdxl-vae",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
metadata_store.add_metadata("test_config_1", input_metadata)
output_metadata = metadata_store.get_metadata("test_config_1")
assert input_metadata == output_metadata
with pytest.raises(UnknownModelException):
metadata_store.add_metadata("unknown_key", input_metadata)
def test_metadata_store_update(metadata_store: ModelMetadataStore) -> None:
input_metadata = HuggingFaceMetadata(
name="sdxl-vae",
author="stabilityai",
tags={"text-to-image", "diffusers"},
id="stabilityai/sdxl-vae",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
metadata_store.add_metadata("test_config_1", input_metadata)
input_metadata.name = "new-name"
metadata_store.update_metadata("test_config_1", input_metadata)
output_metadata = metadata_store.get_metadata("test_config_1")
assert output_metadata.name == "new-name"
assert input_metadata == output_metadata
def test_metadata_search(metadata_store: ModelMetadataStore) -> None:
metadata1 = HuggingFaceMetadata(
name="sdxl-vae",
author="stabilityai",
tags={"text-to-image", "diffusers"},
id="stabilityai/sdxl-vae",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
metadata2 = HuggingFaceMetadata(
name="model2",
author="stabilityai",
tags={"text-to-image", "diffusers", "community-contributed"},
id="author2/model2",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
metadata3 = HuggingFaceMetadata(
name="model3",
author="author3",
tags={"text-to-image", "checkpoint", "community-contributed"},
id="author3/model3",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
metadata_store.add_metadata("test_config_1", metadata1)
metadata_store.add_metadata("test_config_2", metadata2)
metadata_store.add_metadata("test_config_3", metadata3)
matches = metadata_store.search_by_author("stabilityai")
assert len(matches) == 2
assert "test_config_1" in matches
assert "test_config_2" in matches
matches = metadata_store.search_by_author("Sherlock Holmes")
assert not matches
matches = metadata_store.search_by_name("model3")
assert len(matches) == 1
assert "test_config_3" in matches
matches = metadata_store.search_by_tag({"text-to-image"})
assert len(matches) == 3
matches = metadata_store.search_by_tag({"text-to-image", "diffusers"})
assert len(matches) == 2
assert "test_config_1" in matches
assert "test_config_2" in matches
matches = metadata_store.search_by_tag({"checkpoint", "community-contributed"})
assert len(matches) == 1
assert "test_config_3" in matches
# does the tag table update correctly?
matches = metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
assert not matches
metadata3.tags.add("licensed-for-commercial-use")
metadata_store.update_metadata("test_config_3", metadata3)
matches = metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
assert len(matches) == 1
def test_metadata_civitai_fetch(session: Session) -> None:
fetcher = CivitaiMetadataFetch(session)
metadata = fetcher.from_url(HttpUrl("https://civitai.com/models/215485/SDXL-turbo"))
assert isinstance(metadata, CivitaiMetadata)
assert metadata.id == 215485
assert metadata.author == "test_author" # note that this is not the same as the original from Civitai
assert metadata.allow_commercial_use # changed to make sure we are reading locally not remotely
assert metadata.restrictions.AllowCommercialUse == CommercialUsage("RentCivit")
assert metadata.version_id == 242807
assert metadata.tags == {"tool", "turbo", "sdxl turbo"}
def test_metadata_hf_fetch(session: Session) -> None:
fetcher = HuggingFaceMetadataFetch(session)
metadata = fetcher.from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
assert isinstance(metadata, HuggingFaceMetadata)
assert metadata.author == "test_author" # this is not the same as the original
assert metadata.files
assert metadata.tags == {
"diffusers",
"onnx",
"safetensors",
"text-to-image",
"license:other",
"has_space",
"diffusers:StableDiffusionXLPipeline",
"region:us",
}