refactor(ui): use "entity" instead of "data" for canvas

This commit is contained in:
psychedelicious 2024-06-17 17:16:13 +10:00
parent 844590a571
commit e7df53e260
19 changed files with 112 additions and 112 deletions

View File

@ -5,7 +5,7 @@ import IAIDndImage from 'common/components/IAIDndImage';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon'; import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
import { heightChanged, widthChanged } from 'features/controlLayers/store/canvasV2Slice'; import { heightChanged, widthChanged } from 'features/controlLayers/store/canvasV2Slice';
import { selectOptimalDimension } from 'features/controlLayers/store/selectors'; import { selectOptimalDimension } from 'features/controlLayers/store/selectors';
import type { ControlAdapterData } from 'features/controlLayers/store/types'; import type { ControlAdapterEntity } from 'features/controlLayers/store/types';
import type { ImageDraggableData, TypesafeDroppableData } from 'features/dnd/types'; import type { ImageDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize'; import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
import { memo, useCallback, useEffect, useMemo, useState } from 'react'; import { memo, useCallback, useEffect, useMemo, useState } from 'react';
@ -20,7 +20,7 @@ import {
import type { ImageDTO, PostUploadAction } from 'services/api/types'; import type { ImageDTO, PostUploadAction } from 'services/api/types';
type Props = { type Props = {
controlAdapter: ControlAdapterData; controlAdapter: ControlAdapterEntity;
onChangeImage: (imageDTO: ImageDTO | null) => void; onChangeImage: (imageDTO: ImageDTO | null) => void;
droppableData: TypesafeDroppableData; droppableData: TypesafeDroppableData;
postUploadAction: PostUploadAction; postUploadAction: PostUploadAction;

View File

@ -7,7 +7,7 @@ import {
} from 'features/controlLayers/konva/naming'; } from 'features/controlLayers/konva/naming';
import { createBboxRect } from 'features/controlLayers/konva/renderers/objects'; import { createBboxRect } from 'features/controlLayers/konva/renderers/objects';
import { imageDataToDataURL } from "features/controlLayers/konva/util"; import { imageDataToDataURL } from "features/controlLayers/konva/util";
import type { ControlAdapterData, LayerData, RegionalGuidanceData } from 'features/controlLayers/store/types'; import type { ControlAdapterEntity, LayerEntity, RegionEntity } from 'features/controlLayers/store/types';
import Konva from 'konva'; import Konva from 'konva';
import type { IRect } from 'konva/lib/types'; import type { IRect } from 'konva/lib/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
@ -186,7 +186,7 @@ const filterCAChildren = (node: Konva.Node): boolean => node.name() === CA_LAYER
*/ */
export const updateBboxes = ( export const updateBboxes = (
stage: Konva.Stage, stage: Konva.Stage,
entityStates: (ControlAdapterData | LayerData | RegionalGuidanceData)[], entityStates: (ControlAdapterEntity | LayerEntity | RegionEntity)[],
onBboxChanged: (layerId: string, bbox: IRect | null) => void onBboxChanged: (layerId: string, bbox: IRect | null) => void
): void => { ): void => {
for (const entityState of entityStates) { for (const entityState of entityStates) {

View File

@ -1,6 +1,6 @@
import { LightnessToAlphaFilter } from 'features/controlLayers/konva/filters'; import { LightnessToAlphaFilter } from 'features/controlLayers/konva/filters';
import { CA_LAYER_IMAGE_NAME, CA_LAYER_NAME, getCAImageId } from 'features/controlLayers/konva/naming'; import { CA_LAYER_IMAGE_NAME, CA_LAYER_NAME, getCAImageId } from 'features/controlLayers/konva/naming';
import type { ControlAdapterData } from 'features/controlLayers/store/types'; import type { ControlAdapterEntity } from 'features/controlLayers/store/types';
import Konva from 'konva'; import Konva from 'konva';
import type { ImageDTO } from 'services/api/types'; import type { ImageDTO } from 'services/api/types';
@ -14,7 +14,7 @@ import type { ImageDTO } from 'services/api/types';
* @param stage The konva stage * @param stage The konva stage
* @param ca The control adapter layer state * @param ca The control adapter layer state
*/ */
const createCALayer = (stage: Konva.Stage, ca: ControlAdapterData): Konva.Layer => { const createCALayer = (stage: Konva.Stage, ca: ControlAdapterEntity): Konva.Layer => {
const konvaLayer = new Konva.Layer({ const konvaLayer = new Konva.Layer({
id: ca.id, id: ca.id,
name: CA_LAYER_NAME, name: CA_LAYER_NAME,
@ -50,7 +50,7 @@ const createCALayerImage = (konvaLayer: Konva.Layer, imageEl: HTMLImageElement):
const updateCALayerImageSource = async ( const updateCALayerImageSource = async (
stage: Konva.Stage, stage: Konva.Stage,
konvaLayer: Konva.Layer, konvaLayer: Konva.Layer,
ca: ControlAdapterData, ca: ControlAdapterEntity,
getImageDTO: (imageName: string) => Promise<ImageDTO | null> getImageDTO: (imageName: string) => Promise<ImageDTO | null>
): Promise<void> => { ): Promise<void> => {
const image = ca.processedImage ?? ca.image; const image = ca.processedImage ?? ca.image;
@ -90,7 +90,7 @@ const updateCALayerImageSource = async (
* @param ca The control adapter layer state * @param ca The control adapter layer state
*/ */
const updateCALayerImageAttrs = (stage: Konva.Stage, konvaImage: Konva.Image, ca: ControlAdapterData): void => { const updateCALayerImageAttrs = (stage: Konva.Stage, konvaImage: Konva.Image, ca: ControlAdapterEntity): void => {
let needsCache = false; let needsCache = false;
// TODO(psyche): `node.filters()` returns null if no filters; report upstream // TODO(psyche): `node.filters()` returns null if no filters; report upstream
const filters = konvaImage.filters() ?? []; const filters = konvaImage.filters() ?? [];
@ -128,7 +128,7 @@ const updateCALayerImageAttrs = (stage: Konva.Stage, konvaImage: Konva.Image, ca
*/ */
export const renderCALayer = ( export const renderCALayer = (
stage: Konva.Stage, stage: Konva.Stage,
ca: ControlAdapterData, ca: ControlAdapterEntity,
getImageDTO: (imageName: string) => Promise<ImageDTO | null> getImageDTO: (imageName: string) => Promise<ImageDTO | null>
): void => { ): void => {
const konvaLayer = stage.findOne<Konva.Layer>(`#${ca.id}`) ?? createCALayer(stage, ca); const konvaLayer = stage.findOne<Konva.Layer>(`#${ca.id}`) ?? createCALayer(stage, ca);
@ -157,7 +157,7 @@ export const renderCALayer = (
export const renderControlAdapters = ( export const renderControlAdapters = (
stage: Konva.Stage, stage: Konva.Stage,
controlAdapters: ControlAdapterData[], controlAdapters: ControlAdapterEntity[],
getImageDTO: (imageName: string) => Promise<ImageDTO | null> getImageDTO: (imageName: string) => Promise<ImageDTO | null>
): void => { ): void => {
// Destroy nonexistent layers // Destroy nonexistent layers

View File

@ -8,10 +8,10 @@ import { renderRGLayer } from 'features/controlLayers/konva/renderers/rgLayer';
import { mapId, selectRenderableLayers } from 'features/controlLayers/konva/util'; import { mapId, selectRenderableLayers } from 'features/controlLayers/konva/util';
import type { import type {
CanvasEntity, CanvasEntity,
ControlAdapterData, ControlAdapterEntity,
LayerData, LayerEntity,
PosChangedArg, PosChangedArg,
RegionalGuidanceData, RegionEntity,
Tool, Tool,
} from 'features/controlLayers/store/types'; } from 'features/controlLayers/store/types';
import type Konva from 'konva'; import type Konva from 'konva';
@ -33,9 +33,9 @@ import type { ImageDTO } from 'services/api/types';
*/ */
const renderLayers = ( const renderLayers = (
stage: Konva.Stage, stage: Konva.Stage,
layers: LayerData[], layers: LayerEntity[],
controlAdapters: ControlAdapterData[], controlAdapters: ControlAdapterEntity[],
regions: RegionalGuidanceData[], regions: RegionEntity[],
rgGlobalOpacity: number, rgGlobalOpacity: number,
tool: Tool, tool: Tool,
selectedEntity: CanvasEntity | null, selectedEntity: CanvasEntity | null,
@ -96,9 +96,9 @@ export const debouncedRenderers: typeof renderers = getDebouncedRenderers();
export const arrangeEntities = ( export const arrangeEntities = (
stage: Konva.Stage, stage: Konva.Stage,
layers: LayerData[], layers: LayerEntity[],
controlAdapters: ControlAdapterData[], controlAdapters: ControlAdapterEntity[],
regions: RegionalGuidanceData[] regions: RegionEntity[]
): void => { ): void => {
let zIndex = 0; let zIndex = 0;
stage.findOne<Konva.Layer>(`#${BACKGROUND_LAYER_ID}`)?.zIndex(++zIndex); stage.findOne<Konva.Layer>(`#${BACKGROUND_LAYER_ID}`)?.zIndex(++zIndex);

View File

@ -14,7 +14,7 @@ import {
createRectShape, createRectShape,
} from 'features/controlLayers/konva/renderers/objects'; } from 'features/controlLayers/konva/renderers/objects';
import { mapId, selectRasterObjects } from 'features/controlLayers/konva/util'; import { mapId, selectRasterObjects } from 'features/controlLayers/konva/util';
import type { CanvasEntity, LayerData, PosChangedArg, Tool } from 'features/controlLayers/store/types'; import type { CanvasEntity, LayerEntity, PosChangedArg, Tool } from 'features/controlLayers/store/types';
import Konva from 'konva'; import Konva from 'konva';
/** /**
@ -29,7 +29,7 @@ import Konva from 'konva';
*/ */
const createRasterLayer = ( const createRasterLayer = (
stage: Konva.Stage, stage: Konva.Stage,
layerState: LayerData, layerState: LayerEntity,
onPosChanged?: (arg: PosChangedArg, entityType: CanvasEntity['type']) => void onPosChanged?: (arg: PosChangedArg, entityType: CanvasEntity['type']) => void
): Konva.Layer => { ): Konva.Layer => {
// This layer hasn't been added to the konva state yet // This layer hasn't been added to the konva state yet
@ -62,7 +62,7 @@ const createRasterLayer = (
*/ */
export const renderRasterLayer = async ( export const renderRasterLayer = async (
stage: Konva.Stage, stage: Konva.Stage,
layerState: LayerData, layerState: LayerEntity,
tool: Tool, tool: Tool,
onPosChanged?: (arg: PosChangedArg, entityType: CanvasEntity['type']) => void onPosChanged?: (arg: PosChangedArg, entityType: CanvasEntity['type']) => void
) => { ) => {
@ -146,7 +146,7 @@ export const renderRasterLayer = async (
export const renderLayers = ( export const renderLayers = (
stage: Konva.Stage, stage: Konva.Stage,
layers: LayerData[], layers: LayerEntity[],
tool: Tool, tool: Tool,
onPosChanged?: (arg: PosChangedArg, entityType: CanvasEntity['type']) => void onPosChanged?: (arg: PosChangedArg, entityType: CanvasEntity['type']) => void
): void => { ): void => {

View File

@ -18,7 +18,7 @@ import {
createRectShape, createRectShape,
} from 'features/controlLayers/konva/renderers/objects'; } from 'features/controlLayers/konva/renderers/objects';
import { mapId, selectVectorMaskObjects } from 'features/controlLayers/konva/util'; import { mapId, selectVectorMaskObjects } from 'features/controlLayers/konva/util';
import type { CanvasEntity, PosChangedArg, RegionalGuidanceData, Tool } from 'features/controlLayers/store/types'; import type { CanvasEntity, PosChangedArg, RegionEntity, Tool } from 'features/controlLayers/store/types';
import Konva from 'konva'; import Konva from 'konva';
/** /**
@ -46,7 +46,7 @@ const createCompositingRect = (konvaLayer: Konva.Layer): Konva.Rect => {
*/ */
const createRGLayer = ( const createRGLayer = (
stage: Konva.Stage, stage: Konva.Stage,
rg: RegionalGuidanceData, rg: RegionEntity,
onPosChanged?: (arg: PosChangedArg, entityType: CanvasEntity['type']) => void onPosChanged?: (arg: PosChangedArg, entityType: CanvasEntity['type']) => void
): Konva.Layer => { ): Konva.Layer => {
// This layer hasn't been added to the konva state yet // This layer hasn't been added to the konva state yet
@ -80,7 +80,7 @@ const createRGLayer = (
*/ */
export const renderRGLayer = ( export const renderRGLayer = (
stage: Konva.Stage, stage: Konva.Stage,
rg: RegionalGuidanceData, rg: RegionEntity,
globalMaskLayerOpacity: number, globalMaskLayerOpacity: number,
tool: Tool, tool: Tool,
selectedEntity: CanvasEntity | null, selectedEntity: CanvasEntity | null,
@ -234,7 +234,7 @@ export const renderRGLayer = (
export const renderRegions = ( export const renderRegions = (
stage: Konva.Stage, stage: Konva.Stage,
regions: RegionalGuidanceData[], regions: RegionEntity[],
maskOpacity: number, maskOpacity: number,
tool: Tool, tool: Tool,
selectedEntity: CanvasEntity | null, selectedEntity: CanvasEntity | null,

View File

@ -9,7 +9,7 @@ import { v4 as uuidv4 } from 'uuid';
import type { import type {
CanvasV2State, CanvasV2State,
ControlAdapterData, ControlAdapterEntity,
ControlModeV2, ControlModeV2,
ControlNetConfig, ControlNetConfig,
ControlNetData, ControlNetData,
@ -50,7 +50,7 @@ export const controlAdaptersReducers = {
payload: { id: uuidv4(), ...payload }, payload: { id: uuidv4(), ...payload },
}), }),
}, },
caRecalled: (state, action: PayloadAction<{ data: ControlAdapterData }>) => { caRecalled: (state, action: PayloadAction<{ data: ControlAdapterEntity }>) => {
const { data } = action.payload; const { data } = action.payload;
state.controlAdapters.push(data); state.controlAdapters.push(data);
state.selectedEntityIdentifier = { type: 'control_adapter', id: data.id }; state.selectedEntityIdentifier = { type: 'control_adapter', id: data.id };

View File

@ -4,7 +4,7 @@ import type { ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
import type { CanvasV2State, CLIPVisionModelV2, IPAdapterConfig, IPAdapterData, IPMethodV2 } from './types'; import type { CanvasV2State, CLIPVisionModelV2, IPAdapterConfig, IPAdapterEntity, IPMethodV2 } from './types';
import { imageDTOToImageWithDims } from './types'; import { imageDTOToImageWithDims } from './types';
export const selectIPA = (state: CanvasV2State, id: string) => state.ipAdapters.find((ipa) => ipa.id === id); export const selectIPA = (state: CanvasV2State, id: string) => state.ipAdapters.find((ipa) => ipa.id === id);
@ -18,7 +18,7 @@ export const ipAdaptersReducers = {
ipaAdded: { ipaAdded: {
reducer: (state, action: PayloadAction<{ id: string; config: IPAdapterConfig }>) => { reducer: (state, action: PayloadAction<{ id: string; config: IPAdapterConfig }>) => {
const { id, config } = action.payload; const { id, config } = action.payload;
const layer: IPAdapterData = { const layer: IPAdapterEntity = {
id, id,
type: 'ip_adapter', type: 'ip_adapter',
isEnabled: true, isEnabled: true,
@ -29,7 +29,7 @@ export const ipAdaptersReducers = {
}, },
prepare: (payload: { config: IPAdapterConfig }) => ({ payload: { id: uuidv4(), ...payload } }), prepare: (payload: { config: IPAdapterConfig }) => ({ payload: { id: uuidv4(), ...payload } }),
}, },
ipaRecalled: (state, action: PayloadAction<{ data: IPAdapterData }>) => { ipaRecalled: (state, action: PayloadAction<{ data: IPAdapterEntity }>) => {
const { data } = action.payload; const { data } = action.payload;
state.ipAdapters.push(data); state.ipAdapters.push(data);
state.selectedEntityIdentifier = { type: 'ip_adapter', id: data.id }; state.selectedEntityIdentifier = { type: 'ip_adapter', id: data.id };

View File

@ -10,7 +10,7 @@ import type {
CanvasV2State, CanvasV2State,
EraserLineAddedArg, EraserLineAddedArg,
ImageObjectAddedArg, ImageObjectAddedArg,
LayerData, LayerEntity,
PointAddedToLineArg, PointAddedToLineArg,
RectShapeAddedArg, RectShapeAddedArg,
} from './types'; } from './types';
@ -42,7 +42,7 @@ export const layersReducers = {
}, },
prepare: () => ({ payload: { id: uuidv4() } }), prepare: () => ({ payload: { id: uuidv4() } }),
}, },
layerRecalled: (state, action: PayloadAction<{ data: LayerData }>) => { layerRecalled: (state, action: PayloadAction<{ data: LayerEntity }>) => {
const { data } = action.payload; const { data } = action.payload;
state.layers.push(data); state.layers.push(data);
state.selectedEntityIdentifier = { type: 'layer', id: data.id }; state.selectedEntityIdentifier = { type: 'layer', id: data.id };

View File

@ -14,10 +14,10 @@ import { v4 as uuidv4 } from 'uuid';
import type { import type {
BrushLineAddedArg, BrushLineAddedArg,
EraserLineAddedArg, EraserLineAddedArg,
IPAdapterData, IPAdapterEntity,
PointAddedToLineArg, PointAddedToLineArg,
RectShapeAddedArg, RectShapeAddedArg,
RegionalGuidanceData, RegionEntity,
RgbColor, RgbColor,
} from './types'; } from './types';
import { isLine } from './types'; import { isLine } from './types';
@ -55,7 +55,7 @@ export const regionsReducers = {
rgAdded: { rgAdded: {
reducer: (state, action: PayloadAction<{ id: string }>) => { reducer: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload; const { id } = action.payload;
const rg: RegionalGuidanceData = { const rg: RegionEntity = {
id, id,
type: 'regional_guidance', type: 'regional_guidance',
isEnabled: true, isEnabled: true,
@ -87,7 +87,7 @@ export const regionsReducers = {
rg.bboxNeedsUpdate = false; rg.bboxNeedsUpdate = false;
rg.imageCache = null; rg.imageCache = null;
}, },
rgRecalled: (state, action: PayloadAction<{ data: RegionalGuidanceData }>) => { rgRecalled: (state, action: PayloadAction<{ data: RegionEntity }>) => {
const { data } = action.payload; const { data } = action.payload;
state.regions.push(data); state.regions.push(data);
state.selectedEntityIdentifier = { type: 'regional_guidance', id: data.id }; state.selectedEntityIdentifier = { type: 'regional_guidance', id: data.id };
@ -194,7 +194,7 @@ export const regionsReducers = {
} }
rg.autoNegative = autoNegative; rg.autoNegative = autoNegative;
}, },
rgIPAdapterAdded: (state, action: PayloadAction<{ id: string; ipAdapter: IPAdapterData }>) => { rgIPAdapterAdded: (state, action: PayloadAction<{ id: string; ipAdapter: IPAdapterEntity }>) => {
const { id, ipAdapter } = action.payload; const { id, ipAdapter } = action.payload;
const rg = selectRG(state, id); const rg = selectRG(state, id);
if (!rg) { if (!rg) {

View File

@ -573,7 +573,7 @@ const zRect = z.object({
height: z.number().min(1), height: z.number().min(1),
}); });
export const zLayerData = z.object({ export const zLayerEntity = z.object({
id: zId, id: zId,
type: z.literal('layer'), type: z.literal('layer'),
isEnabled: z.boolean(), isEnabled: z.boolean(),
@ -584,9 +584,9 @@ export const zLayerData = z.object({
opacity: zOpacity, opacity: zOpacity,
objects: z.array(zLayerObject), objects: z.array(zLayerObject),
}); });
export type LayerData = z.infer<typeof zLayerData>; export type LayerEntity = z.infer<typeof zLayerEntity>;
export const zIPAdapterData = z.object({ export const zIPAdapterEntity = z.object({
id: zId, id: zId,
type: z.literal('ip_adapter'), type: z.literal('ip_adapter'),
isEnabled: z.boolean(), isEnabled: z.boolean(),
@ -597,9 +597,9 @@ export const zIPAdapterData = z.object({
clipVisionModel: zCLIPVisionModelV2, clipVisionModel: zCLIPVisionModelV2,
beginEndStepPct: zBeginEndStepPct, beginEndStepPct: zBeginEndStepPct,
}); });
export type IPAdapterData = z.infer<typeof zIPAdapterData>; export type IPAdapterEntity = z.infer<typeof zIPAdapterEntity>;
export type IPAdapterConfig = Pick< export type IPAdapterConfig = Pick<
IPAdapterData, IPAdapterEntity,
'weight' | 'image' | 'beginEndStepPct' | 'model' | 'clipVisionModel' | 'method' 'weight' | 'image' | 'beginEndStepPct' | 'model' | 'clipVisionModel' | 'method'
>; >;
@ -636,7 +636,7 @@ const zMaskObject = z
}) })
.pipe(z.discriminatedUnion('type', [zBrushLine, zEraserline, zRectShape])); .pipe(z.discriminatedUnion('type', [zBrushLine, zEraserline, zRectShape]));
export const zRegionalGuidanceData = z.object({ export const zRegionEntity = z.object({
id: zId, id: zId,
type: z.literal('regional_guidance'), type: z.literal('regional_guidance'),
isEnabled: z.boolean(), isEnabled: z.boolean(),
@ -647,12 +647,12 @@ export const zRegionalGuidanceData = z.object({
objects: z.array(zMaskObject), objects: z.array(zMaskObject),
positivePrompt: zParameterPositivePrompt.nullable(), positivePrompt: zParameterPositivePrompt.nullable(),
negativePrompt: zParameterNegativePrompt.nullable(), negativePrompt: zParameterNegativePrompt.nullable(),
ipAdapters: z.array(zIPAdapterData), ipAdapters: z.array(zIPAdapterEntity),
fill: zRgbColor, fill: zRgbColor,
autoNegative: zAutoNegative, autoNegative: zAutoNegative,
imageCache: zImageWithDims.nullable(), imageCache: zImageWithDims.nullable(),
}); });
export type RegionalGuidanceData = z.infer<typeof zRegionalGuidanceData>; export type RegionEntity = z.infer<typeof zRegionEntity>;
const zColorFill = z.object({ const zColorFill = z.object({
type: z.literal('color_fill'), type: z.literal('color_fill'),
@ -680,7 +680,7 @@ export type InpaintMaskData = z.infer<typeof zInpaintMaskData>;
const zFilter = z.enum(['none', 'LightnessToAlphaFilter']); const zFilter = z.enum(['none', 'LightnessToAlphaFilter']);
export type Filter = z.infer<typeof zFilter>; export type Filter = z.infer<typeof zFilter>;
const zControlAdapterDataBase = z.object({ const zControlAdapterEntityBase = z.object({
id: zId, id: zId,
type: z.literal('control_adapter'), type: z.literal('control_adapter'),
isEnabled: z.boolean(), isEnabled: z.boolean(),
@ -698,18 +698,18 @@ const zControlAdapterDataBase = z.object({
beginEndStepPct: zBeginEndStepPct, beginEndStepPct: zBeginEndStepPct,
model: zModelIdentifierField.nullable(), model: zModelIdentifierField.nullable(),
}); });
const zControlNetData = zControlAdapterDataBase.extend({ const zControlNetEntity = zControlAdapterEntityBase.extend({
adapterType: z.literal('controlnet'), adapterType: z.literal('controlnet'),
controlMode: zControlModeV2, controlMode: zControlModeV2,
}); });
export type ControlNetData = z.infer<typeof zControlNetData>; export type ControlNetData = z.infer<typeof zControlNetEntity>;
const zT2IAdapterData = zControlAdapterDataBase.extend({ const zT2IAdapterEntity = zControlAdapterEntityBase.extend({
adapterType: z.literal('t2i_adapter'), adapterType: z.literal('t2i_adapter'),
}); });
export type T2IAdapterData = z.infer<typeof zT2IAdapterData>; export type T2IAdapterData = z.infer<typeof zT2IAdapterEntity>;
export const zControlAdapterData = z.discriminatedUnion('adapterType', [zControlNetData, zT2IAdapterData]); export const zControlAdapterEntity = z.discriminatedUnion('adapterType', [zControlNetEntity, zT2IAdapterEntity]);
export type ControlAdapterData = z.infer<typeof zControlAdapterData>; export type ControlAdapterEntity = z.infer<typeof zControlAdapterEntity>;
export type ControlNetConfig = Pick< export type ControlNetConfig = Pick<
ControlNetData, ControlNetData,
| 'adapterType' | 'adapterType'
@ -778,7 +778,7 @@ export type BoundingBoxScaleMethod = z.infer<typeof zBoundingBoxScaleMethod>;
export const isBoundingBoxScaleMethod = (v: unknown): v is BoundingBoxScaleMethod => export const isBoundingBoxScaleMethod = (v: unknown): v is BoundingBoxScaleMethod =>
zBoundingBoxScaleMethod.safeParse(v).success; zBoundingBoxScaleMethod.safeParse(v).success;
export type CanvasEntity = LayerData | IPAdapterData | ControlAdapterData | RegionalGuidanceData | InpaintMaskData; export type CanvasEntity = LayerEntity | IPAdapterEntity | ControlAdapterEntity | RegionEntity | InpaintMaskData;
export type CanvasEntityIdentifier = Pick<CanvasEntity, 'id' | 'type'>; export type CanvasEntityIdentifier = Pick<CanvasEntity, 'id' | 'type'>;
export type Dimensions = { export type Dimensions = {
@ -796,10 +796,10 @@ export type LoRA = {
export type CanvasV2State = { export type CanvasV2State = {
_version: 3; _version: 3;
selectedEntityIdentifier: CanvasEntityIdentifier | null; selectedEntityIdentifier: CanvasEntityIdentifier | null;
layers: LayerData[]; layers: LayerEntity[];
controlAdapters: ControlAdapterData[]; controlAdapters: ControlAdapterEntity[];
ipAdapters: IPAdapterData[]; ipAdapters: IPAdapterEntity[];
regions: RegionalGuidanceData[]; regions: RegionEntity[];
loras: LoRA[]; loras: LoRA[];
tool: { tool: {
selected: Tool; selected: Tool;

View File

@ -1,4 +1,4 @@
import type { LayerData } from 'features/controlLayers/store/types'; import type { LayerEntity } from 'features/controlLayers/store/types';
import { MetadataItemView } from 'features/metadata/components/MetadataItemView'; import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
import type { MetadataHandlers } from 'features/metadata/types'; import type { MetadataHandlers } from 'features/metadata/types';
import { handlers } from 'features/metadata/util/handlers'; import { handlers } from 'features/metadata/util/handlers';
@ -9,7 +9,7 @@ type Props = {
}; };
export const MetadataLayers = ({ metadata }: Props) => { export const MetadataLayers = ({ metadata }: Props) => {
const [layers, setLayers] = useState<LayerData[]>([]); const [layers, setLayers] = useState<LayerEntity[]>([]);
useEffect(() => { useEffect(() => {
const parse = async () => { const parse = async () => {
@ -40,8 +40,8 @@ const MetadataViewLayer = ({
handlers, handlers,
}: { }: {
label: string; label: string;
layer: LayerData; layer: LayerEntity;
handlers: MetadataHandlers<LayerData[], LayerData>; handlers: MetadataHandlers<LayerEntity[], LayerEntity>;
}) => { }) => {
const onRecall = useCallback(() => { const onRecall = useCallback(() => {
if (!handlers.recallItem) { if (!handlers.recallItem) {

View File

@ -2,7 +2,7 @@ import { getStore } from 'app/store/nanostores/store';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { objectKeys } from 'common/util/objectKeys'; import { objectKeys } from 'common/util/objectKeys';
import { shouldConcatPromptsChanged } from 'features/controlLayers/store/canvasV2Slice'; import { shouldConcatPromptsChanged } from 'features/controlLayers/store/canvasV2Slice';
import type { LayerData, LoRA } from 'features/controlLayers/store/types'; import type { LayerEntity, LoRA } from 'features/controlLayers/store/types';
import type { import type {
AnyControlAdapterConfigMetadata, AnyControlAdapterConfigMetadata,
BuildMetadataHandlers, BuildMetadataHandlers,
@ -48,7 +48,7 @@ const renderControlAdapterValue: MetadataRenderValueFunc<AnyControlAdapterConfig
return `${value.model.key} (${value.model.base.toUpperCase()}) - ${value.weight}`; return `${value.model.key} (${value.model.base.toUpperCase()}) - ${value.weight}`;
} }
}; };
const renderLayerValue: MetadataRenderValueFunc<LayerData> = async (layer) => { const renderLayerValue: MetadataRenderValueFunc<LayerEntity> = async (layer) => {
if (layer.type === 'initial_image_layer') { if (layer.type === 'initial_image_layer') {
let rendered = t('controlLayers.globalInitialImageLayer'); let rendered = t('controlLayers.globalInitialImageLayer');
if (layer.image) { if (layer.image) {
@ -88,7 +88,7 @@ const renderLayerValue: MetadataRenderValueFunc<LayerData> = async (layer) => {
} }
assert(false, 'Unknown layer type'); assert(false, 'Unknown layer type');
}; };
const renderLayersValue: MetadataRenderValueFunc<LayerData[]> = async (layers) => { const renderLayersValue: MetadataRenderValueFunc<LayerEntity[]> = async (layers) => {
return `${layers.length} ${t('controlLayers.layers', { count: layers.length })}`; return `${layers.length} ${t('controlLayers.layers', { count: layers.length })}`;
}; };

View File

@ -1,6 +1,6 @@
import { getCAId, getImageObjectId, getIPAId, getLayerId } from 'features/controlLayers/konva/naming'; import { getCAId, getImageObjectId, getIPAId, getLayerId } from 'features/controlLayers/konva/naming';
import { defaultLoRAConfig } from 'features/controlLayers/store/lorasReducers'; import { defaultLoRAConfig } from 'features/controlLayers/store/lorasReducers';
import type { ControlAdapterData, IPAdapterData, LayerData, LoRA } from 'features/controlLayers/store/types'; import type { ControlAdapterEntity, IPAdapterEntity, LayerEntity, LoRA } from 'features/controlLayers/store/types';
import { import {
CA_PROCESSOR_DATA, CA_PROCESSOR_DATA,
imageDTOToImageWithDims, imageDTOToImageWithDims,
@ -8,7 +8,7 @@ import {
initialIPAdapterV2, initialIPAdapterV2,
initialT2IAdapterV2, initialT2IAdapterV2,
isProcessorTypeV2, isProcessorTypeV2,
zLayerData, zLayerEntity,
} from 'features/controlLayers/store/types'; } from 'features/controlLayers/store/types';
import type { import type {
ControlNetConfigMetadata, ControlNetConfigMetadata,
@ -424,22 +424,22 @@ const parseAllIPAdapters: MetadataParseFunc<IPAdapterConfigMetadata[]> = async (
}; };
//#region Control Layers //#region Control Layers
const parseLayer: MetadataParseFunc<LayerData> = async (metadataItem) => zLayerData.parseAsync(metadataItem); const parseLayer: MetadataParseFunc<LayerEntity> = async (metadataItem) => zLayerEntity.parseAsync(metadataItem);
const parseLayers: MetadataParseFunc<LayerData[]> = async (metadata) => { const parseLayers: MetadataParseFunc<LayerEntity[]> = async (metadata) => {
// We need to support recalling pre-Control Layers metadata into Control Layers. A separate set of parsers handles // We need to support recalling pre-Control Layers metadata into Control Layers. A separate set of parsers handles
// taking pre-CL metadata and parsing it into layers. It doesn't always map 1-to-1, so this is best-effort. For // taking pre-CL metadata and parsing it into layers. It doesn't always map 1-to-1, so this is best-effort. For
// example, CL Control Adapters don't support resize mode, so we simply omit that property. // example, CL Control Adapters don't support resize mode, so we simply omit that property.
try { try {
const layers: LayerData[] = []; const layers: LayerEntity[] = [];
try { try {
const control_layers = await getProperty(metadata, 'control_layers'); const control_layers = await getProperty(metadata, 'control_layers');
const controlLayersRaw = await getProperty(control_layers, 'layers', isArray); const controlLayersRaw = await getProperty(control_layers, 'layers', isArray);
const controlLayersParseResults = await Promise.allSettled(controlLayersRaw.map(parseLayer)); const controlLayersParseResults = await Promise.allSettled(controlLayersRaw.map(parseLayer));
const controlLayers = controlLayersParseResults const controlLayers = controlLayersParseResults
.filter((result): result is PromiseFulfilledResult<LayerData> => result.status === 'fulfilled') .filter((result): result is PromiseFulfilledResult<LayerEntity> => result.status === 'fulfilled')
.map((result) => result.value); .map((result) => result.value);
layers.push(...controlLayers); layers.push(...controlLayers);
} catch { } catch {
@ -452,7 +452,7 @@ const parseLayers: MetadataParseFunc<LayerData[]> = async (metadata) => {
controlNetsRaw.map(async (cn) => await parseControlNetToControlAdapterLayer(cn)) controlNetsRaw.map(async (cn) => await parseControlNetToControlAdapterLayer(cn))
); );
const controlNetsAsLayers = controlNetsParseResults const controlNetsAsLayers = controlNetsParseResults
.filter((result): result is PromiseFulfilledResult<ControlAdapterData> => result.status === 'fulfilled') .filter((result): result is PromiseFulfilledResult<ControlAdapterEntity> => result.status === 'fulfilled')
.map((result) => result.value); .map((result) => result.value);
layers.push(...controlNetsAsLayers); layers.push(...controlNetsAsLayers);
} catch { } catch {
@ -465,7 +465,7 @@ const parseLayers: MetadataParseFunc<LayerData[]> = async (metadata) => {
t2iAdaptersRaw.map(async (cn) => await parseT2IAdapterToControlAdapterLayer(cn)) t2iAdaptersRaw.map(async (cn) => await parseT2IAdapterToControlAdapterLayer(cn))
); );
const t2iAdaptersAsLayers = t2iAdaptersParseResults const t2iAdaptersAsLayers = t2iAdaptersParseResults
.filter((result): result is PromiseFulfilledResult<ControlAdapterData> => result.status === 'fulfilled') .filter((result): result is PromiseFulfilledResult<ControlAdapterEntity> => result.status === 'fulfilled')
.map((result) => result.value); .map((result) => result.value);
layers.push(...t2iAdaptersAsLayers); layers.push(...t2iAdaptersAsLayers);
} catch { } catch {
@ -478,7 +478,7 @@ const parseLayers: MetadataParseFunc<LayerData[]> = async (metadata) => {
ipAdaptersRaw.map(async (cn) => await parseIPAdapterToIPAdapterLayer(cn)) ipAdaptersRaw.map(async (cn) => await parseIPAdapterToIPAdapterLayer(cn))
); );
const ipAdaptersAsLayers = ipAdaptersParseResults const ipAdaptersAsLayers = ipAdaptersParseResults
.filter((result): result is PromiseFulfilledResult<IPAdapterData> => result.status === 'fulfilled') .filter((result): result is PromiseFulfilledResult<IPAdapterEntity> => result.status === 'fulfilled')
.map((result) => result.value); .map((result) => result.value);
layers.push(...ipAdaptersAsLayers); layers.push(...ipAdaptersAsLayers);
} catch { } catch {
@ -498,14 +498,14 @@ const parseLayers: MetadataParseFunc<LayerData[]> = async (metadata) => {
} }
}; };
const parseInitialImageToInitialImageLayer: MetadataParseFunc<LayerData> = async (metadata) => { const parseInitialImageToInitialImageLayer: MetadataParseFunc<LayerEntity> = async (metadata) => {
// TODO(psyche): recall denoise strength // TODO(psyche): recall denoise strength
// const denoisingStrength = await getProperty(metadata, 'strength', isParameterStrength); // const denoisingStrength = await getProperty(metadata, 'strength', isParameterStrength);
const imageName = await getProperty(metadata, 'init_image', isString); const imageName = await getProperty(metadata, 'init_image', isString);
const imageDTO = await getImageDTO(imageName); const imageDTO = await getImageDTO(imageName);
assert(imageDTO, 'ImageDTO is null'); assert(imageDTO, 'ImageDTO is null');
const id = getLayerId(uuidv4()); const id = getLayerId(uuidv4());
const layer: LayerData = { const layer: LayerEntity = {
id, id,
type: 'layer', type: 'layer',
bbox: null, bbox: null,
@ -529,7 +529,7 @@ const parseInitialImageToInitialImageLayer: MetadataParseFunc<LayerData> = async
return layer; return layer;
}; };
const parseControlNetToControlAdapterLayer: MetadataParseFunc<ControlAdapterData> = async (metadataItem) => { const parseControlNetToControlAdapterLayer: MetadataParseFunc<ControlAdapterEntity> = async (metadataItem) => {
const control_model = await getProperty(metadataItem, 'control_model'); const control_model = await getProperty(metadataItem, 'control_model');
const key = await getModelKey(control_model, 'controlnet'); const key = await getModelKey(control_model, 'controlnet');
const controlNetModel = await fetchModelConfigWithTypeGuard(key, isControlNetModelConfig); const controlNetModel = await fetchModelConfigWithTypeGuard(key, isControlNetModelConfig);
@ -569,7 +569,7 @@ const parseControlNetToControlAdapterLayer: MetadataParseFunc<ControlAdapterData
const imageDTO = image ? await getImageDTO(image.image_name) : null; const imageDTO = image ? await getImageDTO(image.image_name) : null;
const processedImageDTO = processedImage ? await getImageDTO(processedImage.image_name) : null; const processedImageDTO = processedImage ? await getImageDTO(processedImage.image_name) : null;
const layer: ControlAdapterData = { const layer: ControlAdapterEntity = {
id: getCAId(uuidv4()), id: getCAId(uuidv4()),
type: 'control_adapter', type: 'control_adapter',
bbox: null, bbox: null,
@ -593,7 +593,7 @@ const parseControlNetToControlAdapterLayer: MetadataParseFunc<ControlAdapterData
return layer; return layer;
}; };
const parseT2IAdapterToControlAdapterLayer: MetadataParseFunc<ControlAdapterData> = async (metadataItem) => { const parseT2IAdapterToControlAdapterLayer: MetadataParseFunc<ControlAdapterEntity> = async (metadataItem) => {
const t2i_adapter_model = await getProperty(metadataItem, 't2i_adapter_model'); const t2i_adapter_model = await getProperty(metadataItem, 't2i_adapter_model');
const key = await getModelKey(t2i_adapter_model, 't2i_adapter'); const key = await getModelKey(t2i_adapter_model, 't2i_adapter');
const t2iAdapterModel = await fetchModelConfigWithTypeGuard(key, isT2IAdapterModelConfig); const t2iAdapterModel = await fetchModelConfigWithTypeGuard(key, isT2IAdapterModelConfig);
@ -630,7 +630,7 @@ const parseT2IAdapterToControlAdapterLayer: MetadataParseFunc<ControlAdapterData
const imageDTO = image ? await getImageDTO(image.image_name) : null; const imageDTO = image ? await getImageDTO(image.image_name) : null;
const processedImageDTO = processedImage ? await getImageDTO(processedImage.image_name) : null; const processedImageDTO = processedImage ? await getImageDTO(processedImage.image_name) : null;
const layer: ControlAdapterData = { const layer: ControlAdapterEntity = {
id: getCAId(uuidv4()), id: getCAId(uuidv4()),
bbox: null, bbox: null,
bboxNeedsUpdate: true, bboxNeedsUpdate: true,
@ -653,7 +653,7 @@ const parseT2IAdapterToControlAdapterLayer: MetadataParseFunc<ControlAdapterData
return layer; return layer;
}; };
const parseIPAdapterToIPAdapterLayer: MetadataParseFunc<IPAdapterData> = async (metadataItem) => { const parseIPAdapterToIPAdapterLayer: MetadataParseFunc<IPAdapterEntity> = async (metadataItem) => {
const ip_adapter_model = await getProperty(metadataItem, 'ip_adapter_model'); const ip_adapter_model = await getProperty(metadataItem, 'ip_adapter_model');
const key = await getModelKey(ip_adapter_model, 'ip_adapter'); const key = await getModelKey(ip_adapter_model, 'ip_adapter');
const ipAdapterModel = await fetchModelConfigWithTypeGuard(key, isIPAdapterModelConfig); const ipAdapterModel = await fetchModelConfigWithTypeGuard(key, isIPAdapterModelConfig);
@ -685,7 +685,7 @@ const parseIPAdapterToIPAdapterLayer: MetadataParseFunc<IPAdapterData> = async (
]; ];
const imageDTO = image ? await getImageDTO(image.image_name) : null; const imageDTO = image ? await getImageDTO(image.image_name) : null;
const layer: IPAdapterData = { const layer: IPAdapterEntity = {
id: getIPAId(uuidv4()), id: getIPAId(uuidv4()),
type: 'ip_adapter', type: 'ip_adapter',
isEnabled: true, isEnabled: true,

View File

@ -40,11 +40,11 @@ import {
widthChanged, widthChanged,
} from 'features/controlLayers/store/canvasV2Slice'; } from 'features/controlLayers/store/canvasV2Slice';
import type { import type {
ControlAdapterData, ControlAdapterEntity,
IPAdapterData, IPAdapterEntity,
LayerData, LayerEntity,
LoRA, LoRA,
RegionalGuidanceData, RegionEntity,
} from 'features/controlLayers/store/types'; } from 'features/controlLayers/store/types';
import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice'; import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice';
import type { import type {
@ -246,7 +246,7 @@ const recallIPAdapters: MetadataRecallFunc<IPAdapterConfigMetadata[]> = (ipAdapt
}); });
}; };
const recallCA: MetadataRecallFunc<ControlAdapterData> = async (ca) => { const recallCA: MetadataRecallFunc<ControlAdapterEntity> = async (ca) => {
const { dispatch } = getStore(); const { dispatch } = getStore();
const clone = deepClone(ca); const clone = deepClone(ca);
if (clone.image) { if (clone.image) {
@ -275,7 +275,7 @@ const recallCA: MetadataRecallFunc<ControlAdapterData> = async (ca) => {
return; return;
}; };
const recallIPA: MetadataRecallFunc<IPAdapterData> = async (ipa) => { const recallIPA: MetadataRecallFunc<IPAdapterEntity> = async (ipa) => {
const { dispatch } = getStore(); const { dispatch } = getStore();
const clone = deepClone(ipa); const clone = deepClone(ipa);
if (clone.image) { if (clone.image) {
@ -298,7 +298,7 @@ const recallIPA: MetadataRecallFunc<IPAdapterData> = async (ipa) => {
return; return;
}; };
const recallRG: MetadataRecallFunc<RegionalGuidanceData> = async (rg) => { const recallRG: MetadataRecallFunc<RegionEntity> = async (rg) => {
const { dispatch } = getStore(); const { dispatch } = getStore();
const clone = deepClone(rg); const clone = deepClone(rg);
// Strip out the uploaded mask image property - this is an intermediate image // Strip out the uploaded mask image property - this is an intermediate image
@ -328,7 +328,7 @@ const recallRG: MetadataRecallFunc<RegionalGuidanceData> = async (rg) => {
}; };
//#region Control Layers //#region Control Layers
const recallLayer: MetadataRecallFunc<LayerData> = async (layer) => { const recallLayer: MetadataRecallFunc<LayerEntity> = async (layer) => {
const { dispatch } = getStore(); const { dispatch } = getStore();
const clone = deepClone(layer); const clone = deepClone(layer);
const invalidObjects: string[] = []; const invalidObjects: string[] = [];
@ -359,7 +359,7 @@ const recallLayer: MetadataRecallFunc<LayerData> = async (layer) => {
return; return;
}; };
const recallLayers: MetadataRecallFunc<LayerData[]> = (layers) => { const recallLayers: MetadataRecallFunc<LayerEntity[]> = (layers) => {
const { dispatch } = getStore(); const { dispatch } = getStore();
dispatch(layerAllDeleted()); dispatch(layerAllDeleted());
for (const l of layers) { for (const l of layers) {

View File

@ -1,5 +1,5 @@
import { getStore } from 'app/store/nanostores/store'; import { getStore } from 'app/store/nanostores/store';
import type { LayerData, LoRA } from 'features/controlLayers/store/types'; import type { LayerEntity, LoRA } from 'features/controlLayers/store/types';
import type { import type {
ControlNetConfigMetadata, ControlNetConfigMetadata,
IPAdapterConfigMetadata, IPAdapterConfigMetadata,
@ -109,7 +109,7 @@ const validateIPAdapters: MetadataValidateFunc<IPAdapterConfigMetadata[]> = (ipA
return new Promise((resolve) => resolve(validatedIPAdapters)); return new Promise((resolve) => resolve(validatedIPAdapters));
}; };
const validateLayer: MetadataValidateFunc<LayerData> = async (layer) => { const validateLayer: MetadataValidateFunc<LayerEntity> = async (layer) => {
if (layer.type === 'control_adapter_layer') { if (layer.type === 'control_adapter_layer') {
const model = layer.controlAdapter.model; const model = layer.controlAdapter.model;
assert(model, 'Control Adapter layer missing model'); assert(model, 'Control Adapter layer missing model');
@ -131,8 +131,8 @@ const validateLayer: MetadataValidateFunc<LayerData> = async (layer) => {
return layer; return layer;
}; };
const validateLayers: MetadataValidateFunc<LayerData[]> = async (layers) => { const validateLayers: MetadataValidateFunc<LayerEntity[]> = async (layers) => {
const validatedLayers: LayerData[] = []; const validatedLayers: LayerEntity[] = [];
for (const l of layers) { for (const l of layers) {
try { try {
const validated = await validateLayer(l); const validated = await validateLayer(l);

View File

@ -1,5 +1,5 @@
import type { import type {
ControlAdapterData, ControlAdapterEntity,
ControlNetData, ControlNetData,
ImageWithDims, ImageWithDims,
ProcessorConfig, ProcessorConfig,
@ -12,11 +12,11 @@ import type { BaseModelType, Invocation } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
export const addControlAdapters = ( export const addControlAdapters = (
controlAdapters: ControlAdapterData[], controlAdapters: ControlAdapterEntity[],
g: Graph, g: Graph,
denoise: Invocation<'denoise_latents'>, denoise: Invocation<'denoise_latents'>,
base: BaseModelType base: BaseModelType
): ControlAdapterData[] => { ): ControlAdapterEntity[] => {
const validControlAdapters = controlAdapters.filter((ca) => isValidControlAdapter(ca, base)); const validControlAdapters = controlAdapters.filter((ca) => isValidControlAdapter(ca, base));
for (const ca of validControlAdapters) { for (const ca of validControlAdapters) {
if (ca.adapterType === 'controlnet') { if (ca.adapterType === 'controlnet') {
@ -122,7 +122,7 @@ const buildControlImage = (
assert(false, 'Attempted to add unprocessed control image'); assert(false, 'Attempted to add unprocessed control image');
}; };
const isValidControlAdapter = (ca: ControlAdapterData, base: BaseModelType): boolean => { const isValidControlAdapter = (ca: ControlAdapterEntity, base: BaseModelType): boolean => {
// Must be have a model that matches the current base and must have a control image // Must be have a model that matches the current base and must have a control image
const hasModel = Boolean(ca.model); const hasModel = Boolean(ca.model);
const modelMatchesBase = ca.model?.base === base; const modelMatchesBase = ca.model?.base === base;

View File

@ -1,15 +1,15 @@
import type { IPAdapterData } from 'features/controlLayers/store/types'; import type { IPAdapterEntity } from 'features/controlLayers/store/types';
import { IP_ADAPTER_COLLECT } from 'features/nodes/util/graph/constants'; import { IP_ADAPTER_COLLECT } from 'features/nodes/util/graph/constants';
import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { BaseModelType, Invocation } from 'services/api/types'; import type { BaseModelType, Invocation } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
export const addIPAdapters = ( export const addIPAdapters = (
ipAdapters: IPAdapterData[], ipAdapters: IPAdapterEntity[],
g: Graph, g: Graph,
denoise: Invocation<'denoise_latents'>, denoise: Invocation<'denoise_latents'>,
base: BaseModelType base: BaseModelType
): IPAdapterData[] => { ): IPAdapterEntity[] => {
const validIPAdapters = ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base)); const validIPAdapters = ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base));
for (const ipa of validIPAdapters) { for (const ipa of validIPAdapters) {
addIPAdapter(ipa, g, denoise); addIPAdapter(ipa, g, denoise);
@ -33,7 +33,7 @@ export const addIPAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise
} }
}; };
const addIPAdapter = (ipa: IPAdapterData, g: Graph, denoise: Invocation<'denoise_latents'>) => { const addIPAdapter = (ipa: IPAdapterEntity, g: Graph, denoise: Invocation<'denoise_latents'>) => {
const { id, weight, model, clipVisionModel, method, beginEndStepPct, image } = ipa; const { id, weight, model, clipVisionModel, method, beginEndStepPct, image } = ipa;
assert(image, 'IP Adapter image is required'); assert(image, 'IP Adapter image is required');
assert(model, 'IP Adapter model is required'); assert(model, 'IP Adapter model is required');
@ -55,7 +55,7 @@ const addIPAdapter = (ipa: IPAdapterData, g: Graph, denoise: Invocation<'denoise
g.addEdge(ipAdapter, 'ip_adapter', ipAdapterCollect, 'item'); g.addEdge(ipAdapter, 'ip_adapter', ipAdapterCollect, 'item');
}; };
export const isValidIPAdapter = (ipa: IPAdapterData, base: BaseModelType): boolean => { export const isValidIPAdapter = (ipa: IPAdapterEntity, base: BaseModelType): boolean => {
// Must be have a model that matches the current base and must have a control image // Must be have a model that matches the current base and must have a control image
const hasModel = Boolean(ipa.model); const hasModel = Boolean(ipa.model);
const modelMatchesBase = ipa.model?.base === base; const modelMatchesBase = ipa.model?.base === base;

View File

@ -5,7 +5,7 @@ import { RG_LAYER_NAME } from 'features/controlLayers/konva/naming';
import { renderers } from 'features/controlLayers/konva/renderers/layers'; import { renderers } from 'features/controlLayers/konva/renderers/layers';
import { blobToDataURL } from "features/controlLayers/konva/util"; import { blobToDataURL } from "features/controlLayers/konva/util";
import { rgMaskImageUploaded } from 'features/controlLayers/store/canvasV2Slice'; import { rgMaskImageUploaded } from 'features/controlLayers/store/canvasV2Slice';
import type { Dimensions, IPAdapterData, RegionalGuidanceData } from 'features/controlLayers/store/types'; import type { Dimensions, IPAdapterEntity, RegionEntity } from 'features/controlLayers/store/types';
import { import {
PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX, PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX,
PROMPT_REGION_MASK_TO_TENSOR_PREFIX, PROMPT_REGION_MASK_TO_TENSOR_PREFIX,
@ -36,7 +36,7 @@ import { assert } from 'tsafe';
*/ */
export const addRegions = async ( export const addRegions = async (
regions: RegionalGuidanceData[], regions: RegionEntity[],
g: Graph, g: Graph,
documentSize: Dimensions, documentSize: Dimensions,
bbox: IRect, bbox: IRect,
@ -46,7 +46,7 @@ export const addRegions = async (
negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>, negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
posCondCollect: Invocation<'collect'>, posCondCollect: Invocation<'collect'>,
negCondCollect: Invocation<'collect'> negCondCollect: Invocation<'collect'>
): Promise<RegionalGuidanceData[]> => { ): Promise<RegionEntity[]> => {
const isSDXL = base === 'sdxl'; const isSDXL = base === 'sdxl';
const validRegions = regions.filter((rg) => isValidRegion(rg, base)); const validRegions = regions.filter((rg) => isValidRegion(rg, base));
@ -186,7 +186,7 @@ export const addRegions = async (
} }
} }
const validRGIPAdapters: IPAdapterData[] = rg.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base)); const validRGIPAdapters: IPAdapterEntity[] = rg.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base));
for (const ipa of validRGIPAdapters) { for (const ipa of validRGIPAdapters) {
const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise); const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise);
@ -218,13 +218,13 @@ export const addRegions = async (
return validRegions; return validRegions;
}; };
export const isValidRegion = (rg: RegionalGuidanceData, base: BaseModelType) => { export const isValidRegion = (rg: RegionEntity, base: BaseModelType) => {
const hasTextPrompt = Boolean(rg.positivePrompt || rg.negativePrompt); const hasTextPrompt = Boolean(rg.positivePrompt || rg.negativePrompt);
const hasIPAdapter = rg.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base)).length > 0; const hasIPAdapter = rg.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base)).length > 0;
return hasTextPrompt || hasIPAdapter; return hasTextPrompt || hasIPAdapter;
}; };
export const getMaskImage = async (rg: RegionalGuidanceData, blob: Blob): Promise<ImageDTO> => { export const getMaskImage = async (rg: RegionEntity, blob: Blob): Promise<ImageDTO> => {
const { id, imageCache } = rg; const { id, imageCache } = rg;
if (imageCache) { if (imageCache) {
const imageDTO = await getImageDTO(imageCache.name); const imageDTO = await getImageDTO(imageCache.name);
@ -253,7 +253,7 @@ export const getMaskImage = async (rg: RegionalGuidanceData, blob: Blob): Promis
*/ */
export const getRGMaskBlobs = async ( export const getRGMaskBlobs = async (
regions: RegionalGuidanceData[], regions: RegionEntity[],
documentSize: Dimensions, documentSize: Dimensions,
bbox: IRect, bbox: IRect,
preview: boolean = false preview: boolean = false