From 05d3a989f64dba7360aecb543081f57f54688e02 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 24 Aug 2024 14:49:17 +1000 Subject: [PATCH] feat(ui): use new Result utils for enqueueing --- .../listeners/enqueueRequestedLinear.ts | 91 ++++++++++++------- .../graph/generation/addControlAdapters.ts | 4 +- 2 files changed, 60 insertions(+), 35 deletions(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts index 5412ccf28f..d2a065ec22 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts @@ -1,6 +1,9 @@ import { logger } from 'app/logging/logger'; import { enqueueRequested } from 'app/store/actions'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; +import type { SerializableObject } from 'common/types'; +import type { Result } from 'common/util/result'; +import { isErr, withResult, withResultAsync } from 'common/util/result'; import { $canvasManager } from 'features/controlLayers/konva/CanvasManager'; import { sessionStagingAreaReset, sessionStartedStaging } from 'features/controlLayers/store/canvasV2Slice'; import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; @@ -27,48 +30,70 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) assert(manager, 'No model found in state'); let didStartStaging = false; + if (!state.canvasV2.session.isStaging && state.canvasV2.session.mode === 'compose') { dispatch(sessionStartedStaging()); didStartStaging = true; } - try { - let g: Graph; - let noise: Invocation<'noise'>; - let posCond: Invocation<'compel' | 'sdxl_compel_prompt'>; - - assert(model, 'No model found in state'); - const base = model.base; - - if (base === 'sdxl') { - const result = await buildSDXLGraph(state, manager); - g = result.g; - noise = result.noise; - posCond = result.posCond; - } else if (base === 'sd-1' || base === 'sd-2') { - const result = await buildSD1Graph(state, manager); - g = result.g; - noise = result.noise; - posCond = result.posCond; - } else { - assert(false, `No graph builders for base ${base}`); - } - - const batchConfig = prepareLinearUIBatch(state, g, prepend, noise, posCond); - - const req = dispatch( - queueApi.endpoints.enqueueBatch.initiate(batchConfig, { - fixedCacheKey: 'enqueueBatch', - }) - ); - req.reset(); - await req.unwrap(); - } catch (error) { - log.error({ error: serializeError(error) }, 'Failed to enqueue batch'); + const abortStaging = () => { if (didStartStaging && getState().canvasV2.session.isStaging) { dispatch(sessionStagingAreaReset()); } + }; + + let buildGraphResult: Result< + { g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'compel' | 'sdxl_compel_prompt'> }, + Error + >; + + assert(model, 'No model found in state'); + const base = model.base; + + switch (base) { + case 'sdxl': + buildGraphResult = await withResultAsync(() => buildSDXLGraph(state, manager)); + break; + case 'sd-1': + case `sd-2`: + buildGraphResult = await withResultAsync(() => buildSD1Graph(state, manager)); + break; + default: + assert(false, `No graph builders for base ${base}`); } + + if (isErr(buildGraphResult)) { + log.error({ error: serializeError(buildGraphResult.error) }, 'Failed to build graph'); + abortStaging(); + return; + } + + const { g, noise, posCond } = buildGraphResult.value; + + const prepareBatchResult = withResult(() => prepareLinearUIBatch(state, g, prepend, noise, posCond)); + + if (isErr(prepareBatchResult)) { + log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch'); + abortStaging(); + return; + } + + const req = dispatch( + queueApi.endpoints.enqueueBatch.initiate(prepareBatchResult.value, { + fixedCacheKey: 'enqueueBatch', + }) + ); + req.reset(); + + const enqueueResult = await withResultAsync(() => req.unwrap()); + + if (isErr(enqueueResult)) { + log.error({ error: serializeError(enqueueResult.error) }, 'Failed to enqueue batch'); + abortStaging(); + return; + } + + log.debug({ batchConfig: prepareBatchResult.value } as SerializableObject, 'Enqueued batch'); }, }); }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts index 8b3630574b..a178fe2369 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts @@ -36,7 +36,7 @@ export const addControlNets = async ( const adapter = manager.adapters.controlLayers.get(layer.id); assert(adapter, 'Adapter not found'); const imageDTO = await adapter.renderer.rasterize({ rect: bbox, attrs: { opacity: 1, filters: [] } }); - await addControlNetToGraph(g, layer, imageDTO, collector); + addControlNetToGraph(g, layer, imageDTO, collector); } return result; @@ -69,7 +69,7 @@ export const addT2IAdapters = async ( const adapter = manager.adapters.controlLayers.get(layer.id); assert(adapter, 'Adapter not found'); const imageDTO = await adapter.renderer.rasterize({ rect: bbox, attrs: { opacity: 1, filters: [] } }); - await addT2IAdapterToGraph(g, layer, imageDTO, collector); + addT2IAdapterToGraph(g, layer, imageDTO, collector); } return result;