sync pydantic and sql field names; merge routes

This commit is contained in:
Lincoln Stein 2023-11-06 18:08:57 -05:00
parent 55f8865524
commit ce22c0fbaa
6 changed files with 82 additions and 45 deletions

View File

@@ -16,7 +16,7 @@ from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType,
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
model_records_router = APIRouter(prefix="/v1/model_records", tags=["model_records"]) model_records_router = APIRouter(prefix="/v1/model/record", tags=["models"])
ModelConfigValidator = TypeAdapter(AnyModelConfig) ModelConfigValidator = TypeAdapter(AnyModelConfig)
@@ -34,7 +34,7 @@ ModelsListValidator = TypeAdapter(ModelsList)
@model_records_router.get( @model_records_router.get(
"/", "/",
operation_id="list_model_configs", operation_id="list_model_records",
responses={200: {"model": ModelsList}}, responses={200: {"model": ModelsList}},
) )
async def list_model_records( async def list_model_records(

View File

@@ -27,7 +27,7 @@ Typical usage:
# fetching config # fetching config
new_config = store.get_model('key1') new_config = store.get_model('key1')
print(new_config.name, new_config.base_model) print(new_config.name, new_config.base)
assert new_config.key == 'key1' assert new_config.key == 'key1'
# deleting # deleting
@@ -100,11 +100,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
"""--sql """--sql
CREATE TABLE IF NOT EXISTS model_config ( CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY, id TEXT NOT NULL PRIMARY KEY,
-- These 4 fields are enums in python, unrestricted string here -- The next 3 fields are enums in python, unrestricted string here
base_model TEXT NOT NULL, base TEXT NOT NULL,
model_type TEXT NOT NULL, type TEXT NOT NULL,
model_name TEXT NOT NULL, name TEXT NOT NULL,
model_path TEXT NOT NULL, path TEXT NOT NULL,
original_hash TEXT, -- could be null original_hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object, -- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses -- which will contain additional fields from subclasses
@@ -139,6 +139,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
""" """
) )
# Add indexes for searchable fields
for stmt in [
"CREATE INDEX IF NOT EXISTS base_index ON model_config(base);",
"CREATE INDEX IF NOT EXISTS type_index ON model_config(type);",
"CREATE INDEX IF NOT EXISTS name_index ON model_config(name);",
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);",
]:
self._cursor.execute(stmt)
# Add our version to the metadata table # Add our version to the metadata table
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
@@ -169,10 +178,10 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
"""--sql """--sql
INSERT INTO model_config ( INSERT INTO model_config (
id, id,
base_model, base,
model_type, type,
model_name, name,
model_path, path,
original_hash, original_hash,
config config
) )
@@ -180,7 +189,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
""", """,
( (
key, key,
record.base_model, record.base,
record.type, record.type,
record.name, record.name,
record.path, record.path,
@@ -193,7 +202,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
except sqlite3.IntegrityError as e: except sqlite3.IntegrityError as e:
self._conn.rollback() self._conn.rollback()
if "UNIQUE constraint failed" in str(e): if "UNIQUE constraint failed" in str(e):
raise DuplicateModelException(f"A model with key '{key}' is already installed") from e if "model_config.path" in str(e):
msg = f"A model with path '{record.path}' is already installed"
else:
msg = f"A model with key '{key}' is already installed"
raise DuplicateModelException(msg) from e
else: else:
raise e raise e
except sqlite3.Error as e: except sqlite3.Error as e:
@@ -257,14 +270,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
UPDATE model_config UPDATE model_config
SET base_model=?, SET base=?,
model_type=?, type=?,
model_name=?, name=?,
model_path=?, path=?,
config=? config=?
WHERE id=?; WHERE id=?;
""", """,
(record.base_model, record.type, record.name, record.path, json_serialized, key), (record.base, record.type, record.name, record.path, json_serialized, key),
) )
if self._cursor.rowcount == 0: if self._cursor.rowcount == 0:
raise UnknownModelException("model not found") raise UnknownModelException("model not found")
@@ -338,13 +351,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
where_clause = [] where_clause = []
bindings = [] bindings = []
if model_name: if model_name:
where_clause.append("model_name=?") where_clause.append("name=?")
bindings.append(model_name) bindings.append(model_name)
if base_model: if base_model:
where_clause.append("base_model=?") where_clause.append("base=?")
bindings.append(base_model) bindings.append(base_model)
if model_type: if model_type:
where_clause.append("model_type=?") where_clause.append("type=?")
bindings.append(model_type) bindings.append(model_type)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else "" where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
with self._lock: with self._lock:

View File

@@ -7,8 +7,8 @@ Typical usage:
from invokeai.backend.model_manager import ModelConfigFactory from invokeai.backend.model_manager import ModelConfigFactory
raw = dict(path='models/sd-1/main/foo.ckpt', raw = dict(path='models/sd-1/main/foo.ckpt',
name='foo', name='foo',
base_model='sd-1', base='sd-1',
model_type='main', type='main',
config='configs/stable-diffusion/v1-inference.yaml', config='configs/stable-diffusion/v1-inference.yaml',
variant='normal', variant='normal',
format='checkpoint' format='checkpoint'
@@ -103,7 +103,7 @@ class ModelConfigBase(BaseModel):
path: str path: str
name: str name: str
base_model: BaseModelType base: BaseModelType
type: ModelType type: ModelType
format: ModelFormat format: ModelFormat
key: str = Field(description="unique key for model", default="<NOKEY>") key: str = Field(description="unique key for model", default="<NOKEY>")
@@ -181,20 +181,29 @@ class MainConfig(ModelConfigBase):
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
variant: ModelVariantType = ModelVariantType.Normal variant: ModelVariantType = ModelVariantType.Normal
ztsnr_training: bool = False
class MainCheckpointConfig(CheckpointConfig, MainConfig): class MainCheckpointConfig(CheckpointConfig, MainConfig):
"""Model config for main checkpoint models.""" """Model config for main checkpoint models."""
# Note that we do not need prediction_type or upcast_attention here
# because they are provided in the checkpoint's own config file.
class MainDiffusersConfig(DiffusersConfig, MainConfig): class MainDiffusersConfig(DiffusersConfig, MainConfig):
"""Model config for main diffusers models.""" """Model config for main diffusers models."""
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
class ONNXSD1Config(MainConfig): class ONNXSD1Config(MainConfig):
"""Model config for ONNX format models based on sd-1.""" """Model config for ONNX format models based on sd-1."""
format: Literal[ModelFormat.Onnx, ModelFormat.Olive] format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
class ONNXSD2Config(MainConfig): class ONNXSD2Config(MainConfig):
@@ -202,8 +211,8 @@ class ONNXSD2Config(MainConfig):
format: Literal[ModelFormat.Onnx, ModelFormat.Olive] format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
# No yaml config file for ONNX, so these are part of config # No yaml config file for ONNX, so these are part of config
prediction_type: SchedulerPredictionType prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction
upcast_attention: bool upcast_attention: bool = True
class IPAdapterConfig(ModelConfigBase): class IPAdapterConfig(ModelConfigBase):
@@ -305,7 +314,7 @@ class ModelConfigFactory(object):
try: try:
format = model_data.get("format") format = model_data.get("format")
type = model_data.get("type") type = model_data.get("type")
model_base = model_data.get("base_model") model_base = model_data.get("base")
class_to_return = dest_class or cls._class_map[format][type] class_to_return = dest_class or cls._class_map[format][type]
if isinstance(class_to_return, dict): # additional level allowed if isinstance(class_to_return, dict): # additional level allowed
class_to_return = class_to_return[model_base] class_to_return = class_to_return[model_base]

View File

@@ -50,7 +50,7 @@ class Migrate:
hash = FastModelHash.hash(self.config.models_path / stanza.path) hash = FastModelHash.hash(self.config.models_path / stanza.path)
new_key = sha1(model_key.encode("utf-8")).hexdigest() new_key = sha1(model_key.encode("utf-8")).hexdigest()
stanza["base_model"] = BaseModelType(base_type) stanza["base"] = BaseModelType(base_type)
stanza["type"] = ModelType(model_type) stanza["type"] = ModelType(model_type)
stanza["name"] = model_name stanza["name"] = model_name
stanza["original_hash"] = hash stanza["original_hash"] = hash

View File

@@ -51,6 +51,7 @@ dependencies = [
"fastapi~=0.103.2", "fastapi~=0.103.2",
"fastapi-events~=0.9.1", "fastapi-events~=0.9.1",
"huggingface-hub~=0.16.4", "huggingface-hub~=0.16.4",
"imohash",
"invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids "invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids
"matplotlib", # needed for plotting of Penner easing functions "matplotlib", # needed for plotting of Penner easing functions
"mediapipe", # needed for "mediapipeface" controlnet model "mediapipe", # needed for "mediapipeface" controlnet model

View File

@@ -8,7 +8,12 @@ from hashlib import sha256
import pytest import pytest
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException from invokeai.app.services.model_records import (
DuplicateModelException,
ModelRecordServiceBase,
ModelRecordServiceSQL,
UnknownModelException,
)
from invokeai.app.services.shared.sqlite import SqliteDatabase from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
BaseModelType, BaseModelType,
@@ -32,8 +37,8 @@ def example_config() -> TextualInversionConfig:
return TextualInversionConfig( return TextualInversionConfig(
path="/tmp/pokemon.bin", path="/tmp/pokemon.bin",
name="old name", name="old name",
base_model="sd-1", base=BaseModelType("sd-1"),
type="embedding", type=ModelType("embedding"),
format="embedding_file", format="embedding_file",
original_hash="ABC123", original_hash="ABC123",
) )
@@ -43,7 +48,7 @@ def test_add(store: ModelRecordServiceBase):
raw = dict( raw = dict(
path="/tmp/foo.ckpt", path="/tmp/foo.ckpt",
name="model1", name="model1",
base_model="sd-1", base=BaseModelType("sd-1"),
type="main", type="main",
config="/tmp/foo.yaml", config="/tmp/foo.yaml",
variant="normal", variant="normal",
@@ -53,16 +58,25 @@ def test_add(store: ModelRecordServiceBase):
store.add_model("key1", raw) store.add_model("key1", raw)
config1 = store.get_model("key1") config1 = store.get_model("key1")
assert config1 is not None assert config1 is not None
raw["name"] = "model2" assert config1.base == BaseModelType("sd-1")
raw["base_model"] = "sd-2"
raw["format"] = "diffusers"
raw.pop("config")
store.add_model("key2", raw)
config2 = store.get_model("key2")
assert config1.name == "model1" assert config1.name == "model1"
assert config2.name == "model2" assert config1.original_hash == "111222333444"
assert config1.base_model == "sd-1" assert config1.current_hash is None
assert config2.base_model == "sd-2"
def test_dup(store: ModelRecordServiceBase):
config = example_config()
store.add_model("key1", example_config())
try:
store.add_model("key1", config)
assert False, "Duplicate model key should have been caught"
except DuplicateModelException:
assert True
try:
store.add_model("key2", config)
assert False, "Duplicate model path should have been caught"
except DuplicateModelException:
assert True
def test_update(store: ModelRecordServiceBase): def test_update(store: ModelRecordServiceBase):
@@ -115,21 +129,21 @@ def test_filter(store: ModelRecordServiceBase):
config1 = DiffusersConfig( config1 = DiffusersConfig(
path="/tmp/config1", path="/tmp/config1",
name="config1", name="config1",
base_model=BaseModelType("sd-1"), base=BaseModelType("sd-1"),
type=ModelType("main"), type=ModelType("main"),
original_hash="CONFIG1HASH", original_hash="CONFIG1HASH",
) )
config2 = DiffusersConfig( config2 = DiffusersConfig(
path="/tmp/config2", path="/tmp/config2",
name="config2", name="config2",
base_model=BaseModelType("sd-1"), base=BaseModelType("sd-1"),
type=ModelType("main"), type=ModelType("main"),
original_hash="CONFIG2HASH", original_hash="CONFIG2HASH",
) )
config3 = VaeDiffusersConfig( config3 = VaeDiffusersConfig(
path="/tmp/config3", path="/tmp/config3",
name="config3", name="config3",
base_model=BaseModelType("sd-2"), base=BaseModelType("sd-2"),
type=ModelType("vae"), type=ModelType("vae"),
original_hash="CONFIG3HASH", original_hash="CONFIG3HASH",
) )