fix merge

This commit is contained in:
maryhipp 2024-02-27 15:35:27 -05:00 committed by Kent Keirsey
parent 8e633d02be
commit a13dce18c7
6 changed files with 109 additions and 41 deletions

View File

@ -5,6 +5,7 @@ import pathlib
import shutil
from hashlib import sha1
from random import randbytes
import traceback
from typing import Any, Dict, List, Optional, Set
from fastapi import Body, Path, Query, Response
@ -14,6 +15,7 @@ from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.services.model_install import ModelInstallJob
from invokeai.app.services.model_metadata.metadata_store_base import ModelMetadataChanges
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
@ -32,6 +34,7 @@ from invokeai.backend.model_manager.config import (
)
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.metadata.metadata_base import BaseMetadata
from invokeai.backend.model_manager.search import ModelSearch
from ..dependencies import ApiDependencies
@ -242,6 +245,47 @@ async def get_model_metadata(
return result
@model_manager_router.patch(
"/i/{key}/metadata",
operation_id="update_model_metadata",
responses={
201: {
"description": "The model metadata was updated successfully",
"content": {"application/json": {"example": example_model_metadata}},
},
400: {"description": "Bad request"},
},
)
async def update_model_metadata(
key: str = Path(description="Key of the model repo metadata to fetch."),
changes: ModelMetadataChanges = Body(description="The changes")
) -> Optional[AnyModelRepoMetadata]:
"""Updates or creates a model metadata object."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store
metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store
try:
original_metadata = record_store.get_metadata(key)
print(original_metadata)
if original_metadata:
original_metadata.trigger_phrases = changes.trigger_phrases
metadata_store.update_metadata(key, original_metadata)
else:
metadata_store.add_metadata(key, BaseMetadata(name="", author="",trigger_phrases=changes.trigger_phrases))
except Exception as e:
ApiDependencies.invoker.services.logger.error(traceback.format_exception(e))
raise HTTPException(
status_code=500,
detail=f"An error occurred while updating the model metadata: {e}",
)
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
return result
@model_manager_router.get(
"/tags",

View File

@ -4,10 +4,23 @@ Storage for Model Metadata
"""
from abc import ABC, abstractmethod
from typing import List, Set, Tuple
from typing import List, Optional, Set, Tuple
from pydantic import Field
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
class ModelMetadataChanges(BaseModelExcludeNull, extra="allow"):
"""A set of changes to apply to model metadata.
Only limited changes are valid:
- `trigger_phrases`: the list of trigger phrases for this model
"""
trigger_phrases: Optional[List[str]] = Field(default=None, description="The model's list of trigger phrases")
"""The model's list of trigger phrases"""
class ModelMetadataStoreBase(ABC):
"""Store, search and fetch model metadata retrieved from remote repositories."""

View File

@ -38,6 +38,8 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
:param metadata: ModelRepoMetadata object to store
"""
json_serialized = metadata.model_dump_json()
print("json_serialized")
print(json_serialized)
with self._db.lock:
try:
self._cursor.execute(
@ -53,7 +55,7 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
json_serialized,
),
)
self._update_tags(model_key, metadata.tags)
# self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
self._db.conn.rollback()
@ -61,6 +63,8 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
except sqlite3.Error as excp:
self._db.conn.rollback()
raise excp
except Exception as e:
raise e
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
@ -115,6 +119,8 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
except Exception as e:
raise e
return self.get_metadata(model_key)
@ -179,44 +185,45 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
)
return {x[0] for x in self._cursor.fetchall()}
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
def _update_tags(self, model_key: str, tags: Optional[Set[str]]) -> None:
"""Update tags for the model referenced by model_key."""
if tags:
# remove previous tags from this model
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE model_id=?;
""",
(model_key,),
)
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE model_id=?;
""",
(model_key,),
)
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
model_id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
model_id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)

View File

@ -164,6 +164,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
AllowDerivatives=model_json["allowDerivatives"],
AllowDifferentLicense=model_json["allowDifferentLicense"],
),
trigger_phrases=version_json["trainedWords"],
)
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:

View File

@ -73,7 +73,8 @@ class ModelMetadataBase(BaseModel):
name: str = Field(description="model's name")
author: str = Field(description="model's author")
tags: Set[str] = Field(description="tags provided by model source")
tags: Optional[Set[str]] = Field(description="tags provided by model source", default=None)
trigger_phrases: Optional[List[str]] = Field(description="trigger phrases for this model", default=None)
class BaseMetadata(ModelMetadataBase):

View File

@ -172,6 +172,8 @@ class ModelPatcher:
text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
ti_list: List[Tuple[str, TextualInversionModelRaw]],
) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
print("TI LIST")
print(ti_list)
init_tokens_count = None
new_tokens_added = None