mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
allow the model record migrate script to update existing model records
This commit is contained in:
parent
11f4a48144
commit
b0cfa58526
@ -85,7 +85,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
|
||||
return self._event_bus
|
||||
|
||||
def stop(self) -> None:
|
||||
def stop(self, *args, **kwargs) -> None:
|
||||
"""Stop the install thread; after this the object can be deleted and garbage collected."""
|
||||
self._install_queue.put(STOP_JOB)
|
||||
|
||||
|
@ -96,10 +96,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
CREATE TABLE IF NOT EXISTS model_config (
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
-- The next 3 fields are enums in python, unrestricted string here
|
||||
base TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
path TEXT NOT NULL,
|
||||
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
|
||||
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
|
||||
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
|
||||
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
|
||||
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
|
||||
original_hash TEXT, -- could be null
|
||||
-- Serialized JSON representation of the whole config object,
|
||||
-- which will contain additional fields from subclasses
|
||||
@ -175,21 +176,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
"""--sql
|
||||
INSERT INTO model_config (
|
||||
id,
|
||||
base,
|
||||
type,
|
||||
name,
|
||||
path,
|
||||
original_hash,
|
||||
config
|
||||
)
|
||||
VALUES (?,?,?,?,?,?,?);
|
||||
VALUES (?,?,?);
|
||||
""",
|
||||
(
|
||||
key,
|
||||
record.base,
|
||||
record.type,
|
||||
record.name,
|
||||
record.path,
|
||||
record.original_hash,
|
||||
json_serialized,
|
||||
),
|
||||
@ -269,14 +262,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_config
|
||||
SET base=?,
|
||||
type=?,
|
||||
name=?,
|
||||
path=?,
|
||||
SET
|
||||
config=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(record.base, record.type, record.name, record.path, json_serialized, key),
|
||||
(json_serialized, key),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownModelException("model not found")
|
||||
@ -374,7 +364,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config FROM model_config
|
||||
WHERE model_path=?;
|
||||
WHERE path=?;
|
||||
""",
|
||||
(str(path),),
|
||||
)
|
||||
|
@ -2,6 +2,7 @@
|
||||
"""Migrate from the InvokeAI v2 models.yaml format to the v3 sqlite format."""
|
||||
|
||||
from hashlib import sha1
|
||||
from logging import Logger
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from pydantic import TypeAdapter
|
||||
@ -10,6 +11,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
ModelRecordServiceSQL,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.model_manager.config import (
|
||||
@ -38,9 +40,9 @@ class MigrateModelYamlToDb:
|
||||
"""
|
||||
|
||||
config: InvokeAIAppConfig
|
||||
logger: InvokeAILogger
|
||||
logger: Logger
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.config = InvokeAIAppConfig.get_config()
|
||||
self.config.parse_args()
|
||||
self.logger = InvokeAILogger.get_logger()
|
||||
@ -53,9 +55,11 @@ class MigrateModelYamlToDb:
|
||||
def get_yaml(self) -> DictConfig:
|
||||
"""Fetch the models.yaml DictConfig for this installation."""
|
||||
yaml_path = self.config.model_conf_path
|
||||
return OmegaConf.load(yaml_path)
|
||||
omegaconf = OmegaConf.load(yaml_path)
|
||||
assert isinstance(omegaconf,DictConfig)
|
||||
return omegaconf
|
||||
|
||||
def migrate(self):
|
||||
def migrate(self) -> None:
|
||||
"""Do the migration from models.yaml to invokeai.db."""
|
||||
db = self.get_db()
|
||||
yaml = self.get_yaml()
|
||||
@ -69,6 +73,7 @@ class MigrateModelYamlToDb:
|
||||
|
||||
base_type, model_type, model_name = str(model_key).split("/")
|
||||
hash = FastModelHash.hash(self.config.models_path / stanza.path)
|
||||
assert isinstance(model_key, str)
|
||||
new_key = sha1(model_key.encode("utf-8")).hexdigest()
|
||||
|
||||
stanza["base"] = BaseModelType(base_type)
|
||||
@ -77,12 +82,20 @@ class MigrateModelYamlToDb:
|
||||
stanza["original_hash"] = hash
|
||||
stanza["current_hash"] = hash
|
||||
|
||||
new_config = ModelsValidator.validate_python(stanza)
|
||||
self.logger.info(f"Adding model {model_name} with key {model_key}")
|
||||
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
|
||||
|
||||
try:
|
||||
db.add_model(new_key, new_config)
|
||||
if original_record := db.search_by_path(stanza.path):
|
||||
key = original_record[0].key
|
||||
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
|
||||
db.update_model(key, new_config)
|
||||
else:
|
||||
self.logger.info(f"Adding model {model_name} with key {model_key}")
|
||||
db.add_model(new_key, new_config)
|
||||
except DuplicateModelException:
|
||||
self.logger.warning(f"Model {model_name} is already in the database")
|
||||
except UnknownModelException:
|
||||
self.logger.warning(f"Model at {stanza.path} could not be found in database")
|
||||
|
||||
|
||||
def main():
|
||||
|
Loading…
Reference in New Issue
Block a user