From 30228ce2a44d9829784c119aed6501acb57c0e59 Mon Sep 17 00:00:00 2001 From: maryhipp Date: Tue, 27 Feb 2024 15:35:27 -0500 Subject: [PATCH] fix merge --- invokeai/app/api/routers/model_manager.py | 44 ++++++++++ invokeai/app/invocations/compel.py | 28 ++++-- .../model_metadata/metadata_store_base.py | 15 +++- .../model_metadata/metadata_store_sql.py | 85 ++++++++++--------- .../model_manager/metadata/fetch/civitai.py | 1 + .../model_manager/metadata/metadata_base.py | 3 +- invokeai/backend/model_patcher.py | 2 + 7 files changed, 129 insertions(+), 49 deletions(-) diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index 50ebe5ce64..774f39909d 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -5,6 +5,7 @@ import pathlib import shutil from hashlib import sha1 from random import randbytes +import traceback from typing import Any, Dict, List, Optional, Set from fastapi import Body, Path, Query, Response @@ -14,6 +15,7 @@ from starlette.exceptions import HTTPException from typing_extensions import Annotated from invokeai.app.services.model_install import ModelInstallJob +from invokeai.app.services.model_metadata.metadata_store_base import ModelMetadataChanges from invokeai.app.services.model_records import ( DuplicateModelException, InvalidModelException, @@ -32,6 +34,7 @@ from invokeai.backend.model_manager.config import ( ) from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata +from invokeai.backend.model_manager.metadata.metadata_base import BaseMetadata from invokeai.backend.model_manager.search import ModelSearch from ..dependencies import ApiDependencies @@ -242,6 +245,47 @@ async def get_model_metadata( return result +@model_manager_router.patch( + "/i/{key}/metadata", + operation_id="update_model_metadata", + responses={ + 201: { + "description": "The model metadata was updated successfully", + "content": {"application/json": {"example": example_model_metadata}}, + }, + 400: {"description": "Bad request"}, + }, +) +async def update_model_metadata( + key: str = Path(description="Key of the model repo metadata to fetch."), + changes: ModelMetadataChanges = Body(description="The changes") +) -> Optional[AnyModelRepoMetadata]: + """Updates or creates a model metadata object.""" + logger = ApiDependencies.invoker.services.logger + record_store = ApiDependencies.invoker.services.model_manager.store + metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store + + try: + original_metadata = record_store.get_metadata(key) + print(original_metadata) + if original_metadata: + original_metadata.trigger_phrases = changes.trigger_phrases + + metadata_store.update_metadata(key, original_metadata) + else: + metadata_store.add_metadata(key, BaseMetadata(name="", author="",trigger_phrases=changes.trigger_phrases)) + except Exception as e: + ApiDependencies.invoker.services.logger.error(traceback.format_exception(e)) + raise HTTPException( + status_code=500, + detail=f"An error occurred while updating the model metadata: {e}", + ) + + result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key) + + return result + + @model_manager_router.get( "/tags", diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 50f5322513..0d558ec898 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -84,16 +84,28 @@ class CompelInvocation(BaseInvocation): ti_list = [] for trigger in extract_ti_triggers_from_prompt(self.prompt): - name = trigger[1:-1] + name_or_key = trigger[1:-1] + print(f"name_or_key: {name_or_key}") try: - loaded_model = context.models.load(key=name).model - assert isinstance(loaded_model, TextualInversionModelRaw) - ti_list.append((name, loaded_model)) + loaded_model = context.models.load(key=name_or_key) + model = loaded_model.model + print(model) + assert isinstance(model, TextualInversionModelRaw) + ti_list.append((name_or_key, model)) except UnknownModelException: - # print(e) - # import traceback - # print(traceback.format_exc()) - print(f'Warn: trigger: "{trigger}" not found') + try: + print(f"base: {text_encoder_info.config.base}") + loaded_model = context.models.load_by_attrs( + model_name=name_or_key, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion + ) + model = loaded_model.model + print(model) + assert isinstance(model, TextualInversionModelRaw) + ti_list.append((name_or_key, model)) + except UnknownModelException: + logger.warning(f'trigger: "{trigger}" not found') + except ValueError: + logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models') with ( ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as ( diff --git a/invokeai/app/services/model_metadata/metadata_store_base.py b/invokeai/app/services/model_metadata/metadata_store_base.py index e0e4381b09..e0ae34378b 100644 --- a/invokeai/app/services/model_metadata/metadata_store_base.py +++ b/invokeai/app/services/model_metadata/metadata_store_base.py @@ -4,10 +4,23 @@ Storage for Model Metadata """ from abc import ABC, abstractmethod -from typing import List, Set, Tuple +from typing import List, Optional, Set, Tuple + +from pydantic import Field +from invokeai.app.util.model_exclude_null import BaseModelExcludeNull from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata +class ModelMetadataChanges(BaseModelExcludeNull, extra="allow"): + """A set of changes to apply to model metadata. + + Only limited changes are valid: + - `trigger_phrases`: the list of trigger phrases for this model + """ + + trigger_phrases: Optional[List[str]] = Field(default=None, description="The model's list of trigger phrases") + """The model's list of trigger phrases""" + class ModelMetadataStoreBase(ABC): """Store, search and fetch model metadata retrieved from remote repositories.""" diff --git a/invokeai/app/services/model_metadata/metadata_store_sql.py b/invokeai/app/services/model_metadata/metadata_store_sql.py index afe9d2c8c6..9d9057d0c5 100644 --- a/invokeai/app/services/model_metadata/metadata_store_sql.py +++ b/invokeai/app/services/model_metadata/metadata_store_sql.py @@ -38,6 +38,8 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase): :param metadata: ModelRepoMetadata object to store """ json_serialized = metadata.model_dump_json() + print("json_serialized") + print(json_serialized) with self._db.lock: try: self._cursor.execute( @@ -53,7 +55,7 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase): json_serialized, ), ) - self._update_tags(model_key, metadata.tags) + # self._update_tags(model_key, metadata.tags) self._db.conn.commit() except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table self._db.conn.rollback() @@ -61,6 +63,8 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase): except sqlite3.Error as excp: self._db.conn.rollback() raise excp + except Exception as e: + raise e def get_metadata(self, model_key: str) -> AnyModelRepoMetadata: """Retrieve the ModelRepoMetadata corresponding to model key.""" @@ -115,6 +119,8 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase): except sqlite3.Error as e: self._db.conn.rollback() raise e + except Exception as e: + raise e return self.get_metadata(model_key) @@ -179,44 +185,45 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase): ) return {x[0] for x in self._cursor.fetchall()} - def _update_tags(self, model_key: str, tags: Set[str]) -> None: + def _update_tags(self, model_key: str, tags: Optional[Set[str]]) -> None: """Update tags for the model referenced by model_key.""" + if tags: # remove previous tags from this model - self._cursor.execute( - """--sql - DELETE FROM model_tags - WHERE model_id=?; - """, - (model_key,), - ) + self._cursor.execute( + """--sql + DELETE FROM model_tags + WHERE model_id=?; + """, + (model_key,), + ) - for tag in tags: - self._cursor.execute( - """--sql - INSERT OR IGNORE INTO tags ( - tag_text - ) - VALUES (?); - """, - (tag,), - ) - self._cursor.execute( - """--sql - SELECT tag_id - FROM tags - WHERE tag_text = ? - LIMIT 1; - """, - (tag,), - ) - tag_id = self._cursor.fetchone()[0] - self._cursor.execute( - """--sql - INSERT OR IGNORE INTO model_tags ( - model_id, - tag_id - ) - VALUES (?,?); - """, - (model_key, tag_id), - ) + for tag in tags: + self._cursor.execute( + """--sql + INSERT OR IGNORE INTO tags ( + tag_text + ) + VALUES (?); + """, + (tag,), + ) + self._cursor.execute( + """--sql + SELECT tag_id + FROM tags + WHERE tag_text = ? + LIMIT 1; + """, + (tag,), + ) + tag_id = self._cursor.fetchone()[0] + self._cursor.execute( + """--sql + INSERT OR IGNORE INTO model_tags ( + model_id, + tag_id + ) + VALUES (?,?); + """, + (model_key, tag_id), + ) diff --git a/invokeai/backend/model_manager/metadata/fetch/civitai.py b/invokeai/backend/model_manager/metadata/fetch/civitai.py index 7991f6a748..dcbf4ac1a9 100644 --- a/invokeai/backend/model_manager/metadata/fetch/civitai.py +++ b/invokeai/backend/model_manager/metadata/fetch/civitai.py @@ -164,6 +164,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase): AllowDerivatives=model_json["allowDerivatives"], AllowDifferentLicense=model_json["allowDifferentLicense"], ), + trigger_phrases=version_json["trainedWords"], ) def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata: diff --git a/invokeai/backend/model_manager/metadata/metadata_base.py b/invokeai/backend/model_manager/metadata/metadata_base.py index 6e410d8222..502467a0af 100644 --- a/invokeai/backend/model_manager/metadata/metadata_base.py +++ b/invokeai/backend/model_manager/metadata/metadata_base.py @@ -73,7 +73,8 @@ class ModelMetadataBase(BaseModel): name: str = Field(description="model's name") author: str = Field(description="model's author") - tags: Set[str] = Field(description="tags provided by model source") + tags: Optional[Set[str]] = Field(description="tags provided by model source", default=None) + trigger_phrases: Optional[List[str]] = Field(description="trigger phrases for this model", default=None) class BaseMetadata(ModelMetadataBase): diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py index bee8909c31..87f10e4adc 100644 --- a/invokeai/backend/model_patcher.py +++ b/invokeai/backend/model_patcher.py @@ -171,6 +171,8 @@ class ModelPatcher: text_encoder: CLIPTextModel, ti_list: List[Tuple[str, TextualInversionModelRaw]], ) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]: + print("TI LIST") + print(ti_list) init_tokens_count = None new_tokens_added = None