mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add storage
This commit is contained in:
parent
c610283158
commit
e86f3fe29e
@ -5,6 +5,7 @@ from invokeai.app.services.image_files.image_files_base import ImageFileStorageB
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import build_migration_1
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import build_migration_2
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
@ -27,6 +28,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator = SqliteMigrator(db=db)
|
||||
migrator.register_migration(build_migration_1())
|
||||
migrator.register_migration(build_migration_2(image_files=image_files, logger=logger))
|
||||
migrator.register_migration(build_migration_4())
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
||||
|
@ -0,0 +1,94 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration4Callback:
    """Callback to do step 4 of migration: create the model metadata and tag tables."""

    def __call__(self, cursor: sqlite3.Cursor) -> None:  # noqa D102
        self._create_model_metadata(cursor)
        # `tags` is created before `model_tags` so that the referenced table exists first.
        self._create_tags(cursor)
        self._create_model_tags(cursor)
        self._create_triggers(cursor)

    def _create_model_metadata(self, cursor: sqlite3.Cursor) -> None:
        """Create the table used to store model metadata downloaded from remote sources."""
        cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS model_metadata (
                -- One metadata row per model: the key is unique as well as a foreign key
                id TEXT NOT NULL PRIMARY KEY,
                name TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.name')) VIRTUAL NOT NULL,
                author TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.author')) VIRTUAL NOT NULL,
                -- Serialized JSON representation of the whole metadata object,
                -- which will contain additional fields from subclasses
                metadata TEXT NOT NULL,
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                FOREIGN KEY(id) REFERENCES model_config(id)
            );
            """
        )

    def _create_model_tags(self, cursor: sqlite3.Cursor) -> None:
        """Create the join table that links a model (`id`) to each of its tags (`tag_id`)."""
        cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS model_tags (
                id TEXT NOT NULL,
                tag_id INTEGER NOT NULL,
                FOREIGN KEY(id) REFERENCES model_config(id),
                FOREIGN KEY(tag_id) REFERENCES tags(tag_id),
                UNIQUE(id,tag_id)
            );
            """
        )

    def _create_tags(self, cursor: sqlite3.Cursor) -> None:
        """Create the table that assigns an integer id to each distinct tag text."""
        cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS tags (
                tag_id INTEGER NOT NULL PRIMARY KEY,
                tag_text TEXT NOT NULL UNIQUE
            );
            """
        )

    def _create_triggers(self, cursor: sqlite3.Cursor) -> None:
        """
        Create the triggers that maintain the metadata tables.

        `model_metadata_updated_at` refreshes `updated_at` whenever a metadata row is updated.
        `model_config_deleted` removes a model's metadata and tag links when its `model_config`
        row is deleted; the `model_config` table must already exist when this runs.
        """
        cursor.execute(
            """--sql
            CREATE TRIGGER IF NOT EXISTS model_metadata_updated_at
            AFTER UPDATE
            ON model_metadata FOR EACH ROW
            BEGIN
                UPDATE model_metadata SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE id = old.id;
            END;
            """
        )
        cursor.execute(
            """--sql
            CREATE TRIGGER IF NOT EXISTS model_config_deleted
            AFTER DELETE
            ON model_config
            BEGIN
                DELETE from model_metadata WHERE id=old.id;
                DELETE from model_tags WHERE id=old.id;
            END;
            """
        )
|
||||
|
||||
|
||||
def build_migration_4() -> Migration:
    """
    Build the migration that adds the tables needed to store model metadata and tags.

    This is nominally the database version 3 to 4 migration, but it is currently
    registered as 2 to 3 (see the comment on `from_version` below).
    """
    callback = Migration4Callback()
    return Migration(
        from_version=2,  # until migration_3 is merged, pretend we are doing 2-3
        to_version=3,
        callback=callback,
    )
|
@ -0,0 +1,38 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend.model_manager.metadata
|
||||
|
||||
Usage:
|
||||
|
||||
from invokeai.backend.model_manager.metadata import(
|
||||
AnyModelRepoMetadata,
|
||||
CommercialUsage,
|
||||
LicenseRestrictions,
|
||||
HuggingFaceMetadata,
|
||||
CivitaiMetadata,
|
||||
)
|
||||
|
||||
from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch
|
||||
|
||||
data = CivitaiMetadataFetch().from_url("https://civitai.com/models/206883/split")
|
||||
assert isinstance(data, CivitaiMetadata)
|
||||
if data.allow_commercial_use:
|
||||
print("Commercial use of this model is allowed")
|
||||
"""
|
||||
|
||||
from .metadata_base import (
|
||||
AnyModelRepoMetadata,
|
||||
AnyModelRepoMetadataValidator,
|
||||
CivitaiMetadata,
|
||||
CommercialUsage,
|
||||
HuggingFaceMetadata,
|
||||
LicenseRestrictions,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AnyModelRepoMetadata",
|
||||
"AnyModelRepoMetadataValidator",
|
||||
"CommercialUsage",
|
||||
"LicenseRestrictions",
|
||||
"HuggingFaceMetadata",
|
||||
"CivitaiMetadata",
|
||||
]
|
@ -1,139 +0,0 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
"""
|
||||
This module fetches metadata objects from HuggingFace and Civitai.
|
||||
|
||||
Usage:
|
||||
|
||||
from invokeai.backend.model_manager.metadata.fetch import MetadataFetch
|
||||
|
||||
metadata = MetadataFetch.from_civitai_url("https://civitai.com/models/58390/detail-tweaker-lora-lora")
|
||||
print(metadata.description)
|
||||
"""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Optional, Any
|
||||
from datetime import datetime
|
||||
|
||||
import requests
|
||||
from huggingface_hub import HfApi, configure_http_backend
|
||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from invokeai.app.services.model_records import UnknownModelException
|
||||
|
||||
from .base import (
|
||||
CivitaiMetadata,
|
||||
HuggingFaceMetadata,
|
||||
LicenseRestrictions,
|
||||
CommercialUsage,
|
||||
)
|
||||
|
||||
# Patterns for the page URLs a user is likely to paste.
# Dots in host names are escaped so they only match a literal ".".
HF_MODEL_RE = r"https?://huggingface\.co/([\w\-.]+/[\w\-.]+)"
CIVITAI_MODEL_PAGE_RE = r"https?://civitai\.com/models/(\d+)"
# Group 1 is the model id, group 2 the version id. An optional "/<slug>" path segment
# is accepted between the model id and the query string.
CIVITAI_VERSION_PAGE_RE = r"https?://civitai\.com/models/(\d+)(?:/[^?#\s]*)?\?modelVersionId=(\d+)"
CIVITAI_DOWNLOAD_RE = r"https?://civitai\.com/api/download/models/(\d+)"

# REST endpoints; the numeric id is appended directly to these prefixes.
CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"
|
||||
|
||||
class MetadataFetch:
    """Fetch metadata from HuggingFace and Civitai URLs."""

    _requests: Session

    def __init__(self, session: Optional[Session] = None):
        """
        Initialize the fetcher with an optional requests.sessions.Session object.

        By providing a configurable Session object, we can support unit tests on
        this module without an internet connection.
        """
        self._requests = session or requests.Session()
        # Route huggingface_hub's HTTP traffic through the same session.
        configure_http_backend(backend_factory=lambda: self._requests)

    def from_huggingface_repoid(self, repo_id: str) -> HuggingFaceMetadata:
        """
        Return a HuggingFaceMetadata object given the model's repo_id.

        Raises UnknownModelException if the repository is not found.
        """
        try:
            model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True)
        except RepositoryNotFoundError as excp:
            raise UnknownModelException(f"'{repo_id}' not found. See trace for details.") from excp

        # Take the last path component so that ids without exactly one "/" do not raise ValueError.
        name = repo_id.split("/")[-1]
        return HuggingFaceMetadata(
            id=model_info.modelId,
            author=model_info.author,
            name=name,
            last_modified=model_info.lastModified,
            tags=model_info.tags,
            # card_data is absent for repos without model card metadata
            tag_dict=model_info.card_data.to_dict() if model_info.card_data else {},
            files=[Path(x.rfilename) for x in model_info.siblings or []],
        )

    def from_huggingface_url(self, url: AnyHttpUrl) -> HuggingFaceMetadata:
        """
        Return a HuggingFaceMetadata object given the model's web page URL.

        In the case of an invalid or missing URL, raises an UnknownModelException.
        """
        if match := re.match(HF_MODEL_RE, str(url)):
            repo_id = match.group(1)
            return self.from_huggingface_repoid(repo_id)
        else:
            raise UnknownModelException(f"'{url}' does not look like a HuggingFace model page")

    def from_civitai_versionid(
        self, version_id: int, model_metadata: Optional[Dict[str, Any]] = None
    ) -> CivitaiMetadata:
        """
        Return Civitai metadata using a model's version id.

        :param version_id: Civitai model version id
        :param model_metadata: already-fetched JSON for the parent model; fetched when omitted
        """
        version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
        version = self._requests.get(version_url).json()

        model_url = CIVITAI_MODEL_ENDPOINT + str(version["modelId"])
        model = model_metadata or self._requests.get(model_url).json()
        # Only images whose nsfw rating is the string "None" are used as thumbnails.
        safe_thumbnails = [x["url"] for x in version["images"] if x["nsfw"] == "None"]

        return CivitaiMetadata(
            id=version["modelId"],
            name=version["model"]["name"],
            version_id=version["id"],
            version_name=version["name"],
            created=datetime.fromisoformat(re.sub(r"Z$", "+00:00", version["createdAt"])),
            base_model_trained_on=version["baseModel"],  # note - need a dictionary to turn into a BaseModelType
            download_url=version["downloadUrl"],
            thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None,
            author=model["creator"]["username"],
            description=model["description"],
            version_description=version["description"] or "",
            tags=model["tags"],
            trained_words=version["trainedWords"],
            nsfw=version["model"]["nsfw"],
            restrictions=LicenseRestrictions(
                AllowNoCredit=model["allowNoCredit"],
                AllowCommercialUse=CommercialUsage(model["allowCommercialUse"]),
                AllowDerivatives=model["allowDerivatives"],
                AllowDifferentLicense=model["allowDifferentLicense"],
            ),
        )

    def from_civitai_modelid(self, model_id: int) -> CivitaiMetadata:
        """Return metadata from the default version of the indicated model."""
        model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
        model = self._requests.get(model_url).json()
        # The first entry of `modelVersions` is treated as the default version.
        default_version = model["modelVersions"][0]["id"]
        return self.from_civitai_versionid(default_version, model)

    def from_civitai_url(self, url: AnyHttpUrl) -> CivitaiMetadata:
        """
        Parse a Civitai URL that user is likely to pass and return its metadata.

        Raises UnknownModelException if the URL matches none of the known patterns.
        """
        url_str = str(url)
        # The version-page pattern is tried before the model-page pattern: the latter is a
        # prefix of the former and re.match() only anchors at the start of the string.
        if match := re.match(CIVITAI_VERSION_PAGE_RE, url_str):
            version_id = match.group(2)  # group 1 is the model id
            return self.from_civitai_versionid(int(version_id))
        elif match := re.match(CIVITAI_DOWNLOAD_RE, url_str):
            version_id = match.group(1)
            return self.from_civitai_versionid(int(version_id))
        elif match := re.match(CIVITAI_MODEL_PAGE_RE, url_str):
            model_id = match.group(1)
            return self.from_civitai_modelid(int(model_id))
        raise UnknownModelException(f"The url '{url}' does not match any known Civitai URL patterns")
|
||||
|
||||
|
21
invokeai/backend/model_manager/metadata/fetch/__init__.py
Normal file
21
invokeai/backend/model_manager/metadata/fetch/__init__.py
Normal file
@ -0,0 +1,21 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend.model_manager.metadata.fetch
|
||||
|
||||
Usage:
|
||||
from invokeai.backend.model_manager.metadata.fetch import (
|
||||
CivitaiMetadataFetch,
|
||||
HuggingFaceMetadataFetch,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import CivitaiMetadata
|
||||
|
||||
data = CivitaiMetadataFetch().from_url("https://civitai.com/models/206883/split")
|
||||
assert isinstance(data, CivitaiMetadata)
|
||||
if data.allow_commercial_use:
|
||||
print("Commercial use of this model is allowed")
|
||||
"""
|
||||
|
||||
from .fetch_base import ModelMetadataFetchBase
|
||||
from .fetch_civitai import CivitaiMetadataFetch
|
||||
from .fetch_huggingface import HuggingFaceMetadataFetch
|
||||
|
||||
__all__ = ["ModelMetadataFetchBase", "CivitaiMetadataFetch", "HuggingFaceMetadataFetch"]
|
63
invokeai/backend/model_manager/metadata/fetch/fetch_base.py
Normal file
63
invokeai/backend/model_manager/metadata/fetch/fetch_base.py
Normal file
@ -0,0 +1,63 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
"""
|
||||
This module is the base class for subclasses that fetch metadata from model repositories
|
||||
|
||||
Usage:
|
||||
|
||||
from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch
|
||||
|
||||
fetcher = CivitaiMetadataFetch()
|
||||
metadata = fetcher.from_url("https://civitai.com/models/206883/split")
|
||||
print(metadata.trained_words)
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator
|
||||
|
||||
|
||||
class ModelMetadataFetchBase(ABC):
    """Fetch metadata from remote generative model repositories."""

    @abstractmethod
    def __init__(self, session: Optional[Session] = None):
        """
        Initialize the fetcher with an optional requests.sessions.Session object.

        By providing a configurable Session object, we can support unit tests on
        this module without an internet connection.
        """
        pass

    @abstractmethod
    def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
        """
        Given a URL to a model repository, return a ModelMetadata object.

        This method will raise an `invokeai.app.services.model_records.UnknownModelException`
        in the event that the requested model metadata is not found at the provided location.
        """
        pass

    @abstractmethod
    def from_id(self, id: str) -> AnyModelRepoMetadata:
        """
        Given an ID for a model, return a ModelMetadata object.

        This method will raise an `invokeai.app.services.model_records.UnknownModelException`
        in the event that the requested model's metadata is not found at the provided id.
        """
        pass

    @classmethod
    def from_json(cls, json: str) -> AnyModelRepoMetadata:
        """
        Given the JSON representation of the metadata, return the corresponding Pydantic object.

        The JSON is parsed by `AnyModelRepoMetadataValidator`, so the concrete class of
        the result is chosen by the validator rather than by `cls`.
        """
        metadata = AnyModelRepoMetadataValidator.validate_json(json)
        # NOTE: mypy complains that metadata is a <typing special form> and issues a type checking error.
        return metadata
|
150
invokeai/backend/model_manager/metadata/fetch/fetch_civitai.py
Normal file
150
invokeai/backend/model_manager/metadata/fetch/fetch_civitai.py
Normal file
@ -0,0 +1,150 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
"""
|
||||
This module fetches model metadata objects from the Civitai model repository.
|
||||
In addition to the `from_url()` and `from_id()` methods inherited from the
|
||||
`ModelMetadataFetchBase` base class, this module provides `from_civitai_modelid()`.
|
||||
|
||||
Civitai has two separate ID spaces: a model ID and a version ID. The
|
||||
version ID corresponds to a specific model, and is the ID accepted by
|
||||
`from_id()`. The model ID corresponds to a family of related models,
|
||||
such as different training checkpoints or 16 vs 32-bit versions. The
|
||||
`from_civitai_modelid()` method will accept a model ID and return the
|
||||
metadata from the default version within this model set. The default
|
||||
version is the same as what the user sees when they click on a model's
|
||||
thumbnail.
|
||||
|
||||
Usage:
|
||||
|
||||
from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch
|
||||
|
||||
fetcher = CivitaiMetadataFetch()
|
||||
metadata = fetcher.from_url("https://civitai.com/models/206883/split")
|
||||
print(metadata.trained_words)
|
||||
"""
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import requests
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from invokeai.app.services.model_records import UnknownModelException
|
||||
|
||||
from ..metadata_base import (
|
||||
AnyModelRepoMetadata,
|
||||
AnyModelRepoMetadataValidator,
|
||||
CivitaiMetadata,
|
||||
CommercialUsage,
|
||||
LicenseRestrictions,
|
||||
)
|
||||
from .fetch_base import ModelMetadataFetchBase
|
||||
|
||||
# Patterns for the Civitai URLs a user is likely to paste.
# Dots in the host name are escaped so they only match a literal ".".
CIVITAI_MODEL_PAGE_RE = r"https?://civitai\.com/models/(\d+)"
# Group 1 is the model id, group 2 the version id. An optional "/<slug>" path segment
# is accepted between the model id and the query string.
CIVITAI_VERSION_PAGE_RE = r"https?://civitai\.com/models/(\d+)(?:/[^?#\s]*)?\?modelVersionId=(\d+)"
CIVITAI_DOWNLOAD_RE = r"https?://civitai\.com/api/download/models/(\d+)"

# REST endpoints; the numeric id is appended directly to these prefixes.
CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"
|
||||
|
||||
|
||||
class CivitaiMetadataFetch(ModelMetadataFetchBase):
    """Fetch model metadata from Civitai."""

    _requests: Session

    def __init__(self, session: Optional[Session] = None):
        """
        Initialize the fetcher with an optional requests.sessions.Session object.

        By providing a configurable Session object, we can support unit tests on
        this module without an internet connection.
        """
        self._requests = session or requests.Session()

    def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
        """
        Given a URL to a CivitAI model or version page, return a ModelMetadata object.

        In the event that the URL points to a model page without the particular version
        indicated, the default model version is returned. Otherwise, the requested version
        is returned.

        Raises an `UnknownModelException` if the URL matches none of the known patterns.
        """
        url_str = str(url)
        # The version-page pattern is tried before the model-page pattern: the latter is a
        # prefix of the former and re.match() only anchors at the start of the string, so
        # testing the model page first would make the version branch unreachable.
        if match := re.match(CIVITAI_VERSION_PAGE_RE, url_str):
            version_id = match.group(2)  # group 1 is the model id
            return self._from_civitai_versionid(int(version_id))
        elif match := re.match(CIVITAI_DOWNLOAD_RE, url_str):
            version_id = match.group(1)
            return self._from_civitai_versionid(int(version_id))
        elif match := re.match(CIVITAI_MODEL_PAGE_RE, url_str):
            model_id = match.group(1)
            return self.from_civitai_modelid(int(model_id))
        raise UnknownModelException(f"The url '{url}' does not match any known Civitai URL patterns")

    def from_id(self, id: str) -> AnyModelRepoMetadata:
        """
        Given a Civitai model version ID, return a ModelRepoMetadata object.

        May raise an `UnknownModelException`.
        """
        return self._from_civitai_versionid(int(id))

    def from_civitai_modelid(self, model_id: int) -> CivitaiMetadata:
        """
        Return metadata from the default version of the indicated model.

        May raise an `UnknownModelException`.
        """
        model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
        model = self._requests.get(model_url).json()
        # The first entry of `modelVersions` is treated as the default version. A payload
        # without one is reported as an unknown model instead of a bare KeyError/IndexError.
        try:
            default_version = model["modelVersions"][0]["id"]
        except (KeyError, IndexError, TypeError) as excp:
            raise UnknownModelException(f"No Civitai model with id {model_id} was found") from excp
        return self._from_civitai_versionid(default_version, model)

    def _from_civitai_versionid(
        self, version_id: int, model_metadata: Optional[Dict[str, Any]] = None
    ) -> CivitaiMetadata:
        """
        Return metadata for a specific Civitai model version.

        :param version_id: Civitai model version id
        :param model_metadata: already-fetched JSON for the parent model; fetched when omitted

        Raises an `UnknownModelException` if the version payload does not identify a model.
        """
        version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
        version = self._requests.get(version_url).json()

        try:
            parent_model_id = version["modelId"]
        except (KeyError, TypeError) as excp:
            raise UnknownModelException(f"No Civitai model version with id {version_id} was found") from excp

        model_url = CIVITAI_MODEL_ENDPOINT + str(parent_model_id)
        model = model_metadata or self._requests.get(model_url).json()
        # Only images whose nsfw rating is the string "None" are used as thumbnails.
        safe_thumbnails = [x["url"] for x in version["images"] if x["nsfw"] == "None"]

        # It would be more elegant to define a Pydantic BaseModel that matches the Civitai metadata JSON.
        # However the contents of the JSON does not exactly match the documentation at
        # https://github.com/civitai/civitai/wiki/REST-API-Reference, and it feels safer to cherry pick
        # a subset of the fields.
        #
        # In addition, there are some fields that I want to pick up from the model JSON, such as `tags`,
        # that are not present in the version JSON.
        return CivitaiMetadata(
            id=parent_model_id,
            name=version["model"]["name"],
            version_id=version["id"],
            version_name=version["name"],
            created=datetime.fromisoformat(re.sub(r"Z$", "+00:00", version["createdAt"])),
            base_model_trained_on=version["baseModel"],  # note - need a dictionary to turn into a BaseModelType
            download_url=version["downloadUrl"],
            thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None,
            author=model["creator"]["username"],
            description=model["description"],
            version_description=version["description"] or "",
            tags=model["tags"],
            trained_words=version["trainedWords"],
            nsfw=version["model"]["nsfw"],
            restrictions=LicenseRestrictions(
                AllowNoCredit=model["allowNoCredit"],
                AllowCommercialUse=CommercialUsage(model["allowCommercialUse"]),
                AllowDerivatives=model["allowDerivatives"],
                AllowDifferentLicense=model["allowDifferentLicense"],
            ),
        )

    @classmethod
    def from_json(cls, json: str) -> CivitaiMetadata:
        """Given the JSON representation of the metadata, return the corresponding Pydantic object."""
        metadata = AnyModelRepoMetadataValidator.validate_json(json)
        # Narrows the validated union to the Civitai variant.
        assert isinstance(metadata, CivitaiMetadata)
        return metadata
|
@ -0,0 +1,83 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
"""
|
||||
This module fetches model metadata objects from the HuggingFace model repository,
|
||||
using either a `repo_id` or the model page URL.
|
||||
|
||||
Usage:
|
||||
|
||||
from invokeai.backend.model_manager.metadata.fetch import HuggingFaceMetadataFetch
|
||||
|
||||
fetcher = HuggingFaceMetadataFetch()
|
||||
metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo")
|
||||
print(metadata.tags)
|
||||
"""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from huggingface_hub import HfApi, configure_http_backend
|
||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from invokeai.app.services.model_records import UnknownModelException
|
||||
|
||||
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator, HuggingFaceMetadata
|
||||
from .fetch_base import ModelMetadataFetchBase
|
||||
|
||||
|
||||
class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
    """Fetch model metadata from HuggingFace."""

    _requests: Session

    def __init__(self, session: Optional[Session] = None):
        """
        Initialize the fetcher with an optional requests.sessions.Session object.

        By providing a configurable Session object, we can support unit tests on
        this module without an internet connection.
        """
        self._requests = session or requests.Session()
        # Route huggingface_hub's HTTP traffic through the same session.
        configure_http_backend(backend_factory=lambda: self._requests)

    def from_id(self, id: str) -> AnyModelRepoMetadata:
        """
        Return a HuggingFaceMetadata object given the model's repo_id.

        Raises an `UnknownModelException` if the repository is not found.
        """
        try:
            model_info = HfApi().model_info(repo_id=id, files_metadata=True)
        except RepositoryNotFoundError as excp:
            raise UnknownModelException(f"'{id}' not found. See trace for details.") from excp

        # Take the last path component so that ids without exactly one "/" do not raise ValueError.
        name = id.split("/")[-1]
        return HuggingFaceMetadata(
            id=model_info.modelId,
            author=model_info.author,
            name=name,
            last_modified=model_info.lastModified,
            # card_data is absent for repos without model card metadata
            tag_dict=model_info.card_data.to_dict() if model_info.card_data else {},
            tags=model_info.tags,
            files=[Path(x.rfilename) for x in model_info.siblings or []],
        )

    def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
        """
        Return a HuggingFaceMetadata object given the model's web page URL.

        In the case of an invalid or missing URL, raises an `UnknownModelException`.
        """
        HF_MODEL_RE = r"https?://huggingface.co/([\w\-.]+/[\w\-.]+)"
        if match := re.match(HF_MODEL_RE, str(url)):
            repo_id = match.group(1)
            return self.from_id(repo_id)
        else:
            raise UnknownModelException(f"'{url}' does not look like a HuggingFace model page")

    @classmethod
    def from_json(cls, json: str) -> HuggingFaceMetadata:
        """Given the JSON representation of the metadata, return the corresponding Pydantic object."""
        metadata = AnyModelRepoMetadataValidator.validate_json(json)
        # Narrows the validated union to the HuggingFace variant.
        assert isinstance(metadata, HuggingFaceMetadata)
        return metadata
|
@ -15,10 +15,12 @@ This may need reworking.
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Set, Optional
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
class CommercialUsage(str, Enum):
|
||||
"""Type of commercial usage allowed."""
|
||||
@ -29,13 +31,21 @@ class CommercialUsage(str, Enum):
|
||||
RentCivit = "RentCivit"
|
||||
Sell = "Sell"
|
||||
|
||||
|
||||
class LicenseRestrictions(BaseModel):
    """Broad categories of licensing restrictions."""

    AllowNoCredit: bool = Field(
        description="if true, model can be redistributed without crediting author", default=False
    )
    AllowDerivatives: bool = Field(description="if true, derivatives of this model can be redistributed", default=False)
    AllowDifferentLicense: bool = Field(
        description="if true, derivatives of this model be redistributed under a different license", default=False
    )
    # The default must itself be a CommercialUsage member; the previous `default_factory=set`
    # produced an empty set for this enum-typed field.
    AllowCommercialUse: CommercialUsage = Field(
        description="Type of commercial use allowed or 'No' if no commercial use is allowed.",
        default=CommercialUsage("None"),
    )
|
||||
|
||||
|
||||
class ModelMetadataBase(BaseModel):
|
||||
"""Base class for model metadata information."""
|
||||
@ -44,32 +54,41 @@ class ModelMetadataBase(BaseModel):
|
||||
author: str = Field(description="model's author")
|
||||
tags: Set[str] = Field(description="tags provided by model source")
|
||||
|
||||
|
||||
class HuggingFaceMetadata(ModelMetadataBase):
    """Extended metadata fields provided by HuggingFace."""

    # Fixed tag value; `AnyModelRepoMetadata` uses the `type` field as its discriminator.
    type: Literal["huggingface"] = "huggingface"
    id: str = Field(description="huggingface model id")
    # Free-form dictionary; the HuggingFace fetcher fills it from the model's card data.
    tag_dict: Dict[str, Any]
    last_modified: datetime = Field(description="date of last commit to repo")
    files: List[Path] = Field(description="sibling files that belong to this model", default_factory=list)
|
||||
|
||||
|
||||
class CivitaiMetadata(ModelMetadataBase):
|
||||
"""Extended metadata fields provided by Civitai."""
|
||||
|
||||
type: Literal["civitai"] = "civitai"
|
||||
id: int = Field(description="Civitai model identifier")
|
||||
version_name: str = Field(description="Version identifier, such as 'V2-alpha'")
|
||||
version_id: int = Field(description="Civitai model version identifier")
|
||||
created: datetime = Field(description="date the model was posted to CivitAI")
|
||||
description: str = Field(description="text description of model; may contain HTML")
|
||||
version_description: str = Field(description="text description of the model's reversion; usually change history; may contain HTML")
|
||||
version_description: str = Field(
|
||||
description="text description of the model's reversion; usually change history; may contain HTML"
|
||||
)
|
||||
nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False)
|
||||
restrictions: LicenseRestrictions = Field(description="license terms", default=LicenseRestrictions)
|
||||
trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set)
|
||||
download_url: AnyHttpUrl = Field(description="download URL for this model")
|
||||
base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)")
|
||||
thumbnail_url: Optional[AnyHttpUrl] = Field(description="a thumbnail image for this model", default=None)
|
||||
weight_min: float = Field(description="minimum suggested value for a LoRA or other secondary model", default=-1.0) # note: For future use; not currently easily
|
||||
weight_max: float = Field(description="maximum suggested value for a LoRA or other secondary model", default=+2.0) # recoverable from metadata
|
||||
|
||||
weight_min: float = Field(
|
||||
description="minimum suggested value for a LoRA or other secondary model", default=-1.0
|
||||
) # note: For future use; not currently easily
|
||||
weight_max: float = Field(
|
||||
description="maximum suggested value for a LoRA or other secondary model", default=+2.0
|
||||
) # recoverable from metadata
|
||||
|
||||
@property
|
||||
def credit_required(self) -> bool:
|
||||
@ -79,7 +98,7 @@ class CivitaiMetadata(ModelMetadataBase):
|
||||
@property
def allow_commercial_use(self) -> bool:
    """Return True if commercial use is allowed."""
    # The "None" usage value stands for "no commercial use", so any other value allows it.
    # The previous `==` comparison returned the opposite of what this docstring promises.
    return self.restrictions.AllowCommercialUse != CommercialUsage("None")
|
||||
|
||||
@property
|
||||
def allow_derivatives(self) -> bool:
|
||||
@ -90,3 +109,7 @@ class CivitaiMetadata(ModelMetadataBase):
|
||||
def allow_different_license(self) -> bool:
|
||||
"""Return true if derivatives of this model can use a different license."""
|
||||
return self.restrictions.AllowDifferentLicense
|
||||
|
||||
|
||||
AnyModelRepoMetadata = Annotated[Union[HuggingFaceMetadata, CivitaiMetadata], Field(discriminator="type")]
|
||||
AnyModelRepoMetadataValidator = TypeAdapter(AnyModelRepoMetadata)
|
185
invokeai/backend/model_manager/metadata/metadata_store.py
Normal file
185
invokeai/backend/model_manager/metadata/metadata_store.py
Normal file
@ -0,0 +1,185 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
SQL Storage for Model Metadata
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from typing import Set
|
||||
|
||||
from invokeai.app.services.model_records import UnknownModelException
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
from .fetch import ModelMetadataFetchBase
|
||||
from .metadata_base import AnyModelRepoMetadata
|
||||
|
||||
|
||||
class ModelMetadataStore:
    """Store, search and fetch model metadata retrieved from remote repositories."""

    # Database wrapper; this class relies on its `conn` (sqlite3 connection)
    # and `lock` attributes.
    _db: SqliteDatabase
    # One cursor shared by every method, so all cursor use is serialized
    # through `self._db.lock`.
    _cursor: sqlite3.Cursor

    def __init__(self, db: SqliteDatabase):
        """
        Initialize a new object from a preexisting database object.

        :param db: SqliteDatabase object supplying the sqlite3 connection (`db.conn`)
                   and the lock (`db.lock`) used to guard access to it
        """
        super().__init__()
        self._db = db
        self._cursor = self._db.conn.cursor()
        self._enable_foreign_key_constraints()

    def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
        """
        Add a block of repo metadata to a model record.

        The model record config must already exist in the database with the
        same key. Otherwise a FOREIGN KEY constraint exception will be raised.

        The metadata row and the model's tags are written in one transaction;
        on any sqlite3 error the transaction is rolled back and the error re-raised.

        :param model_key: Existing model key in the `model_config` table
        :param metadata: ModelRepoMetadata object to store
        """
        json_serialized = metadata.model_dump_json()
        with self._db.lock:
            try:
                self._cursor.execute(
                    """--sql
                    INSERT INTO model_metadata(
                       id,
                       metadata
                    )
                    VALUES (?,?);
                    """,
                    (
                        model_key,
                        json_serialized,
                    ),
                )
                self._update_tags(model_key, metadata.tags)
                self._db.conn.commit()
            except sqlite3.Error:
                self._db.conn.rollback()
                raise

    def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
        """
        Retrieve the ModelRepoMetadata corresponding to model key.

        :param model_key: Key of the model whose metadata is wanted
        :raises UnknownModelException: if no metadata row exists for `model_key`
        """
        with self._db.lock:
            self._cursor.execute(
                """--sql
                SELECT metadata FROM model_metadata
                WHERE id=?;
                """,
                (model_key,),
            )
            row = self._cursor.fetchone()
            if not row:
                raise UnknownModelException("model metadata not found")
            return ModelMetadataFetchBase.from_json(row[0])

    def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
        """
        Update metadata corresponding to the model with the indicated key.

        :param model_key: Existing model key in the `model_config` table
        :param metadata: ModelRepoMetadata object to update
        :return: The metadata as re-read from the database after the update
        :raises UnknownModelException: if no metadata row exists for `model_key`
        """
        json_serialized = metadata.model_dump_json()  # turn it into a json string.
        with self._db.lock:
            try:
                self._cursor.execute(
                    """--sql
                    UPDATE model_metadata
                    SET
                        metadata=?
                    WHERE id=?;
                    """,
                    (json_serialized, model_key),
                )
                # An UPDATE that touched no rows means the key is unknown.
                if self._cursor.rowcount == 0:
                    raise UnknownModelException("model not found")
                self._update_tags(model_key, metadata.tags)
                self._db.conn.commit()
            except sqlite3.Error:
                self._db.conn.rollback()
                raise

        return self.get_metadata(model_key)

    def search_by_tag(self, tags: Set[str]) -> Set[str]:
        """
        Return the keys of models containing all of the listed tags.

        An empty `tags` collection yields an empty result.

        :param tags: Tag texts that a model must all carry to be returned
        """
        matches: Set[str] = set()
        with self._db.lock:
            for position, tag in enumerate(tags):
                self._cursor.execute(
                    """--sql
                    SELECT a.id FROM model_tags AS a,
                                       tags AS b
                    WHERE a.tag_id=b.tag_id
                      AND b.tag_text=?;
                    """,
                    (tag,),
                )
                model_keys = {row[0] for row in self._cursor.fetchall()}
                # Seed the result with the first tag's models, then only ever
                # narrow it. Using "matches is empty" as the seed test would let
                # a later tag repopulate a result that an earlier tag emptied.
                matches = model_keys if position == 0 else matches & model_keys
                if not matches:
                    # The intersection cannot grow again; skip remaining queries.
                    break
        return matches

    def search_by_author(self, author: str) -> Set[str]:
        """
        Return the keys of models authored by the indicated author.

        :param author: Value compared for equality against the `author` column
        """
        # Hold the lock like the other query methods do: the cursor is shared,
        # so execute() and fetchall() must not interleave with another caller.
        with self._db.lock:
            self._cursor.execute(
                """--sql
                SELECT id FROM model_metadata
                WHERE author=?;
                """,
                (author,),
            )
            return {row[0] for row in self._cursor.fetchall()}

    def _update_tags(self, model_key: str, tags: Set[str]) -> None:
        """
        Update tags for the model referenced by model_key.

        Does not commit and does not take the lock; callers are expected to
        hold the lock and to commit or roll back.

        :param model_key: Key of the model whose tag links are replaced
        :param tags: Complete set of tag texts the model should have afterwards
        """
        # remove previous tags from this model
        self._cursor.execute(
            """--sql
            DELETE FROM model_tags
            WHERE id=?;
            """,
            (model_key,),
        )

        for tag in tags:
            # Create the tag if it is new; an existing tag is left untouched.
            self._cursor.execute(
                """--sql
                INSERT OR IGNORE INTO tags (
                  tag_text
                  )
                VALUES (?);
                """,
                (tag,),
            )
            # Look the id up explicitly, since the insert above may have been ignored.
            self._cursor.execute(
                """--sql
                SELECT tag_id
                FROM tags
                WHERE tag_text = ?
                LIMIT 1;
                """,
                (tag,),
            )
            tag_id = self._cursor.fetchone()[0]
            self._cursor.execute(
                """--sql
                INSERT OR IGNORE INTO model_tags (
                   id,
                   tag_id
                  )
                VALUES (?,?);
                """,
                (model_key, tag_id),
            )

    def _enable_foreign_key_constraints(self) -> None:
        """Turn on foreign key enforcement for this store's connection."""
        self._cursor.execute("PRAGMA foreign_keys = ON;")
|
Loading…
Reference in New Issue
Block a user