feat(ui): use graph util for ad-hoc upscale graph

This commit is contained in:
psychedelicious 2024-07-24 06:51:45 +10:00
parent e8d2e2330e
commit aeb53563ff

View File

@ -1,15 +1,13 @@
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import type { GraphType } from 'features/nodes/util/graph/generation/Graph';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
import {
type ImageDTO,
type Invocation,
isSpandrelImageToImageModelConfig,
type NonNullableGraph,
} from 'services/api/types';
import type { ImageDTO } from 'services/api/types';
import { isSpandrelImageToImageModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { addCoreMetadataNode, getModelMetadataField, upsertMetadata } from './canvas/metadata';
import { getModelMetadataField } from './canvas/metadata';
import { SPANDREL } from './constants';
type Arg = {
@ -17,32 +15,26 @@ type Arg = {
state: RootState;
};
export const buildAdHocUpscaleGraph = async ({ image, state }: Arg): Promise<NonNullableGraph> => {
export const buildAdHocUpscaleGraph = async ({ image, state }: Arg): Promise<GraphType> => {
const { simpleUpscaleModel } = state.upscale;
assert(simpleUpscaleModel, 'No upscale model found in state');
const upscaleNode: Invocation<'spandrel_image_to_image'> = {
const g = new Graph('adhoc-upscale-graph');
g.addNode({
id: SPANDREL,
type: 'spandrel_image_to_image',
image_to_image_model: simpleUpscaleModel,
image,
board: getBoardField(state),
};
is_intermediate: false,
});
const graph: NonNullableGraph = {
id: `adhoc-upscale-graph`,
nodes: {
[SPANDREL]: upscaleNode,
},
edges: [],
};
const modelConfig = await fetchModelConfigWithTypeGuard(simpleUpscaleModel.key, isSpandrelImageToImageModelConfig);
addCoreMetadataNode(graph, {}, SPANDREL);
upsertMetadata(graph, {
g.upsertMetadata({
upscale_model: getModelMetadataField(modelConfig),
});
return graph;
return g.getGraph();
};