mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
4 Commits
feat/batch
...
fix/diffus
Author | SHA1 | Date | |
---|---|---|---|
e04c25eba7 | |||
be61ffdbf6 | |||
823b879329 | |||
17c901aaf7 |
10
flake.nix
10
flake.nix
@ -34,10 +34,6 @@
|
|||||||
cudaPackages.cudnn
|
cudaPackages.cudnn
|
||||||
cudaPackages.cuda_nvrtc
|
cudaPackages.cuda_nvrtc
|
||||||
cudatoolkit
|
cudatoolkit
|
||||||
pkgconfig
|
|
||||||
libconfig
|
|
||||||
cmake
|
|
||||||
blas
|
|
||||||
freeglut
|
freeglut
|
||||||
glib
|
glib
|
||||||
gperf
|
gperf
|
||||||
@ -46,12 +42,6 @@
|
|||||||
libGLU
|
libGLU
|
||||||
linuxPackages.nvidia_x11
|
linuxPackages.nvidia_x11
|
||||||
python
|
python
|
||||||
(opencv4.override {
|
|
||||||
enableGtk3 = true;
|
|
||||||
enableFfmpeg = true;
|
|
||||||
enableCuda = true;
|
|
||||||
enableUnfree = true;
|
|
||||||
})
|
|
||||||
stdenv.cc
|
stdenv.cc
|
||||||
stdenv.cc.cc.lib
|
stdenv.cc.cc.lib
|
||||||
xorg.libX11
|
xorg.libX11
|
||||||
|
@ -30,8 +30,6 @@ from ..services.invoker import Invoker
|
|||||||
from ..services.processor import DefaultInvocationProcessor
|
from ..services.processor import DefaultInvocationProcessor
|
||||||
from ..services.sqlite import SqliteItemStorage
|
from ..services.sqlite import SqliteItemStorage
|
||||||
from ..services.model_manager_service import ModelManagerService
|
from ..services.model_manager_service import ModelManagerService
|
||||||
from ..services.batch_manager import BatchManager
|
|
||||||
from ..services.batch_manager_storage import SqliteBatchProcessStorage
|
|
||||||
from .events import FastAPIEventService
|
from .events import FastAPIEventService
|
||||||
|
|
||||||
|
|
||||||
@ -118,15 +116,11 @@ class ApiDependencies:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
batch_manager_storage = SqliteBatchProcessStorage(db_location)
|
|
||||||
batch_manager = BatchManager(batch_manager_storage)
|
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
model_manager=ModelManagerService(config, logger),
|
model_manager=ModelManagerService(config, logger),
|
||||||
events=events,
|
events=events,
|
||||||
latents=latents,
|
latents=latents,
|
||||||
images=images,
|
images=images,
|
||||||
batch_manager=batch_manager,
|
|
||||||
boards=boards,
|
boards=boards,
|
||||||
board_images=board_images,
|
board_images=board_images,
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
|
@ -15,7 +15,6 @@ from ...services.graph import (
|
|||||||
GraphExecutionState,
|
GraphExecutionState,
|
||||||
NodeAlreadyExecutedError,
|
NodeAlreadyExecutedError,
|
||||||
)
|
)
|
||||||
from ...services.batch_manager import Batch, BatchProcess
|
|
||||||
from ...services.item_storage import PaginatedResults
|
from ...services.item_storage import PaginatedResults
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
@ -38,37 +37,6 @@ async def create_session(
|
|||||||
return session
|
return session
|
||||||
|
|
||||||
|
|
||||||
@session_router.post(
|
|
||||||
"/batch",
|
|
||||||
operation_id="create_batch",
|
|
||||||
responses={
|
|
||||||
200: {"model": BatchProcess},
|
|
||||||
400: {"description": "Invalid json"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
async def create_batch(
|
|
||||||
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
|
|
||||||
batches: list[Batch] = Body(description="Batch config to apply to the given graph"),
|
|
||||||
) -> BatchProcess:
|
|
||||||
"""Creates and starts a new new batch process"""
|
|
||||||
batch_id = ApiDependencies.invoker.services.batch_manager.create_batch_process(batches, graph)
|
|
||||||
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_id)
|
|
||||||
return {"batch_id":batch_id}
|
|
||||||
|
|
||||||
|
|
||||||
@session_router.delete(
|
|
||||||
"{batch_process_id}/batch",
|
|
||||||
operation_id="cancel_batch",
|
|
||||||
responses={202: {"description": "The batch is canceled"}},
|
|
||||||
)
|
|
||||||
async def cancel_batch(
|
|
||||||
batch_process_id: str = Path(description="The id of the batch process to cancel"),
|
|
||||||
) -> Response:
|
|
||||||
"""Creates and starts a new new batch process"""
|
|
||||||
ApiDependencies.invoker.services.batch_manager.cancel_batch_process(batch_process_id)
|
|
||||||
return Response(status_code=202)
|
|
||||||
|
|
||||||
|
|
||||||
@session_router.get(
|
@session_router.get(
|
||||||
"/",
|
"/",
|
||||||
operation_id="list_sessions",
|
operation_id="list_sessions",
|
||||||
|
@ -37,8 +37,6 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
|||||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||||
from invokeai.app.services.resource_name import SimpleNameService
|
from invokeai.app.services.resource_name import SimpleNameService
|
||||||
from invokeai.app.services.urls import LocalUrlService
|
from invokeai.app.services.urls import LocalUrlService
|
||||||
from invokeai.app.services.batch_manager import BatchManager
|
|
||||||
from invokeai.app.services.batch_manager_storage import SqliteBatchProcessStorage
|
|
||||||
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
|
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
|
||||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
|
|
||||||
@ -302,16 +300,12 @@ def invoke_cli():
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
batch_manager_storage = SqliteBatchProcessStorage(db_location)
|
|
||||||
batch_manager = BatchManager(batch_manager_storage)
|
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
model_manager=model_manager,
|
model_manager=model_manager,
|
||||||
events=events,
|
events=events,
|
||||||
latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
|
latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
|
||||||
images=images,
|
images=images,
|
||||||
boards=boards,
|
boards=boards,
|
||||||
batch_manager=batch_manager,
|
|
||||||
board_images=board_images,
|
board_images=board_images,
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
||||||
|
@ -1,139 +0,0 @@
|
|||||||
import networkx as nx
|
|
||||||
import copy
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from itertools import product
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from fastapi_events.handlers.local import local_handler
|
|
||||||
from fastapi_events.typing import Event
|
|
||||||
|
|
||||||
from invokeai.app.services.events import EventServiceBase
|
|
||||||
from invokeai.app.services.graph import Graph, GraphExecutionState
|
|
||||||
from invokeai.app.services.invoker import Invoker
|
|
||||||
from invokeai.app.services.batch_manager_storage import (
|
|
||||||
BatchProcessStorageBase,
|
|
||||||
BatchSessionNotFoundException,
|
|
||||||
Batch,
|
|
||||||
BatchProcess,
|
|
||||||
BatchSession,
|
|
||||||
BatchSessionChanges,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BatchManagerBase(ABC):
|
|
||||||
@abstractmethod
|
|
||||||
def start(self, invoker: Invoker):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def create_batch_process(self, batches: list[Batch], graph: Graph) -> str:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def run_batch_process(self, batch_id: str):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def cancel_batch_process(self, batch_process_id: str):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class BatchManager(BatchManagerBase):
|
|
||||||
"""Responsible for managing currently running and scheduled batch jobs"""
|
|
||||||
|
|
||||||
__invoker: Invoker
|
|
||||||
__batches: list[BatchProcess]
|
|
||||||
__batch_process_storage: BatchProcessStorageBase
|
|
||||||
|
|
||||||
def __init__(self, batch_process_storage: BatchProcessStorageBase) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.__batch_process_storage = batch_process_storage
|
|
||||||
|
|
||||||
def start(self, invoker: Invoker) -> None:
|
|
||||||
# if we do want multithreading at some point, we could make this configurable
|
|
||||||
self.__invoker = invoker
|
|
||||||
self.__batches = list()
|
|
||||||
local_handler.register(event_name=EventServiceBase.session_event, _func=self.on_event)
|
|
||||||
|
|
||||||
async def on_event(self, event: Event):
|
|
||||||
event_name = event[1]["event"]
|
|
||||||
|
|
||||||
match event_name:
|
|
||||||
case "graph_execution_state_complete":
|
|
||||||
await self.process(event, False)
|
|
||||||
case "invocation_error":
|
|
||||||
await self.process(event, True)
|
|
||||||
|
|
||||||
return event
|
|
||||||
|
|
||||||
async def process(self, event: Event, err: bool):
|
|
||||||
data = event[1]["data"]
|
|
||||||
batch_session = self.__batch_process_storage.get_session(data["graph_execution_state_id"])
|
|
||||||
if not batch_session:
|
|
||||||
return
|
|
||||||
updateSession = BatchSessionChanges(
|
|
||||||
state='error' if err else 'completed'
|
|
||||||
)
|
|
||||||
batch_session = self.__batch_process_storage.update_session_state(
|
|
||||||
batch_session.batch_id,
|
|
||||||
batch_session.session_id,
|
|
||||||
updateSession,
|
|
||||||
)
|
|
||||||
self.run_batch_process(batch_session.batch_id)
|
|
||||||
|
|
||||||
def _create_batch_session(self, batch_process: BatchProcess, batch_indices: list[int]) -> GraphExecutionState:
|
|
||||||
graph = copy.deepcopy(batch_process.graph)
|
|
||||||
batches = batch_process.batches
|
|
||||||
g = graph.nx_graph_flat()
|
|
||||||
sorted_nodes = nx.topological_sort(g)
|
|
||||||
for npath in sorted_nodes:
|
|
||||||
node = graph.get_node(npath)
|
|
||||||
(index, batch) = next(((i, b) for i, b in enumerate(batches) if b.node_id in node.id), (None, None))
|
|
||||||
if batch:
|
|
||||||
batch_index = batch_indices[index]
|
|
||||||
datum = batch.data[batch_index]
|
|
||||||
for key in datum:
|
|
||||||
node.__dict__[key] = datum[key]
|
|
||||||
graph.update_node(npath, node)
|
|
||||||
|
|
||||||
return GraphExecutionState(graph=graph)
|
|
||||||
|
|
||||||
def run_batch_process(self, batch_id: str):
|
|
||||||
try:
|
|
||||||
created_session = self.__batch_process_storage.get_created_session(batch_id)
|
|
||||||
except BatchSessionNotFoundException:
|
|
||||||
return
|
|
||||||
ges = self.__invoker.services.graph_execution_manager.get(created_session.session_id)
|
|
||||||
self.__invoker.invoke(ges, invoke_all=True)
|
|
||||||
|
|
||||||
def _valid_batch_config(self, batch_process: BatchProcess) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
def create_batch_process(self, batches: list[Batch], graph: Graph) -> str:
|
|
||||||
batch_process = BatchProcess(
|
|
||||||
batches=batches,
|
|
||||||
graph=graph,
|
|
||||||
)
|
|
||||||
if not self._valid_batch_config(batch_process):
|
|
||||||
return None
|
|
||||||
batch_process = self.__batch_process_storage.save(batch_process)
|
|
||||||
self._create_sessions(batch_process)
|
|
||||||
return batch_process.batch_id
|
|
||||||
|
|
||||||
def _create_sessions(self, batch_process: BatchProcess):
|
|
||||||
batch_indices = list()
|
|
||||||
for batch in batch_process.batches:
|
|
||||||
batch_indices.append(list(range(len(batch.data))))
|
|
||||||
all_batch_indices = product(*batch_indices)
|
|
||||||
for bi in all_batch_indices:
|
|
||||||
ges = self._create_batch_session(batch_process, bi)
|
|
||||||
self.__invoker.services.graph_execution_manager.set(ges)
|
|
||||||
batch_session = BatchSession(
|
|
||||||
batch_id=batch_process.batch_id,
|
|
||||||
session_id=ges.id,
|
|
||||||
state="created"
|
|
||||||
)
|
|
||||||
self.__batch_process_storage.create_session(batch_session)
|
|
||||||
|
|
||||||
def cancel_batch_process(self, batch_process_id: str):
|
|
||||||
self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id]
|
|
@ -1,505 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import cast
|
|
||||||
import uuid
|
|
||||||
import sqlite3
|
|
||||||
import threading
|
|
||||||
from typing import (
|
|
||||||
Any,
|
|
||||||
List,
|
|
||||||
Literal,
|
|
||||||
Optional,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
import json
|
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
|
||||||
BaseInvocation,
|
|
||||||
)
|
|
||||||
from invokeai.app.services.graph import Graph
|
|
||||||
from invokeai.app.models.image import ImageField
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, Extra, parse_raw_as
|
|
||||||
|
|
||||||
invocations = BaseInvocation.get_invocations()
|
|
||||||
InvocationsUnion = Union[invocations] # type: ignore
|
|
||||||
|
|
||||||
BatchDataType = Union[str, int, float, ImageField]
|
|
||||||
|
|
||||||
class Batch(BaseModel):
|
|
||||||
data: list[dict[str, BatchDataType]] = Field(description="Mapping of node field to data value")
|
|
||||||
node_id: str = Field(description="ID of the node to batch")
|
|
||||||
|
|
||||||
|
|
||||||
class BatchSession(BaseModel):
|
|
||||||
batch_id: str = Field(description="Identifier for which batch this Index belongs to")
|
|
||||||
session_id: str = Field(description="Session ID Created for this Batch Index")
|
|
||||||
state: Literal["created", "completed", "inprogress", "error"] = Field(
|
|
||||||
description="Is this session created, completed, in progress, or errored?"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def uuid_string():
|
|
||||||
res = uuid.uuid4()
|
|
||||||
return str(res)
|
|
||||||
|
|
||||||
class BatchProcess(BaseModel):
|
|
||||||
batch_id: Optional[str] = Field(default_factory=uuid_string, description="Identifier for this batch")
|
|
||||||
batches: List[Batch] = Field(
|
|
||||||
description="List of batch configs to apply to this session",
|
|
||||||
default_factory=list,
|
|
||||||
)
|
|
||||||
graph: Graph = Field(description="The graph being executed")
|
|
||||||
|
|
||||||
|
|
||||||
class BatchSessionChanges(BaseModel, extra=Extra.forbid):
|
|
||||||
state: Literal["created", "completed", "inprogress", "error"] = Field(
|
|
||||||
description="Is this session created, completed, in progress, or errored?"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BatchProcessNotFoundException(Exception):
|
|
||||||
"""Raised when an Batch Process record is not found."""
|
|
||||||
|
|
||||||
def __init__(self, message="BatchProcess record not found"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class BatchProcessSaveException(Exception):
|
|
||||||
"""Raised when an Batch Process record cannot be saved."""
|
|
||||||
|
|
||||||
def __init__(self, message="BatchProcess record not saved"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class BatchProcessDeleteException(Exception):
|
|
||||||
"""Raised when an Batch Process record cannot be deleted."""
|
|
||||||
|
|
||||||
def __init__(self, message="BatchProcess record not deleted"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class BatchSessionNotFoundException(Exception):
|
|
||||||
"""Raised when an Batch Session record is not found."""
|
|
||||||
|
|
||||||
def __init__(self, message="BatchSession record not found"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class BatchSessionSaveException(Exception):
|
|
||||||
"""Raised when an Batch Session record cannot be saved."""
|
|
||||||
|
|
||||||
def __init__(self, message="BatchSession record not saved"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class BatchSessionDeleteException(Exception):
|
|
||||||
"""Raised when an Batch Session record cannot be deleted."""
|
|
||||||
|
|
||||||
def __init__(self, message="BatchSession record not deleted"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class BatchProcessStorageBase(ABC):
|
|
||||||
"""Low-level service responsible for interfacing with the Batch Process record store."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete(self, batch_id: str) -> None:
|
|
||||||
"""Deletes a Batch Process record."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def save(
|
|
||||||
self,
|
|
||||||
batch_process: BatchProcess,
|
|
||||||
) -> BatchProcess:
|
|
||||||
"""Saves a Batch Process record."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get(
|
|
||||||
self,
|
|
||||||
batch_id: str,
|
|
||||||
) -> BatchProcess:
|
|
||||||
"""Gets a Batch Process record."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def create_session(
|
|
||||||
self,
|
|
||||||
session: BatchSession,
|
|
||||||
) -> BatchSession:
|
|
||||||
"""Creates a Batch Session attached to a Batch Process."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_session(
|
|
||||||
self,
|
|
||||||
session_id: str
|
|
||||||
) -> BatchSession:
|
|
||||||
"""Gets session by session_id"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_created_session(
|
|
||||||
self,
|
|
||||||
batch_id: str
|
|
||||||
) -> BatchSession:
|
|
||||||
"""Gets all created Batch Sessions for a given Batch Process id."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_created_sessions(
|
|
||||||
self,
|
|
||||||
batch_id: str
|
|
||||||
) -> List[BatchSession]:
|
|
||||||
"""Gets all created Batch Sessions for a given Batch Process id."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update_session_state(
|
|
||||||
self,
|
|
||||||
batch_id: str,
|
|
||||||
session_id: str,
|
|
||||||
changes: BatchSessionChanges,
|
|
||||||
) -> BatchSession:
|
|
||||||
"""Updates the state of a Batch Session record."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|
||||||
_filename: str
|
|
||||||
_conn: sqlite3.Connection
|
|
||||||
_cursor: sqlite3.Cursor
|
|
||||||
_lock: threading.Lock
|
|
||||||
|
|
||||||
def __init__(self, filename: str) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self._filename = filename
|
|
||||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
|
||||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
|
||||||
self._conn.row_factory = sqlite3.Row
|
|
||||||
self._cursor = self._conn.cursor()
|
|
||||||
self._lock = threading.Lock()
|
|
||||||
|
|
||||||
try:
|
|
||||||
self._lock.acquire()
|
|
||||||
# Enable foreign keys
|
|
||||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
|
||||||
self._create_tables()
|
|
||||||
self._conn.commit()
|
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
|
|
||||||
def _create_tables(self) -> None:
|
|
||||||
"""Creates the `batch_process` table and `batch_session` junction table."""
|
|
||||||
|
|
||||||
# Create the `batch_process` table.
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE TABLE IF NOT EXISTS batch_process (
|
|
||||||
batch_id TEXT NOT NULL PRIMARY KEY,
|
|
||||||
batches TEXT NOT NULL,
|
|
||||||
graph TEXT NOT NULL,
|
|
||||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
|
||||||
-- Updated via trigger
|
|
||||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
|
||||||
-- Soft delete, currently unused
|
|
||||||
deleted_at DATETIME
|
|
||||||
);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_batch_process_created_at ON batch_process (created_at);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add trigger for `updated_at`.
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE TRIGGER IF NOT EXISTS tg_batch_process_updated_at
|
|
||||||
AFTER UPDATE
|
|
||||||
ON batch_process FOR EACH ROW
|
|
||||||
BEGIN
|
|
||||||
UPDATE batch_process SET updated_at = current_timestamp
|
|
||||||
WHERE batch_id = old.batch_id;
|
|
||||||
END;
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create the `batch_session` junction table.
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE TABLE IF NOT EXISTS batch_session (
|
|
||||||
batch_id TEXT NOT NULL,
|
|
||||||
session_id TEXT NOT NULL,
|
|
||||||
state TEXT NOT NULL,
|
|
||||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
|
||||||
-- updated via trigger
|
|
||||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
|
||||||
-- Soft delete, currently unused
|
|
||||||
deleted_at DATETIME,
|
|
||||||
-- enforce one-to-many relationship between batch_process and batch_session using PK
|
|
||||||
-- (we can extend this to many-to-many later)
|
|
||||||
PRIMARY KEY (batch_id,session_id),
|
|
||||||
FOREIGN KEY (batch_id) REFERENCES batch_process (batch_id) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add index for batch id
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id ON batch_session (batch_id);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add index for batch id, sorted by created_at
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id_created_at ON batch_session (batch_id,created_at);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add trigger for `updated_at`.
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE TRIGGER IF NOT EXISTS tg_batch_session_updated_at
|
|
||||||
AFTER UPDATE
|
|
||||||
ON batch_session FOR EACH ROW
|
|
||||||
BEGIN
|
|
||||||
UPDATE batch_session SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
|
||||||
WHERE batch_id = old.batch_id AND session_id = old.session_id;
|
|
||||||
END;
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
def delete(self, batch_id: str) -> None:
|
|
||||||
try:
|
|
||||||
self._lock.acquire()
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
DELETE FROM batch_process
|
|
||||||
WHERE batch_id = ?;
|
|
||||||
""",
|
|
||||||
(batch_id,),
|
|
||||||
)
|
|
||||||
self._conn.commit()
|
|
||||||
except sqlite3.Error as e:
|
|
||||||
self._conn.rollback()
|
|
||||||
raise BatchProcessDeleteException from e
|
|
||||||
except Exception as e:
|
|
||||||
self._conn.rollback()
|
|
||||||
raise BatchProcessDeleteException from e
|
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
|
|
||||||
def save(
|
|
||||||
self,
|
|
||||||
batch_process: BatchProcess,
|
|
||||||
) -> BatchProcess:
|
|
||||||
try:
|
|
||||||
self._lock.acquire()
|
|
||||||
batches = [batch.json() for batch in batch_process.batches]
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
INSERT OR IGNORE INTO batch_process (batch_id, batches, graph)
|
|
||||||
VALUES (?, ?, ?);
|
|
||||||
""",
|
|
||||||
(batch_process.batch_id, json.dumps(batches), batch_process.graph.json()),
|
|
||||||
)
|
|
||||||
self._conn.commit()
|
|
||||||
except sqlite3.Error as e:
|
|
||||||
self._conn.rollback()
|
|
||||||
raise BatchProcessSaveException from e
|
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
return self.get(batch_process.batch_id)
|
|
||||||
|
|
||||||
def _deserialize_batch_process(self, session_dict: dict) -> BatchProcess:
|
|
||||||
"""Deserializes a batch session."""
|
|
||||||
|
|
||||||
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
|
||||||
|
|
||||||
batch_id = session_dict.get("batch_id", "unknown")
|
|
||||||
batches_raw = session_dict.get("batches", "unknown")
|
|
||||||
graph_raw = session_dict.get("graph", "unknown")
|
|
||||||
batches = json.loads(batches_raw)
|
|
||||||
batches = [parse_raw_as(Batch, batch) for batch in batches]
|
|
||||||
return BatchProcess(
|
|
||||||
batch_id=batch_id,
|
|
||||||
batches=batches,
|
|
||||||
graph=parse_raw_as(Graph, graph_raw),
|
|
||||||
)
|
|
||||||
|
|
||||||
def get(
|
|
||||||
self,
|
|
||||||
batch_id: str,
|
|
||||||
) -> BatchProcess:
|
|
||||||
try:
|
|
||||||
self._lock.acquire()
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
SELECT *
|
|
||||||
FROM batch_process
|
|
||||||
WHERE batch_id = ?;
|
|
||||||
""",
|
|
||||||
(batch_id,)
|
|
||||||
)
|
|
||||||
|
|
||||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
|
||||||
except sqlite3.Error as e:
|
|
||||||
self._conn.rollback()
|
|
||||||
raise BatchProcessNotFoundException from e
|
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
if result is None:
|
|
||||||
raise BatchProcessNotFoundException
|
|
||||||
return self._deserialize_batch_process(dict(result))
|
|
||||||
|
|
||||||
def create_session(
|
|
||||||
self,
|
|
||||||
session: BatchSession,
|
|
||||||
) -> BatchSession:
|
|
||||||
try:
|
|
||||||
self._lock.acquire()
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state)
|
|
||||||
VALUES (?, ?, ?);
|
|
||||||
""",
|
|
||||||
(session.batch_id, session.session_id, session.state),
|
|
||||||
)
|
|
||||||
self._conn.commit()
|
|
||||||
except sqlite3.Error as e:
|
|
||||||
self._conn.rollback()
|
|
||||||
raise BatchSessionSaveException from e
|
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
return self.get_session(session.session_id)
|
|
||||||
|
|
||||||
|
|
||||||
def get_session(
|
|
||||||
self,
|
|
||||||
session_id: str
|
|
||||||
) -> BatchSession:
|
|
||||||
try:
|
|
||||||
self._lock.acquire()
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
SELECT *
|
|
||||||
FROM batch_session
|
|
||||||
WHERE session_id= ?;
|
|
||||||
""",
|
|
||||||
(session_id,),
|
|
||||||
)
|
|
||||||
|
|
||||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
|
||||||
except sqlite3.Error as e:
|
|
||||||
self._conn.rollback()
|
|
||||||
raise BatchSessionNotFoundException from e
|
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
if result is None:
|
|
||||||
raise BatchSessionNotFoundException
|
|
||||||
return self._deserialize_batch_session(dict(result))
|
|
||||||
|
|
||||||
def _deserialize_batch_session(self, session_dict: dict) -> BatchSession:
|
|
||||||
"""Deserializes a batch session."""
|
|
||||||
|
|
||||||
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
|
||||||
|
|
||||||
batch_id = session_dict.get("batch_id", "unknown")
|
|
||||||
session_id = session_dict.get("session_id", "unknown")
|
|
||||||
state = session_dict.get("state", "unknown")
|
|
||||||
|
|
||||||
return BatchSession(
|
|
||||||
batch_id=batch_id,
|
|
||||||
session_id=session_id,
|
|
||||||
state=state,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_created_session(
|
|
||||||
self,
|
|
||||||
batch_id: str
|
|
||||||
) -> BatchSession:
|
|
||||||
try:
|
|
||||||
self._lock.acquire()
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
SELECT *
|
|
||||||
FROM batch_session
|
|
||||||
WHERE batch_id = ? AND state = 'created';
|
|
||||||
""",
|
|
||||||
(batch_id,),
|
|
||||||
)
|
|
||||||
|
|
||||||
result = cast(list[sqlite3.Row], self._cursor.fetchone())
|
|
||||||
except sqlite3.Error as e:
|
|
||||||
self._conn.rollback()
|
|
||||||
raise BatchSessionNotFoundException from e
|
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
if result is None:
|
|
||||||
raise BatchSessionNotFoundException
|
|
||||||
session = self._deserialize_batch_session(dict(result))
|
|
||||||
return session
|
|
||||||
|
|
||||||
|
|
||||||
def get_created_sessions(
|
|
||||||
self,
|
|
||||||
batch_id: str
|
|
||||||
) -> List[BatchSession]:
|
|
||||||
try:
|
|
||||||
self._lock.acquire()
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
SELECT *
|
|
||||||
FROM batch_session
|
|
||||||
WHERE batch_id = ? AND state = created;
|
|
||||||
""",
|
|
||||||
(batch_id,),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
|
||||||
except sqlite3.Error as e:
|
|
||||||
self._conn.rollback()
|
|
||||||
raise BatchSessionNotFoundException from e
|
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
if result is None:
|
|
||||||
raise BatchSessionNotFoundException
|
|
||||||
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
|
|
||||||
return sessions
|
|
||||||
|
|
||||||
|
|
||||||
def update_session_state(
|
|
||||||
self,
|
|
||||||
batch_id: str,
|
|
||||||
session_id: str,
|
|
||||||
changes: BatchSessionChanges,
|
|
||||||
) -> BatchSession:
|
|
||||||
try:
|
|
||||||
self._lock.acquire()
|
|
||||||
|
|
||||||
# Change the state of a batch session
|
|
||||||
if changes.state is not None:
|
|
||||||
self._cursor.execute(
|
|
||||||
f"""--sql
|
|
||||||
UPDATE batch_session
|
|
||||||
SET state = ?
|
|
||||||
WHERE batch_id = ? AND session_id = ?;
|
|
||||||
""",
|
|
||||||
(changes.state, batch_id, session_id),
|
|
||||||
)
|
|
||||||
|
|
||||||
self._conn.commit()
|
|
||||||
except sqlite3.Error as e:
|
|
||||||
self._conn.rollback()
|
|
||||||
raise BatchSessionSaveException from e
|
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
return self.get_session(session_id)
|
|
@ -4,7 +4,6 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from invokeai.app.services.batch_manager import BatchManagerBase
|
|
||||||
from invokeai.app.services.board_images import BoardImagesServiceABC
|
from invokeai.app.services.board_images import BoardImagesServiceABC
|
||||||
from invokeai.app.services.boards import BoardServiceABC
|
from invokeai.app.services.boards import BoardServiceABC
|
||||||
from invokeai.app.services.images import ImageServiceABC
|
from invokeai.app.services.images import ImageServiceABC
|
||||||
@ -22,7 +21,6 @@ class InvocationServices:
|
|||||||
"""Services that can be used by invocations"""
|
"""Services that can be used by invocations"""
|
||||||
|
|
||||||
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
||||||
batch_manager: "BatchManagerBase"
|
|
||||||
board_images: "BoardImagesServiceABC"
|
board_images: "BoardImagesServiceABC"
|
||||||
boards: "BoardServiceABC"
|
boards: "BoardServiceABC"
|
||||||
configuration: "InvokeAIAppConfig"
|
configuration: "InvokeAIAppConfig"
|
||||||
@ -38,7 +36,6 @@ class InvocationServices:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
batch_manager: "BatchManagerBase",
|
|
||||||
board_images: "BoardImagesServiceABC",
|
board_images: "BoardImagesServiceABC",
|
||||||
boards: "BoardServiceABC",
|
boards: "BoardServiceABC",
|
||||||
configuration: "InvokeAIAppConfig",
|
configuration: "InvokeAIAppConfig",
|
||||||
@ -52,7 +49,6 @@ class InvocationServices:
|
|||||||
processor: "InvocationProcessorABC",
|
processor: "InvocationProcessorABC",
|
||||||
queue: "InvocationQueueABC",
|
queue: "InvocationQueueABC",
|
||||||
):
|
):
|
||||||
self.batch_manager = batch_manager
|
|
||||||
self.board_images = board_images
|
self.board_images = board_images
|
||||||
self.boards = boards
|
self.boards = boards
|
||||||
self.boards = boards
|
self.boards = boards
|
||||||
|
@ -29,7 +29,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
self._conn = sqlite3.connect(
|
self._conn = sqlite3.connect(
|
||||||
self._filename, check_same_thread=False
|
self._filename, check_same_thread=False
|
||||||
) # TODO: figure out a better threading solution
|
) # TODO: figure out a better threading solution
|
||||||
self._conn.execute('pragma journal_mode=wal')
|
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
|
|
||||||
self._create_table()
|
self._create_table()
|
||||||
|
@ -651,7 +651,10 @@ class TextualInversionModel:
|
|||||||
file_path = Path(file_path)
|
file_path = Path(file_path)
|
||||||
|
|
||||||
result = cls() # TODO:
|
result = cls() # TODO:
|
||||||
result.name = file_path.stem # TODO:
|
if file_path.name == "learned_embeds.bin":
|
||||||
|
result.name = file_path.parent.name
|
||||||
|
else:
|
||||||
|
result.name = file_path.stem
|
||||||
|
|
||||||
if file_path.suffix == ".safetensors":
|
if file_path.suffix == ".safetensors":
|
||||||
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
||||||
|
@ -188,7 +188,7 @@ class ModelCache(object):
|
|||||||
cache_entry = self._cached_models.get(key, None)
|
cache_entry = self._cached_models.get(key, None)
|
||||||
if cache_entry is None:
|
if cache_entry is None:
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
f"Loading model {model_path}, type {base_model.value}:{model_type.value}:{submodel.value if submodel else ''}"
|
f"Loading model {model_path}, type {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# this will remove older cached models until
|
# this will remove older cached models until
|
||||||
|
@ -472,7 +472,7 @@ class ModelManager(object):
|
|||||||
if submodel_type is not None and hasattr(model_config, submodel_type):
|
if submodel_type is not None and hasattr(model_config, submodel_type):
|
||||||
override_path = getattr(model_config, submodel_type)
|
override_path = getattr(model_config, submodel_type)
|
||||||
if override_path:
|
if override_path:
|
||||||
model_path = self.app_config.root_path / override_path
|
model_path = self.resolve_path(override_path)
|
||||||
model_type = submodel_type
|
model_type = submodel_type
|
||||||
submodel_type = None
|
submodel_type = None
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
@ -670,7 +670,7 @@ class ModelManager(object):
|
|||||||
# TODO: if path changed and old_model.path inside models folder should we delete this too?
|
# TODO: if path changed and old_model.path inside models folder should we delete this too?
|
||||||
|
|
||||||
# remove conversion cache as config changed
|
# remove conversion cache as config changed
|
||||||
old_model_path = self.app_config.root_path / old_model.path
|
old_model_path = self.resolve_model_path(old_model.path)
|
||||||
old_model_cache = self._get_model_cache_path(old_model_path)
|
old_model_cache = self._get_model_cache_path(old_model_path)
|
||||||
if old_model_cache.exists():
|
if old_model_cache.exists():
|
||||||
if old_model_cache.is_dir():
|
if old_model_cache.is_dir():
|
||||||
@ -780,7 +780,7 @@ class ModelManager(object):
|
|||||||
model_type,
|
model_type,
|
||||||
**submodel,
|
**submodel,
|
||||||
)
|
)
|
||||||
checkpoint_path = self.app_config.root_path / info["path"]
|
checkpoint_path = self.resolve_model_path(info["path"])
|
||||||
old_diffusers_path = self.resolve_model_path(model.location)
|
old_diffusers_path = self.resolve_model_path(model.location)
|
||||||
new_diffusers_path = (
|
new_diffusers_path = (
|
||||||
dest_directory or self.app_config.models_path / base_model.value / model_type.value
|
dest_directory or self.app_config.models_path / base_model.value / model_type.value
|
||||||
@ -992,7 +992,7 @@ class ModelManager(object):
|
|||||||
model_manager=self,
|
model_manager=self,
|
||||||
prediction_type_helper=ask_user_for_prediction_type,
|
prediction_type_helper=ask_user_for_prediction_type,
|
||||||
)
|
)
|
||||||
known_paths = {config.root_path / x["path"] for x in self.list_models()}
|
known_paths = {self.resolve_model_path(x["path"]) for x in self.list_models()}
|
||||||
directories = {
|
directories = {
|
||||||
config.root_path / x
|
config.root_path / x
|
||||||
for x in [
|
for x in [
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { ButtonGroup, Flex, Spinner, Text } from '@chakra-ui/react';
|
import { ButtonGroup, Flex, Text } from '@chakra-ui/react';
|
||||||
import { EntityState } from '@reduxjs/toolkit';
|
import { EntityState } from '@reduxjs/toolkit';
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
import IAIInput from 'common/components/IAIInput';
|
import IAIInput from 'common/components/IAIInput';
|
||||||
@ -6,23 +6,23 @@ import { forEach } from 'lodash-es';
|
|||||||
import type { ChangeEvent, PropsWithChildren } from 'react';
|
import type { ChangeEvent, PropsWithChildren } from 'react';
|
||||||
import { useCallback, useState } from 'react';
|
import { useCallback, useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
|
||||||
import {
|
import {
|
||||||
LoRAModelConfigEntity,
|
|
||||||
MainModelConfigEntity,
|
MainModelConfigEntity,
|
||||||
OnnxModelConfigEntity,
|
OnnxModelConfigEntity,
|
||||||
useGetLoRAModelsQuery,
|
|
||||||
useGetMainModelsQuery,
|
useGetMainModelsQuery,
|
||||||
useGetOnnxModelsQuery,
|
useGetOnnxModelsQuery,
|
||||||
|
useGetLoRAModelsQuery,
|
||||||
|
LoRAModelConfigEntity,
|
||||||
} from 'services/api/endpoints/models';
|
} from 'services/api/endpoints/models';
|
||||||
import ModelListItem from './ModelListItem';
|
import ModelListItem from './ModelListItem';
|
||||||
|
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||||
|
|
||||||
type ModelListProps = {
|
type ModelListProps = {
|
||||||
selectedModelId: string | undefined;
|
selectedModelId: string | undefined;
|
||||||
setSelectedModelId: (name: string | undefined) => void;
|
setSelectedModelId: (name: string | undefined) => void;
|
||||||
};
|
};
|
||||||
|
|
||||||
type ModelFormat = 'all' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';
|
type ModelFormat = 'images' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';
|
||||||
|
|
||||||
type ModelType = 'main' | 'lora' | 'onnx';
|
type ModelType = 'main' | 'lora' | 'onnx';
|
||||||
|
|
||||||
@ -33,63 +33,47 @@ const ModelList = (props: ModelListProps) => {
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const [nameFilter, setNameFilter] = useState<string>('');
|
const [nameFilter, setNameFilter] = useState<string>('');
|
||||||
const [modelFormatFilter, setModelFormatFilter] =
|
const [modelFormatFilter, setModelFormatFilter] =
|
||||||
useState<CombinedModelFormat>('all');
|
useState<CombinedModelFormat>('images');
|
||||||
|
|
||||||
const { filteredDiffusersModels, isLoadingDiffusersModels } =
|
const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||||
useGetMainModelsQuery(ALL_BASE_MODELS, {
|
selectFromResult: ({ data }) => ({
|
||||||
selectFromResult: ({ data, isLoading }) => ({
|
filteredDiffusersModels: modelsFilter(
|
||||||
filteredDiffusersModels: modelsFilter(
|
data,
|
||||||
data,
|
'main',
|
||||||
'main',
|
'diffusers',
|
||||||
'diffusers',
|
nameFilter
|
||||||
nameFilter
|
),
|
||||||
),
|
}),
|
||||||
isLoadingDiffusersModels: isLoading,
|
});
|
||||||
}),
|
|
||||||
});
|
|
||||||
|
|
||||||
const { filteredCheckpointModels, isLoadingCheckpointModels } =
|
const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||||
useGetMainModelsQuery(ALL_BASE_MODELS, {
|
selectFromResult: ({ data }) => ({
|
||||||
selectFromResult: ({ data, isLoading }) => ({
|
filteredCheckpointModels: modelsFilter(
|
||||||
filteredCheckpointModels: modelsFilter(
|
data,
|
||||||
data,
|
'main',
|
||||||
'main',
|
'checkpoint',
|
||||||
'checkpoint',
|
nameFilter
|
||||||
nameFilter
|
),
|
||||||
),
|
}),
|
||||||
isLoadingCheckpointModels: isLoading,
|
});
|
||||||
}),
|
|
||||||
});
|
|
||||||
|
|
||||||
const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery(
|
const { filteredLoraModels } = useGetLoRAModelsQuery(undefined, {
|
||||||
undefined,
|
selectFromResult: ({ data }) => ({
|
||||||
{
|
filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
|
||||||
selectFromResult: ({ data, isLoading }) => ({
|
}),
|
||||||
filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
|
});
|
||||||
isLoadingLoraModels: isLoading,
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const { filteredOnnxModels, isLoadingOnnxModels } = useGetOnnxModelsQuery(
|
const { filteredOnnxModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
|
||||||
ALL_BASE_MODELS,
|
selectFromResult: ({ data }) => ({
|
||||||
{
|
filteredOnnxModels: modelsFilter(data, 'onnx', 'onnx', nameFilter),
|
||||||
selectFromResult: ({ data, isLoading }) => ({
|
}),
|
||||||
filteredOnnxModels: modelsFilter(data, 'onnx', 'onnx', nameFilter),
|
});
|
||||||
isLoadingOnnxModels: isLoading,
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const { filteredOliveModels, isLoadingOliveModels } = useGetOnnxModelsQuery(
|
const { filteredOliveModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
|
||||||
ALL_BASE_MODELS,
|
selectFromResult: ({ data }) => ({
|
||||||
{
|
filteredOliveModels: modelsFilter(data, 'onnx', 'olive', nameFilter),
|
||||||
selectFromResult: ({ data, isLoading }) => ({
|
}),
|
||||||
filteredOliveModels: modelsFilter(data, 'onnx', 'olive', nameFilter),
|
});
|
||||||
isLoadingOliveModels: isLoading,
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||||
setNameFilter(e.target.value);
|
setNameFilter(e.target.value);
|
||||||
@ -100,8 +84,8 @@ const ModelList = (props: ModelListProps) => {
|
|||||||
<Flex flexDirection="column" gap={4} paddingInlineEnd={4}>
|
<Flex flexDirection="column" gap={4} paddingInlineEnd={4}>
|
||||||
<ButtonGroup isAttached>
|
<ButtonGroup isAttached>
|
||||||
<IAIButton
|
<IAIButton
|
||||||
onClick={() => setModelFormatFilter('all')}
|
onClick={() => setModelFormatFilter('images')}
|
||||||
isChecked={modelFormatFilter === 'all'}
|
isChecked={modelFormatFilter === 'images'}
|
||||||
size="sm"
|
size="sm"
|
||||||
>
|
>
|
||||||
{t('modelManager.allModels')}
|
{t('modelManager.allModels')}
|
||||||
@ -155,76 +139,95 @@ const ModelList = (props: ModelListProps) => {
|
|||||||
maxHeight={window.innerHeight - 280}
|
maxHeight={window.innerHeight - 280}
|
||||||
overflow="scroll"
|
overflow="scroll"
|
||||||
>
|
>
|
||||||
{/* Diffusers List */}
|
{['images', 'diffusers'].includes(modelFormatFilter) &&
|
||||||
{isLoadingDiffusersModels && (
|
|
||||||
<FetchingModelsLoader loadingMessage="Loading Diffusers..." />
|
|
||||||
)}
|
|
||||||
{['all', 'diffusers'].includes(modelFormatFilter) &&
|
|
||||||
!isLoadingDiffusersModels &&
|
|
||||||
filteredDiffusersModels.length > 0 && (
|
filteredDiffusersModels.length > 0 && (
|
||||||
<ModelListWrapper
|
<StyledModelContainer>
|
||||||
title="Diffusers"
|
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||||
modelList={filteredDiffusersModels}
|
<Text variant="subtext" fontSize="sm">
|
||||||
selected={{ selectedModelId, setSelectedModelId }}
|
Diffusers
|
||||||
key="diffusers"
|
</Text>
|
||||||
/>
|
{filteredDiffusersModels.map((model) => (
|
||||||
|
<ModelListItem
|
||||||
|
key={model.id}
|
||||||
|
model={model}
|
||||||
|
isSelected={selectedModelId === model.id}
|
||||||
|
setSelectedModelId={setSelectedModelId}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</Flex>
|
||||||
|
</StyledModelContainer>
|
||||||
)}
|
)}
|
||||||
{/* Checkpoints List */}
|
{['images', 'checkpoint'].includes(modelFormatFilter) &&
|
||||||
{isLoadingCheckpointModels && (
|
|
||||||
<FetchingModelsLoader loadingMessage="Loading Checkpoints..." />
|
|
||||||
)}
|
|
||||||
{['all', 'checkpoint'].includes(modelFormatFilter) &&
|
|
||||||
!isLoadingCheckpointModels &&
|
|
||||||
filteredCheckpointModels.length > 0 && (
|
filteredCheckpointModels.length > 0 && (
|
||||||
<ModelListWrapper
|
<StyledModelContainer>
|
||||||
title="Checkpoints"
|
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||||
modelList={filteredCheckpointModels}
|
<Text variant="subtext" fontSize="sm">
|
||||||
selected={{ selectedModelId, setSelectedModelId }}
|
Checkpoints
|
||||||
key="checkpoints"
|
</Text>
|
||||||
/>
|
{filteredCheckpointModels.map((model) => (
|
||||||
|
<ModelListItem
|
||||||
|
key={model.id}
|
||||||
|
model={model}
|
||||||
|
isSelected={selectedModelId === model.id}
|
||||||
|
setSelectedModelId={setSelectedModelId}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</Flex>
|
||||||
|
</StyledModelContainer>
|
||||||
)}
|
)}
|
||||||
|
{['images', 'olive'].includes(modelFormatFilter) &&
|
||||||
{/* LoRAs List */}
|
|
||||||
{isLoadingLoraModels && (
|
|
||||||
<FetchingModelsLoader loadingMessage="Loading LoRAs..." />
|
|
||||||
)}
|
|
||||||
{['all', 'lora'].includes(modelFormatFilter) &&
|
|
||||||
!isLoadingLoraModels &&
|
|
||||||
filteredLoraModels.length > 0 && (
|
|
||||||
<ModelListWrapper
|
|
||||||
title="LoRAs"
|
|
||||||
modelList={filteredLoraModels}
|
|
||||||
selected={{ selectedModelId, setSelectedModelId }}
|
|
||||||
key="loras"
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{/* Olive List */}
|
|
||||||
{isLoadingOliveModels && (
|
|
||||||
<FetchingModelsLoader loadingMessage="Loading Olives..." />
|
|
||||||
)}
|
|
||||||
{['all', 'olive'].includes(modelFormatFilter) &&
|
|
||||||
!isLoadingOliveModels &&
|
|
||||||
filteredOliveModels.length > 0 && (
|
filteredOliveModels.length > 0 && (
|
||||||
<ModelListWrapper
|
<StyledModelContainer>
|
||||||
title="Olives"
|
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||||
modelList={filteredOliveModels}
|
<Text variant="subtext" fontSize="sm">
|
||||||
selected={{ selectedModelId, setSelectedModelId }}
|
Olives
|
||||||
key="olive"
|
</Text>
|
||||||
/>
|
{filteredOliveModels.map((model) => (
|
||||||
|
<ModelListItem
|
||||||
|
key={model.id}
|
||||||
|
model={model}
|
||||||
|
isSelected={selectedModelId === model.id}
|
||||||
|
setSelectedModelId={setSelectedModelId}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</Flex>
|
||||||
|
</StyledModelContainer>
|
||||||
)}
|
)}
|
||||||
{/* Onnx List */}
|
{['images', 'onnx'].includes(modelFormatFilter) &&
|
||||||
{isLoadingOnnxModels && (
|
|
||||||
<FetchingModelsLoader loadingMessage="Loading ONNX..." />
|
|
||||||
)}
|
|
||||||
{['all', 'onnx'].includes(modelFormatFilter) &&
|
|
||||||
!isLoadingOnnxModels &&
|
|
||||||
filteredOnnxModels.length > 0 && (
|
filteredOnnxModels.length > 0 && (
|
||||||
<ModelListWrapper
|
<StyledModelContainer>
|
||||||
title="ONNX"
|
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||||
modelList={filteredOnnxModels}
|
<Text variant="subtext" fontSize="sm">
|
||||||
selected={{ selectedModelId, setSelectedModelId }}
|
Onnx
|
||||||
key="onnx"
|
</Text>
|
||||||
/>
|
{filteredOnnxModels.map((model) => (
|
||||||
|
<ModelListItem
|
||||||
|
key={model.id}
|
||||||
|
model={model}
|
||||||
|
isSelected={selectedModelId === model.id}
|
||||||
|
setSelectedModelId={setSelectedModelId}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</Flex>
|
||||||
|
</StyledModelContainer>
|
||||||
|
)}
|
||||||
|
{['images', 'lora'].includes(modelFormatFilter) &&
|
||||||
|
filteredLoraModels.length > 0 && (
|
||||||
|
<StyledModelContainer>
|
||||||
|
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||||
|
<Text variant="subtext" fontSize="sm">
|
||||||
|
LoRAs
|
||||||
|
</Text>
|
||||||
|
{filteredLoraModels.map((model) => (
|
||||||
|
<ModelListItem
|
||||||
|
key={model.id}
|
||||||
|
model={model}
|
||||||
|
isSelected={selectedModelId === model.id}
|
||||||
|
setSelectedModelId={setSelectedModelId}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</Flex>
|
||||||
|
</StyledModelContainer>
|
||||||
)}
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
@ -284,52 +287,3 @@ const StyledModelContainer = (props: PropsWithChildren) => {
|
|||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
type ModelListWrapperProps = {
|
|
||||||
title: string;
|
|
||||||
modelList:
|
|
||||||
| MainModelConfigEntity[]
|
|
||||||
| LoRAModelConfigEntity[]
|
|
||||||
| OnnxModelConfigEntity[];
|
|
||||||
selected: ModelListProps;
|
|
||||||
};
|
|
||||||
|
|
||||||
function ModelListWrapper(props: ModelListWrapperProps) {
|
|
||||||
const { title, modelList, selected } = props;
|
|
||||||
return (
|
|
||||||
<StyledModelContainer>
|
|
||||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
|
||||||
<Text variant="subtext" fontSize="sm">
|
|
||||||
{title}
|
|
||||||
</Text>
|
|
||||||
{modelList.map((model) => (
|
|
||||||
<ModelListItem
|
|
||||||
key={model.id}
|
|
||||||
model={model}
|
|
||||||
isSelected={selected.selectedModelId === model.id}
|
|
||||||
setSelectedModelId={selected.setSelectedModelId}
|
|
||||||
/>
|
|
||||||
))}
|
|
||||||
</Flex>
|
|
||||||
</StyledModelContainer>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
function FetchingModelsLoader({ loadingMessage }: { loadingMessage?: string }) {
|
|
||||||
return (
|
|
||||||
<StyledModelContainer>
|
|
||||||
<Flex
|
|
||||||
justifyContent="center"
|
|
||||||
alignItems="center"
|
|
||||||
flexDirection="column"
|
|
||||||
p={4}
|
|
||||||
gap={8}
|
|
||||||
>
|
|
||||||
<Spinner />
|
|
||||||
<Text variant="subtext">
|
|
||||||
{loadingMessage ? loadingMessage : 'Fetching...'}
|
|
||||||
</Text>
|
|
||||||
</Flex>
|
|
||||||
</StyledModelContainer>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
@ -181,7 +181,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
},
|
},
|
||||||
providesTags: (result, error, arg) => {
|
providesTags: (result, error, arg) => {
|
||||||
const tags: ApiFullTagDescription[] = [
|
const tags: ApiFullTagDescription[] = [
|
||||||
{ type: 'OnnxModel', id: LIST_TAG },
|
{ id: 'OnnxModel', type: LIST_TAG },
|
||||||
];
|
];
|
||||||
|
|
||||||
if (result) {
|
if (result) {
|
||||||
@ -266,7 +266,6 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
invalidatesTags: [
|
invalidatesTags: [
|
||||||
{ type: 'MainModel', id: LIST_TAG },
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
{ type: 'OnnxModel', id: LIST_TAG },
|
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
importMainModels: build.mutation<
|
importMainModels: build.mutation<
|
||||||
@ -283,7 +282,6 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
invalidatesTags: [
|
invalidatesTags: [
|
||||||
{ type: 'MainModel', id: LIST_TAG },
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
{ type: 'OnnxModel', id: LIST_TAG },
|
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
|
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
|
||||||
@ -297,7 +295,6 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
invalidatesTags: [
|
invalidatesTags: [
|
||||||
{ type: 'MainModel', id: LIST_TAG },
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
{ type: 'OnnxModel', id: LIST_TAG },
|
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
deleteMainModels: build.mutation<
|
deleteMainModels: build.mutation<
|
||||||
@ -313,7 +310,6 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
invalidatesTags: [
|
invalidatesTags: [
|
||||||
{ type: 'MainModel', id: LIST_TAG },
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
{ type: 'OnnxModel', id: LIST_TAG },
|
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
convertMainModels: build.mutation<
|
convertMainModels: build.mutation<
|
||||||
@ -330,7 +326,6 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
invalidatesTags: [
|
invalidatesTags: [
|
||||||
{ type: 'MainModel', id: LIST_TAG },
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
{ type: 'OnnxModel', id: LIST_TAG },
|
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
|
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
|
||||||
@ -344,7 +339,6 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
invalidatesTags: [
|
invalidatesTags: [
|
||||||
{ type: 'MainModel', id: LIST_TAG },
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
{ type: 'OnnxModel', id: LIST_TAG },
|
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
syncModels: build.mutation<SyncModelsResponse, void>({
|
syncModels: build.mutation<SyncModelsResponse, void>({
|
||||||
@ -357,7 +351,6 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
invalidatesTags: [
|
invalidatesTags: [
|
||||||
{ type: 'MainModel', id: LIST_TAG },
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
{ type: 'OnnxModel', id: LIST_TAG },
|
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
|
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
|
||||||
|
@ -48,7 +48,6 @@ def mock_services() -> InvocationServices:
|
|||||||
images=None, # type: ignore
|
images=None, # type: ignore
|
||||||
latents=None, # type: ignore
|
latents=None, # type: ignore
|
||||||
boards=None, # type: ignore
|
boards=None, # type: ignore
|
||||||
batch_manager=None, # type: ignore
|
|
||||||
board_images=None, # type: ignore
|
board_images=None, # type: ignore
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"),
|
graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"),
|
||||||
|
@ -40,7 +40,6 @@ def mock_services() -> InvocationServices:
|
|||||||
logger=None, # type: ignore
|
logger=None, # type: ignore
|
||||||
images=None, # type: ignore
|
images=None, # type: ignore
|
||||||
latents=None, # type: ignore
|
latents=None, # type: ignore
|
||||||
batch_manager=None, # type: ignore
|
|
||||||
boards=None, # type: ignore
|
boards=None, # type: ignore
|
||||||
board_images=None, # type: ignore
|
board_images=None, # type: ignore
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
|
Reference in New Issue
Block a user