feat: remove enqueue_graph routes/methods (#4922)

This is totally extraneous - it's almost identical to `enqueue_batch`.
This commit is contained in:
psychedelicious 2023-10-18 05:00:40 +11:00 committed by GitHub
parent 58a0709c1e
commit 284a257c25
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 126 additions and 252 deletions

View File

@ -12,13 +12,11 @@ from invokeai.app.services.session_queue.session_queue_common import (
CancelByBatchIDsResult, CancelByBatchIDsResult,
ClearResult, ClearResult,
EnqueueBatchResult, EnqueueBatchResult,
EnqueueGraphResult,
PruneResult, PruneResult,
SessionQueueItem, SessionQueueItem,
SessionQueueItemDTO, SessionQueueItemDTO,
SessionQueueStatus, SessionQueueStatus,
) )
from invokeai.app.services.shared.graph import Graph
from invokeai.app.services.shared.pagination import CursorPaginatedResults from invokeai.app.services.shared.pagination import CursorPaginatedResults
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
@ -33,23 +31,6 @@ class SessionQueueAndProcessorStatus(BaseModel):
processor: SessionProcessorStatus processor: SessionProcessorStatus
@session_queue_router.post(
"/{queue_id}/enqueue_graph",
operation_id="enqueue_graph",
responses={
201: {"model": EnqueueGraphResult},
},
)
async def enqueue_graph(
queue_id: str = Path(description="The queue id to perform this operation on"),
graph: Graph = Body(description="The graph to enqueue"),
prepend: bool = Body(default=False, description="Whether or not to prepend this batch in the queue"),
) -> EnqueueGraphResult:
"""Enqueues a graph for single execution."""
return ApiDependencies.invoker.services.session_queue.enqueue_graph(queue_id=queue_id, graph=graph, prepend=prepend)
@session_queue_router.post( @session_queue_router.post(
"/{queue_id}/enqueue_batch", "/{queue_id}/enqueue_batch",
operation_id="enqueue_batch", operation_id="enqueue_batch",

View File

@ -9,7 +9,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
CancelByQueueIDResult, CancelByQueueIDResult,
ClearResult, ClearResult,
EnqueueBatchResult, EnqueueBatchResult,
EnqueueGraphResult,
IsEmptyResult, IsEmptyResult,
IsFullResult, IsFullResult,
PruneResult, PruneResult,
@ -17,7 +16,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueItemDTO, SessionQueueItemDTO,
SessionQueueStatus, SessionQueueStatus,
) )
from invokeai.app.services.shared.graph import Graph
from invokeai.app.services.shared.pagination import CursorPaginatedResults from invokeai.app.services.shared.pagination import CursorPaginatedResults
@ -29,11 +27,6 @@ class SessionQueueBase(ABC):
"""Dequeues the next session queue item.""" """Dequeues the next session queue item."""
pass pass
@abstractmethod
def enqueue_graph(self, queue_id: str, graph: Graph, prepend: bool) -> EnqueueGraphResult:
"""Enqueues a single graph for execution."""
pass
@abstractmethod @abstractmethod
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult: def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
"""Enqueues all permutations of a batch for execution.""" """Enqueues all permutations of a batch for execution."""

View File

@ -276,14 +276,6 @@ class EnqueueBatchResult(BaseModel):
priority: int = Field(description="The priority of the enqueued batch") priority: int = Field(description="The priority of the enqueued batch")
class EnqueueGraphResult(BaseModel):
enqueued: int = Field(description="The total number of queue items enqueued")
requested: int = Field(description="The total number of queue items requested to be enqueued")
batch: Batch = Field(description="The batch that was enqueued")
priority: int = Field(description="The priority of the enqueued batch")
queue_item: SessionQueueItemDTO = Field(description="The queue item that was enqueued")
class ClearResult(BaseModel): class ClearResult(BaseModel):
"""Result of clearing the session queue""" """Result of clearing the session queue"""

View File

@ -17,7 +17,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
CancelByQueueIDResult, CancelByQueueIDResult,
ClearResult, ClearResult,
EnqueueBatchResult, EnqueueBatchResult,
EnqueueGraphResult,
IsEmptyResult, IsEmptyResult,
IsFullResult, IsFullResult,
PruneResult, PruneResult,
@ -28,7 +27,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
calc_session_count, calc_session_count,
prepare_values_to_insert, prepare_values_to_insert,
) )
from invokeai.app.services.shared.graph import Graph
from invokeai.app.services.shared.pagination import CursorPaginatedResults from invokeai.app.services.shared.pagination import CursorPaginatedResults
from invokeai.app.services.shared.sqlite import SqliteDatabase from invokeai.app.services.shared.sqlite import SqliteDatabase
@ -255,32 +253,6 @@ class SqliteSessionQueue(SessionQueueBase):
) )
return cast(Union[int, None], self.__cursor.fetchone()[0]) or 0 return cast(Union[int, None], self.__cursor.fetchone()[0]) or 0
def enqueue_graph(self, queue_id: str, graph: Graph, prepend: bool) -> EnqueueGraphResult:
enqueue_result = self.enqueue_batch(queue_id=queue_id, batch=Batch(graph=graph), prepend=prepend)
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
AND batch_id = ?
""",
(queue_id, enqueue_result.batch.batch_id),
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
if result is None:
raise SessionQueueItemNotFoundError(f"No queue item with batch id {enqueue_result.batch.batch_id}")
return EnqueueGraphResult(
**enqueue_result.model_dump(),
queue_item=SessionQueueItemDTO.queue_item_dto_from_dict(dict(result)),
)
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult: def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
try: try:
self.__lock.acquire() self.__lock.acquire()

View File

@ -1,15 +1,9 @@
import { isAnyOf } from '@reduxjs/toolkit';
import { queueApi } from 'services/api/endpoints/queue'; import { queueApi } from 'services/api/endpoints/queue';
import { startAppListening } from '..'; import { startAppListening } from '..';
const matcher = isAnyOf(
queueApi.endpoints.enqueueBatch.matchFulfilled,
queueApi.endpoints.enqueueGraph.matchFulfilled
);
export const addAnyEnqueuedListener = () => { export const addAnyEnqueuedListener = () => {
startAppListening({ startAppListening({
matcher, matcher: queueApi.endpoints.enqueueBatch.matchFulfilled,
effect: async (_, { dispatch, getState }) => { effect: async (_, { dispatch, getState }) => {
const { data } = queueApi.endpoints.getQueueStatus.select()(getState()); const { data } = queueApi.endpoints.getQueueStatus.select()(getState());

View File

@ -1,22 +1,22 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize'; import { parseify } from 'common/util/serialize';
import { controlAdapterImageProcessed } from 'features/controlAdapters/store/actions';
import { import {
pendingControlImagesCleared,
controlAdapterImageChanged, controlAdapterImageChanged,
selectControlAdapterById,
controlAdapterProcessedImageChanged, controlAdapterProcessedImageChanged,
pendingControlImagesCleared,
selectControlAdapterById,
} from 'features/controlAdapters/store/controlAdaptersSlice'; } from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
import { SAVE_IMAGE } from 'features/nodes/util/graphBuilders/constants'; import { SAVE_IMAGE } from 'features/nodes/util/graphBuilders/constants';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next'; import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue'; import { queueApi } from 'services/api/endpoints/queue';
import { isImageOutput } from 'services/api/guards'; import { isImageOutput } from 'services/api/guards';
import { Graph, ImageDTO } from 'services/api/types'; import { BatchConfig, ImageDTO } from 'services/api/types';
import { socketInvocationComplete } from 'services/events/actions'; import { socketInvocationComplete } from 'services/events/actions';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { controlAdapterImageProcessed } from 'features/controlAdapters/store/actions';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
export const addControlNetImageProcessedListener = () => { export const addControlNetImageProcessedListener = () => {
startAppListening({ startAppListening({
@ -37,41 +37,46 @@ export const addControlNetImageProcessedListener = () => {
// ControlNet one-off procressing graph is just the processor node, no edges. // ControlNet one-off procressing graph is just the processor node, no edges.
// Also we need to grab the image. // Also we need to grab the image.
const graph: Graph = {
nodes: { const enqueueBatchArg: BatchConfig = {
[ca.processorNode.id]: { prepend: true,
...ca.processorNode, batch: {
is_intermediate: true, graph: {
image: { image_name: ca.controlImage }, nodes: {
}, [ca.processorNode.id]: {
[SAVE_IMAGE]: { ...ca.processorNode,
id: SAVE_IMAGE, is_intermediate: true,
type: 'save_image', image: { image_name: ca.controlImage },
is_intermediate: true, },
use_cache: false, [SAVE_IMAGE]: {
id: SAVE_IMAGE,
type: 'save_image',
is_intermediate: true,
use_cache: false,
},
},
edges: [
{
source: {
node_id: ca.processorNode.id,
field: 'image',
},
destination: {
node_id: SAVE_IMAGE,
field: 'image',
},
},
],
}, },
runs: 1,
}, },
edges: [
{
source: {
node_id: ca.processorNode.id,
field: 'image',
},
destination: {
node_id: SAVE_IMAGE,
field: 'image',
},
},
],
}; };
try { try {
const req = dispatch( const req = dispatch(
queueApi.endpoints.enqueueGraph.initiate( queueApi.endpoints.enqueueBatch.initiate(enqueueBatchArg, {
{ graph, prepend: true }, fixedCacheKey: 'enqueueBatch',
{ })
fixedCacheKey: 'enqueueGraph',
}
)
); );
const enqueueResult = await req.unwrap(); const enqueueResult = await req.unwrap();
req.reset(); req.reset();
@ -83,8 +88,8 @@ export const addControlNetImageProcessedListener = () => {
const [invocationCompleteAction] = await take( const [invocationCompleteAction] = await take(
(action): action is ReturnType<typeof socketInvocationComplete> => (action): action is ReturnType<typeof socketInvocationComplete> =>
socketInvocationComplete.match(action) && socketInvocationComplete.match(action) &&
action.payload.data.graph_execution_state_id === action.payload.data.queue_batch_id ===
enqueueResult.queue_item.session_id && enqueueResult.batch.batch_id &&
action.payload.data.source_node_id === SAVE_IMAGE action.payload.data.source_node_id === SAVE_IMAGE
); );
@ -116,7 +121,10 @@ export const addControlNetImageProcessedListener = () => {
); );
} }
} catch (error) { } catch (error) {
log.error({ graph: parseify(graph) }, t('queue.graphFailedToQueue')); log.error(
{ enqueueBatchArg: parseify(enqueueBatchArg) },
t('queue.graphFailedToQueue')
);
// handle usage-related errors // handle usage-related errors
if (error instanceof Object) { if (error instanceof Object) {

View File

@ -6,7 +6,7 @@ import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next'; import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue'; import { queueApi } from 'services/api/endpoints/queue';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { ImageDTO } from 'services/api/types'; import { BatchConfig, ImageDTO } from 'services/api/types';
import { createIsAllowedToUpscaleSelector } from 'features/parameters/hooks/useIsAllowedToUpscale'; import { createIsAllowedToUpscaleSelector } from 'features/parameters/hooks/useIsAllowedToUpscale';
export const upscaleRequested = createAction<{ imageDTO: ImageDTO }>( export const upscaleRequested = createAction<{ imageDTO: ImageDTO }>(
@ -44,20 +44,23 @@ export const addUpscaleRequestedListener = () => {
const { esrganModelName } = state.postprocessing; const { esrganModelName } = state.postprocessing;
const { autoAddBoardId } = state.gallery; const { autoAddBoardId } = state.gallery;
const graph = buildAdHocUpscaleGraph({ const enqueueBatchArg: BatchConfig = {
image_name, prepend: true,
esrganModelName, batch: {
autoAddBoardId, graph: buildAdHocUpscaleGraph({
}); image_name,
esrganModelName,
autoAddBoardId,
}),
runs: 1,
},
};
try { try {
const req = dispatch( const req = dispatch(
queueApi.endpoints.enqueueGraph.initiate( queueApi.endpoints.enqueueBatch.initiate(enqueueBatchArg, {
{ graph, prepend: true }, fixedCacheKey: 'enqueueBatch',
{ })
fixedCacheKey: 'enqueueGraph',
}
)
); );
const enqueueResult = await req.unwrap(); const enqueueResult = await req.unwrap();
@ -67,7 +70,10 @@ export const addUpscaleRequestedListener = () => {
t('queue.graphQueued') t('queue.graphQueued')
); );
} catch (error) { } catch (error) {
log.error({ graph: parseify(graph) }, t('queue.graphFailedToQueue')); log.error(
{ enqueueBatchArg: parseify(enqueueBatchArg) },
t('queue.graphFailedToQueue')
);
// handle usage-related errors // handle usage-related errors
if (error instanceof Object) { if (error instanceof Object) {

View File

@ -3,7 +3,6 @@ import {
// useCancelByBatchIdsMutation, // useCancelByBatchIdsMutation,
useClearQueueMutation, useClearQueueMutation,
useEnqueueBatchMutation, useEnqueueBatchMutation,
useEnqueueGraphMutation,
usePruneQueueMutation, usePruneQueueMutation,
useResumeProcessorMutation, useResumeProcessorMutation,
usePauseProcessorMutation, usePauseProcessorMutation,
@ -14,10 +13,6 @@ export const useIsQueueMutationInProgress = () => {
useEnqueueBatchMutation({ useEnqueueBatchMutation({
fixedCacheKey: 'enqueueBatch', fixedCacheKey: 'enqueueBatch',
}); });
const [_triggerEnqueueGraph, { isLoading: isLoadingEnqueueGraph }] =
useEnqueueGraphMutation({
fixedCacheKey: 'enqueueGraph',
});
const [_triggerResumeProcessor, { isLoading: isLoadingResumeProcessor }] = const [_triggerResumeProcessor, { isLoading: isLoadingResumeProcessor }] =
useResumeProcessorMutation({ useResumeProcessorMutation({
fixedCacheKey: 'resumeProcessor', fixedCacheKey: 'resumeProcessor',
@ -44,7 +39,6 @@ export const useIsQueueMutationInProgress = () => {
// }); // });
return ( return (
isLoadingEnqueueBatch || isLoadingEnqueueBatch ||
isLoadingEnqueueGraph ||
isLoadingResumeProcessor || isLoadingResumeProcessor ||
isLoadingPauseProcessor || isLoadingPauseProcessor ||
isLoadingCancelQueue || isLoadingCancelQueue ||

View File

@ -83,30 +83,6 @@ export const queueApi = api.injectEndpoints({
} }
}, },
}), }),
enqueueGraph: build.mutation<
paths['/api/v1/queue/{queue_id}/enqueue_graph']['post']['responses']['201']['content']['application/json'],
paths['/api/v1/queue/{queue_id}/enqueue_graph']['post']['requestBody']['content']['application/json']
>({
query: (arg) => ({
url: `queue/${$queueId.get()}/enqueue_graph`,
body: arg,
method: 'POST',
}),
invalidatesTags: [
'SessionQueueStatus',
'CurrentSessionQueueItem',
'NextSessionQueueItem',
],
onQueryStarted: async (arg, api) => {
const { dispatch, queryFulfilled } = api;
try {
await queryFulfilled;
resetListQueryData(dispatch);
} catch {
// no-op
}
},
}),
resumeProcessor: build.mutation< resumeProcessor: build.mutation<
paths['/api/v1/queue/{queue_id}/processor/resume']['put']['responses']['200']['content']['application/json'], paths['/api/v1/queue/{queue_id}/processor/resume']['put']['responses']['200']['content']['application/json'],
void void
@ -341,7 +317,6 @@ export const queueApi = api.injectEndpoints({
export const { export const {
useCancelByBatchIdsMutation, useCancelByBatchIdsMutation,
useEnqueueGraphMutation,
useEnqueueBatchMutation, useEnqueueBatchMutation,
usePauseProcessorMutation, usePauseProcessorMutation,
useResumeProcessorMutation, useResumeProcessorMutation,

File diff suppressed because one or more lines are too long

View File

@ -26,7 +26,6 @@ export type BatchConfig =
paths['/api/v1/queue/{queue_id}/enqueue_batch']['post']['requestBody']['content']['application/json']; paths['/api/v1/queue/{queue_id}/enqueue_batch']['post']['requestBody']['content']['application/json'];
export type EnqueueBatchResult = components['schemas']['EnqueueBatchResult']; export type EnqueueBatchResult = components['schemas']['EnqueueBatchResult'];
export type EnqueueGraphResult = components['schemas']['EnqueueGraphResult'];
/** /**
* This is an unsafe type; the object inside is not guaranteed to be valid. * This is an unsafe type; the object inside is not guaranteed to be valid.