mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
[Feature] Allow the model record migrate script to update existing model records (#5264)
## What type of PR is this? (check all applicable) - [ ] Refactor - [X] Feature - [ ] Bug Fix - [ ] Optimization - [ ] Documentation Update - [ ] Community Node Submission ## Have you discussed this change with the InvokeAI team? - [X] Yes - [ ] No, because: ## Have you updated all relevant documentation? - [ ] Yes - [X] No ## Description 1. The new model manager sqlite3-based configuration record storage system is automatically populated with probed values from existing models found in the models path when `invokeai-web` starts up for the first time. However, the user's customization of these models in `invokeai.yaml`, including such things as the prediction type and model description, are not automatically copied over. This PR enhances the `invokeai-migrate-models-to-db` script so that any customized configuration data from `invokeai.yaml` replaces the original probed values. This script only needs to be run once, but it does not hurt to run it additional times. In the near future, I'm going to register this module with psychedelicious's sqlite migration system so that the update happens automatically during database migration. 2. The SQL-based model config record system stores a JSON version of the config, as well as several fields that are broken out into individual columns for search/indexing purposes. This PR keeps the JSON and the broken-out fields in sync using the `json_extract()` sqlite3 function to populate the broken out `base`, `type`, `name`, `path` and `format` fields in the `model_config` table. 3. Finally, this PR fixes the annoying `invokeai-web` shutdown message: `TypeError: ModelInstallService.stop() takes 1 positional argument but 2 were given` ## Related Tickets & Documents - Related Issue # - Closes # ## QA Instructions, Screenshots, Recordings If you've run `invokeai-web` at any time since PR #5039, your `invokeai.db` will have a `model_config` table containing probe information from all models in the invokeai models directory as well as those in `autoimport` (if applicable). However, any models present in `models.yaml` whose paths are outside these directories will not be present. To add them, and to update the description and other values from `models.yaml`, run the command `invokeai-migrate-models-to-db`. You should see the missing models added to the database table with the correct information. <!-- Please provide steps on how to test changes, any hardware or software specifications as well as any other pertinent information. --> ## Added/updated tests? - [X] Yes - [ ] No : _please replace this line with details on why tests have not been included_ ## [optional] Are there any post deployment tasks we need to perform?
This commit is contained in:
commit
22ccaa4e9a
@ -85,7 +85,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
|
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
|
||||||
return self._event_bus
|
return self._event_bus
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self, *args, **kwargs) -> None:
|
||||||
"""Stop the install thread; after this the object can be deleted and garbage collected."""
|
"""Stop the install thread; after this the object can be deleted and garbage collected."""
|
||||||
self._install_queue.put(STOP_JOB)
|
self._install_queue.put(STOP_JOB)
|
||||||
|
|
||||||
|
@ -96,10 +96,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
CREATE TABLE IF NOT EXISTS model_config (
|
CREATE TABLE IF NOT EXISTS model_config (
|
||||||
id TEXT NOT NULL PRIMARY KEY,
|
id TEXT NOT NULL PRIMARY KEY,
|
||||||
-- The next 3 fields are enums in python, unrestricted string here
|
-- The next 3 fields are enums in python, unrestricted string here
|
||||||
base TEXT NOT NULL,
|
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
|
||||||
type TEXT NOT NULL,
|
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
|
||||||
name TEXT NOT NULL,
|
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
|
||||||
path TEXT NOT NULL,
|
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
|
||||||
|
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
|
||||||
original_hash TEXT, -- could be null
|
original_hash TEXT, -- could be null
|
||||||
-- Serialized JSON representation of the whole config object,
|
-- Serialized JSON representation of the whole config object,
|
||||||
-- which will contain additional fields from subclasses
|
-- which will contain additional fields from subclasses
|
||||||
@ -175,21 +176,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
"""--sql
|
"""--sql
|
||||||
INSERT INTO model_config (
|
INSERT INTO model_config (
|
||||||
id,
|
id,
|
||||||
base,
|
|
||||||
type,
|
|
||||||
name,
|
|
||||||
path,
|
|
||||||
original_hash,
|
original_hash,
|
||||||
config
|
config
|
||||||
)
|
)
|
||||||
VALUES (?,?,?,?,?,?,?);
|
VALUES (?,?,?);
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
key,
|
key,
|
||||||
record.base,
|
|
||||||
record.type,
|
|
||||||
record.name,
|
|
||||||
record.path,
|
|
||||||
record.original_hash,
|
record.original_hash,
|
||||||
json_serialized,
|
json_serialized,
|
||||||
),
|
),
|
||||||
@ -269,14 +262,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
UPDATE model_config
|
UPDATE model_config
|
||||||
SET base=?,
|
SET
|
||||||
type=?,
|
|
||||||
name=?,
|
|
||||||
path=?,
|
|
||||||
config=?
|
config=?
|
||||||
WHERE id=?;
|
WHERE id=?;
|
||||||
""",
|
""",
|
||||||
(record.base, record.type, record.name, record.path, json_serialized, key),
|
(json_serialized, key),
|
||||||
)
|
)
|
||||||
if self._cursor.rowcount == 0:
|
if self._cursor.rowcount == 0:
|
||||||
raise UnknownModelException("model not found")
|
raise UnknownModelException("model not found")
|
||||||
@ -374,7 +364,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
SELECT config FROM model_config
|
SELECT config FROM model_config
|
||||||
WHERE model_path=?;
|
WHERE path=?;
|
||||||
""",
|
""",
|
||||||
(str(path),),
|
(str(path),),
|
||||||
)
|
)
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
"""Migrate from the InvokeAI v2 models.yaml format to the v3 sqlite format."""
|
"""Migrate from the InvokeAI v2 models.yaml format to the v3 sqlite format."""
|
||||||
|
|
||||||
from hashlib import sha1
|
from hashlib import sha1
|
||||||
|
from logging import Logger
|
||||||
|
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
@ -10,6 +11,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
|||||||
from invokeai.app.services.model_records import (
|
from invokeai.app.services.model_records import (
|
||||||
DuplicateModelException,
|
DuplicateModelException,
|
||||||
ModelRecordServiceSQL,
|
ModelRecordServiceSQL,
|
||||||
|
UnknownModelException,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
@ -38,9 +40,9 @@ class MigrateModelYamlToDb:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
config: InvokeAIAppConfig
|
config: InvokeAIAppConfig
|
||||||
logger: InvokeAILogger
|
logger: Logger
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.config = InvokeAIAppConfig.get_config()
|
self.config = InvokeAIAppConfig.get_config()
|
||||||
self.config.parse_args()
|
self.config.parse_args()
|
||||||
self.logger = InvokeAILogger.get_logger()
|
self.logger = InvokeAILogger.get_logger()
|
||||||
@ -53,9 +55,11 @@ class MigrateModelYamlToDb:
|
|||||||
def get_yaml(self) -> DictConfig:
|
def get_yaml(self) -> DictConfig:
|
||||||
"""Fetch the models.yaml DictConfig for this installation."""
|
"""Fetch the models.yaml DictConfig for this installation."""
|
||||||
yaml_path = self.config.model_conf_path
|
yaml_path = self.config.model_conf_path
|
||||||
return OmegaConf.load(yaml_path)
|
omegaconf = OmegaConf.load(yaml_path)
|
||||||
|
assert isinstance(omegaconf, DictConfig)
|
||||||
|
return omegaconf
|
||||||
|
|
||||||
def migrate(self):
|
def migrate(self) -> None:
|
||||||
"""Do the migration from models.yaml to invokeai.db."""
|
"""Do the migration from models.yaml to invokeai.db."""
|
||||||
db = self.get_db()
|
db = self.get_db()
|
||||||
yaml = self.get_yaml()
|
yaml = self.get_yaml()
|
||||||
@ -69,6 +73,7 @@ class MigrateModelYamlToDb:
|
|||||||
|
|
||||||
base_type, model_type, model_name = str(model_key).split("/")
|
base_type, model_type, model_name = str(model_key).split("/")
|
||||||
hash = FastModelHash.hash(self.config.models_path / stanza.path)
|
hash = FastModelHash.hash(self.config.models_path / stanza.path)
|
||||||
|
assert isinstance(model_key, str)
|
||||||
new_key = sha1(model_key.encode("utf-8")).hexdigest()
|
new_key = sha1(model_key.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
stanza["base"] = BaseModelType(base_type)
|
stanza["base"] = BaseModelType(base_type)
|
||||||
@ -77,12 +82,20 @@ class MigrateModelYamlToDb:
|
|||||||
stanza["original_hash"] = hash
|
stanza["original_hash"] = hash
|
||||||
stanza["current_hash"] = hash
|
stanza["current_hash"] = hash
|
||||||
|
|
||||||
new_config = ModelsValidator.validate_python(stanza)
|
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
|
||||||
self.logger.info(f"Adding model {model_name} with key {model_key}")
|
|
||||||
try:
|
try:
|
||||||
|
if original_record := db.search_by_path(stanza.path):
|
||||||
|
key = original_record[0].key
|
||||||
|
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
|
||||||
|
db.update_model(key, new_config)
|
||||||
|
else:
|
||||||
|
self.logger.info(f"Adding model {model_name} with key {model_key}")
|
||||||
db.add_model(new_key, new_config)
|
db.add_model(new_key, new_config)
|
||||||
except DuplicateModelException:
|
except DuplicateModelException:
|
||||||
self.logger.warning(f"Model {model_name} is already in the database")
|
self.logger.warning(f"Model {model_name} is already in the database")
|
||||||
|
except UnknownModelException:
|
||||||
|
self.logger.warning(f"Model at {stanza.path} could not be found in database")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
Loading…
Reference in New Issue
Block a user