mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix merge
This commit is contained in:
parent
efb5f2d202
commit
30228ce2a4
@ -5,6 +5,7 @@ import pathlib
|
|||||||
import shutil
|
import shutil
|
||||||
from hashlib import sha1
|
from hashlib import sha1
|
||||||
from random import randbytes
|
from random import randbytes
|
||||||
|
import traceback
|
||||||
from typing import Any, Dict, List, Optional, Set
|
from typing import Any, Dict, List, Optional, Set
|
||||||
|
|
||||||
from fastapi import Body, Path, Query, Response
|
from fastapi import Body, Path, Query, Response
|
||||||
@ -14,6 +15,7 @@ from starlette.exceptions import HTTPException
|
|||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from invokeai.app.services.model_install import ModelInstallJob
|
from invokeai.app.services.model_install import ModelInstallJob
|
||||||
|
from invokeai.app.services.model_metadata.metadata_store_base import ModelMetadataChanges
|
||||||
from invokeai.app.services.model_records import (
|
from invokeai.app.services.model_records import (
|
||||||
DuplicateModelException,
|
DuplicateModelException,
|
||||||
InvalidModelException,
|
InvalidModelException,
|
||||||
@ -32,6 +34,7 @@ from invokeai.backend.model_manager.config import (
|
|||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
||||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||||
|
from invokeai.backend.model_manager.metadata.metadata_base import BaseMetadata
|
||||||
from invokeai.backend.model_manager.search import ModelSearch
|
from invokeai.backend.model_manager.search import ModelSearch
|
||||||
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
@ -242,6 +245,47 @@ async def get_model_metadata(
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@model_manager_router.patch(
|
||||||
|
"/i/{key}/metadata",
|
||||||
|
operation_id="update_model_metadata",
|
||||||
|
responses={
|
||||||
|
201: {
|
||||||
|
"description": "The model metadata was updated successfully",
|
||||||
|
"content": {"application/json": {"example": example_model_metadata}},
|
||||||
|
},
|
||||||
|
400: {"description": "Bad request"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def update_model_metadata(
|
||||||
|
key: str = Path(description="Key of the model repo metadata to fetch."),
|
||||||
|
changes: ModelMetadataChanges = Body(description="The changes")
|
||||||
|
) -> Optional[AnyModelRepoMetadata]:
|
||||||
|
"""Updates or creates a model metadata object."""
|
||||||
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||||
|
metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store
|
||||||
|
|
||||||
|
try:
|
||||||
|
original_metadata = record_store.get_metadata(key)
|
||||||
|
print(original_metadata)
|
||||||
|
if original_metadata:
|
||||||
|
original_metadata.trigger_phrases = changes.trigger_phrases
|
||||||
|
|
||||||
|
metadata_store.update_metadata(key, original_metadata)
|
||||||
|
else:
|
||||||
|
metadata_store.add_metadata(key, BaseMetadata(name="", author="",trigger_phrases=changes.trigger_phrases))
|
||||||
|
except Exception as e:
|
||||||
|
ApiDependencies.invoker.services.logger.error(traceback.format_exception(e))
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"An error occurred while updating the model metadata: {e}",
|
||||||
|
)
|
||||||
|
|
||||||
|
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@model_manager_router.get(
|
@model_manager_router.get(
|
||||||
"/tags",
|
"/tags",
|
||||||
|
@ -84,16 +84,28 @@ class CompelInvocation(BaseInvocation):
|
|||||||
|
|
||||||
ti_list = []
|
ti_list = []
|
||||||
for trigger in extract_ti_triggers_from_prompt(self.prompt):
|
for trigger in extract_ti_triggers_from_prompt(self.prompt):
|
||||||
name = trigger[1:-1]
|
name_or_key = trigger[1:-1]
|
||||||
|
print(f"name_or_key: {name_or_key}")
|
||||||
try:
|
try:
|
||||||
loaded_model = context.models.load(key=name).model
|
loaded_model = context.models.load(key=name_or_key)
|
||||||
assert isinstance(loaded_model, TextualInversionModelRaw)
|
model = loaded_model.model
|
||||||
ti_list.append((name, loaded_model))
|
print(model)
|
||||||
|
assert isinstance(model, TextualInversionModelRaw)
|
||||||
|
ti_list.append((name_or_key, model))
|
||||||
except UnknownModelException:
|
except UnknownModelException:
|
||||||
# print(e)
|
try:
|
||||||
# import traceback
|
print(f"base: {text_encoder_info.config.base}")
|
||||||
# print(traceback.format_exc())
|
loaded_model = context.models.load_by_attrs(
|
||||||
print(f'Warn: trigger: "{trigger}" not found')
|
model_name=name_or_key, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
|
||||||
|
)
|
||||||
|
model = loaded_model.model
|
||||||
|
print(model)
|
||||||
|
assert isinstance(model, TextualInversionModelRaw)
|
||||||
|
ti_list.append((name_or_key, model))
|
||||||
|
except UnknownModelException:
|
||||||
|
logger.warning(f'trigger: "{trigger}" not found')
|
||||||
|
except ValueError:
|
||||||
|
logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
|
||||||
|
|
||||||
with (
|
with (
|
||||||
ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
|
ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
|
||||||
|
@ -4,10 +4,23 @@ Storage for Model Metadata
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Set, Tuple
|
from typing import List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||||
|
|
||||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||||
|
|
||||||
|
class ModelMetadataChanges(BaseModelExcludeNull, extra="allow"):
|
||||||
|
"""A set of changes to apply to model metadata.
|
||||||
|
|
||||||
|
Only limited changes are valid:
|
||||||
|
- `trigger_phrases`: the list of trigger phrases for this model
|
||||||
|
"""
|
||||||
|
|
||||||
|
trigger_phrases: Optional[List[str]] = Field(default=None, description="The model's list of trigger phrases")
|
||||||
|
"""The model's list of trigger phrases"""
|
||||||
|
|
||||||
|
|
||||||
class ModelMetadataStoreBase(ABC):
|
class ModelMetadataStoreBase(ABC):
|
||||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||||
|
@ -38,6 +38,8 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
|||||||
:param metadata: ModelRepoMetadata object to store
|
:param metadata: ModelRepoMetadata object to store
|
||||||
"""
|
"""
|
||||||
json_serialized = metadata.model_dump_json()
|
json_serialized = metadata.model_dump_json()
|
||||||
|
print("json_serialized")
|
||||||
|
print(json_serialized)
|
||||||
with self._db.lock:
|
with self._db.lock:
|
||||||
try:
|
try:
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
@ -53,7 +55,7 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
|||||||
json_serialized,
|
json_serialized,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self._update_tags(model_key, metadata.tags)
|
# self._update_tags(model_key, metadata.tags)
|
||||||
self._db.conn.commit()
|
self._db.conn.commit()
|
||||||
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
|
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
|
||||||
self._db.conn.rollback()
|
self._db.conn.rollback()
|
||||||
@ -61,6 +63,8 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
|||||||
except sqlite3.Error as excp:
|
except sqlite3.Error as excp:
|
||||||
self._db.conn.rollback()
|
self._db.conn.rollback()
|
||||||
raise excp
|
raise excp
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||||
@ -115,6 +119,8 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
|||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
self._db.conn.rollback()
|
self._db.conn.rollback()
|
||||||
raise e
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
return self.get_metadata(model_key)
|
return self.get_metadata(model_key)
|
||||||
|
|
||||||
@ -179,44 +185,45 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
|||||||
)
|
)
|
||||||
return {x[0] for x in self._cursor.fetchall()}
|
return {x[0] for x in self._cursor.fetchall()}
|
||||||
|
|
||||||
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
|
def _update_tags(self, model_key: str, tags: Optional[Set[str]]) -> None:
|
||||||
"""Update tags for the model referenced by model_key."""
|
"""Update tags for the model referenced by model_key."""
|
||||||
|
if tags:
|
||||||
# remove previous tags from this model
|
# remove previous tags from this model
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
DELETE FROM model_tags
|
DELETE FROM model_tags
|
||||||
WHERE model_id=?;
|
WHERE model_id=?;
|
||||||
""",
|
""",
|
||||||
(model_key,),
|
(model_key,),
|
||||||
)
|
)
|
||||||
|
|
||||||
for tag in tags:
|
for tag in tags:
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
INSERT OR IGNORE INTO tags (
|
INSERT OR IGNORE INTO tags (
|
||||||
tag_text
|
tag_text
|
||||||
)
|
)
|
||||||
VALUES (?);
|
VALUES (?);
|
||||||
""",
|
""",
|
||||||
(tag,),
|
(tag,),
|
||||||
)
|
)
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
SELECT tag_id
|
SELECT tag_id
|
||||||
FROM tags
|
FROM tags
|
||||||
WHERE tag_text = ?
|
WHERE tag_text = ?
|
||||||
LIMIT 1;
|
LIMIT 1;
|
||||||
""",
|
""",
|
||||||
(tag,),
|
(tag,),
|
||||||
)
|
)
|
||||||
tag_id = self._cursor.fetchone()[0]
|
tag_id = self._cursor.fetchone()[0]
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
INSERT OR IGNORE INTO model_tags (
|
INSERT OR IGNORE INTO model_tags (
|
||||||
model_id,
|
model_id,
|
||||||
tag_id
|
tag_id
|
||||||
)
|
)
|
||||||
VALUES (?,?);
|
VALUES (?,?);
|
||||||
""",
|
""",
|
||||||
(model_key, tag_id),
|
(model_key, tag_id),
|
||||||
)
|
)
|
||||||
|
@ -164,6 +164,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
|||||||
AllowDerivatives=model_json["allowDerivatives"],
|
AllowDerivatives=model_json["allowDerivatives"],
|
||||||
AllowDifferentLicense=model_json["allowDifferentLicense"],
|
AllowDifferentLicense=model_json["allowDifferentLicense"],
|
||||||
),
|
),
|
||||||
|
trigger_phrases=version_json["trainedWords"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
|
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
|
||||||
|
@ -73,7 +73,8 @@ class ModelMetadataBase(BaseModel):
|
|||||||
|
|
||||||
name: str = Field(description="model's name")
|
name: str = Field(description="model's name")
|
||||||
author: str = Field(description="model's author")
|
author: str = Field(description="model's author")
|
||||||
tags: Set[str] = Field(description="tags provided by model source")
|
tags: Optional[Set[str]] = Field(description="tags provided by model source", default=None)
|
||||||
|
trigger_phrases: Optional[List[str]] = Field(description="trigger phrases for this model", default=None)
|
||||||
|
|
||||||
|
|
||||||
class BaseMetadata(ModelMetadataBase):
|
class BaseMetadata(ModelMetadataBase):
|
||||||
|
@ -171,6 +171,8 @@ class ModelPatcher:
|
|||||||
text_encoder: CLIPTextModel,
|
text_encoder: CLIPTextModel,
|
||||||
ti_list: List[Tuple[str, TextualInversionModelRaw]],
|
ti_list: List[Tuple[str, TextualInversionModelRaw]],
|
||||||
) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
|
) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
|
||||||
|
print("TI LIST")
|
||||||
|
print(ti_list)
|
||||||
init_tokens_count = None
|
init_tokens_count = None
|
||||||
new_tokens_added = None
|
new_tokens_added = None
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user