fix merge

This commit is contained in:
maryhipp 2024-02-27 15:35:27 -05:00
parent efb5f2d202
commit 30228ce2a4
7 changed files with 129 additions and 49 deletions

View File

@ -5,6 +5,7 @@ import pathlib
import shutil
from hashlib import sha1
from random import randbytes
import traceback
from typing import Any, Dict, List, Optional, Set
from fastapi import Body, Path, Query, Response
@ -14,6 +15,7 @@ from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.services.model_install import ModelInstallJob
from invokeai.app.services.model_metadata.metadata_store_base import ModelMetadataChanges
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
@ -32,6 +34,7 @@ from invokeai.backend.model_manager.config import (
)
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.metadata.metadata_base import BaseMetadata
from invokeai.backend.model_manager.search import ModelSearch
from ..dependencies import ApiDependencies
@ -242,6 +245,47 @@ async def get_model_metadata(
return result
@model_manager_router.patch(
"/i/{key}/metadata",
operation_id="update_model_metadata",
responses={
201: {
"description": "The model metadata was updated successfully",
"content": {"application/json": {"example": example_model_metadata}},
},
400: {"description": "Bad request"},
},
)
async def update_model_metadata(
key: str = Path(description="Key of the model repo metadata to fetch."),
changes: ModelMetadataChanges = Body(description="The changes")
) -> Optional[AnyModelRepoMetadata]:
"""Updates or creates a model metadata object."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store
metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store
try:
original_metadata = record_store.get_metadata(key)
print(original_metadata)
if original_metadata:
original_metadata.trigger_phrases = changes.trigger_phrases
metadata_store.update_metadata(key, original_metadata)
else:
metadata_store.add_metadata(key, BaseMetadata(name="", author="",trigger_phrases=changes.trigger_phrases))
except Exception as e:
ApiDependencies.invoker.services.logger.error(traceback.format_exception(e))
raise HTTPException(
status_code=500,
detail=f"An error occurred while updating the model metadata: {e}",
)
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
return result
@model_manager_router.get(
"/tags",

View File

@ -84,16 +84,28 @@ class CompelInvocation(BaseInvocation):
ti_list = []
for trigger in extract_ti_triggers_from_prompt(self.prompt):
name = trigger[1:-1]
name_or_key = trigger[1:-1]
print(f"name_or_key: {name_or_key}")
try:
loaded_model = context.models.load(key=name).model
assert isinstance(loaded_model, TextualInversionModelRaw)
ti_list.append((name, loaded_model))
loaded_model = context.models.load(key=name_or_key)
model = loaded_model.model
print(model)
assert isinstance(model, TextualInversionModelRaw)
ti_list.append((name_or_key, model))
except UnknownModelException:
# print(e)
# import traceback
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')
try:
print(f"base: {text_encoder_info.config.base}")
loaded_model = context.models.load_by_attrs(
model_name=name_or_key, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
)
model = loaded_model.model
print(model)
assert isinstance(model, TextualInversionModelRaw)
ti_list.append((name_or_key, model))
except UnknownModelException:
logger.warning(f'trigger: "{trigger}" not found')
except ValueError:
logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
with (
ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (

View File

@ -4,10 +4,23 @@ Storage for Model Metadata
"""
from abc import ABC, abstractmethod
from typing import List, Set, Tuple
from typing import List, Optional, Set, Tuple
from pydantic import Field
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
class ModelMetadataChanges(BaseModelExcludeNull, extra="allow"):
"""A set of changes to apply to model metadata.
Only limited changes are valid:
- `trigger_phrases`: the list of trigger phrases for this model
"""
trigger_phrases: Optional[List[str]] = Field(default=None, description="The model's list of trigger phrases")
"""The model's list of trigger phrases"""
class ModelMetadataStoreBase(ABC):
"""Store, search and fetch model metadata retrieved from remote repositories."""

View File

@ -38,6 +38,8 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
:param metadata: ModelRepoMetadata object to store
"""
json_serialized = metadata.model_dump_json()
print("json_serialized")
print(json_serialized)
with self._db.lock:
try:
self._cursor.execute(
@ -53,7 +55,7 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
json_serialized,
),
)
self._update_tags(model_key, metadata.tags)
# self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
self._db.conn.rollback()
@ -61,6 +63,8 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
except sqlite3.Error as excp:
self._db.conn.rollback()
raise excp
except Exception as e:
raise e
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
@ -115,6 +119,8 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
except Exception as e:
raise e
return self.get_metadata(model_key)
@ -179,44 +185,45 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
)
return {x[0] for x in self._cursor.fetchall()}
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
def _update_tags(self, model_key: str, tags: Optional[Set[str]]) -> None:
"""Update tags for the model referenced by model_key."""
if tags:
# remove previous tags from this model
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE model_id=?;
""",
(model_key,),
)
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE model_id=?;
""",
(model_key,),
)
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
model_id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
model_id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)

View File

@ -164,6 +164,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
AllowDerivatives=model_json["allowDerivatives"],
AllowDifferentLicense=model_json["allowDifferentLicense"],
),
trigger_phrases=version_json["trainedWords"],
)
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:

View File

@ -73,7 +73,8 @@ class ModelMetadataBase(BaseModel):
name: str = Field(description="model's name")
author: str = Field(description="model's author")
tags: Set[str] = Field(description="tags provided by model source")
tags: Optional[Set[str]] = Field(description="tags provided by model source", default=None)
trigger_phrases: Optional[List[str]] = Field(description="trigger phrases for this model", default=None)
class BaseMetadata(ModelMetadataBase):

View File

@ -171,6 +171,8 @@ class ModelPatcher:
text_encoder: CLIPTextModel,
ti_list: List[Tuple[str, TextualInversionModelRaw]],
) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
print("TI LIST")
print(ti_list)
init_tokens_count = None
new_tokens_added = None