fix merge

This commit is contained in:
maryhipp 2024-02-27 15:35:27 -05:00
parent efb5f2d202
commit 30228ce2a4
7 changed files with 129 additions and 49 deletions

View File

@ -5,6 +5,7 @@ import pathlib
import shutil import shutil
from hashlib import sha1 from hashlib import sha1
from random import randbytes from random import randbytes
import traceback
from typing import Any, Dict, List, Optional, Set from typing import Any, Dict, List, Optional, Set
from fastapi import Body, Path, Query, Response from fastapi import Body, Path, Query, Response
@ -14,6 +15,7 @@ from starlette.exceptions import HTTPException
from typing_extensions import Annotated from typing_extensions import Annotated
from invokeai.app.services.model_install import ModelInstallJob from invokeai.app.services.model_install import ModelInstallJob
from invokeai.app.services.model_metadata.metadata_store_base import ModelMetadataChanges
from invokeai.app.services.model_records import ( from invokeai.app.services.model_records import (
DuplicateModelException, DuplicateModelException,
InvalidModelException, InvalidModelException,
@ -32,6 +34,7 @@ from invokeai.backend.model_manager.config import (
) )
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.metadata.metadata_base import BaseMetadata
from invokeai.backend.model_manager.search import ModelSearch from invokeai.backend.model_manager.search import ModelSearch
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
@ -242,6 +245,47 @@ async def get_model_metadata(
return result return result
@model_manager_router.patch(
"/i/{key}/metadata",
operation_id="update_model_metadata",
responses={
201: {
"description": "The model metadata was updated successfully",
"content": {"application/json": {"example": example_model_metadata}},
},
400: {"description": "Bad request"},
},
)
async def update_model_metadata(
key: str = Path(description="Key of the model repo metadata to fetch."),
changes: ModelMetadataChanges = Body(description="The changes")
) -> Optional[AnyModelRepoMetadata]:
"""Updates or creates a model metadata object."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store
metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store
try:
original_metadata = record_store.get_metadata(key)
print(original_metadata)
if original_metadata:
original_metadata.trigger_phrases = changes.trigger_phrases
metadata_store.update_metadata(key, original_metadata)
else:
metadata_store.add_metadata(key, BaseMetadata(name="", author="",trigger_phrases=changes.trigger_phrases))
except Exception as e:
ApiDependencies.invoker.services.logger.error(traceback.format_exception(e))
raise HTTPException(
status_code=500,
detail=f"An error occurred while updating the model metadata: {e}",
)
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
return result
@model_manager_router.get( @model_manager_router.get(
"/tags", "/tags",

View File

@ -84,16 +84,28 @@ class CompelInvocation(BaseInvocation):
ti_list = [] ti_list = []
for trigger in extract_ti_triggers_from_prompt(self.prompt): for trigger in extract_ti_triggers_from_prompt(self.prompt):
name = trigger[1:-1] name_or_key = trigger[1:-1]
print(f"name_or_key: {name_or_key}")
try: try:
loaded_model = context.models.load(key=name).model loaded_model = context.models.load(key=name_or_key)
assert isinstance(loaded_model, TextualInversionModelRaw) model = loaded_model.model
ti_list.append((name, loaded_model)) print(model)
assert isinstance(model, TextualInversionModelRaw)
ti_list.append((name_or_key, model))
except UnknownModelException: except UnknownModelException:
# print(e) try:
# import traceback print(f"base: {text_encoder_info.config.base}")
# print(traceback.format_exc()) loaded_model = context.models.load_by_attrs(
print(f'Warn: trigger: "{trigger}" not found') model_name=name_or_key, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
)
model = loaded_model.model
print(model)
assert isinstance(model, TextualInversionModelRaw)
ti_list.append((name_or_key, model))
except UnknownModelException:
logger.warning(f'trigger: "{trigger}" not found')
except ValueError:
logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
with ( with (
ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as ( ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (

View File

@ -4,10 +4,23 @@ Storage for Model Metadata
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Set, Tuple from typing import List, Optional, Set, Tuple
from pydantic import Field
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
class ModelMetadataChanges(BaseModelExcludeNull, extra="allow"):
"""A set of changes to apply to model metadata.
Only limited changes are valid:
- `trigger_phrases`: the list of trigger phrases for this model
"""
trigger_phrases: Optional[List[str]] = Field(default=None, description="The model's list of trigger phrases")
"""The model's list of trigger phrases"""
class ModelMetadataStoreBase(ABC): class ModelMetadataStoreBase(ABC):
"""Store, search and fetch model metadata retrieved from remote repositories.""" """Store, search and fetch model metadata retrieved from remote repositories."""

View File

@ -38,6 +38,8 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
:param metadata: ModelRepoMetadata object to store :param metadata: ModelRepoMetadata object to store
""" """
json_serialized = metadata.model_dump_json() json_serialized = metadata.model_dump_json()
print("json_serialized")
print(json_serialized)
with self._db.lock: with self._db.lock:
try: try:
self._cursor.execute( self._cursor.execute(
@ -53,7 +55,7 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
json_serialized, json_serialized,
), ),
) )
self._update_tags(model_key, metadata.tags) # self._update_tags(model_key, metadata.tags)
self._db.conn.commit() self._db.conn.commit()
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
self._db.conn.rollback() self._db.conn.rollback()
@ -61,6 +63,8 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
except sqlite3.Error as excp: except sqlite3.Error as excp:
self._db.conn.rollback() self._db.conn.rollback()
raise excp raise excp
except Exception as e:
raise e
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata: def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key.""" """Retrieve the ModelRepoMetadata corresponding to model key."""
@ -115,6 +119,8 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
except sqlite3.Error as e: except sqlite3.Error as e:
self._db.conn.rollback() self._db.conn.rollback()
raise e raise e
except Exception as e:
raise e
return self.get_metadata(model_key) return self.get_metadata(model_key)
@ -179,44 +185,45 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
) )
return {x[0] for x in self._cursor.fetchall()} return {x[0] for x in self._cursor.fetchall()}
def _update_tags(self, model_key: str, tags: Set[str]) -> None: def _update_tags(self, model_key: str, tags: Optional[Set[str]]) -> None:
"""Update tags for the model referenced by model_key.""" """Update tags for the model referenced by model_key."""
if tags:
# remove previous tags from this model # remove previous tags from this model
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
DELETE FROM model_tags DELETE FROM model_tags
WHERE model_id=?; WHERE model_id=?;
""", """,
(model_key,), (model_key,),
) )
for tag in tags: for tag in tags:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
INSERT OR IGNORE INTO tags ( INSERT OR IGNORE INTO tags (
tag_text tag_text
) )
VALUES (?); VALUES (?);
""", """,
(tag,), (tag,),
) )
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
SELECT tag_id SELECT tag_id
FROM tags FROM tags
WHERE tag_text = ? WHERE tag_text = ?
LIMIT 1; LIMIT 1;
""", """,
(tag,), (tag,),
) )
tag_id = self._cursor.fetchone()[0] tag_id = self._cursor.fetchone()[0]
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
INSERT OR IGNORE INTO model_tags ( INSERT OR IGNORE INTO model_tags (
model_id, model_id,
tag_id tag_id
) )
VALUES (?,?); VALUES (?,?);
""", """,
(model_key, tag_id), (model_key, tag_id),
) )

View File

@ -164,6 +164,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
AllowDerivatives=model_json["allowDerivatives"], AllowDerivatives=model_json["allowDerivatives"],
AllowDifferentLicense=model_json["allowDifferentLicense"], AllowDifferentLicense=model_json["allowDifferentLicense"],
), ),
trigger_phrases=version_json["trainedWords"],
) )
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata: def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:

View File

@ -73,7 +73,8 @@ class ModelMetadataBase(BaseModel):
name: str = Field(description="model's name") name: str = Field(description="model's name")
author: str = Field(description="model's author") author: str = Field(description="model's author")
tags: Set[str] = Field(description="tags provided by model source") tags: Optional[Set[str]] = Field(description="tags provided by model source", default=None)
trigger_phrases: Optional[List[str]] = Field(description="trigger phrases for this model", default=None)
class BaseMetadata(ModelMetadataBase): class BaseMetadata(ModelMetadataBase):

View File

@ -171,6 +171,8 @@ class ModelPatcher:
text_encoder: CLIPTextModel, text_encoder: CLIPTextModel,
ti_list: List[Tuple[str, TextualInversionModelRaw]], ti_list: List[Tuple[str, TextualInversionModelRaw]],
) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]: ) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
print("TI LIST")
print(ti_list)
init_tokens_count = None init_tokens_count = None
new_tokens_added = None new_tokens_added = None