feat(ui): clean up state, add mutex for image loading, add thumbnail loading

This commit is contained in:
psychedelicious 2024-08-07 17:20:18 +10:00
parent 6b385614f0
commit a27d39b9ff
15 changed files with 96 additions and 96 deletions

View File

@ -41,7 +41,7 @@ const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: Im
const deleteControlAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
state.canvasV2.controlAdapters.entities.forEach(({ id, imageObject, processedImageObject }) => {
if (imageObject?.image.name === imageDTO.image_name || processedImageObject?.image.name === imageDTO.image_name) {
if (imageObject?.image.image_name === imageDTO.image_name || processedImageObject?.image.image_name === imageDTO.image_name) {
dispatch(caImageChanged({ id, imageDTO: null }));
dispatch(caProcessedImageChanged({ id, imageDTO: null }));
}
@ -50,7 +50,7 @@ const deleteControlAdapterImages = (state: RootState, dispatch: AppDispatch, ima
const deleteIPAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
state.canvasV2.ipAdapters.entities.forEach(({ id, imageObject }) => {
if (imageObject?.image.name === imageDTO.image_name) {
if (imageObject?.image.image_name === imageDTO.image_name) {
dispatch(ipaImageChanged({ id, imageDTO: null }));
}
});
@ -60,7 +60,7 @@ const deleteLayerImages = (state: RootState, dispatch: AppDispatch, imageDTO: Im
state.canvasV2.layers.entities.forEach(({ id, objects }) => {
let shouldDelete = false;
for (const obj of objects) {
if (obj.type === 'image' && obj.image.name === imageDTO.image_name) {
if (obj.type === 'image' && obj.image.image_name === imageDTO.image_name) {
shouldDelete = true;
break;
}

View File

@ -47,10 +47,10 @@ export const CAImagePreview = memo(
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery(
controlAdapter.imageObject?.image.name ?? skipToken
controlAdapter.imageObject?.image.image_name ?? skipToken
);
const { currentData: processedControlImage, isError: isErrorProcessedControlImage } = useGetImageDTOQuery(
controlAdapter.processedImageObject?.image.name ?? skipToken
controlAdapter.processedImageObject?.image.image_name ?? skipToken
);
const [changeIsIntermediate] = useChangeImageIsIntermediateMutation();

View File

@ -29,7 +29,7 @@ export const IPAImagePreview = memo(({ image, onChangeImage, ipAdapterId, droppa
const optimalDimension = useAppSelector(selectOptimalDimension);
const shift = useShiftModifier();
const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery(image?.name ?? skipToken);
const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery(image?.image_name ?? skipToken);
const handleResetControlImage = useCallback(() => {
onChangeImage(null);
}, [onChangeImage]);

View File

@ -72,7 +72,7 @@ export class CanvasControlAdapter extends CanvasEntity {
this.image = new CanvasImageRenderer(imageObject, this);
this.updateGroup(true);
this.konva.objectGroup.add(this.image.konva.group);
await this.image.updateImageSource(imageObject.image.name);
await this.image.updateImageSource(imageObject.image.image_name);
} else if (!this.image.isLoading && !this.image.isError) {
if (await this.image.update(imageObject)) {
didDraw = true;

View File

@ -1,3 +1,4 @@
import { Mutex } from 'async-mutex';
import { deepClone } from 'common/util/deepClone';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type { CanvasObjectRenderer } from 'features/controlLayers/konva/CanvasObjectRenderer';
@ -29,12 +30,15 @@ export class CanvasImageRenderer {
placeholder: { group: Konva.Group; rect: Konva.Rect; text: Konva.Text };
image: Konva.Image | null; // The image is loaded asynchronously, so it may not be available immediately
};
imageName: string | null;
isLoading: boolean;
isError: boolean;
thumbnailElement: HTMLImageElement | null = null;
imageElement: HTMLImageElement | null = null;
isLoading: boolean = false;
isError: boolean = false;
mutex = new Mutex();
constructor(state: CanvasImageState, parent: CanvasObjectRenderer) {
const { id, width, height, x, y } = state;
const { id, image } = state;
const { width, height } = image;
this.id = id;
this.parent = parent;
this.manager = parent.manager;
@ -44,7 +48,7 @@ export class CanvasImageRenderer {
this.log.trace({ state }, 'Creating image');
this.konva = {
group: new Konva.Group({ name: CanvasImageRenderer.GROUP_NAME, listening: false, x, y }),
group: new Konva.Group({ name: CanvasImageRenderer.GROUP_NAME, listening: false }),
placeholder: {
group: new Konva.Group({ name: CanvasImageRenderer.PLACEHOLDER_GROUP_NAME, listening: false }),
rect: new Konva.Rect({
@ -73,10 +77,6 @@ export class CanvasImageRenderer {
this.konva.placeholder.group.add(this.konva.placeholder.rect);
this.konva.placeholder.group.add(this.konva.placeholder.text);
this.konva.group.add(this.konva.placeholder.group);
this.imageName = null;
this.isLoading = false;
this.isError = false;
this.state = state;
}
@ -94,22 +94,50 @@ export class CanvasImageRenderer {
const imageDTO = await getImageDTO(imageName);
if (imageDTO === null) {
this.log.error({ imageName }, 'Image not found');
this.onFailedToLoadImage();
return;
}
const imageEl = await loadImage(imageDTO.image_url);
loadImage(imageDTO.thumbnail_url)
.then((thumbnailElement) => {
this.thumbnailElement = thumbnailElement;
this.mutex.runExclusive(this.updateImageElement);
})
.catch(this.onFailedToLoadImage);
loadImage(imageDTO.image_url)
.then((imageElement) => {
this.imageElement = imageElement;
this.mutex.runExclusive(this.updateImageElement);
})
.catch(this.onFailedToLoadImage);
} catch {
this.onFailedToLoadImage();
}
};
if (this.konva.image) {
onFailedToLoadImage = () => {
this.log({ image: this.state.image }, 'Failed to load image');
this.konva.image?.visible(false);
this.isLoading = false;
this.isError = true;
this.konva.placeholder.text.text(t('common.imageFailedToLoad', 'Image Failed to Load'));
this.konva.placeholder.group.visible(true);
};
updateImageElement = () => {
const element = this.imageElement ?? this.thumbnailElement;
if (element) {
if (this.konva.image && this.konva.image.image() !== element) {
this.konva.image.setAttrs({
image: imageEl,
image: element,
});
} else {
this.konva.image = new Konva.Image({
name: CanvasImageRenderer.IMAGE_NAME,
listening: false,
image: imageEl,
width: this.state.width,
height: this.state.height,
image: element,
width: this.state.image.width,
height: this.state.image.height,
});
this.konva.group.add(this.konva.image);
}
@ -122,18 +150,9 @@ export class CanvasImageRenderer {
this.konva.image.filters([]);
}
this.imageName = imageName;
this.isLoading = false;
this.isError = false;
this.konva.placeholder.group.visible(false);
} catch {
this.log({ imageName }, 'Failed to load image');
this.konva.image?.visible(false);
this.imageName = null;
this.isLoading = false;
this.isError = true;
this.konva.placeholder.text.text(t('common.imageFailedToLoad', 'Image Failed to Load'));
this.konva.placeholder.group.visible(true);
}
};
@ -141,11 +160,12 @@ export class CanvasImageRenderer {
if (force || this.state !== state) {
this.log.trace({ state }, 'Updating image');
const { width, height, x, y, image, filters } = state;
if (force || (this.state.image.name !== image.name && !this.isLoading)) {
await this.updateImageSource(image.name);
const { image, filters } = state;
const { width, height, image_name } = image;
if (force || (this.state.image.image_name !== image_name && !this.isLoading)) {
await this.updateImageSource(image_name);
}
this.konva.image?.setAttrs({ x, y, width, height });
this.konva.image?.setAttrs({ width, height });
if (filters.length > 0) {
this.konva.image?.cache();
this.konva.image?.filters(filters.map((f) => FILTER_MAP[f]));
@ -177,7 +197,6 @@ export class CanvasImageRenderer {
id: this.id,
type: CanvasImageRenderer.TYPE,
parent: this.parent.id,
imageName: this.imageName,
isLoading: this.isLoading,
isError: this.isError,
state: deepClone(this.state),

View File

@ -685,7 +685,7 @@ export class CanvasManager {
const region = this.getEntity({ id, type: 'regional_guidance' });
assert(region?.type === 'regional_guidance');
if (region.state.imageCache) {
const imageDTO = await getImageDTO(region.state.imageCache.name);
const imageDTO = await getImageDTO(region.state.imageCache);
if (imageDTO) {
return imageDTO;
}

View File

@ -25,7 +25,7 @@ export class CanvasRectRenderer {
isFirstRender: boolean = false;
constructor(state: CanvasRectState, parent: CanvasObjectRenderer) {
const { id, x, y, width, height, color } = state;
const { id, rect, color } = state;
this.id = id;
this.parent = parent;
this.manager = parent.manager;
@ -37,10 +37,7 @@ export class CanvasRectRenderer {
group: new Konva.Group({ name: CanvasRectRenderer.GROUP_NAME, listening: false }),
rect: new Konva.Rect({
name: CanvasRectRenderer.RECT_NAME,
x,
y,
width,
height,
...rect,
listening: false,
fill: rgbaColorToString(color),
}),
@ -54,12 +51,9 @@ export class CanvasRectRenderer {
this.isFirstRender = false;
this.log.trace({ state }, 'Updating rect');
const { x, y, width, height, color } = state;
const { rect, color } = state;
this.konva.rect.setAttrs({
x,
y,
width,
height,
...rect,
fill: rgbaColorToString(color),
});
this.state = state;

View File

@ -53,7 +53,7 @@ export class CanvasStagingArea {
height,
filters: [],
image: {
name: image_name,
image_name: image_name,
width,
height,
},

View File

@ -261,10 +261,7 @@ export const setStageEventHandlers = (manager: CanvasManager): (() => void) => {
await selectedEntity.adapter.renderer.setBuffer({
id: getObjectId('rect', true),
type: 'rect',
x: Math.round(normalizedPoint.x),
y: Math.round(normalizedPoint.y),
width: 0,
height: 0,
rect: { x: Math.round(normalizedPoint.x), y: Math.round(normalizedPoint.y), width: 0, height: 0 },
color: getCurrentFill(),
});
}
@ -407,8 +404,8 @@ export const setStageEventHandlers = (manager: CanvasManager): (() => void) => {
if (drawingBuffer) {
if (drawingBuffer.type === 'rect') {
const normalizedPoint = offsetCoord(pos, selectedEntity.state.position);
drawingBuffer.width = Math.round(normalizedPoint.x - drawingBuffer.x);
drawingBuffer.height = Math.round(normalizedPoint.y - drawingBuffer.y);
drawingBuffer.rect.width = Math.round(normalizedPoint.x - drawingBuffer.rect.x);
drawingBuffer.rect.height = Math.round(normalizedPoint.y - drawingBuffer.rect.y);
await selectedEntity.adapter.renderer.setBuffer(drawingBuffer);
} else {
await selectedEntity.adapter.renderer.clearBuffer();
@ -447,8 +444,8 @@ export const setStageEventHandlers = (manager: CanvasManager): (() => void) => {
await selectedEntity.adapter.renderer.setBuffer(drawingBuffer);
await selectedEntity.adapter.renderer.commitBuffer();
} else if (toolState.selected === 'rect' && drawingBuffer?.type === 'rect') {
drawingBuffer.width = Math.round(normalizedPoint.x - drawingBuffer.x);
drawingBuffer.height = Math.round(normalizedPoint.y - drawingBuffer.y);
drawingBuffer.rect.width = Math.round(normalizedPoint.x - drawingBuffer.rect.x);
drawingBuffer.rect.height = Math.round(normalizedPoint.y - drawingBuffer.rect.y);
await selectedEntity.adapter.renderer.setBuffer(drawingBuffer);
await selectedEntity.adapter.renderer.commitBuffer();
}

View File

@ -43,7 +43,7 @@ import { z } from 'zod';
export const zId = z.string().min(1);
export const zImageWithDims = z.object({
name: z.string(),
image_name: z.string(),
width: z.number().int().positive(),
height: z.number().int().positive(),
});
@ -248,7 +248,7 @@ export const CA_PROCESSOR_DATA: CAProcessorsData = {
buildNode: (image, config) => ({
...config,
type: 'canny_image_processor',
image: { image_name: image.name },
image: { image_name: image.image_name },
detect_resolution: minDim(image),
image_resolution: minDim(image),
}),
@ -265,7 +265,7 @@ export const CA_PROCESSOR_DATA: CAProcessorsData = {
buildNode: (image, config) => ({
...config,
type: 'color_map_image_processor',
image: { image_name: image.name },
image: { image_name: image.image_name },
}),
},
content_shuffle_image_processor: {
@ -281,7 +281,7 @@ export const CA_PROCESSOR_DATA: CAProcessorsData = {
}),
buildNode: (image, config) => ({
...config,
image: { image_name: image.name },
image: { image_name: image.image_name },
detect_resolution: minDim(image),
image_resolution: minDim(image),
}),
@ -297,7 +297,7 @@ export const CA_PROCESSOR_DATA: CAProcessorsData = {
}),
buildNode: (image, config) => ({
...config,
image: { image_name: image.name },
image: { image_name: image.image_name },
resolution: minDim(image),
}),
},
@ -312,7 +312,7 @@ export const CA_PROCESSOR_DATA: CAProcessorsData = {
}),
buildNode: (image, config) => ({
...config,
image: { image_name: image.name },
image: { image_name: image.image_name },
detect_resolution: minDim(image),
image_resolution: minDim(image),
}),
@ -327,7 +327,7 @@ export const CA_PROCESSOR_DATA: CAProcessorsData = {
}),
buildNode: (image, config) => ({
...config,
image: { image_name: image.name },
image: { image_name: image.image_name },
detect_resolution: minDim(image),
image_resolution: minDim(image),
}),
@ -343,7 +343,7 @@ export const CA_PROCESSOR_DATA: CAProcessorsData = {
}),
buildNode: (image, config) => ({
...config,
image: { image_name: image.name },
image: { image_name: image.image_name },
detect_resolution: minDim(image),
image_resolution: minDim(image),
}),
@ -360,7 +360,7 @@ export const CA_PROCESSOR_DATA: CAProcessorsData = {
}),
buildNode: (image, config) => ({
...config,
image: { image_name: image.name },
image: { image_name: image.image_name },
detect_resolution: minDim(image),
image_resolution: minDim(image),
}),
@ -377,7 +377,7 @@ export const CA_PROCESSOR_DATA: CAProcessorsData = {
}),
buildNode: (image, config) => ({
...config,
image: { image_name: image.name },
image: { image_name: image.image_name },
detect_resolution: minDim(image),
image_resolution: minDim(image),
}),
@ -394,7 +394,7 @@ export const CA_PROCESSOR_DATA: CAProcessorsData = {
}),
buildNode: (image, config) => ({
...config,
image: { image_name: image.name },
image: { image_name: image.image_name },
detect_resolution: minDim(image),
image_resolution: minDim(image),
}),
@ -409,7 +409,7 @@ export const CA_PROCESSOR_DATA: CAProcessorsData = {
}),
buildNode: (image, config) => ({
...config,
image: { image_name: image.name },
image: { image_name: image.image_name },
detect_resolution: minDim(image),
image_resolution: minDim(image),
}),
@ -427,7 +427,7 @@ export const CA_PROCESSOR_DATA: CAProcessorsData = {
}),
buildNode: (image, config) => ({
...config,
image: { image_name: image.name },
image: { image_name: image.image_name },
image_resolution: minDim(image),
}),
},
@ -443,7 +443,7 @@ export const CA_PROCESSOR_DATA: CAProcessorsData = {
}),
buildNode: (image, config) => ({
...config,
image: { image_name: image.name },
image: { image_name: image.image_name },
detect_resolution: minDim(image),
image_resolution: minDim(image),
}),
@ -458,7 +458,7 @@ export const CA_PROCESSOR_DATA: CAProcessorsData = {
}),
buildNode: (image, config) => ({
...config,
image: { image_name: image.name },
image: { image_name: image.image_name },
}),
},
};
@ -548,10 +548,7 @@ export type CanvasEraserLineState = z.infer<typeof zCanvasEraserLineState>;
const zCanvasRectState = z.object({
id: zId,
type: z.literal('rect'),
x: z.number(),
y: z.number(),
width: z.number().min(1),
height: z.number().min(1),
rect: zRect,
color: zRgbaColor,
});
export type CanvasRectState = z.infer<typeof zCanvasRectState>;
@ -563,10 +560,6 @@ const zCanvasImageState = z.object({
id: zId,
type: z.literal('image'),
image: zImageWithDims,
x: z.number(),
y: z.number(),
width: z.number().min(1),
height: z.number().min(1),
filters: z.array(zFilter),
});
export type CanvasImageState = z.infer<typeof zCanvasImageState>;
@ -589,6 +582,7 @@ export const zCanvasLayerState = z.object({
position: zCoordinate,
opacity: zOpacity,
objects: z.array(zCanvasObjectState),
imageCache: z.string().min(1).nullable(),
});
export type CanvasLayerState = z.infer<typeof zCanvasLayerState>;
@ -661,7 +655,7 @@ export const zCanvasRegionalGuidanceState = z.object({
negativePrompt: zParameterNegativePrompt.nullable(),
ipAdapters: z.array(zCanvasIPAdapterState),
autoNegative: zAutoNegative,
imageCache: zImageWithDims.nullable(),
imageCache: z.string().min(1).nullable(),
});
export type CanvasRegionalGuidanceState = z.infer<typeof zCanvasRegionalGuidanceState>;
@ -681,7 +675,7 @@ const zCanvasInpaintMaskState = z.object({
position: zCoordinate,
fill: zRgbColor,
objects: z.array(zCanvasObjectState),
imageCache: zImageWithDims.nullable(),
imageCache: z.string().min(1).nullable(),
});
export type CanvasInpaintMaskState = z.infer<typeof zCanvasInpaintMaskState>;
@ -773,7 +767,7 @@ export const buildControlAdapterProcessorV2 = (
};
export const imageDTOToImageWithDims = ({ image_name, width, height }: ImageDTO): ImageWithDims => ({
name: image_name,
image_name,
width,
height,
});
@ -783,13 +777,9 @@ export const imageDTOToImageObject = (imageDTO: ImageDTO, overrides?: Partial<Ca
return {
id: getObjectId('image'),
type: 'image',
x: 0,
y: 0,
width,
height,
filters: [],
image: {
name: image_name,
image_name,
width,
height,
},

View File

@ -12,7 +12,7 @@ import type { ImageUsage } from './types';
export const getImageUsage = (nodes: NodesState, canvasV2: CanvasV2State, image_name: string) => {
const isLayerImage = canvasV2.layers.entities.some((layer) =>
layer.objects.some((obj) => obj.type === 'image' && obj.image.name === image_name)
layer.objects.some((obj) => obj.type === 'image' && obj.image.image_name === image_name)
);
const isNodesImage = nodes.nodes
@ -22,10 +22,10 @@ export const getImageUsage = (nodes: NodesState, canvasV2: CanvasV2State, image_
);
const isControlAdapterImage = canvasV2.controlAdapters.entities.some(
(ca) => ca.imageObject?.image.name === image_name || ca.processedImageObject?.image.name === image_name
(ca) => ca.imageObject?.image.image_name === image_name || ca.processedImageObject?.image.image_name === image_name
);
const isIPAdapterImage = canvasV2.ipAdapters.entities.some((ipa) => ipa.imageObject?.image.name === image_name);
const isIPAdapterImage = canvasV2.ipAdapters.entities.some((ipa) => ipa.imageObject?.image.image_name === image_name);
const imageUsage: ImageUsage = {
isLayerImage,

View File

@ -334,7 +334,7 @@ const recallLayer: MetadataRecallFunc<CanvasLayerState> = async (layer) => {
const invalidObjects: string[] = [];
for (const obj of clone.objects) {
if (obj.type === 'image') {
const imageDTO = await getImageDTO(obj.image.name);
const imageDTO = await getImageDTO(obj.image.image_name);
if (!imageDTO) {
invalidObjects.push(obj.id);
}

View File

@ -129,12 +129,12 @@ const buildControlImage = (
if (processedImage && processorConfig) {
// We've processed the image in the app - use it for the control image.
return {
image_name: processedImage.name,
image_name: processedImage.image_name,
};
} else if (image) {
// No processor selected, and we have an image - the user provided a processed image, use it for the control image.
return {
image_name: image.name,
image_name: image.image_name,
};
}
assert(false, 'Attempted to add unprocessed control image');

View File

@ -49,7 +49,7 @@ const addIPAdapter = (ipa: CanvasIPAdapterState, g: Graph, denoise: Invocation<'
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
image: {
image_name: imageObject.image.name,
image_name: imageObject.image.image_name,
},
});
g.addEdge(ipAdapter, 'ip_adapter', ipAdapterCollect, 'item');

View File

@ -191,7 +191,7 @@ export const addRegions = async (
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
image: {
image_name: imageObject.image.name,
image_name: imageObject.image.image_name,
},
});