mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(api): add style_presets router, make sure all CRUD is working, add is_default
This commit is contained in:
parent
b76bf50b93
commit
217fe40d99
@ -110,7 +110,7 @@ class ApiDependencies:
|
|||||||
session_queue = SqliteSessionQueue(db=db)
|
session_queue = SqliteSessionQueue(db=db)
|
||||||
urls = LocalUrlService()
|
urls = LocalUrlService()
|
||||||
workflow_records = SqliteWorkflowRecordsStorage(db=db)
|
workflow_records = SqliteWorkflowRecordsStorage(db=db)
|
||||||
style_presets = SqliteStylePresetRecordsStorage(db=db)
|
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
board_image_records=board_image_records,
|
board_image_records=board_image_records,
|
||||||
@ -136,7 +136,7 @@ class ApiDependencies:
|
|||||||
workflow_records=workflow_records,
|
workflow_records=workflow_records,
|
||||||
tensors=tensors,
|
tensors=tensors,
|
||||||
conditioning=conditioning,
|
conditioning=conditioning,
|
||||||
style_presets=style_presets,
|
style_preset_records=style_preset_records,
|
||||||
)
|
)
|
||||||
|
|
||||||
ApiDependencies.invoker = Invoker(services)
|
ApiDependencies.invoker = Invoker(services)
|
||||||
|
100
invokeai/app/api/routers/style_presets.py
Normal file
100
invokeai/app/api/routers/style_presets.py
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Body, HTTPException, Path, Query
|
||||||
|
|
||||||
|
from invokeai.app.api.dependencies import ApiDependencies
|
||||||
|
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||||
|
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||||
|
from invokeai.app.services.style_preset_records.style_preset_records_common import (
|
||||||
|
StylePresetChanges,
|
||||||
|
StylePresetNotFoundError,
|
||||||
|
StylePresetRecordDTO,
|
||||||
|
StylePresetWithoutId,
|
||||||
|
StylePresetRecordOrderBy,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
style_presets_router = APIRouter(prefix="/v1/style_presets", tags=["style_presets"])
|
||||||
|
|
||||||
|
|
||||||
|
@style_presets_router.get(
|
||||||
|
"/i/{style_preset_id}",
|
||||||
|
operation_id="get_style_preset",
|
||||||
|
responses={
|
||||||
|
200: {"model": StylePresetRecordDTO},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def get_style_preset(
|
||||||
|
style_preset_id: str = Path(description="The style preset to get"),
|
||||||
|
) -> StylePresetRecordDTO:
|
||||||
|
"""Gets a style preset"""
|
||||||
|
try:
|
||||||
|
return ApiDependencies.invoker.services.style_preset_records.get(style_preset_id)
|
||||||
|
except StylePresetNotFoundError:
|
||||||
|
raise HTTPException(status_code=404, detail="Style preset not found")
|
||||||
|
|
||||||
|
|
||||||
|
@style_presets_router.patch(
|
||||||
|
"/i/{style_preset_id}",
|
||||||
|
operation_id="update_style_preset",
|
||||||
|
responses={
|
||||||
|
200: {"model": StylePresetRecordDTO},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def update_style_preset(
|
||||||
|
style_preset_id: str = Path(description="The id of the style preset to update"),
|
||||||
|
changes: StylePresetChanges = Body(description="The updated style preset", embed=True),
|
||||||
|
) -> StylePresetRecordDTO:
|
||||||
|
"""Updates a style preset"""
|
||||||
|
return ApiDependencies.invoker.services.style_preset_records.update(id=style_preset_id, changes=changes)
|
||||||
|
|
||||||
|
|
||||||
|
@style_presets_router.delete(
|
||||||
|
"/i/{style_preset_id}",
|
||||||
|
operation_id="delete_style_preset",
|
||||||
|
)
|
||||||
|
async def delete_style_preset(
|
||||||
|
style_preset_id: str = Path(description="The style preset to delete"),
|
||||||
|
) -> None:
|
||||||
|
"""Deletes a style preset"""
|
||||||
|
ApiDependencies.invoker.services.style_preset_records.delete(style_preset_id)
|
||||||
|
|
||||||
|
|
||||||
|
@style_presets_router.post(
|
||||||
|
"/",
|
||||||
|
operation_id="create_style_preset",
|
||||||
|
responses={
|
||||||
|
200: {"model": StylePresetRecordDTO},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def create_style_preset(
|
||||||
|
style_preset: StylePresetWithoutId = Body(description="The style preset to create", embed=True),
|
||||||
|
) -> StylePresetRecordDTO:
|
||||||
|
"""Creates a style preset"""
|
||||||
|
return ApiDependencies.invoker.services.style_preset_records.create(style_preset=style_preset)
|
||||||
|
|
||||||
|
|
||||||
|
@style_presets_router.get(
|
||||||
|
"/",
|
||||||
|
operation_id="list_style_presets",
|
||||||
|
responses={
|
||||||
|
200: {"model": PaginatedResults[StylePresetRecordDTO]},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def list_style_presets(
|
||||||
|
page: int = Query(default=0, description="The page to get"),
|
||||||
|
per_page: int = Query(default=10, description="The number of style presets per page"),
|
||||||
|
order_by: StylePresetRecordOrderBy = Query(
|
||||||
|
default=StylePresetRecordOrderBy.Name, description="The attribute to order by"
|
||||||
|
),
|
||||||
|
direction: SQLiteDirection = Query(default=SQLiteDirection.Ascending, description="The direction to order by"),
|
||||||
|
query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"),
|
||||||
|
) -> PaginatedResults[StylePresetRecordDTO]:
|
||||||
|
"""Gets a page of style presets"""
|
||||||
|
return ApiDependencies.invoker.services.style_preset_records.get_many(
|
||||||
|
page=page,
|
||||||
|
per_page=per_page,
|
||||||
|
order_by=order_by,
|
||||||
|
direction=direction,
|
||||||
|
query=query,
|
||||||
|
)
|
@ -32,6 +32,7 @@ from invokeai.app.api.routers import (
|
|||||||
session_queue,
|
session_queue,
|
||||||
utilities,
|
utilities,
|
||||||
workflows,
|
workflows,
|
||||||
|
style_presets,
|
||||||
)
|
)
|
||||||
from invokeai.app.api.sockets import SocketIO
|
from invokeai.app.api.sockets import SocketIO
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config.config_default import get_config
|
||||||
@ -106,6 +107,7 @@ app.include_router(board_images.board_images_router, prefix="/api")
|
|||||||
app.include_router(app_info.app_router, prefix="/api")
|
app.include_router(app_info.app_router, prefix="/api")
|
||||||
app.include_router(session_queue.session_queue_router, prefix="/api")
|
app.include_router(session_queue.session_queue_router, prefix="/api")
|
||||||
app.include_router(workflows.workflows_router, prefix="/api")
|
app.include_router(workflows.workflows_router, prefix="/api")
|
||||||
|
app.include_router(style_presets.style_presets_router, prefix="/api")
|
||||||
|
|
||||||
app.openapi = get_openapi_func(app)
|
app.openapi = get_openapi_func(app)
|
||||||
|
|
||||||
|
@ -62,7 +62,7 @@ class InvocationServices:
|
|||||||
workflow_records: "WorkflowRecordsStorageBase",
|
workflow_records: "WorkflowRecordsStorageBase",
|
||||||
tensors: "ObjectSerializerBase[torch.Tensor]",
|
tensors: "ObjectSerializerBase[torch.Tensor]",
|
||||||
conditioning: "ObjectSerializerBase[ConditioningFieldData]",
|
conditioning: "ObjectSerializerBase[ConditioningFieldData]",
|
||||||
style_presets: "StylePresetRecordsStorageBase",
|
style_preset_records: "StylePresetRecordsStorageBase",
|
||||||
):
|
):
|
||||||
self.board_images = board_images
|
self.board_images = board_images
|
||||||
self.board_image_records = board_image_records
|
self.board_image_records = board_image_records
|
||||||
@ -87,4 +87,4 @@ class InvocationServices:
|
|||||||
self.workflow_records = workflow_records
|
self.workflow_records = workflow_records
|
||||||
self.tensors = tensors
|
self.tensors = tensors
|
||||||
self.conditioning = conditioning
|
self.conditioning = conditioning
|
||||||
self.style_presets = style_presets
|
self.style_preset_records = style_preset_records
|
||||||
|
@ -15,13 +15,14 @@ class Migration14Callback:
|
|||||||
id TEXT NOT NULL PRIMARY KEY,
|
id TEXT NOT NULL PRIMARY KEY,
|
||||||
name TEXT NOT NULL,
|
name TEXT NOT NULL,
|
||||||
preset_data TEXT NOT NULL,
|
preset_data TEXT NOT NULL,
|
||||||
|
is_default BOOLEAN DEFAULT FALSE,
|
||||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
-- Updated via trigger
|
-- Updated via trigger
|
||||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
|
||||||
);
|
);
|
||||||
"""
|
"""
|
||||||
]
|
]
|
||||||
|
|
||||||
# Add trigger for `updated_at`.
|
# Add trigger for `updated_at`.
|
||||||
triggers = [
|
triggers = [
|
||||||
"""--sql
|
"""--sql
|
||||||
|
@ -3,9 +3,11 @@ from typing import Optional
|
|||||||
|
|
||||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||||
from invokeai.app.services.style_preset_records.style_preset_records_common import StylePresetRecordDTO, StylePresetWithoutId
|
from invokeai.app.services.style_preset_records.style_preset_records_common import (
|
||||||
from invokeai.app.services.workflow_records.workflow_records_common import (
|
StylePresetChanges,
|
||||||
WorkflowRecordOrderBy,
|
StylePresetRecordDTO,
|
||||||
|
StylePresetWithoutId,
|
||||||
|
StylePresetRecordOrderBy,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -23,7 +25,7 @@ class StylePresetRecordsStorageBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def update(self, id: str, changes: StylePresetWithoutId) -> StylePresetRecordDTO:
|
def update(self, id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
|
||||||
"""Updates a style preset."""
|
"""Updates a style preset."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -37,7 +39,7 @@ class StylePresetRecordsStorageBase(ABC):
|
|||||||
self,
|
self,
|
||||||
page: int,
|
page: int,
|
||||||
per_page: int,
|
per_page: int,
|
||||||
order_by: WorkflowRecordOrderBy,
|
order_by: StylePresetRecordOrderBy,
|
||||||
direction: SQLiteDirection,
|
direction: SQLiteDirection,
|
||||||
query: Optional[str],
|
query: Optional[str],
|
||||||
) -> PaginatedResults[StylePresetRecordDTO]:
|
) -> PaginatedResults[StylePresetRecordDTO]:
|
||||||
|
@ -1,13 +1,23 @@
|
|||||||
from typing import Any, Union
|
from enum import Enum
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, field_validator
|
from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, field_validator
|
||||||
|
|
||||||
|
from invokeai.app.util.metaenum import MetaEnum
|
||||||
|
|
||||||
|
|
||||||
class StylePresetNotFoundError(Exception):
|
class StylePresetNotFoundError(Exception):
|
||||||
"""Raised when a style preset is not found"""
|
"""Raised when a style preset is not found"""
|
||||||
|
|
||||||
|
|
||||||
class PresetData(BaseModel):
|
class StylePresetRecordOrderBy(str, Enum, metaclass=MetaEnum):
|
||||||
|
"""The order by options for workflow records"""
|
||||||
|
|
||||||
|
CreatedAt = "created_at"
|
||||||
|
Name = "name"
|
||||||
|
|
||||||
|
|
||||||
|
class PresetData(BaseModel, extra="forbid"):
|
||||||
positive_prompt: str = Field(description="Positive prompt")
|
positive_prompt: str = Field(description="Positive prompt")
|
||||||
negative_prompt: str = Field(description="Negative prompt")
|
negative_prompt: str = Field(description="Negative prompt")
|
||||||
|
|
||||||
@ -15,6 +25,11 @@ class PresetData(BaseModel):
|
|||||||
PresetDataValidator = TypeAdapter(PresetData)
|
PresetDataValidator = TypeAdapter(PresetData)
|
||||||
|
|
||||||
|
|
||||||
|
class StylePresetChanges(BaseModel, extra="forbid"):
|
||||||
|
name: Optional[str] = Field(default=None, description="The style preset's new name.")
|
||||||
|
preset_data: Optional[PresetData] = Field(default=None, description="The updated data for style preset.")
|
||||||
|
|
||||||
|
|
||||||
class StylePresetWithoutId(BaseModel):
|
class StylePresetWithoutId(BaseModel):
|
||||||
name: str = Field(description="The name of the style preset.")
|
name: str = Field(description="The name of the style preset.")
|
||||||
preset_data: PresetData = Field(description="The preset data")
|
preset_data: PresetData = Field(description="The preset data")
|
||||||
@ -22,6 +37,7 @@ class StylePresetWithoutId(BaseModel):
|
|||||||
|
|
||||||
class StylePresetRecordDTO(StylePresetWithoutId):
|
class StylePresetRecordDTO(StylePresetWithoutId):
|
||||||
id: str = Field(description="The style preset ID.")
|
id: str = Field(description="The style preset ID.")
|
||||||
|
is_default: bool = Field(description="Whether or not the style preset is default")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict[str, Any]) -> "StylePresetRecordDTO":
|
def from_dict(cls, data: dict[str, Any]) -> "StylePresetRecordDTO":
|
||||||
|
@ -7,12 +7,11 @@ from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
|||||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||||
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
|
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
|
||||||
from invokeai.app.services.style_preset_records.style_preset_records_common import (
|
from invokeai.app.services.style_preset_records.style_preset_records_common import (
|
||||||
|
StylePresetChanges,
|
||||||
StylePresetNotFoundError,
|
StylePresetNotFoundError,
|
||||||
StylePresetRecordDTO,
|
StylePresetRecordDTO,
|
||||||
StylePresetWithoutId,
|
StylePresetWithoutId,
|
||||||
)
|
StylePresetRecordOrderBy,
|
||||||
from invokeai.app.services.workflow_records.workflow_records_common import (
|
|
||||||
WorkflowRecordOrderBy,
|
|
||||||
)
|
)
|
||||||
from invokeai.app.util.misc import uuid_string
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
|
||||||
@ -58,14 +57,14 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
|||||||
INSERT OR IGNORE INTO style_presets (
|
INSERT OR IGNORE INTO style_presets (
|
||||||
id,
|
id,
|
||||||
name,
|
name,
|
||||||
preset_data,
|
preset_data
|
||||||
)
|
)
|
||||||
VALUES (?, ?, ?);
|
VALUES (?, ?, ?);
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
id,
|
id,
|
||||||
style_preset.name,
|
style_preset.name,
|
||||||
style_preset.preset_data,
|
style_preset.preset_data.model_dump_json(),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
@ -76,20 +75,31 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
|||||||
self._lock.release()
|
self._lock.release()
|
||||||
return self.get(id)
|
return self.get(id)
|
||||||
|
|
||||||
def update(self, id: str, changes: StylePresetWithoutId) -> StylePresetRecordDTO:
|
def update(self, id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
# Change the name of a style preset
|
||||||
"""--sql
|
if changes.name is not None:
|
||||||
UPDATE style_presets
|
self._cursor.execute(
|
||||||
SET preset_data = ?
|
"""--sql
|
||||||
WHERE id = ? ;
|
UPDATE style_presets
|
||||||
""",
|
SET name = ?
|
||||||
(
|
WHERE id = ?;
|
||||||
changes.preset_data,
|
""",
|
||||||
id,
|
(changes.name, id),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
|
# Change the preset data for a style preset
|
||||||
|
if changes.preset_data is not None:
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
UPDATE style_presets
|
||||||
|
SET preset_data = ?
|
||||||
|
WHERE id = ?;
|
||||||
|
""",
|
||||||
|
(changes.preset_data.model_dump_json(), id),
|
||||||
|
)
|
||||||
|
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
except Exception:
|
except Exception:
|
||||||
self._conn.rollback()
|
self._conn.rollback()
|
||||||
@ -120,14 +130,14 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
|||||||
self,
|
self,
|
||||||
page: int,
|
page: int,
|
||||||
per_page: int,
|
per_page: int,
|
||||||
order_by: WorkflowRecordOrderBy,
|
order_by: StylePresetRecordOrderBy,
|
||||||
direction: SQLiteDirection,
|
direction: SQLiteDirection,
|
||||||
query: Optional[str] = None,
|
query: Optional[str] = None,
|
||||||
) -> PaginatedResults[StylePresetRecordDTO]:
|
) -> PaginatedResults[StylePresetRecordDTO]:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
# sanitize!
|
# sanitize!
|
||||||
assert order_by in WorkflowRecordOrderBy
|
assert order_by in StylePresetRecordOrderBy
|
||||||
assert direction in SQLiteDirection
|
assert direction in SQLiteDirection
|
||||||
count_query = "SELECT COUNT(*) FROM style_presets"
|
count_query = "SELECT COUNT(*) FROM style_presets"
|
||||||
main_query = """
|
main_query = """
|
||||||
@ -140,8 +150,8 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
|||||||
stripped_query = query.strip() if query else None
|
stripped_query = query.strip() if query else None
|
||||||
if stripped_query:
|
if stripped_query:
|
||||||
wildcard_query = "%" + stripped_query + "%"
|
wildcard_query = "%" + stripped_query + "%"
|
||||||
main_query += " AND name LIKE ? OR description LIKE ? "
|
main_query += " AND name LIKE ? "
|
||||||
count_query += " AND name LIKE ? OR description LIKE ?;"
|
count_query += " AND name LIKE ?;"
|
||||||
main_params.extend([wildcard_query, wildcard_query])
|
main_params.extend([wildcard_query, wildcard_query])
|
||||||
count_params.extend([wildcard_query, wildcard_query])
|
count_params.extend([wildcard_query, wildcard_query])
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user