Merge branch 'main' into fix/lora_node_inputs_definition

This commit is contained in:
blessedcoolant 2023-09-02 13:38:05 +12:00 committed by GitHub
commit 2c754cfce7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 44 additions and 36 deletions

View File

@ -110,7 +110,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
); );
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery( const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
lastSelectedImage?.image_name ?? skipToken, lastSelectedImage ?? skipToken,
{ {
selectFromResult: (res) => ({ selectFromResult: (res) => ({
isLoading: res.isFetching, isLoading: res.isFetching,

View File

@ -52,7 +52,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled; const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery( const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
imageDTO.image_name, imageDTO,
{ {
selectFromResult: (res) => ({ selectFromResult: (res) => ({
isLoading: res.isFetching, isLoading: res.isFetching,

View File

@ -27,15 +27,12 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
// dispatch(setShouldShowImageDetails(false)); // dispatch(setShouldShowImageDetails(false));
// }); // });
const { metadata, workflow } = useGetImageMetadataFromFileQuery( const { metadata, workflow } = useGetImageMetadataFromFileQuery(image, {
image.image_name,
{
selectFromResult: (res) => ({ selectFromResult: (res) => ({
metadata: res?.currentData?.metadata, metadata: res?.currentData?.metadata,
workflow: res?.currentData?.workflow, workflow: res?.currentData?.workflow,
}), }),
} });
);
return ( return (
<Flex <Flex

View File

@ -1,4 +1,3 @@
import { store } from 'app/store/store';
import { import {
SchedulerParam, SchedulerParam,
zBaseModel, zBaseModel,
@ -10,7 +9,6 @@ import { keyBy } from 'lodash-es';
import { OpenAPIV3 } from 'openapi-types'; import { OpenAPIV3 } from 'openapi-types';
import { RgbaColor } from 'react-colorful'; import { RgbaColor } from 'react-colorful';
import { Node } from 'reactflow'; import { Node } from 'reactflow';
import { JsonObject } from 'type-fest';
import { Graph, ImageDTO, _InputField, _OutputField } from 'services/api/types'; import { Graph, ImageDTO, _InputField, _OutputField } from 'services/api/types';
import { import {
AnyInvocationType, AnyInvocationType,
@ -18,6 +16,7 @@ import {
ProgressImage, ProgressImage,
} from 'services/events/types'; } from 'services/events/types';
import { O } from 'ts-toolbelt'; import { O } from 'ts-toolbelt';
import { JsonObject } from 'type-fest';
import { z } from 'zod'; import { z } from 'zod';
export type NonNullableGraph = O.Required<Graph, 'nodes' | 'edges'>; export type NonNullableGraph = O.Required<Graph, 'nodes' | 'edges'>;
@ -936,22 +935,10 @@ export const zWorkflow = z.object({
}); });
export const zValidatedWorkflow = zWorkflow.transform((workflow) => { export const zValidatedWorkflow = zWorkflow.transform((workflow) => {
const nodeTemplates = store.getState().nodes.nodeTemplates;
const { nodes, edges } = workflow; const { nodes, edges } = workflow;
const warnings: WorkflowWarning[] = []; const warnings: WorkflowWarning[] = [];
const invocationNodes = nodes.filter(isWorkflowInvocationNode); const invocationNodes = nodes.filter(isWorkflowInvocationNode);
const keyedNodes = keyBy(invocationNodes, 'id'); const keyedNodes = keyBy(invocationNodes, 'id');
invocationNodes.forEach((node, i) => {
const nodeTemplate = nodeTemplates[node.data.type];
if (!nodeTemplate) {
warnings.push({
message: `Node "${node.data.label || node.data.id}" skipped`,
issues: [`Unable to find template for type "${node.data.type}"`],
data: node,
});
delete nodes[i];
}
});
edges.forEach((edge, i) => { edges.forEach((edge, i) => {
const sourceNode = keyedNodes[edge.source]; const sourceNode = keyedNodes[edge.source];
const targetNode = keyedNodes[edge.target]; const targetNode = keyedNodes[edge.target];

View File

@ -28,6 +28,8 @@ import {
} from '../util'; } from '../util';
import { boardsApi } from './boards'; import { boardsApi } from './boards';
import { ImageMetadataAndWorkflow } from 'features/nodes/types/types'; import { ImageMetadataAndWorkflow } from 'features/nodes/types/types';
import { fetchBaseQuery } from '@reduxjs/toolkit/dist/query';
import { $authToken, $projectId } from '../client';
export const imagesApi = api.injectEndpoints({ export const imagesApi = api.injectEndpoints({
endpoints: (build) => ({ endpoints: (build) => ({
@ -115,18 +117,40 @@ export const imagesApi = api.injectEndpoints({
], ],
keepUnusedDataFor: 86400, // 24 hours keepUnusedDataFor: 86400, // 24 hours
}), }),
getImageMetadataFromFile: build.query<ImageMetadataAndWorkflow, string>({ getImageMetadataFromFile: build.query<ImageMetadataAndWorkflow, ImageDTO>({
query: (image_name) => ({ queryFn: async (args: ImageDTO, api, extraOptions) => {
url: `images/i/${image_name}/full`, const authToken = $authToken.get();
const projectId = $projectId.get();
const customBaseQuery = fetchBaseQuery({
baseUrl: '',
prepareHeaders: (headers) => {
if (authToken) {
headers.set('Authorization', `Bearer ${authToken}`);
}
if (projectId) {
headers.set('project-id', projectId);
}
return headers;
},
responseHandler: async (res) => { responseHandler: async (res) => {
return await res.blob(); return await res.blob();
}, },
}), });
providesTags: (result, error, image_name) => [
{ type: 'ImageMetadataFromFile', id: image_name }, const response = await customBaseQuery(
args.image_url,
api,
extraOptions
);
const data = await getMetadataAndWorkflowFromImageBlob(
response.data as Blob
);
return { data };
},
providesTags: (result, error, image_dto) => [
{ type: 'ImageMetadataFromFile', id: image_dto.image_name },
], ],
transformResponse: (response: Blob) =>
getMetadataAndWorkflowFromImageBlob(response),
keepUnusedDataFor: 86400, // 24 hours keepUnusedDataFor: 86400, // 24 hours
}), }),
clearIntermediates: build.mutation<number, void>({ clearIntermediates: build.mutation<number, void>({