mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix merge
This commit is contained in:
parent
efb5f2d202
commit
30228ce2a4
@ -5,6 +5,7 @@ import pathlib
|
||||
import shutil
|
||||
from hashlib import sha1
|
||||
from random import randbytes
|
||||
import traceback
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
@ -14,6 +15,7 @@ from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallJob
|
||||
from invokeai.app.services.model_metadata.metadata_store_base import ModelMetadataChanges
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
@ -32,6 +34,7 @@ from invokeai.backend.model_manager.config import (
|
||||
)
|
||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import BaseMetadata
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
@ -242,6 +245,47 @@ async def get_model_metadata(
|
||||
|
||||
return result
|
||||
|
||||
@model_manager_router.patch(
|
||||
"/i/{key}/metadata",
|
||||
operation_id="update_model_metadata",
|
||||
responses={
|
||||
201: {
|
||||
"description": "The model metadata was updated successfully",
|
||||
"content": {"application/json": {"example": example_model_metadata}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def update_model_metadata(
|
||||
key: str = Path(description="Key of the model repo metadata to fetch."),
|
||||
changes: ModelMetadataChanges = Body(description="The changes")
|
||||
) -> Optional[AnyModelRepoMetadata]:
|
||||
"""Updates or creates a model metadata object."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store
|
||||
|
||||
try:
|
||||
original_metadata = record_store.get_metadata(key)
|
||||
print(original_metadata)
|
||||
if original_metadata:
|
||||
original_metadata.trigger_phrases = changes.trigger_phrases
|
||||
|
||||
metadata_store.update_metadata(key, original_metadata)
|
||||
else:
|
||||
metadata_store.add_metadata(key, BaseMetadata(name="", author="",trigger_phrases=changes.trigger_phrases))
|
||||
except Exception as e:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exception(e))
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"An error occurred while updating the model metadata: {e}",
|
||||
)
|
||||
|
||||
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/tags",
|
||||
|
@ -84,16 +84,28 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
ti_list = []
|
||||
for trigger in extract_ti_triggers_from_prompt(self.prompt):
|
||||
name = trigger[1:-1]
|
||||
name_or_key = trigger[1:-1]
|
||||
print(f"name_or_key: {name_or_key}")
|
||||
try:
|
||||
loaded_model = context.models.load(key=name).model
|
||||
assert isinstance(loaded_model, TextualInversionModelRaw)
|
||||
ti_list.append((name, loaded_model))
|
||||
loaded_model = context.models.load(key=name_or_key)
|
||||
model = loaded_model.model
|
||||
print(model)
|
||||
assert isinstance(model, TextualInversionModelRaw)
|
||||
ti_list.append((name_or_key, model))
|
||||
except UnknownModelException:
|
||||
# print(e)
|
||||
# import traceback
|
||||
# print(traceback.format_exc())
|
||||
print(f'Warn: trigger: "{trigger}" not found')
|
||||
try:
|
||||
print(f"base: {text_encoder_info.config.base}")
|
||||
loaded_model = context.models.load_by_attrs(
|
||||
model_name=name_or_key, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
|
||||
)
|
||||
model = loaded_model.model
|
||||
print(model)
|
||||
assert isinstance(model, TextualInversionModelRaw)
|
||||
ti_list.append((name_or_key, model))
|
||||
except UnknownModelException:
|
||||
logger.warning(f'trigger: "{trigger}" not found')
|
||||
except ValueError:
|
||||
logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
|
||||
|
||||
with (
|
||||
ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
|
||||
|
@ -4,10 +4,23 @@ Storage for Model Metadata
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Set, Tuple
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from pydantic import Field
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
class ModelMetadataChanges(BaseModelExcludeNull, extra="allow"):
|
||||
"""A set of changes to apply to model metadata.
|
||||
|
||||
Only limited changes are valid:
|
||||
- `trigger_phrases`: the list of trigger phrases for this model
|
||||
"""
|
||||
|
||||
trigger_phrases: Optional[List[str]] = Field(default=None, description="The model's list of trigger phrases")
|
||||
"""The model's list of trigger phrases"""
|
||||
|
||||
|
||||
class ModelMetadataStoreBase(ABC):
|
||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||
|
@ -38,6 +38,8 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
||||
:param metadata: ModelRepoMetadata object to store
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json()
|
||||
print("json_serialized")
|
||||
print(json_serialized)
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
@ -53,7 +55,7 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
# self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
|
||||
self._db.conn.rollback()
|
||||
@ -61,6 +63,8 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
||||
except sqlite3.Error as excp:
|
||||
self._db.conn.rollback()
|
||||
raise excp
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
@ -115,6 +119,8 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
return self.get_metadata(model_key)
|
||||
|
||||
@ -179,44 +185,45 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
|
||||
def _update_tags(self, model_key: str, tags: Optional[Set[str]]) -> None:
|
||||
"""Update tags for the model referenced by model_key."""
|
||||
if tags:
|
||||
# remove previous tags from this model
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_tags
|
||||
WHERE model_id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_tags
|
||||
WHERE model_id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO tags (
|
||||
tag_text
|
||||
)
|
||||
VALUES (?);
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT tag_id
|
||||
FROM tags
|
||||
WHERE tag_text = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_id = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO model_tags (
|
||||
model_id,
|
||||
tag_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(model_key, tag_id),
|
||||
)
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO tags (
|
||||
tag_text
|
||||
)
|
||||
VALUES (?);
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT tag_id
|
||||
FROM tags
|
||||
WHERE tag_text = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_id = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO model_tags (
|
||||
model_id,
|
||||
tag_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(model_key, tag_id),
|
||||
)
|
||||
|
@ -164,6 +164,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
AllowDerivatives=model_json["allowDerivatives"],
|
||||
AllowDifferentLicense=model_json["allowDifferentLicense"],
|
||||
),
|
||||
trigger_phrases=version_json["trainedWords"],
|
||||
)
|
||||
|
||||
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
|
||||
|
@ -73,7 +73,8 @@ class ModelMetadataBase(BaseModel):
|
||||
|
||||
name: str = Field(description="model's name")
|
||||
author: str = Field(description="model's author")
|
||||
tags: Set[str] = Field(description="tags provided by model source")
|
||||
tags: Optional[Set[str]] = Field(description="tags provided by model source", default=None)
|
||||
trigger_phrases: Optional[List[str]] = Field(description="trigger phrases for this model", default=None)
|
||||
|
||||
|
||||
class BaseMetadata(ModelMetadataBase):
|
||||
|
@ -171,6 +171,8 @@ class ModelPatcher:
|
||||
text_encoder: CLIPTextModel,
|
||||
ti_list: List[Tuple[str, TextualInversionModelRaw]],
|
||||
) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
|
||||
print("TI LIST")
|
||||
print(ti_list)
|
||||
init_tokens_count = None
|
||||
new_tokens_added = None
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user