Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

sync pydantic and sql field names; merge routes

Parent: 55f8865524
Commit: ce22c0fbaa
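This commit makes the pydantic config classes and the SQL schema use the same short field names: the SQL columns base_model, model_type, model_name, and model_path become base, type, name, and path, and the pydantic field base_model becomes base. It also moves the model-record router to the /v1/model/record prefix and adds a unique index on path so duplicate installs are rejected. A minimal sketch of reading a record after the rename, assuming a ModelRecordServiceSQL store already populated under "key1" (the setup is hypothetical; the printed fields mirror the docstring example in the hunks below):

```python
# Assumes `store` is a populated ModelRecordServiceSQL instance (hypothetical setup).
config = store.get_model("key1")

# New, synced field names; before this commit the pydantic attribute was
# `base_model` and the SQL columns were model_type/model_name/model_path.
print(config.name, config.base, config.type, config.path)
assert config.key == "key1"
```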
@@ -16,7 +16,7 @@ from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType,
from ..dependencies import ApiDependencies

model_records_router = APIRouter(prefix="/v1/model_records", tags=["model_records"])
model_records_router = APIRouter(prefix="/v1/model/record", tags=["models"])

ModelConfigValidator = TypeAdapter(AnyModelConfig)
@@ -34,7 +34,7 @@ ModelsListValidator = TypeAdapter(ModelsList)
@model_records_router.get(
    "/",
    operation_id="list_model_configs",
    operation_id="list_model_recordss",
    responses={200: {"model": ModelsList}},
)
async def list_model_records(
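With the router now registered under /v1/model/record, the listing endpoint above is reachable with a plain HTTP GET; a client-side sketch, assuming a locally running InvokeAI server and that the API is mounted under /api on the default port (both are assumptions, the mount point is not part of this diff):

```python
import requests

# Assumed base URL; host, port, and the /api mount are illustrative.
url = "http://127.0.0.1:9090/api/v1/model/record/"

resp = requests.get(url)  # handled by list_model_records above
resp.raise_for_status()
payload = resp.json()  # validated server-side against ModelsList
for model in payload.get("models", []):  # "models" key assumed from ModelsList
    print(model["key"], model["name"], model["base"], model["type"])
```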
@@ -27,7 +27,7 @@ Typical usage:
# fetching config
new_config = store.get_model('key1')
print(new_config.name, new_config.base_model)
print(new_config.name, new_config.base)
assert new_config.key == 'key1'

# deleting
@@ -100,11 +100,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
"""--sql
CREATE TABLE IF NOT EXISTS model_config (
    id TEXT NOT NULL PRIMARY KEY,
    -- These 4 fields are enums in python, unrestricted string here
    base_model TEXT NOT NULL,
    model_type TEXT NOT NULL,
    model_name TEXT NOT NULL,
    model_path TEXT NOT NULL,
    -- The next 3 fields are enums in python, unrestricted string here
    base TEXT NOT NULL,
    type TEXT NOT NULL,
    name TEXT NOT NULL,
    path TEXT NOT NULL,
    original_hash TEXT, -- could be null
    -- Serialized JSON representation of the whole config object,
    -- which will contain additional fields from subclasses
@@ -139,6 +139,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
"""
)

# Add indexes for searchable fields
for stmt in [
    "CREATE INDEX IF NOT EXISTS base_index ON model_config(base);",
    "CREATE INDEX IF NOT EXISTS type_index ON model_config(type);",
    "CREATE INDEX IF NOT EXISTS name_index ON model_config(name);",
    "CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);",
]:
    self._cursor.execute(stmt)

# Add our version to the metadata table
self._cursor.execute(
    """--sql
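The CREATE UNIQUE INDEX on path is what turns a second install of the same file into a constraint violation; a self-contained sketch of that behaviour against an in-memory SQLite database (the table is trimmed to the relevant columns):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE model_config (id TEXT NOT NULL PRIMARY KEY, path TEXT NOT NULL, config TEXT NOT NULL);")
cur.execute("CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);")

cur.execute("INSERT INTO model_config (id, path, config) VALUES (?, ?, ?);", ("key1", "/tmp/foo.ckpt", "{}"))
try:
    cur.execute("INSERT INTO model_config (id, path, config) VALUES (?, ?, ?);", ("key2", "/tmp/foo.ckpt", "{}"))
except sqlite3.IntegrityError as e:
    # The message contains "UNIQUE constraint failed: model_config.path", the
    # substring the service's error handler checks before raising DuplicateModelException.
    print(e)
```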
@@ -169,10 +178,10 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
"""--sql
INSERT INTO model_config (
    id,
    base_model,
    model_type,
    model_name,
    model_path,
    base,
    type,
    name,
    path,
    original_hash,
    config
)
@@ -180,7 +189,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
""",
(
    key,
    record.base_model,
    record.base,
    record.type,
    record.name,
    record.path,
@@ -193,7 +202,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
except sqlite3.IntegrityError as e:
    self._conn.rollback()
    if "UNIQUE constraint failed" in str(e):
        raise DuplicateModelException(f"A model with key '{key}' is already installed") from e
        if "model_config.path" in str(e):
            msg = f"A model with path '{record.path}' is already installed"
        else:
            msg = f"A model with key '{key}' is already installed"
        raise DuplicateModelException(msg) from e
    else:
        raise e
except sqlite3.Error as e:
@@ -257,14 +270,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._cursor.execute(
    """--sql
    UPDATE model_config
    SET base_model=?,
        model_type=?,
        model_name=?,
        model_path=?,
    SET base=?,
        type=?,
        name=?,
        path=?,
        config=?
    WHERE id=?;
    """,
    (record.base_model, record.type, record.name, record.path, json_serialized, key),
    (record.base, record.type, record.name, record.path, json_serialized, key),
)
if self._cursor.rowcount == 0:
    raise UnknownModelException("model not found")
@@ -338,13 +351,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
where_clause = []
bindings = []
if model_name:
    where_clause.append("model_name=?")
    where_clause.append("name=?")
    bindings.append(model_name)
if base_model:
    where_clause.append("base_model=?")
    where_clause.append("base=?")
    bindings.append(base_model)
if model_type:
    where_clause.append("model_type=?")
    where_clause.append("type=?")
    bindings.append(model_type)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
with self._lock:
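The search method assembles its WHERE clause from whichever of the three optional filters were passed; a standalone sketch of that assembly with hypothetical argument values (the SELECT column below is an assumption, only the clause-building mirrors the code above):

```python
# Hypothetical filter values; None means "do not filter on this column".
model_name, base_model, model_type = None, "sd-1", "main"

where_clause = []
bindings = []
if model_name:
    where_clause.append("name=?")
    bindings.append(model_name)
if base_model:
    where_clause.append("base=?")
    bindings.append(base_model)
if model_type:
    where_clause.append("type=?")
    bindings.append(model_type)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""

# -> SELECT config FROM model_config WHERE base=? AND type=?;  bindings: ['sd-1', 'main']
print(f"SELECT config FROM model_config {where};", bindings)
```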
@@ -7,8 +7,8 @@ Typical usage:
from invokeai.backend.model_manager import ModelConfigFactory
raw = dict(path='models/sd-1/main/foo.ckpt',
           name='foo',
           base_model='sd-1',
           model_type='main',
           base='sd-1',
           type='main',
           config='configs/stable-diffusion/v1-inference.yaml',
           variant='normal',
           format='checkpoint'
@@ -103,7 +103,7 @@ class ModelConfigBase(BaseModel):
    path: str
    name: str
    base_model: BaseModelType
    base: BaseModelType
    type: ModelType
    format: ModelFormat
    key: str = Field(description="unique key for model", default="<NOKEY>")
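With base replacing base_model on ModelConfigBase, configs are constructed and validated against the new name; a trimmed pydantic sketch using a stand-in enum and a reduced field set (not the real class hierarchy):

```python
from enum import Enum
from pydantic import BaseModel, ValidationError


class BaseModelTypeSketch(str, Enum):
    """Stand-in for the real BaseModelType enum."""
    sd_1 = "sd-1"
    sd_2 = "sd-2"


class ConfigSketch(BaseModel):
    """Reduced stand-in for ModelConfigBase."""
    path: str
    name: str
    base: BaseModelTypeSketch


ok = ConfigSketch(path="/tmp/foo.ckpt", name="foo", base="sd-1")
print(ok.base)  # BaseModelTypeSketch.sd_1

try:
    # The old keyword is not a declared field, so the required `base` is
    # missing and validation fails.
    ConfigSketch(path="/tmp/foo.ckpt", name="foo", base_model="sd-1")
except ValidationError as exc:
    print(exc.error_count(), "validation error")
```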
@@ -181,20 +181,29 @@ class MainConfig(ModelConfigBase):
    vae: Optional[str] = Field(None)
    variant: ModelVariantType = ModelVariantType.Normal
    ztsnr_training: bool = False


class MainCheckpointConfig(CheckpointConfig, MainConfig):
    """Model config for main checkpoint models."""

    # Note that we do not need prediction_type or upcast_attention here
    # because they are provided in the checkpoint's own config file.


class MainDiffusersConfig(DiffusersConfig, MainConfig):
    """Model config for main diffusers models."""

    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
    upcast_attention: bool = False


class ONNXSD1Config(MainConfig):
    """Model config for ONNX format models based on sd-1."""

    format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
    upcast_attention: bool = False


class ONNXSD2Config(MainConfig):
@@ -202,8 +211,8 @@ class ONNXSD2Config(MainConfig):
    format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
    # No yaml config file for ONNX, so these are part of config
    prediction_type: SchedulerPredictionType
    upcast_attention: bool
    prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction
    upcast_attention: bool = True


class IPAdapterConfig(ModelConfigBase):
@@ -305,7 +314,7 @@ class ModelConfigFactory(object):
try:
    format = model_data.get("format")
    type = model_data.get("type")
    model_base = model_data.get("base_model")
    model_base = model_data.get("base")
    class_to_return = dest_class or cls._class_map[format][type]
    if isinstance(class_to_return, dict): # additional level allowed
        class_to_return = class_to_return[model_base]
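The factory resolves the concrete config class from a nested lookup keyed by format, then type, then (where the map has an extra level) the model's base; a toy sketch of the same dispatch with made-up classes and map contents:

```python
# Made-up stand-ins; the real map lives in ModelConfigFactory._class_map.
class MainDiffusersSketch: ...
class MainOnnxSD1Sketch: ...
class VaeDiffusersSketch: ...

_class_map = {
    "diffusers": {"main": MainDiffusersSketch, "vae": VaeDiffusersSketch},
    "onnx": {"main": {"sd-1": MainOnnxSD1Sketch}},  # extra level keyed by base
}

model_data = {"format": "onnx", "type": "main", "base": "sd-1"}
format = model_data.get("format")
type = model_data.get("type")
model_base = model_data.get("base")

class_to_return = _class_map[format][type]
if isinstance(class_to_return, dict):  # additional level allowed
    class_to_return = class_to_return[model_base]
print(class_to_return)  # <class '...MainOnnxSD1Sketch'>
```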
@@ -50,7 +50,7 @@ class Migrate:
hash = FastModelHash.hash(self.config.models_path / stanza.path)
new_key = sha1(model_key.encode("utf-8")).hexdigest()

stanza["base_model"] = BaseModelType(base_type)
stanza["base"] = BaseModelType(base_type)
stanza["type"] = ModelType(model_type)
stanza["name"] = model_name
stanza["original_hash"] = hash
@@ -51,6 +51,7 @@ dependencies = [
"fastapi~=0.103.2",
"fastapi-events~=0.9.1",
"huggingface-hub~=0.16.4",
"imohash",
"invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids
"matplotlib", # needed for plotting of Penner easing functions
"mediapipe", # needed for "mediapipeface" controlnet model
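imohash is added to the dependency list for fast, sampled hashing of large model files (presumably what FastModelHash.hash in the migrate hunk above relies on); a minimal sketch, assuming imohash's hashfile helper and an illustrative checkpoint path:

```python
from imohash import hashfile  # hashfile(path, hexdigest=True) assumed from the imohash package

# Illustrative path; imohash samples the file instead of reading it in full,
# which keeps hashing multi-gigabyte checkpoints quick.
digest = hashfile("/tmp/foo.ckpt", hexdigest=True)
print(digest)
```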
@@ -8,7 +8,12 @@ from hashlib import sha256
import pytest

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
from invokeai.app.services.model_records import (
    DuplicateModelException,
    ModelRecordServiceBase,
    ModelRecordServiceSQL,
    UnknownModelException,
)
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.backend.model_manager.config import (
    BaseModelType,
@@ -32,8 +37,8 @@ def example_config() -> TextualInversionConfig:
return TextualInversionConfig(
    path="/tmp/pokemon.bin",
    name="old name",
    base_model="sd-1",
    type="embedding",
    base=BaseModelType("sd-1"),
    type=ModelType("embedding"),
    format="embedding_file",
    original_hash="ABC123",
)
@@ -43,7 +48,7 @@ def test_add(store: ModelRecordServiceBase):
raw = dict(
    path="/tmp/foo.ckpt",
    name="model1",
    base_model="sd-1",
    base=BaseModelType("sd-1"),
    type="main",
    config="/tmp/foo.yaml",
    variant="normal",
@@ -53,16 +58,25 @@ def test_add(store: ModelRecordServiceBase):
store.add_model("key1", raw)
config1 = store.get_model("key1")
assert config1 is not None
raw["name"] = "model2"
raw["base_model"] = "sd-2"
raw["format"] = "diffusers"
raw.pop("config")
store.add_model("key2", raw)
config2 = store.get_model("key2")
assert config1.base == BaseModelType("sd-1")
assert config1.name == "model1"
assert config2.name == "model2"
assert config1.base_model == "sd-1"
assert config2.base_model == "sd-2"
assert config1.original_hash == "111222333444"
assert config1.current_hash is None


def test_dup(store: ModelRecordServiceBase):
    config = example_config()
    store.add_model("key1", example_config())
    try:
        store.add_model("key1", config)
        assert False, "Duplicate model key should have been caught"
    except DuplicateModelException:
        assert True
    try:
        store.add_model("key2", config)
        assert False, "Duplicate model path should have been caught"
    except DuplicateModelException:
        assert True


def test_update(store: ModelRecordServiceBase):
@@ -115,21 +129,21 @@ def test_filter(store: ModelRecordServiceBase):
config1 = DiffusersConfig(
    path="/tmp/config1",
    name="config1",
    base_model=BaseModelType("sd-1"),
    base=BaseModelType("sd-1"),
    type=ModelType("main"),
    original_hash="CONFIG1HASH",
)
config2 = DiffusersConfig(
    path="/tmp/config2",
    name="config2",
    base_model=BaseModelType("sd-1"),
    base=BaseModelType("sd-1"),
    type=ModelType("main"),
    original_hash="CONFIG2HASH",
)
config3 = VaeDiffusersConfig(
    path="/tmp/config3",
    name="config3",
    base_model=BaseModelType("sd-2"),
    base=BaseModelType("sd-2"),
    type=ModelType("vae"),
    original_hash="CONFIG3HASH",
)