mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
allow the model record migrate script to update existing model records
This commit is contained in:
parent
11f4a48144
commit
b0cfa58526
@ -85,7 +85,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
|
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
|
||||||
return self._event_bus
|
return self._event_bus
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self, *args, **kwargs) -> None:
|
||||||
"""Stop the install thread; after this the object can be deleted and garbage collected."""
|
"""Stop the install thread; after this the object can be deleted and garbage collected."""
|
||||||
self._install_queue.put(STOP_JOB)
|
self._install_queue.put(STOP_JOB)
|
||||||
|
|
||||||
|
@ -96,10 +96,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
CREATE TABLE IF NOT EXISTS model_config (
|
CREATE TABLE IF NOT EXISTS model_config (
|
||||||
id TEXT NOT NULL PRIMARY KEY,
|
id TEXT NOT NULL PRIMARY KEY,
|
||||||
-- The next 3 fields are enums in python, unrestricted string here
|
-- The next 3 fields are enums in python, unrestricted string here
|
||||||
base TEXT NOT NULL,
|
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
|
||||||
type TEXT NOT NULL,
|
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
|
||||||
name TEXT NOT NULL,
|
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
|
||||||
path TEXT NOT NULL,
|
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
|
||||||
|
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
|
||||||
original_hash TEXT, -- could be null
|
original_hash TEXT, -- could be null
|
||||||
-- Serialized JSON representation of the whole config object,
|
-- Serialized JSON representation of the whole config object,
|
||||||
-- which will contain additional fields from subclasses
|
-- which will contain additional fields from subclasses
|
||||||
@ -175,21 +176,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
"""--sql
|
"""--sql
|
||||||
INSERT INTO model_config (
|
INSERT INTO model_config (
|
||||||
id,
|
id,
|
||||||
base,
|
|
||||||
type,
|
|
||||||
name,
|
|
||||||
path,
|
|
||||||
original_hash,
|
original_hash,
|
||||||
config
|
config
|
||||||
)
|
)
|
||||||
VALUES (?,?,?,?,?,?,?);
|
VALUES (?,?,?);
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
key,
|
key,
|
||||||
record.base,
|
|
||||||
record.type,
|
|
||||||
record.name,
|
|
||||||
record.path,
|
|
||||||
record.original_hash,
|
record.original_hash,
|
||||||
json_serialized,
|
json_serialized,
|
||||||
),
|
),
|
||||||
@ -269,14 +262,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
UPDATE model_config
|
UPDATE model_config
|
||||||
SET base=?,
|
SET
|
||||||
type=?,
|
|
||||||
name=?,
|
|
||||||
path=?,
|
|
||||||
config=?
|
config=?
|
||||||
WHERE id=?;
|
WHERE id=?;
|
||||||
""",
|
""",
|
||||||
(record.base, record.type, record.name, record.path, json_serialized, key),
|
(json_serialized, key),
|
||||||
)
|
)
|
||||||
if self._cursor.rowcount == 0:
|
if self._cursor.rowcount == 0:
|
||||||
raise UnknownModelException("model not found")
|
raise UnknownModelException("model not found")
|
||||||
@ -374,7 +364,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
SELECT config FROM model_config
|
SELECT config FROM model_config
|
||||||
WHERE model_path=?;
|
WHERE path=?;
|
||||||
""",
|
""",
|
||||||
(str(path),),
|
(str(path),),
|
||||||
)
|
)
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
"""Migrate from the InvokeAI v2 models.yaml format to the v3 sqlite format."""
|
"""Migrate from the InvokeAI v2 models.yaml format to the v3 sqlite format."""
|
||||||
|
|
||||||
from hashlib import sha1
|
from hashlib import sha1
|
||||||
|
from logging import Logger
|
||||||
|
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
@ -10,6 +11,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
|||||||
from invokeai.app.services.model_records import (
|
from invokeai.app.services.model_records import (
|
||||||
DuplicateModelException,
|
DuplicateModelException,
|
||||||
ModelRecordServiceSQL,
|
ModelRecordServiceSQL,
|
||||||
|
UnknownModelException,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
@ -38,9 +40,9 @@ class MigrateModelYamlToDb:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
config: InvokeAIAppConfig
|
config: InvokeAIAppConfig
|
||||||
logger: InvokeAILogger
|
logger: Logger
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.config = InvokeAIAppConfig.get_config()
|
self.config = InvokeAIAppConfig.get_config()
|
||||||
self.config.parse_args()
|
self.config.parse_args()
|
||||||
self.logger = InvokeAILogger.get_logger()
|
self.logger = InvokeAILogger.get_logger()
|
||||||
@ -53,9 +55,11 @@ class MigrateModelYamlToDb:
|
|||||||
def get_yaml(self) -> DictConfig:
|
def get_yaml(self) -> DictConfig:
|
||||||
"""Fetch the models.yaml DictConfig for this installation."""
|
"""Fetch the models.yaml DictConfig for this installation."""
|
||||||
yaml_path = self.config.model_conf_path
|
yaml_path = self.config.model_conf_path
|
||||||
return OmegaConf.load(yaml_path)
|
omegaconf = OmegaConf.load(yaml_path)
|
||||||
|
assert isinstance(omegaconf,DictConfig)
|
||||||
|
return omegaconf
|
||||||
|
|
||||||
def migrate(self):
|
def migrate(self) -> None:
|
||||||
"""Do the migration from models.yaml to invokeai.db."""
|
"""Do the migration from models.yaml to invokeai.db."""
|
||||||
db = self.get_db()
|
db = self.get_db()
|
||||||
yaml = self.get_yaml()
|
yaml = self.get_yaml()
|
||||||
@ -69,6 +73,7 @@ class MigrateModelYamlToDb:
|
|||||||
|
|
||||||
base_type, model_type, model_name = str(model_key).split("/")
|
base_type, model_type, model_name = str(model_key).split("/")
|
||||||
hash = FastModelHash.hash(self.config.models_path / stanza.path)
|
hash = FastModelHash.hash(self.config.models_path / stanza.path)
|
||||||
|
assert isinstance(model_key, str)
|
||||||
new_key = sha1(model_key.encode("utf-8")).hexdigest()
|
new_key = sha1(model_key.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
stanza["base"] = BaseModelType(base_type)
|
stanza["base"] = BaseModelType(base_type)
|
||||||
@ -77,12 +82,20 @@ class MigrateModelYamlToDb:
|
|||||||
stanza["original_hash"] = hash
|
stanza["original_hash"] = hash
|
||||||
stanza["current_hash"] = hash
|
stanza["current_hash"] = hash
|
||||||
|
|
||||||
new_config = ModelsValidator.validate_python(stanza)
|
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
|
||||||
self.logger.info(f"Adding model {model_name} with key {model_key}")
|
|
||||||
try:
|
try:
|
||||||
|
if original_record := db.search_by_path(stanza.path):
|
||||||
|
key = original_record[0].key
|
||||||
|
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
|
||||||
|
db.update_model(key, new_config)
|
||||||
|
else:
|
||||||
|
self.logger.info(f"Adding model {model_name} with key {model_key}")
|
||||||
db.add_model(new_key, new_config)
|
db.add_model(new_key, new_config)
|
||||||
except DuplicateModelException:
|
except DuplicateModelException:
|
||||||
self.logger.warning(f"Model {model_name} is already in the database")
|
self.logger.warning(f"Model {model_name} is already in the database")
|
||||||
|
except UnknownModelException:
|
||||||
|
self.logger.warning(f"Model at {stanza.path} could not be found in database")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user