mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
sync pydantic and sql field names; merge routes
This commit is contained in:
parent
55f8865524
commit
ce22c0fbaa
@ -16,7 +16,7 @@ from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType,
|
|||||||
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
model_records_router = APIRouter(prefix="/v1/model_records", tags=["model_records"])
|
model_records_router = APIRouter(prefix="/v1/model/record", tags=["models"])
|
||||||
|
|
||||||
ModelConfigValidator = TypeAdapter(AnyModelConfig)
|
ModelConfigValidator = TypeAdapter(AnyModelConfig)
|
||||||
|
|
||||||
@ -34,7 +34,7 @@ ModelsListValidator = TypeAdapter(ModelsList)
|
|||||||
|
|
||||||
@model_records_router.get(
|
@model_records_router.get(
|
||||||
"/",
|
"/",
|
||||||
operation_id="list_model_configs",
|
operation_id="list_model_recordss",
|
||||||
responses={200: {"model": ModelsList}},
|
responses={200: {"model": ModelsList}},
|
||||||
)
|
)
|
||||||
async def list_model_records(
|
async def list_model_records(
|
||||||
|
@ -27,7 +27,7 @@ Typical usage:
|
|||||||
|
|
||||||
# fetching config
|
# fetching config
|
||||||
new_config = store.get_model('key1')
|
new_config = store.get_model('key1')
|
||||||
print(new_config.name, new_config.base_model)
|
print(new_config.name, new_config.base)
|
||||||
assert new_config.key == 'key1'
|
assert new_config.key == 'key1'
|
||||||
|
|
||||||
# deleting
|
# deleting
|
||||||
@ -100,11 +100,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
"""--sql
|
"""--sql
|
||||||
CREATE TABLE IF NOT EXISTS model_config (
|
CREATE TABLE IF NOT EXISTS model_config (
|
||||||
id TEXT NOT NULL PRIMARY KEY,
|
id TEXT NOT NULL PRIMARY KEY,
|
||||||
-- These 4 fields are enums in python, unrestricted string here
|
-- The next 3 fields are enums in python, unrestricted string here
|
||||||
base_model TEXT NOT NULL,
|
base TEXT NOT NULL,
|
||||||
model_type TEXT NOT NULL,
|
type TEXT NOT NULL,
|
||||||
model_name TEXT NOT NULL,
|
name TEXT NOT NULL,
|
||||||
model_path TEXT NOT NULL,
|
path TEXT NOT NULL,
|
||||||
original_hash TEXT, -- could be null
|
original_hash TEXT, -- could be null
|
||||||
-- Serialized JSON representation of the whole config object,
|
-- Serialized JSON representation of the whole config object,
|
||||||
-- which will contain additional fields from subclasses
|
-- which will contain additional fields from subclasses
|
||||||
@ -139,6 +139,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add indexes for searchable fields
|
||||||
|
for stmt in [
|
||||||
|
"CREATE INDEX IF NOT EXISTS base_index ON model_config(base);",
|
||||||
|
"CREATE INDEX IF NOT EXISTS type_index ON model_config(type);",
|
||||||
|
"CREATE INDEX IF NOT EXISTS name_index ON model_config(name);",
|
||||||
|
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);",
|
||||||
|
]:
|
||||||
|
self._cursor.execute(stmt)
|
||||||
|
|
||||||
# Add our version to the metadata table
|
# Add our version to the metadata table
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
@ -169,10 +178,10 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
"""--sql
|
"""--sql
|
||||||
INSERT INTO model_config (
|
INSERT INTO model_config (
|
||||||
id,
|
id,
|
||||||
base_model,
|
base,
|
||||||
model_type,
|
type,
|
||||||
model_name,
|
name,
|
||||||
model_path,
|
path,
|
||||||
original_hash,
|
original_hash,
|
||||||
config
|
config
|
||||||
)
|
)
|
||||||
@ -180,7 +189,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
key,
|
key,
|
||||||
record.base_model,
|
record.base,
|
||||||
record.type,
|
record.type,
|
||||||
record.name,
|
record.name,
|
||||||
record.path,
|
record.path,
|
||||||
@ -193,7 +202,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
except sqlite3.IntegrityError as e:
|
except sqlite3.IntegrityError as e:
|
||||||
self._conn.rollback()
|
self._conn.rollback()
|
||||||
if "UNIQUE constraint failed" in str(e):
|
if "UNIQUE constraint failed" in str(e):
|
||||||
raise DuplicateModelException(f"A model with key '{key}' is already installed") from e
|
if "model_config.path" in str(e):
|
||||||
|
msg = f"A model with path '{record.path}' is already installed"
|
||||||
|
else:
|
||||||
|
msg = f"A model with key '{key}' is already installed"
|
||||||
|
raise DuplicateModelException(msg) from e
|
||||||
else:
|
else:
|
||||||
raise e
|
raise e
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
@ -257,14 +270,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
UPDATE model_config
|
UPDATE model_config
|
||||||
SET base_model=?,
|
SET base=?,
|
||||||
model_type=?,
|
type=?,
|
||||||
model_name=?,
|
name=?,
|
||||||
model_path=?,
|
path=?,
|
||||||
config=?
|
config=?
|
||||||
WHERE id=?;
|
WHERE id=?;
|
||||||
""",
|
""",
|
||||||
(record.base_model, record.type, record.name, record.path, json_serialized, key),
|
(record.base, record.type, record.name, record.path, json_serialized, key),
|
||||||
)
|
)
|
||||||
if self._cursor.rowcount == 0:
|
if self._cursor.rowcount == 0:
|
||||||
raise UnknownModelException("model not found")
|
raise UnknownModelException("model not found")
|
||||||
@ -338,13 +351,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
where_clause = []
|
where_clause = []
|
||||||
bindings = []
|
bindings = []
|
||||||
if model_name:
|
if model_name:
|
||||||
where_clause.append("model_name=?")
|
where_clause.append("name=?")
|
||||||
bindings.append(model_name)
|
bindings.append(model_name)
|
||||||
if base_model:
|
if base_model:
|
||||||
where_clause.append("base_model=?")
|
where_clause.append("base=?")
|
||||||
bindings.append(base_model)
|
bindings.append(base_model)
|
||||||
if model_type:
|
if model_type:
|
||||||
where_clause.append("model_type=?")
|
where_clause.append("type=?")
|
||||||
bindings.append(model_type)
|
bindings.append(model_type)
|
||||||
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
|
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
@ -7,8 +7,8 @@ Typical usage:
|
|||||||
from invokeai.backend.model_manager import ModelConfigFactory
|
from invokeai.backend.model_manager import ModelConfigFactory
|
||||||
raw = dict(path='models/sd-1/main/foo.ckpt',
|
raw = dict(path='models/sd-1/main/foo.ckpt',
|
||||||
name='foo',
|
name='foo',
|
||||||
base_model='sd-1',
|
base='sd-1',
|
||||||
model_type='main',
|
type='main',
|
||||||
config='configs/stable-diffusion/v1-inference.yaml',
|
config='configs/stable-diffusion/v1-inference.yaml',
|
||||||
variant='normal',
|
variant='normal',
|
||||||
format='checkpoint'
|
format='checkpoint'
|
||||||
@ -103,7 +103,7 @@ class ModelConfigBase(BaseModel):
|
|||||||
|
|
||||||
path: str
|
path: str
|
||||||
name: str
|
name: str
|
||||||
base_model: BaseModelType
|
base: BaseModelType
|
||||||
type: ModelType
|
type: ModelType
|
||||||
format: ModelFormat
|
format: ModelFormat
|
||||||
key: str = Field(description="unique key for model", default="<NOKEY>")
|
key: str = Field(description="unique key for model", default="<NOKEY>")
|
||||||
@ -181,20 +181,29 @@ class MainConfig(ModelConfigBase):
|
|||||||
|
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
variant: ModelVariantType = ModelVariantType.Normal
|
variant: ModelVariantType = ModelVariantType.Normal
|
||||||
|
ztsnr_training: bool = False
|
||||||
|
|
||||||
|
|
||||||
class MainCheckpointConfig(CheckpointConfig, MainConfig):
|
class MainCheckpointConfig(CheckpointConfig, MainConfig):
|
||||||
"""Model config for main checkpoint models."""
|
"""Model config for main checkpoint models."""
|
||||||
|
|
||||||
|
# Note that we do not need prediction_type or upcast_attention here
|
||||||
|
# because they are provided in the checkpoint's own config file.
|
||||||
|
|
||||||
|
|
||||||
class MainDiffusersConfig(DiffusersConfig, MainConfig):
|
class MainDiffusersConfig(DiffusersConfig, MainConfig):
|
||||||
"""Model config for main diffusers models."""
|
"""Model config for main diffusers models."""
|
||||||
|
|
||||||
|
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||||
|
upcast_attention: bool = False
|
||||||
|
|
||||||
|
|
||||||
class ONNXSD1Config(MainConfig):
|
class ONNXSD1Config(MainConfig):
|
||||||
"""Model config for ONNX format models based on sd-1."""
|
"""Model config for ONNX format models based on sd-1."""
|
||||||
|
|
||||||
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
|
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
|
||||||
|
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||||
|
upcast_attention: bool = False
|
||||||
|
|
||||||
|
|
||||||
class ONNXSD2Config(MainConfig):
|
class ONNXSD2Config(MainConfig):
|
||||||
@ -202,8 +211,8 @@ class ONNXSD2Config(MainConfig):
|
|||||||
|
|
||||||
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
|
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
|
||||||
# No yaml config file for ONNX, so these are part of config
|
# No yaml config file for ONNX, so these are part of config
|
||||||
prediction_type: SchedulerPredictionType
|
prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction
|
||||||
upcast_attention: bool
|
upcast_attention: bool = True
|
||||||
|
|
||||||
|
|
||||||
class IPAdapterConfig(ModelConfigBase):
|
class IPAdapterConfig(ModelConfigBase):
|
||||||
@ -305,7 +314,7 @@ class ModelConfigFactory(object):
|
|||||||
try:
|
try:
|
||||||
format = model_data.get("format")
|
format = model_data.get("format")
|
||||||
type = model_data.get("type")
|
type = model_data.get("type")
|
||||||
model_base = model_data.get("base_model")
|
model_base = model_data.get("base")
|
||||||
class_to_return = dest_class or cls._class_map[format][type]
|
class_to_return = dest_class or cls._class_map[format][type]
|
||||||
if isinstance(class_to_return, dict): # additional level allowed
|
if isinstance(class_to_return, dict): # additional level allowed
|
||||||
class_to_return = class_to_return[model_base]
|
class_to_return = class_to_return[model_base]
|
||||||
|
@ -50,7 +50,7 @@ class Migrate:
|
|||||||
hash = FastModelHash.hash(self.config.models_path / stanza.path)
|
hash = FastModelHash.hash(self.config.models_path / stanza.path)
|
||||||
new_key = sha1(model_key.encode("utf-8")).hexdigest()
|
new_key = sha1(model_key.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
stanza["base_model"] = BaseModelType(base_type)
|
stanza["base"] = BaseModelType(base_type)
|
||||||
stanza["type"] = ModelType(model_type)
|
stanza["type"] = ModelType(model_type)
|
||||||
stanza["name"] = model_name
|
stanza["name"] = model_name
|
||||||
stanza["original_hash"] = hash
|
stanza["original_hash"] = hash
|
||||||
|
@ -51,6 +51,7 @@ dependencies = [
|
|||||||
"fastapi~=0.103.2",
|
"fastapi~=0.103.2",
|
||||||
"fastapi-events~=0.9.1",
|
"fastapi-events~=0.9.1",
|
||||||
"huggingface-hub~=0.16.4",
|
"huggingface-hub~=0.16.4",
|
||||||
|
"imohash",
|
||||||
"invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids
|
"invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids
|
||||||
"matplotlib", # needed for plotting of Penner easing functions
|
"matplotlib", # needed for plotting of Penner easing functions
|
||||||
"mediapipe", # needed for "mediapipeface" controlnet model
|
"mediapipe", # needed for "mediapipeface" controlnet model
|
||||||
|
@ -8,7 +8,12 @@ from hashlib import sha256
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
|
from invokeai.app.services.model_records import (
|
||||||
|
DuplicateModelException,
|
||||||
|
ModelRecordServiceBase,
|
||||||
|
ModelRecordServiceSQL,
|
||||||
|
UnknownModelException,
|
||||||
|
)
|
||||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
@ -32,8 +37,8 @@ def example_config() -> TextualInversionConfig:
|
|||||||
return TextualInversionConfig(
|
return TextualInversionConfig(
|
||||||
path="/tmp/pokemon.bin",
|
path="/tmp/pokemon.bin",
|
||||||
name="old name",
|
name="old name",
|
||||||
base_model="sd-1",
|
base=BaseModelType("sd-1"),
|
||||||
type="embedding",
|
type=ModelType("embedding"),
|
||||||
format="embedding_file",
|
format="embedding_file",
|
||||||
original_hash="ABC123",
|
original_hash="ABC123",
|
||||||
)
|
)
|
||||||
@ -43,7 +48,7 @@ def test_add(store: ModelRecordServiceBase):
|
|||||||
raw = dict(
|
raw = dict(
|
||||||
path="/tmp/foo.ckpt",
|
path="/tmp/foo.ckpt",
|
||||||
name="model1",
|
name="model1",
|
||||||
base_model="sd-1",
|
base=BaseModelType("sd-1"),
|
||||||
type="main",
|
type="main",
|
||||||
config="/tmp/foo.yaml",
|
config="/tmp/foo.yaml",
|
||||||
variant="normal",
|
variant="normal",
|
||||||
@ -53,16 +58,25 @@ def test_add(store: ModelRecordServiceBase):
|
|||||||
store.add_model("key1", raw)
|
store.add_model("key1", raw)
|
||||||
config1 = store.get_model("key1")
|
config1 = store.get_model("key1")
|
||||||
assert config1 is not None
|
assert config1 is not None
|
||||||
raw["name"] = "model2"
|
assert config1.base == BaseModelType("sd-1")
|
||||||
raw["base_model"] = "sd-2"
|
|
||||||
raw["format"] = "diffusers"
|
|
||||||
raw.pop("config")
|
|
||||||
store.add_model("key2", raw)
|
|
||||||
config2 = store.get_model("key2")
|
|
||||||
assert config1.name == "model1"
|
assert config1.name == "model1"
|
||||||
assert config2.name == "model2"
|
assert config1.original_hash == "111222333444"
|
||||||
assert config1.base_model == "sd-1"
|
assert config1.current_hash is None
|
||||||
assert config2.base_model == "sd-2"
|
|
||||||
|
|
||||||
|
def test_dup(store: ModelRecordServiceBase):
|
||||||
|
config = example_config()
|
||||||
|
store.add_model("key1", example_config())
|
||||||
|
try:
|
||||||
|
store.add_model("key1", config)
|
||||||
|
assert False, "Duplicate model key should have been caught"
|
||||||
|
except DuplicateModelException:
|
||||||
|
assert True
|
||||||
|
try:
|
||||||
|
store.add_model("key2", config)
|
||||||
|
assert False, "Duplicate model path should have been caught"
|
||||||
|
except DuplicateModelException:
|
||||||
|
assert True
|
||||||
|
|
||||||
|
|
||||||
def test_update(store: ModelRecordServiceBase):
|
def test_update(store: ModelRecordServiceBase):
|
||||||
@ -115,21 +129,21 @@ def test_filter(store: ModelRecordServiceBase):
|
|||||||
config1 = DiffusersConfig(
|
config1 = DiffusersConfig(
|
||||||
path="/tmp/config1",
|
path="/tmp/config1",
|
||||||
name="config1",
|
name="config1",
|
||||||
base_model=BaseModelType("sd-1"),
|
base=BaseModelType("sd-1"),
|
||||||
type=ModelType("main"),
|
type=ModelType("main"),
|
||||||
original_hash="CONFIG1HASH",
|
original_hash="CONFIG1HASH",
|
||||||
)
|
)
|
||||||
config2 = DiffusersConfig(
|
config2 = DiffusersConfig(
|
||||||
path="/tmp/config2",
|
path="/tmp/config2",
|
||||||
name="config2",
|
name="config2",
|
||||||
base_model=BaseModelType("sd-1"),
|
base=BaseModelType("sd-1"),
|
||||||
type=ModelType("main"),
|
type=ModelType("main"),
|
||||||
original_hash="CONFIG2HASH",
|
original_hash="CONFIG2HASH",
|
||||||
)
|
)
|
||||||
config3 = VaeDiffusersConfig(
|
config3 = VaeDiffusersConfig(
|
||||||
path="/tmp/config3",
|
path="/tmp/config3",
|
||||||
name="config3",
|
name="config3",
|
||||||
base_model=BaseModelType("sd-2"),
|
base=BaseModelType("sd-2"),
|
||||||
type=ModelType("vae"),
|
type=ModelType("vae"),
|
||||||
original_hash="CONFIG3HASH",
|
original_hash="CONFIG3HASH",
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user