sync pydantic and sql field names; merge routes

This commit is contained in:
Lincoln Stein 2023-11-06 18:08:57 -05:00
parent 55f8865524
commit ce22c0fbaa
6 changed files with 82 additions and 45 deletions

View File

@ -16,7 +16,7 @@ from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType,
from ..dependencies import ApiDependencies
model_records_router = APIRouter(prefix="/v1/model_records", tags=["model_records"])
model_records_router = APIRouter(prefix="/v1/model/record", tags=["models"])
ModelConfigValidator = TypeAdapter(AnyModelConfig)
@ -34,7 +34,7 @@ ModelsListValidator = TypeAdapter(ModelsList)
@model_records_router.get(
"/",
operation_id="list_model_configs",
operation_id="list_model_recordss",
responses={200: {"model": ModelsList}},
)
async def list_model_records(

View File

@ -27,7 +27,7 @@ Typical usage:
# fetching config
new_config = store.get_model('key1')
print(new_config.name, new_config.base_model)
print(new_config.name, new_config.base)
assert new_config.key == 'key1'
# deleting
@ -100,11 +100,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
"""--sql
CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY,
-- These 4 fields are enums in python, unrestricted string here
base_model TEXT NOT NULL,
model_type TEXT NOT NULL,
model_name TEXT NOT NULL,
model_path TEXT NOT NULL,
-- The next 3 fields are enums in python, unrestricted string here
base TEXT NOT NULL,
type TEXT NOT NULL,
name TEXT NOT NULL,
path TEXT NOT NULL,
original_hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses
@ -139,6 +139,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
"""
)
# Add indexes for searchable fields
for stmt in [
"CREATE INDEX IF NOT EXISTS base_index ON model_config(base);",
"CREATE INDEX IF NOT EXISTS type_index ON model_config(type);",
"CREATE INDEX IF NOT EXISTS name_index ON model_config(name);",
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);",
]:
self._cursor.execute(stmt)
# Add our version to the metadata table
self._cursor.execute(
"""--sql
@ -169,10 +178,10 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
"""--sql
INSERT INTO model_config (
id,
base_model,
model_type,
model_name,
model_path,
base,
type,
name,
path,
original_hash,
config
)
@ -180,7 +189,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
""",
(
key,
record.base_model,
record.base,
record.type,
record.name,
record.path,
@ -193,7 +202,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
except sqlite3.IntegrityError as e:
self._conn.rollback()
if "UNIQUE constraint failed" in str(e):
raise DuplicateModelException(f"A model with key '{key}' is already installed") from e
if "model_config.path" in str(e):
msg = f"A model with path '{record.path}' is already installed"
else:
msg = f"A model with key '{key}' is already installed"
raise DuplicateModelException(msg) from e
else:
raise e
except sqlite3.Error as e:
@ -257,14 +270,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._cursor.execute(
"""--sql
UPDATE model_config
SET base_model=?,
model_type=?,
model_name=?,
model_path=?,
SET base=?,
type=?,
name=?,
path=?,
config=?
WHERE id=?;
""",
(record.base_model, record.type, record.name, record.path, json_serialized, key),
(record.base, record.type, record.name, record.path, json_serialized, key),
)
if self._cursor.rowcount == 0:
raise UnknownModelException("model not found")
@ -338,13 +351,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
where_clause = []
bindings = []
if model_name:
where_clause.append("model_name=?")
where_clause.append("name=?")
bindings.append(model_name)
if base_model:
where_clause.append("base_model=?")
where_clause.append("base=?")
bindings.append(base_model)
if model_type:
where_clause.append("model_type=?")
where_clause.append("type=?")
bindings.append(model_type)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
with self._lock:

View File

@ -7,8 +7,8 @@ Typical usage:
from invokeai.backend.model_manager import ModelConfigFactory
raw = dict(path='models/sd-1/main/foo.ckpt',
name='foo',
base_model='sd-1',
model_type='main',
base='sd-1',
type='main',
config='configs/stable-diffusion/v1-inference.yaml',
variant='normal',
format='checkpoint'
@ -103,7 +103,7 @@ class ModelConfigBase(BaseModel):
path: str
name: str
base_model: BaseModelType
base: BaseModelType
type: ModelType
format: ModelFormat
key: str = Field(description="unique key for model", default="<NOKEY>")
@ -181,20 +181,29 @@ class MainConfig(ModelConfigBase):
vae: Optional[str] = Field(None)
variant: ModelVariantType = ModelVariantType.Normal
ztsnr_training: bool = False
class MainCheckpointConfig(CheckpointConfig, MainConfig):
"""Model config for main checkpoint models."""
# Note that we do not need prediction_type or upcast_attention here
# because they are provided in the checkpoint's own config file.
class MainDiffusersConfig(DiffusersConfig, MainConfig):
"""Model config for main diffusers models."""
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
class ONNXSD1Config(MainConfig):
"""Model config for ONNX format models based on sd-1."""
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
class ONNXSD2Config(MainConfig):
@ -202,8 +211,8 @@ class ONNXSD2Config(MainConfig):
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
# No yaml config file for ONNX, so these are part of config
prediction_type: SchedulerPredictionType
upcast_attention: bool
prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction
upcast_attention: bool = True
class IPAdapterConfig(ModelConfigBase):
@ -305,7 +314,7 @@ class ModelConfigFactory(object):
try:
format = model_data.get("format")
type = model_data.get("type")
model_base = model_data.get("base_model")
model_base = model_data.get("base")
class_to_return = dest_class or cls._class_map[format][type]
if isinstance(class_to_return, dict): # additional level allowed
class_to_return = class_to_return[model_base]

View File

@ -50,7 +50,7 @@ class Migrate:
hash = FastModelHash.hash(self.config.models_path / stanza.path)
new_key = sha1(model_key.encode("utf-8")).hexdigest()
stanza["base_model"] = BaseModelType(base_type)
stanza["base"] = BaseModelType(base_type)
stanza["type"] = ModelType(model_type)
stanza["name"] = model_name
stanza["original_hash"] = hash

View File

@ -51,6 +51,7 @@ dependencies = [
"fastapi~=0.103.2",
"fastapi-events~=0.9.1",
"huggingface-hub~=0.16.4",
"imohash",
"invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids
"matplotlib", # needed for plotting of Penner easing functions
"mediapipe", # needed for "mediapipeface" controlnet model

View File

@ -8,7 +8,12 @@ from hashlib import sha256
import pytest
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
from invokeai.app.services.model_records import (
DuplicateModelException,
ModelRecordServiceBase,
ModelRecordServiceSQL,
UnknownModelException,
)
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.backend.model_manager.config import (
BaseModelType,
@ -32,8 +37,8 @@ def example_config() -> TextualInversionConfig:
return TextualInversionConfig(
path="/tmp/pokemon.bin",
name="old name",
base_model="sd-1",
type="embedding",
base=BaseModelType("sd-1"),
type=ModelType("embedding"),
format="embedding_file",
original_hash="ABC123",
)
@ -43,7 +48,7 @@ def test_add(store: ModelRecordServiceBase):
raw = dict(
path="/tmp/foo.ckpt",
name="model1",
base_model="sd-1",
base=BaseModelType("sd-1"),
type="main",
config="/tmp/foo.yaml",
variant="normal",
@ -53,16 +58,25 @@ def test_add(store: ModelRecordServiceBase):
store.add_model("key1", raw)
config1 = store.get_model("key1")
assert config1 is not None
raw["name"] = "model2"
raw["base_model"] = "sd-2"
raw["format"] = "diffusers"
raw.pop("config")
store.add_model("key2", raw)
config2 = store.get_model("key2")
assert config1.base == BaseModelType("sd-1")
assert config1.name == "model1"
assert config2.name == "model2"
assert config1.base_model == "sd-1"
assert config2.base_model == "sd-2"
assert config1.original_hash == "111222333444"
assert config1.current_hash is None
def test_dup(store: ModelRecordServiceBase):
config = example_config()
store.add_model("key1", example_config())
try:
store.add_model("key1", config)
assert False, "Duplicate model key should have been caught"
except DuplicateModelException:
assert True
try:
store.add_model("key2", config)
assert False, "Duplicate model path should have been caught"
except DuplicateModelException:
assert True
def test_update(store: ModelRecordServiceBase):
@ -115,21 +129,21 @@ def test_filter(store: ModelRecordServiceBase):
config1 = DiffusersConfig(
path="/tmp/config1",
name="config1",
base_model=BaseModelType("sd-1"),
base=BaseModelType("sd-1"),
type=ModelType("main"),
original_hash="CONFIG1HASH",
)
config2 = DiffusersConfig(
path="/tmp/config2",
name="config2",
base_model=BaseModelType("sd-1"),
base=BaseModelType("sd-1"),
type=ModelType("main"),
original_hash="CONFIG2HASH",
)
config3 = VaeDiffusersConfig(
path="/tmp/config3",
name="config3",
base_model=BaseModelType("sd-2"),
base=BaseModelType("sd-2"),
type=ModelType("vae"),
original_hash="CONFIG3HASH",
)