This commit is contained in:
Mary Hipp 2024-07-19 20:16:03 -04:00 committed by psychedelicious
parent f18431a999
commit 845d77916e

View File

@ -2,7 +2,8 @@ import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import type { GraphType } from 'features/nodes/util/graph/generation/Graph';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { ImageDTO, isNonRefinerMainModelConfig, isSpandrelImageToImageModelConfig } from 'services/api/types';
import type { ImageDTO } from 'services/api/types';
import { isNonRefinerMainModelConfig, isSpandrelImageToImageModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import {
@ -28,8 +29,11 @@ import { getBoardField, getSDXLStylePrompts } from './graphBuilderUtils';
const UPSCALE_SCALE = 2;
export const getOutputImageSize = (initialImage: ImageDTO) => {
return { width: ((initialImage.width * UPSCALE_SCALE) / 8) * 8, height: ((initialImage.height * UPSCALE_SCALE) / 8) * 8 }
}
return {
width: ((initialImage.width * UPSCALE_SCALE) / 8) * 8,
height: ((initialImage.height * UPSCALE_SCALE) / 8) * 8,
};
};
export const buildMultidiffusionUpscsaleGraph = async (state: RootState): Promise<GraphType> => {
const { model, cfgScale: cfg_scale, scheduler, steps, vaePrecision, seed, vae } = state.generation;
@ -42,7 +46,7 @@ export const buildMultidiffusionUpscsaleGraph = async (state: RootState): Promis
assert(upscaleInitialImage, 'No initial image found in state');
assert(tileControlnetModel, 'Tile controlnet is required');
const { width: outputWidth, height: outputHeight } = getOutputImageSize(upscaleInitialImage)
const { width: outputWidth, height: outputHeight } = getOutputImageSize(upscaleInitialImage);
const g = new Graph();
@ -59,10 +63,9 @@ export const buildMultidiffusionUpscsaleGraph = async (state: RootState): Promis
type: 'spandrel_image_to_image',
image_to_image_model: upscaleModel,
tile_size: 500,
image: upscaleInitialImage
image: upscaleInitialImage,
});
const unsharpMaskNode2 = g.addNode({
id: `${UNSHARP_MASK}_2`,
type: 'unsharp_mask',