mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
4 Commits
feat/batch
...
fix/diffus
Author | SHA1 | Date | |
---|---|---|---|
e04c25eba7 | |||
be61ffdbf6 | |||
823b879329 | |||
17c901aaf7 |
10
flake.nix
10
flake.nix
@ -34,10 +34,6 @@
|
||||
cudaPackages.cudnn
|
||||
cudaPackages.cuda_nvrtc
|
||||
cudatoolkit
|
||||
pkgconfig
|
||||
libconfig
|
||||
cmake
|
||||
blas
|
||||
freeglut
|
||||
glib
|
||||
gperf
|
||||
@ -46,12 +42,6 @@
|
||||
libGLU
|
||||
linuxPackages.nvidia_x11
|
||||
python
|
||||
(opencv4.override {
|
||||
enableGtk3 = true;
|
||||
enableFfmpeg = true;
|
||||
enableCuda = true;
|
||||
enableUnfree = true;
|
||||
})
|
||||
stdenv.cc
|
||||
stdenv.cc.cc.lib
|
||||
xorg.libX11
|
||||
|
@ -30,8 +30,6 @@ from ..services.invoker import Invoker
|
||||
from ..services.processor import DefaultInvocationProcessor
|
||||
from ..services.sqlite import SqliteItemStorage
|
||||
from ..services.model_manager_service import ModelManagerService
|
||||
from ..services.batch_manager import BatchManager
|
||||
from ..services.batch_manager_storage import SqliteBatchProcessStorage
|
||||
from .events import FastAPIEventService
|
||||
|
||||
|
||||
@ -118,15 +116,11 @@ class ApiDependencies:
|
||||
)
|
||||
)
|
||||
|
||||
batch_manager_storage = SqliteBatchProcessStorage(db_location)
|
||||
batch_manager = BatchManager(batch_manager_storage)
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=ModelManagerService(config, logger),
|
||||
events=events,
|
||||
latents=latents,
|
||||
images=images,
|
||||
batch_manager=batch_manager,
|
||||
boards=boards,
|
||||
board_images=board_images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
|
@ -15,7 +15,6 @@ from ...services.graph import (
|
||||
GraphExecutionState,
|
||||
NodeAlreadyExecutedError,
|
||||
)
|
||||
from ...services.batch_manager import Batch, BatchProcess
|
||||
from ...services.item_storage import PaginatedResults
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
@ -38,37 +37,6 @@ async def create_session(
|
||||
return session
|
||||
|
||||
|
||||
@session_router.post(
|
||||
"/batch",
|
||||
operation_id="create_batch",
|
||||
responses={
|
||||
200: {"model": BatchProcess},
|
||||
400: {"description": "Invalid json"},
|
||||
},
|
||||
)
|
||||
async def create_batch(
|
||||
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
|
||||
batches: list[Batch] = Body(description="Batch config to apply to the given graph"),
|
||||
) -> BatchProcess:
|
||||
"""Creates and starts a new new batch process"""
|
||||
batch_id = ApiDependencies.invoker.services.batch_manager.create_batch_process(batches, graph)
|
||||
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_id)
|
||||
return {"batch_id":batch_id}
|
||||
|
||||
|
||||
@session_router.delete(
|
||||
"{batch_process_id}/batch",
|
||||
operation_id="cancel_batch",
|
||||
responses={202: {"description": "The batch is canceled"}},
|
||||
)
|
||||
async def cancel_batch(
|
||||
batch_process_id: str = Path(description="The id of the batch process to cancel"),
|
||||
) -> Response:
|
||||
"""Creates and starts a new new batch process"""
|
||||
ApiDependencies.invoker.services.batch_manager.cancel_batch_process(batch_process_id)
|
||||
return Response(status_code=202)
|
||||
|
||||
|
||||
@session_router.get(
|
||||
"/",
|
||||
operation_id="list_sessions",
|
||||
|
@ -37,8 +37,6 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
from invokeai.app.services.batch_manager import BatchManager
|
||||
from invokeai.app.services.batch_manager_storage import SqliteBatchProcessStorage
|
||||
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
@ -302,16 +300,12 @@ def invoke_cli():
|
||||
)
|
||||
)
|
||||
|
||||
batch_manager_storage = SqliteBatchProcessStorage(db_location)
|
||||
batch_manager = BatchManager(batch_manager_storage)
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=model_manager,
|
||||
events=events,
|
||||
latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
|
||||
images=images,
|
||||
boards=boards,
|
||||
batch_manager=batch_manager,
|
||||
board_images=board_images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
||||
|
@ -1,139 +0,0 @@
|
||||
import networkx as nx
|
||||
import copy
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from itertools import product
|
||||
from pydantic import BaseModel, Field
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event
|
||||
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.graph import Graph, GraphExecutionState
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.batch_manager_storage import (
|
||||
BatchProcessStorageBase,
|
||||
BatchSessionNotFoundException,
|
||||
Batch,
|
||||
BatchProcess,
|
||||
BatchSession,
|
||||
BatchSessionChanges,
|
||||
)
|
||||
|
||||
|
||||
class BatchManagerBase(ABC):
|
||||
@abstractmethod
|
||||
def start(self, invoker: Invoker):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_batch_process(self, batches: list[Batch], graph: Graph) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def run_batch_process(self, batch_id: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_batch_process(self, batch_process_id: str):
|
||||
pass
|
||||
|
||||
|
||||
class BatchManager(BatchManagerBase):
|
||||
"""Responsible for managing currently running and scheduled batch jobs"""
|
||||
|
||||
__invoker: Invoker
|
||||
__batches: list[BatchProcess]
|
||||
__batch_process_storage: BatchProcessStorageBase
|
||||
|
||||
def __init__(self, batch_process_storage: BatchProcessStorageBase) -> None:
|
||||
super().__init__()
|
||||
self.__batch_process_storage = batch_process_storage
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
# if we do want multithreading at some point, we could make this configurable
|
||||
self.__invoker = invoker
|
||||
self.__batches = list()
|
||||
local_handler.register(event_name=EventServiceBase.session_event, _func=self.on_event)
|
||||
|
||||
async def on_event(self, event: Event):
|
||||
event_name = event[1]["event"]
|
||||
|
||||
match event_name:
|
||||
case "graph_execution_state_complete":
|
||||
await self.process(event, False)
|
||||
case "invocation_error":
|
||||
await self.process(event, True)
|
||||
|
||||
return event
|
||||
|
||||
async def process(self, event: Event, err: bool):
|
||||
data = event[1]["data"]
|
||||
batch_session = self.__batch_process_storage.get_session(data["graph_execution_state_id"])
|
||||
if not batch_session:
|
||||
return
|
||||
updateSession = BatchSessionChanges(
|
||||
state='error' if err else 'completed'
|
||||
)
|
||||
batch_session = self.__batch_process_storage.update_session_state(
|
||||
batch_session.batch_id,
|
||||
batch_session.session_id,
|
||||
updateSession,
|
||||
)
|
||||
self.run_batch_process(batch_session.batch_id)
|
||||
|
||||
def _create_batch_session(self, batch_process: BatchProcess, batch_indices: list[int]) -> GraphExecutionState:
|
||||
graph = copy.deepcopy(batch_process.graph)
|
||||
batches = batch_process.batches
|
||||
g = graph.nx_graph_flat()
|
||||
sorted_nodes = nx.topological_sort(g)
|
||||
for npath in sorted_nodes:
|
||||
node = graph.get_node(npath)
|
||||
(index, batch) = next(((i, b) for i, b in enumerate(batches) if b.node_id in node.id), (None, None))
|
||||
if batch:
|
||||
batch_index = batch_indices[index]
|
||||
datum = batch.data[batch_index]
|
||||
for key in datum:
|
||||
node.__dict__[key] = datum[key]
|
||||
graph.update_node(npath, node)
|
||||
|
||||
return GraphExecutionState(graph=graph)
|
||||
|
||||
def run_batch_process(self, batch_id: str):
|
||||
try:
|
||||
created_session = self.__batch_process_storage.get_created_session(batch_id)
|
||||
except BatchSessionNotFoundException:
|
||||
return
|
||||
ges = self.__invoker.services.graph_execution_manager.get(created_session.session_id)
|
||||
self.__invoker.invoke(ges, invoke_all=True)
|
||||
|
||||
def _valid_batch_config(self, batch_process: BatchProcess) -> bool:
|
||||
return True
|
||||
|
||||
def create_batch_process(self, batches: list[Batch], graph: Graph) -> str:
|
||||
batch_process = BatchProcess(
|
||||
batches=batches,
|
||||
graph=graph,
|
||||
)
|
||||
if not self._valid_batch_config(batch_process):
|
||||
return None
|
||||
batch_process = self.__batch_process_storage.save(batch_process)
|
||||
self._create_sessions(batch_process)
|
||||
return batch_process.batch_id
|
||||
|
||||
def _create_sessions(self, batch_process: BatchProcess):
|
||||
batch_indices = list()
|
||||
for batch in batch_process.batches:
|
||||
batch_indices.append(list(range(len(batch.data))))
|
||||
all_batch_indices = product(*batch_indices)
|
||||
for bi in all_batch_indices:
|
||||
ges = self._create_batch_session(batch_process, bi)
|
||||
self.__invoker.services.graph_execution_manager.set(ges)
|
||||
batch_session = BatchSession(
|
||||
batch_id=batch_process.batch_id,
|
||||
session_id=ges.id,
|
||||
state="created"
|
||||
)
|
||||
self.__batch_process_storage.create_session(batch_session)
|
||||
|
||||
def cancel_batch_process(self, batch_process_id: str):
|
||||
self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id]
|
@ -1,505 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import cast
|
||||
import uuid
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import (
|
||||
Any,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
import json
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
)
|
||||
from invokeai.app.services.graph import Graph
|
||||
from invokeai.app.models.image import ImageField
|
||||
|
||||
from pydantic import BaseModel, Field, Extra, parse_raw_as
|
||||
|
||||
invocations = BaseInvocation.get_invocations()
|
||||
InvocationsUnion = Union[invocations] # type: ignore
|
||||
|
||||
BatchDataType = Union[str, int, float, ImageField]
|
||||
|
||||
class Batch(BaseModel):
|
||||
data: list[dict[str, BatchDataType]] = Field(description="Mapping of node field to data value")
|
||||
node_id: str = Field(description="ID of the node to batch")
|
||||
|
||||
|
||||
class BatchSession(BaseModel):
|
||||
batch_id: str = Field(description="Identifier for which batch this Index belongs to")
|
||||
session_id: str = Field(description="Session ID Created for this Batch Index")
|
||||
state: Literal["created", "completed", "inprogress", "error"] = Field(
|
||||
description="Is this session created, completed, in progress, or errored?"
|
||||
)
|
||||
|
||||
|
||||
def uuid_string():
|
||||
res = uuid.uuid4()
|
||||
return str(res)
|
||||
|
||||
class BatchProcess(BaseModel):
|
||||
batch_id: Optional[str] = Field(default_factory=uuid_string, description="Identifier for this batch")
|
||||
batches: List[Batch] = Field(
|
||||
description="List of batch configs to apply to this session",
|
||||
default_factory=list,
|
||||
)
|
||||
graph: Graph = Field(description="The graph being executed")
|
||||
|
||||
|
||||
class BatchSessionChanges(BaseModel, extra=Extra.forbid):
|
||||
state: Literal["created", "completed", "inprogress", "error"] = Field(
|
||||
description="Is this session created, completed, in progress, or errored?"
|
||||
)
|
||||
|
||||
|
||||
class BatchProcessNotFoundException(Exception):
|
||||
"""Raised when an Batch Process record is not found."""
|
||||
|
||||
def __init__(self, message="BatchProcess record not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BatchProcessSaveException(Exception):
|
||||
"""Raised when an Batch Process record cannot be saved."""
|
||||
|
||||
def __init__(self, message="BatchProcess record not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BatchProcessDeleteException(Exception):
|
||||
"""Raised when an Batch Process record cannot be deleted."""
|
||||
|
||||
def __init__(self, message="BatchProcess record not deleted"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BatchSessionNotFoundException(Exception):
|
||||
"""Raised when an Batch Session record is not found."""
|
||||
|
||||
def __init__(self, message="BatchSession record not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BatchSessionSaveException(Exception):
|
||||
"""Raised when an Batch Session record cannot be saved."""
|
||||
|
||||
def __init__(self, message="BatchSession record not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BatchSessionDeleteException(Exception):
|
||||
"""Raised when an Batch Session record cannot be deleted."""
|
||||
|
||||
def __init__(self, message="BatchSession record not deleted"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BatchProcessStorageBase(ABC):
|
||||
"""Low-level service responsible for interfacing with the Batch Process record store."""
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, batch_id: str) -> None:
|
||||
"""Deletes a Batch Process record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
batch_process: BatchProcess,
|
||||
) -> BatchProcess:
|
||||
"""Saves a Batch Process record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(
|
||||
self,
|
||||
batch_id: str,
|
||||
) -> BatchProcess:
|
||||
"""Gets a Batch Process record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_session(
|
||||
self,
|
||||
session: BatchSession,
|
||||
) -> BatchSession:
|
||||
"""Creates a Batch Session attached to a Batch Process."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_session(
|
||||
self,
|
||||
session_id: str
|
||||
) -> BatchSession:
|
||||
"""Gets session by session_id"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_created_session(
|
||||
self,
|
||||
batch_id: str
|
||||
) -> BatchSession:
|
||||
"""Gets all created Batch Sessions for a given Batch Process id."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_created_sessions(
|
||||
self,
|
||||
batch_id: str
|
||||
) -> List[BatchSession]:
|
||||
"""Gets all created Batch Sessions for a given Batch Process id."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_session_state(
|
||||
self,
|
||||
batch_id: str,
|
||||
session_id: str,
|
||||
changes: BatchSessionChanges,
|
||||
) -> BatchSession:
|
||||
"""Updates the state of a Batch Session record."""
|
||||
pass
|
||||
|
||||
|
||||
class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, filename: str) -> None:
|
||||
super().__init__()
|
||||
self._filename = filename
|
||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Enable foreign keys
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Creates the `batch_process` table and `batch_session` junction table."""
|
||||
|
||||
# Create the `batch_process` table.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS batch_process (
|
||||
batch_id TEXT NOT NULL PRIMARY KEY,
|
||||
batches TEXT NOT NULL,
|
||||
graph TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Soft delete, currently unused
|
||||
deleted_at DATETIME
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_batch_process_created_at ON batch_process (created_at);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_batch_process_updated_at
|
||||
AFTER UPDATE
|
||||
ON batch_process FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE batch_process SET updated_at = current_timestamp
|
||||
WHERE batch_id = old.batch_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
# Create the `batch_session` junction table.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS batch_session (
|
||||
batch_id TEXT NOT NULL,
|
||||
session_id TEXT NOT NULL,
|
||||
state TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Soft delete, currently unused
|
||||
deleted_at DATETIME,
|
||||
-- enforce one-to-many relationship between batch_process and batch_session using PK
|
||||
-- (we can extend this to many-to-many later)
|
||||
PRIMARY KEY (batch_id,session_id),
|
||||
FOREIGN KEY (batch_id) REFERENCES batch_process (batch_id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add index for batch id
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id ON batch_session (batch_id);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add index for batch id, sorted by created_at
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id_created_at ON batch_session (batch_id,created_at);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_batch_session_updated_at
|
||||
AFTER UPDATE
|
||||
ON batch_session FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE batch_session SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE batch_id = old.batch_id AND session_id = old.session_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
def delete(self, batch_id: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM batch_process
|
||||
WHERE batch_id = ?;
|
||||
""",
|
||||
(batch_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchProcessDeleteException from e
|
||||
except Exception as e:
|
||||
self._conn.rollback()
|
||||
raise BatchProcessDeleteException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def save(
|
||||
self,
|
||||
batch_process: BatchProcess,
|
||||
) -> BatchProcess:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
batches = [batch.json() for batch in batch_process.batches]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO batch_process (batch_id, batches, graph)
|
||||
VALUES (?, ?, ?);
|
||||
""",
|
||||
(batch_process.batch_id, json.dumps(batches), batch_process.graph.json()),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchProcessSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(batch_process.batch_id)
|
||||
|
||||
def _deserialize_batch_process(self, session_dict: dict) -> BatchProcess:
|
||||
"""Deserializes a batch session."""
|
||||
|
||||
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||
|
||||
batch_id = session_dict.get("batch_id", "unknown")
|
||||
batches_raw = session_dict.get("batches", "unknown")
|
||||
graph_raw = session_dict.get("graph", "unknown")
|
||||
batches = json.loads(batches_raw)
|
||||
batches = [parse_raw_as(Batch, batch) for batch in batches]
|
||||
return BatchProcess(
|
||||
batch_id=batch_id,
|
||||
batches=batches,
|
||||
graph=parse_raw_as(Graph, graph_raw),
|
||||
)
|
||||
|
||||
def get(
|
||||
self,
|
||||
batch_id: str,
|
||||
) -> BatchProcess:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM batch_process
|
||||
WHERE batch_id = ?;
|
||||
""",
|
||||
(batch_id,)
|
||||
)
|
||||
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchProcessNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
raise BatchProcessNotFoundException
|
||||
return self._deserialize_batch_process(dict(result))
|
||||
|
||||
def create_session(
|
||||
self,
|
||||
session: BatchSession,
|
||||
) -> BatchSession:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state)
|
||||
VALUES (?, ?, ?);
|
||||
""",
|
||||
(session.batch_id, session.session_id, session.state),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get_session(session.session_id)
|
||||
|
||||
|
||||
def get_session(
|
||||
self,
|
||||
session_id: str
|
||||
) -> BatchSession:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM batch_session
|
||||
WHERE session_id= ?;
|
||||
""",
|
||||
(session_id,),
|
||||
)
|
||||
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
raise BatchSessionNotFoundException
|
||||
return self._deserialize_batch_session(dict(result))
|
||||
|
||||
def _deserialize_batch_session(self, session_dict: dict) -> BatchSession:
|
||||
"""Deserializes a batch session."""
|
||||
|
||||
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||
|
||||
batch_id = session_dict.get("batch_id", "unknown")
|
||||
session_id = session_dict.get("session_id", "unknown")
|
||||
state = session_dict.get("state", "unknown")
|
||||
|
||||
return BatchSession(
|
||||
batch_id=batch_id,
|
||||
session_id=session_id,
|
||||
state=state,
|
||||
)
|
||||
|
||||
|
||||
def get_created_session(
|
||||
self,
|
||||
batch_id: str
|
||||
) -> BatchSession:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM batch_session
|
||||
WHERE batch_id = ? AND state = 'created';
|
||||
""",
|
||||
(batch_id,),
|
||||
)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
raise BatchSessionNotFoundException
|
||||
session = self._deserialize_batch_session(dict(result))
|
||||
return session
|
||||
|
||||
|
||||
def get_created_sessions(
|
||||
self,
|
||||
batch_id: str
|
||||
) -> List[BatchSession]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM batch_session
|
||||
WHERE batch_id = ? AND state = created;
|
||||
""",
|
||||
(batch_id,),
|
||||
)
|
||||
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
raise BatchSessionNotFoundException
|
||||
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
|
||||
return sessions
|
||||
|
||||
|
||||
def update_session_state(
|
||||
self,
|
||||
batch_id: str,
|
||||
session_id: str,
|
||||
changes: BatchSessionChanges,
|
||||
) -> BatchSession:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
# Change the state of a batch session
|
||||
if changes.state is not None:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE batch_session
|
||||
SET state = ?
|
||||
WHERE batch_id = ? AND session_id = ?;
|
||||
""",
|
||||
(changes.state, batch_id, session_id),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get_session(session_id)
|
@ -4,7 +4,6 @@ from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
from invokeai.app.services.batch_manager import BatchManagerBase
|
||||
from invokeai.app.services.board_images import BoardImagesServiceABC
|
||||
from invokeai.app.services.boards import BoardServiceABC
|
||||
from invokeai.app.services.images import ImageServiceABC
|
||||
@ -22,7 +21,6 @@ class InvocationServices:
|
||||
"""Services that can be used by invocations"""
|
||||
|
||||
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
||||
batch_manager: "BatchManagerBase"
|
||||
board_images: "BoardImagesServiceABC"
|
||||
boards: "BoardServiceABC"
|
||||
configuration: "InvokeAIAppConfig"
|
||||
@ -38,7 +36,6 @@ class InvocationServices:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_manager: "BatchManagerBase",
|
||||
board_images: "BoardImagesServiceABC",
|
||||
boards: "BoardServiceABC",
|
||||
configuration: "InvokeAIAppConfig",
|
||||
@ -52,7 +49,6 @@ class InvocationServices:
|
||||
processor: "InvocationProcessorABC",
|
||||
queue: "InvocationQueueABC",
|
||||
):
|
||||
self.batch_manager = batch_manager
|
||||
self.board_images = board_images
|
||||
self.boards = boards
|
||||
self.boards = boards
|
||||
|
@ -29,7 +29,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
self._conn = sqlite3.connect(
|
||||
self._filename, check_same_thread=False
|
||||
) # TODO: figure out a better threading solution
|
||||
self._conn.execute('pragma journal_mode=wal')
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
self._create_table()
|
||||
|
@ -651,7 +651,10 @@ class TextualInversionModel:
|
||||
file_path = Path(file_path)
|
||||
|
||||
result = cls() # TODO:
|
||||
result.name = file_path.stem # TODO:
|
||||
if file_path.name == "learned_embeds.bin":
|
||||
result.name = file_path.parent.name
|
||||
else:
|
||||
result.name = file_path.stem
|
||||
|
||||
if file_path.suffix == ".safetensors":
|
||||
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
||||
|
@ -188,7 +188,7 @@ class ModelCache(object):
|
||||
cache_entry = self._cached_models.get(key, None)
|
||||
if cache_entry is None:
|
||||
self.logger.info(
|
||||
f"Loading model {model_path}, type {base_model.value}:{model_type.value}:{submodel.value if submodel else ''}"
|
||||
f"Loading model {model_path}, type {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
|
||||
)
|
||||
|
||||
# this will remove older cached models until
|
||||
|
@ -472,7 +472,7 @@ class ModelManager(object):
|
||||
if submodel_type is not None and hasattr(model_config, submodel_type):
|
||||
override_path = getattr(model_config, submodel_type)
|
||||
if override_path:
|
||||
model_path = self.app_config.root_path / override_path
|
||||
model_path = self.resolve_path(override_path)
|
||||
model_type = submodel_type
|
||||
submodel_type = None
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
@ -670,7 +670,7 @@ class ModelManager(object):
|
||||
# TODO: if path changed and old_model.path inside models folder should we delete this too?
|
||||
|
||||
# remove conversion cache as config changed
|
||||
old_model_path = self.app_config.root_path / old_model.path
|
||||
old_model_path = self.resolve_model_path(old_model.path)
|
||||
old_model_cache = self._get_model_cache_path(old_model_path)
|
||||
if old_model_cache.exists():
|
||||
if old_model_cache.is_dir():
|
||||
@ -780,7 +780,7 @@ class ModelManager(object):
|
||||
model_type,
|
||||
**submodel,
|
||||
)
|
||||
checkpoint_path = self.app_config.root_path / info["path"]
|
||||
checkpoint_path = self.resolve_model_path(info["path"])
|
||||
old_diffusers_path = self.resolve_model_path(model.location)
|
||||
new_diffusers_path = (
|
||||
dest_directory or self.app_config.models_path / base_model.value / model_type.value
|
||||
@ -992,7 +992,7 @@ class ModelManager(object):
|
||||
model_manager=self,
|
||||
prediction_type_helper=ask_user_for_prediction_type,
|
||||
)
|
||||
known_paths = {config.root_path / x["path"] for x in self.list_models()}
|
||||
known_paths = {self.resolve_model_path(x["path"]) for x in self.list_models()}
|
||||
directories = {
|
||||
config.root_path / x
|
||||
for x in [
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { ButtonGroup, Flex, Spinner, Text } from '@chakra-ui/react';
|
||||
import { ButtonGroup, Flex, Text } from '@chakra-ui/react';
|
||||
import { EntityState } from '@reduxjs/toolkit';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
@ -6,23 +6,23 @@ import { forEach } from 'lodash-es';
|
||||
import type { ChangeEvent, PropsWithChildren } from 'react';
|
||||
import { useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
import {
|
||||
LoRAModelConfigEntity,
|
||||
MainModelConfigEntity,
|
||||
OnnxModelConfigEntity,
|
||||
useGetLoRAModelsQuery,
|
||||
useGetMainModelsQuery,
|
||||
useGetOnnxModelsQuery,
|
||||
useGetLoRAModelsQuery,
|
||||
LoRAModelConfigEntity,
|
||||
} from 'services/api/endpoints/models';
|
||||
import ModelListItem from './ModelListItem';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
|
||||
type ModelListProps = {
|
||||
selectedModelId: string | undefined;
|
||||
setSelectedModelId: (name: string | undefined) => void;
|
||||
};
|
||||
|
||||
type ModelFormat = 'all' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';
|
||||
type ModelFormat = 'images' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';
|
||||
|
||||
type ModelType = 'main' | 'lora' | 'onnx';
|
||||
|
||||
@ -33,63 +33,47 @@ const ModelList = (props: ModelListProps) => {
|
||||
const { t } = useTranslation();
|
||||
const [nameFilter, setNameFilter] = useState<string>('');
|
||||
const [modelFormatFilter, setModelFormatFilter] =
|
||||
useState<CombinedModelFormat>('all');
|
||||
useState<CombinedModelFormat>('images');
|
||||
|
||||
const { filteredDiffusersModels, isLoadingDiffusersModels } =
|
||||
useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredDiffusersModels: modelsFilter(
|
||||
data,
|
||||
'main',
|
||||
'diffusers',
|
||||
nameFilter
|
||||
),
|
||||
isLoadingDiffusersModels: isLoading,
|
||||
}),
|
||||
});
|
||||
const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
filteredDiffusersModels: modelsFilter(
|
||||
data,
|
||||
'main',
|
||||
'diffusers',
|
||||
nameFilter
|
||||
),
|
||||
}),
|
||||
});
|
||||
|
||||
const { filteredCheckpointModels, isLoadingCheckpointModels } =
|
||||
useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredCheckpointModels: modelsFilter(
|
||||
data,
|
||||
'main',
|
||||
'checkpoint',
|
||||
nameFilter
|
||||
),
|
||||
isLoadingCheckpointModels: isLoading,
|
||||
}),
|
||||
});
|
||||
const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
filteredCheckpointModels: modelsFilter(
|
||||
data,
|
||||
'main',
|
||||
'checkpoint',
|
||||
nameFilter
|
||||
),
|
||||
}),
|
||||
});
|
||||
|
||||
const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery(
|
||||
undefined,
|
||||
{
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
|
||||
isLoadingLoraModels: isLoading,
|
||||
}),
|
||||
}
|
||||
);
|
||||
const { filteredLoraModels } = useGetLoRAModelsQuery(undefined, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
|
||||
}),
|
||||
});
|
||||
|
||||
const { filteredOnnxModels, isLoadingOnnxModels } = useGetOnnxModelsQuery(
|
||||
ALL_BASE_MODELS,
|
||||
{
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredOnnxModels: modelsFilter(data, 'onnx', 'onnx', nameFilter),
|
||||
isLoadingOnnxModels: isLoading,
|
||||
}),
|
||||
}
|
||||
);
|
||||
const { filteredOnnxModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
filteredOnnxModels: modelsFilter(data, 'onnx', 'onnx', nameFilter),
|
||||
}),
|
||||
});
|
||||
|
||||
const { filteredOliveModels, isLoadingOliveModels } = useGetOnnxModelsQuery(
|
||||
ALL_BASE_MODELS,
|
||||
{
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredOliveModels: modelsFilter(data, 'onnx', 'olive', nameFilter),
|
||||
isLoadingOliveModels: isLoading,
|
||||
}),
|
||||
}
|
||||
);
|
||||
const { filteredOliveModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
filteredOliveModels: modelsFilter(data, 'onnx', 'olive', nameFilter),
|
||||
}),
|
||||
});
|
||||
|
||||
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setNameFilter(e.target.value);
|
||||
@ -100,8 +84,8 @@ const ModelList = (props: ModelListProps) => {
|
||||
<Flex flexDirection="column" gap={4} paddingInlineEnd={4}>
|
||||
<ButtonGroup isAttached>
|
||||
<IAIButton
|
||||
onClick={() => setModelFormatFilter('all')}
|
||||
isChecked={modelFormatFilter === 'all'}
|
||||
onClick={() => setModelFormatFilter('images')}
|
||||
isChecked={modelFormatFilter === 'images'}
|
||||
size="sm"
|
||||
>
|
||||
{t('modelManager.allModels')}
|
||||
@ -155,76 +139,95 @@ const ModelList = (props: ModelListProps) => {
|
||||
maxHeight={window.innerHeight - 280}
|
||||
overflow="scroll"
|
||||
>
|
||||
{/* Diffusers List */}
|
||||
{isLoadingDiffusersModels && (
|
||||
<FetchingModelsLoader loadingMessage="Loading Diffusers..." />
|
||||
)}
|
||||
{['all', 'diffusers'].includes(modelFormatFilter) &&
|
||||
!isLoadingDiffusersModels &&
|
||||
{['images', 'diffusers'].includes(modelFormatFilter) &&
|
||||
filteredDiffusersModels.length > 0 && (
|
||||
<ModelListWrapper
|
||||
title="Diffusers"
|
||||
modelList={filteredDiffusersModels}
|
||||
selected={{ selectedModelId, setSelectedModelId }}
|
||||
key="diffusers"
|
||||
/>
|
||||
<StyledModelContainer>
|
||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||
<Text variant="subtext" fontSize="sm">
|
||||
Diffusers
|
||||
</Text>
|
||||
{filteredDiffusersModels.map((model) => (
|
||||
<ModelListItem
|
||||
key={model.id}
|
||||
model={model}
|
||||
isSelected={selectedModelId === model.id}
|
||||
setSelectedModelId={setSelectedModelId}
|
||||
/>
|
||||
))}
|
||||
</Flex>
|
||||
</StyledModelContainer>
|
||||
)}
|
||||
{/* Checkpoints List */}
|
||||
{isLoadingCheckpointModels && (
|
||||
<FetchingModelsLoader loadingMessage="Loading Checkpoints..." />
|
||||
)}
|
||||
{['all', 'checkpoint'].includes(modelFormatFilter) &&
|
||||
!isLoadingCheckpointModels &&
|
||||
{['images', 'checkpoint'].includes(modelFormatFilter) &&
|
||||
filteredCheckpointModels.length > 0 && (
|
||||
<ModelListWrapper
|
||||
title="Checkpoints"
|
||||
modelList={filteredCheckpointModels}
|
||||
selected={{ selectedModelId, setSelectedModelId }}
|
||||
key="checkpoints"
|
||||
/>
|
||||
<StyledModelContainer>
|
||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||
<Text variant="subtext" fontSize="sm">
|
||||
Checkpoints
|
||||
</Text>
|
||||
{filteredCheckpointModels.map((model) => (
|
||||
<ModelListItem
|
||||
key={model.id}
|
||||
model={model}
|
||||
isSelected={selectedModelId === model.id}
|
||||
setSelectedModelId={setSelectedModelId}
|
||||
/>
|
||||
))}
|
||||
</Flex>
|
||||
</StyledModelContainer>
|
||||
)}
|
||||
|
||||
{/* LoRAs List */}
|
||||
{isLoadingLoraModels && (
|
||||
<FetchingModelsLoader loadingMessage="Loading LoRAs..." />
|
||||
)}
|
||||
{['all', 'lora'].includes(modelFormatFilter) &&
|
||||
!isLoadingLoraModels &&
|
||||
filteredLoraModels.length > 0 && (
|
||||
<ModelListWrapper
|
||||
title="LoRAs"
|
||||
modelList={filteredLoraModels}
|
||||
selected={{ selectedModelId, setSelectedModelId }}
|
||||
key="loras"
|
||||
/>
|
||||
)}
|
||||
{/* Olive List */}
|
||||
{isLoadingOliveModels && (
|
||||
<FetchingModelsLoader loadingMessage="Loading Olives..." />
|
||||
)}
|
||||
{['all', 'olive'].includes(modelFormatFilter) &&
|
||||
!isLoadingOliveModels &&
|
||||
{['images', 'olive'].includes(modelFormatFilter) &&
|
||||
filteredOliveModels.length > 0 && (
|
||||
<ModelListWrapper
|
||||
title="Olives"
|
||||
modelList={filteredOliveModels}
|
||||
selected={{ selectedModelId, setSelectedModelId }}
|
||||
key="olive"
|
||||
/>
|
||||
<StyledModelContainer>
|
||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||
<Text variant="subtext" fontSize="sm">
|
||||
Olives
|
||||
</Text>
|
||||
{filteredOliveModels.map((model) => (
|
||||
<ModelListItem
|
||||
key={model.id}
|
||||
model={model}
|
||||
isSelected={selectedModelId === model.id}
|
||||
setSelectedModelId={setSelectedModelId}
|
||||
/>
|
||||
))}
|
||||
</Flex>
|
||||
</StyledModelContainer>
|
||||
)}
|
||||
{/* Onnx List */}
|
||||
{isLoadingOnnxModels && (
|
||||
<FetchingModelsLoader loadingMessage="Loading ONNX..." />
|
||||
)}
|
||||
{['all', 'onnx'].includes(modelFormatFilter) &&
|
||||
!isLoadingOnnxModels &&
|
||||
{['images', 'onnx'].includes(modelFormatFilter) &&
|
||||
filteredOnnxModels.length > 0 && (
|
||||
<ModelListWrapper
|
||||
title="ONNX"
|
||||
modelList={filteredOnnxModels}
|
||||
selected={{ selectedModelId, setSelectedModelId }}
|
||||
key="onnx"
|
||||
/>
|
||||
<StyledModelContainer>
|
||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||
<Text variant="subtext" fontSize="sm">
|
||||
Onnx
|
||||
</Text>
|
||||
{filteredOnnxModels.map((model) => (
|
||||
<ModelListItem
|
||||
key={model.id}
|
||||
model={model}
|
||||
isSelected={selectedModelId === model.id}
|
||||
setSelectedModelId={setSelectedModelId}
|
||||
/>
|
||||
))}
|
||||
</Flex>
|
||||
</StyledModelContainer>
|
||||
)}
|
||||
{['images', 'lora'].includes(modelFormatFilter) &&
|
||||
filteredLoraModels.length > 0 && (
|
||||
<StyledModelContainer>
|
||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||
<Text variant="subtext" fontSize="sm">
|
||||
LoRAs
|
||||
</Text>
|
||||
{filteredLoraModels.map((model) => (
|
||||
<ModelListItem
|
||||
key={model.id}
|
||||
model={model}
|
||||
isSelected={selectedModelId === model.id}
|
||||
setSelectedModelId={setSelectedModelId}
|
||||
/>
|
||||
))}
|
||||
</Flex>
|
||||
</StyledModelContainer>
|
||||
)}
|
||||
</Flex>
|
||||
</Flex>
|
||||
@ -284,52 +287,3 @@ const StyledModelContainer = (props: PropsWithChildren) => {
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
type ModelListWrapperProps = {
|
||||
title: string;
|
||||
modelList:
|
||||
| MainModelConfigEntity[]
|
||||
| LoRAModelConfigEntity[]
|
||||
| OnnxModelConfigEntity[];
|
||||
selected: ModelListProps;
|
||||
};
|
||||
|
||||
function ModelListWrapper(props: ModelListWrapperProps) {
|
||||
const { title, modelList, selected } = props;
|
||||
return (
|
||||
<StyledModelContainer>
|
||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||
<Text variant="subtext" fontSize="sm">
|
||||
{title}
|
||||
</Text>
|
||||
{modelList.map((model) => (
|
||||
<ModelListItem
|
||||
key={model.id}
|
||||
model={model}
|
||||
isSelected={selected.selectedModelId === model.id}
|
||||
setSelectedModelId={selected.setSelectedModelId}
|
||||
/>
|
||||
))}
|
||||
</Flex>
|
||||
</StyledModelContainer>
|
||||
);
|
||||
}
|
||||
|
||||
function FetchingModelsLoader({ loadingMessage }: { loadingMessage?: string }) {
|
||||
return (
|
||||
<StyledModelContainer>
|
||||
<Flex
|
||||
justifyContent="center"
|
||||
alignItems="center"
|
||||
flexDirection="column"
|
||||
p={4}
|
||||
gap={8}
|
||||
>
|
||||
<Spinner />
|
||||
<Text variant="subtext">
|
||||
{loadingMessage ? loadingMessage : 'Fetching...'}
|
||||
</Text>
|
||||
</Flex>
|
||||
</StyledModelContainer>
|
||||
);
|
||||
}
|
||||
|
@ -181,7 +181,7 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
providesTags: (result, error, arg) => {
|
||||
const tags: ApiFullTagDescription[] = [
|
||||
{ type: 'OnnxModel', id: LIST_TAG },
|
||||
{ id: 'OnnxModel', type: LIST_TAG },
|
||||
];
|
||||
|
||||
if (result) {
|
||||
@ -266,7 +266,6 @@ export const modelsApi = api.injectEndpoints({
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
{ type: 'OnnxModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
importMainModels: build.mutation<
|
||||
@ -283,7 +282,6 @@ export const modelsApi = api.injectEndpoints({
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
{ type: 'OnnxModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
|
||||
@ -297,7 +295,6 @@ export const modelsApi = api.injectEndpoints({
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
{ type: 'OnnxModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
deleteMainModels: build.mutation<
|
||||
@ -313,7 +310,6 @@ export const modelsApi = api.injectEndpoints({
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
{ type: 'OnnxModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
convertMainModels: build.mutation<
|
||||
@ -330,7 +326,6 @@ export const modelsApi = api.injectEndpoints({
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
{ type: 'OnnxModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
|
||||
@ -344,7 +339,6 @@ export const modelsApi = api.injectEndpoints({
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
{ type: 'OnnxModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
syncModels: build.mutation<SyncModelsResponse, void>({
|
||||
@ -357,7 +351,6 @@ export const modelsApi = api.injectEndpoints({
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
{ type: 'OnnxModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
|
||||
|
@ -48,7 +48,6 @@ def mock_services() -> InvocationServices:
|
||||
images=None, # type: ignore
|
||||
latents=None, # type: ignore
|
||||
boards=None, # type: ignore
|
||||
batch_manager=None, # type: ignore
|
||||
board_images=None, # type: ignore
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"),
|
||||
|
@ -40,7 +40,6 @@ def mock_services() -> InvocationServices:
|
||||
logger=None, # type: ignore
|
||||
images=None, # type: ignore
|
||||
latents=None, # type: ignore
|
||||
batch_manager=None, # type: ignore
|
||||
boards=None, # type: ignore
|
||||
board_images=None, # type: ignore
|
||||
queue=MemoryInvocationQueue(),
|
||||
|
Reference in New Issue
Block a user