mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into lama-infill
This commit is contained in:
commit
ec09e21fc2
@ -279,8 +279,8 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
crop_left: int = InputField(default=0, description="")
|
||||
target_width: int = InputField(default=1024, description="")
|
||||
target_height: int = InputField(default=1024, description="")
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
|
||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
|
@ -72,10 +72,10 @@ class CoreMetadata(BaseModelExcludeNull):
|
||||
)
|
||||
refiner_steps: Optional[int] = Field(default=None, description="The number of steps used for the refiner")
|
||||
refiner_scheduler: Optional[str] = Field(default=None, description="The scheduler used for the refiner")
|
||||
refiner_positive_aesthetic_store: Optional[float] = Field(
|
||||
refiner_positive_aesthetic_score: Optional[float] = Field(
|
||||
default=None, description="The aesthetic score used for the refiner"
|
||||
)
|
||||
refiner_negative_aesthetic_store: Optional[float] = Field(
|
||||
refiner_negative_aesthetic_score: Optional[float] = Field(
|
||||
default=None, description="The aesthetic score used for the refiner"
|
||||
)
|
||||
refiner_start: Optional[float] = Field(default=None, description="The start value used for refiner denoising")
|
||||
@ -160,11 +160,11 @@ class MetadataAccumulatorInvocation(BaseInvocation):
|
||||
default=None,
|
||||
description="The scheduler used for the refiner",
|
||||
)
|
||||
refiner_positive_aesthetic_store: Optional[float] = InputField(
|
||||
refiner_positive_aesthetic_score: Optional[float] = InputField(
|
||||
default=None,
|
||||
description="The aesthetic score used for the refiner",
|
||||
)
|
||||
refiner_negative_aesthetic_store: Optional[float] = InputField(
|
||||
refiner_negative_aesthetic_score: Optional[float] = InputField(
|
||||
default=None,
|
||||
description="The aesthetic score used for the refiner",
|
||||
)
|
||||
|
@ -250,13 +250,13 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
|
||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
unet: Optional[UNetField] = Field(
|
||||
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNET"
|
||||
unet: Optional[UNetField] = InputField(
|
||||
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
|
||||
)
|
||||
clip: Optional[ClipField] = Field(
|
||||
clip: Optional[ClipField] = InputField(
|
||||
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1"
|
||||
)
|
||||
clip2: Optional[ClipField] = Field(
|
||||
clip2: Optional[ClipField] = InputField(
|
||||
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2"
|
||||
)
|
||||
|
||||
|
@ -50,6 +50,7 @@ class ModelProbe(object):
|
||||
"StableDiffusionInpaintPipeline": ModelType.Main,
|
||||
"StableDiffusionXLPipeline": ModelType.Main,
|
||||
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
||||
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
||||
"AutoencoderKL": ModelType.Vae,
|
||||
"ControlNetModel": ModelType.ControlNet,
|
||||
}
|
||||
|
@ -110,7 +110,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
);
|
||||
|
||||
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
|
||||
lastSelectedImage?.image_name ?? skipToken,
|
||||
lastSelectedImage ?? skipToken,
|
||||
{
|
||||
selectFromResult: (res) => ({
|
||||
isLoading: res.isFetching,
|
||||
|
@ -52,7 +52,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
||||
|
||||
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
|
||||
imageDTO.image_name,
|
||||
imageDTO,
|
||||
{
|
||||
selectFromResult: (res) => ({
|
||||
isLoading: res.isFetching,
|
||||
|
@ -101,13 +101,15 @@ const ImageMetadataActions = (props: Props) => {
|
||||
onClick={handleRecallSeed}
|
||||
/>
|
||||
)}
|
||||
{metadata.model !== undefined && metadata.model !== null && (
|
||||
<ImageMetadataItem
|
||||
label="Model"
|
||||
value={metadata.model.model_name}
|
||||
onClick={handleRecallModel}
|
||||
/>
|
||||
)}
|
||||
{metadata.model !== undefined &&
|
||||
metadata.model !== null &&
|
||||
metadata.model.model_name && (
|
||||
<ImageMetadataItem
|
||||
label="Model"
|
||||
value={metadata.model.model_name}
|
||||
onClick={handleRecallModel}
|
||||
/>
|
||||
)}
|
||||
{metadata.width && (
|
||||
<ImageMetadataItem
|
||||
label="Width"
|
||||
|
@ -27,15 +27,12 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
|
||||
// dispatch(setShouldShowImageDetails(false));
|
||||
// });
|
||||
|
||||
const { metadata, workflow } = useGetImageMetadataFromFileQuery(
|
||||
image.image_name,
|
||||
{
|
||||
selectFromResult: (res) => ({
|
||||
metadata: res?.currentData?.metadata,
|
||||
workflow: res?.currentData?.workflow,
|
||||
}),
|
||||
}
|
||||
);
|
||||
const { metadata, workflow } = useGetImageMetadataFromFileQuery(image, {
|
||||
selectFromResult: (res) => ({
|
||||
metadata: res?.currentData?.metadata,
|
||||
workflow: res?.currentData?.workflow,
|
||||
}),
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex
|
||||
|
@ -1,8 +1,9 @@
|
||||
import { store } from 'app/store/store';
|
||||
import {
|
||||
SchedulerParam,
|
||||
zBaseModel,
|
||||
zMainModel,
|
||||
zMainOrOnnxModel,
|
||||
zOnnxModel,
|
||||
zSDXLRefinerModel,
|
||||
zScheduler,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
@ -10,7 +11,6 @@ import { keyBy } from 'lodash-es';
|
||||
import { OpenAPIV3 } from 'openapi-types';
|
||||
import { RgbaColor } from 'react-colorful';
|
||||
import { Node } from 'reactflow';
|
||||
import { JsonObject } from 'type-fest';
|
||||
import { Graph, ImageDTO, _InputField, _OutputField } from 'services/api/types';
|
||||
import {
|
||||
AnyInvocationType,
|
||||
@ -18,6 +18,7 @@ import {
|
||||
ProgressImage,
|
||||
} from 'services/events/types';
|
||||
import { O } from 'ts-toolbelt';
|
||||
import { JsonObject } from 'type-fest';
|
||||
import { z } from 'zod';
|
||||
|
||||
export type NonNullableGraph = O.Required<Graph, 'nodes' | 'edges'>;
|
||||
@ -770,12 +771,14 @@ export const zCoreMetadata = z
|
||||
steps: z.number().int().nullish(),
|
||||
scheduler: z.string().nullish(),
|
||||
clip_skip: z.number().int().nullish(),
|
||||
model: zMainOrOnnxModel.nullish(),
|
||||
controlnets: z.array(zControlField).nullish(),
|
||||
model: z
|
||||
.union([zMainModel.deepPartial(), zOnnxModel.deepPartial()])
|
||||
.nullish(),
|
||||
controlnets: z.array(zControlField.deepPartial()).nullish(),
|
||||
loras: z
|
||||
.array(
|
||||
z.object({
|
||||
lora: zLoRAModelField,
|
||||
lora: zLoRAModelField.deepPartial(),
|
||||
weight: z.number(),
|
||||
})
|
||||
)
|
||||
@ -785,15 +788,15 @@ export const zCoreMetadata = z
|
||||
init_image: z.string().nullish(),
|
||||
positive_style_prompt: z.string().nullish(),
|
||||
negative_style_prompt: z.string().nullish(),
|
||||
refiner_model: zSDXLRefinerModel.nullish(),
|
||||
refiner_model: zSDXLRefinerModel.deepPartial().nullish(),
|
||||
refiner_cfg_scale: z.number().nullish(),
|
||||
refiner_steps: z.number().int().nullish(),
|
||||
refiner_scheduler: z.string().nullish(),
|
||||
refiner_positive_aesthetic_store: z.number().nullish(),
|
||||
refiner_negative_aesthetic_store: z.number().nullish(),
|
||||
refiner_positive_aesthetic_score: z.number().nullish(),
|
||||
refiner_negative_aesthetic_score: z.number().nullish(),
|
||||
refiner_start: z.number().nullish(),
|
||||
})
|
||||
.catchall(z.record(z.any()));
|
||||
.passthrough();
|
||||
|
||||
export type CoreMetadata = z.infer<typeof zCoreMetadata>;
|
||||
|
||||
@ -936,22 +939,10 @@ export const zWorkflow = z.object({
|
||||
});
|
||||
|
||||
export const zValidatedWorkflow = zWorkflow.transform((workflow) => {
|
||||
const nodeTemplates = store.getState().nodes.nodeTemplates;
|
||||
const { nodes, edges } = workflow;
|
||||
const warnings: WorkflowWarning[] = [];
|
||||
const invocationNodes = nodes.filter(isWorkflowInvocationNode);
|
||||
const keyedNodes = keyBy(invocationNodes, 'id');
|
||||
invocationNodes.forEach((node, i) => {
|
||||
const nodeTemplate = nodeTemplates[node.data.type];
|
||||
if (!nodeTemplate) {
|
||||
warnings.push({
|
||||
message: `Node "${node.data.label || node.data.id}" skipped`,
|
||||
issues: [`Unable to find template for type "${node.data.type}"`],
|
||||
data: node,
|
||||
});
|
||||
delete nodes[i];
|
||||
}
|
||||
});
|
||||
edges.forEach((edge, i) => {
|
||||
const sourceNode = keyedNodes[edge.source];
|
||||
const targetNode = keyedNodes[edge.target];
|
||||
|
@ -1,4 +1,6 @@
|
||||
import * as png from '@stevebel/png';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import {
|
||||
ImageMetadataAndWorkflow,
|
||||
zCoreMetadata,
|
||||
@ -18,6 +20,11 @@ export const getMetadataAndWorkflowFromImageBlob = async (
|
||||
const metadataResult = zCoreMetadata.safeParse(JSON.parse(rawMetadata));
|
||||
if (metadataResult.success) {
|
||||
data.metadata = metadataResult.data;
|
||||
} else {
|
||||
logger('system').error(
|
||||
{ error: parseify(metadataResult.error) },
|
||||
'Problem reading metadata from image'
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -26,6 +33,11 @@ export const getMetadataAndWorkflowFromImageBlob = async (
|
||||
const workflowResult = zWorkflow.safeParse(JSON.parse(rawWorkflow));
|
||||
if (workflowResult.success) {
|
||||
data.workflow = workflowResult.data;
|
||||
} else {
|
||||
logger('system').error(
|
||||
{ error: parseify(workflowResult.error) },
|
||||
'Problem reading workflow from image'
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -61,9 +61,9 @@ export const addSDXLRefinerToGraph = (
|
||||
|
||||
if (metadataAccumulator) {
|
||||
metadataAccumulator.refiner_model = refinerModel;
|
||||
metadataAccumulator.refiner_positive_aesthetic_store =
|
||||
metadataAccumulator.refiner_positive_aesthetic_score =
|
||||
refinerPositiveAestheticScore;
|
||||
metadataAccumulator.refiner_negative_aesthetic_store =
|
||||
metadataAccumulator.refiner_negative_aesthetic_score =
|
||||
refinerNegativeAestheticScore;
|
||||
metadataAccumulator.refiner_cfg_scale = refinerCFGScale;
|
||||
metadataAccumulator.refiner_scheduler = refinerScheduler;
|
||||
|
@ -341,8 +341,8 @@ export const useRecallParameters = () => {
|
||||
refiner_cfg_scale,
|
||||
refiner_steps,
|
||||
refiner_scheduler,
|
||||
refiner_positive_aesthetic_store,
|
||||
refiner_negative_aesthetic_store,
|
||||
refiner_positive_aesthetic_score,
|
||||
refiner_negative_aesthetic_score,
|
||||
refiner_start,
|
||||
} = metadata;
|
||||
|
||||
@ -403,21 +403,21 @@ export const useRecallParameters = () => {
|
||||
|
||||
if (
|
||||
isValidSDXLRefinerPositiveAestheticScore(
|
||||
refiner_positive_aesthetic_store
|
||||
refiner_positive_aesthetic_score
|
||||
)
|
||||
) {
|
||||
dispatch(
|
||||
setRefinerPositiveAestheticScore(refiner_positive_aesthetic_store)
|
||||
setRefinerPositiveAestheticScore(refiner_positive_aesthetic_score)
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
isValidSDXLRefinerNegativeAestheticScore(
|
||||
refiner_negative_aesthetic_store
|
||||
refiner_negative_aesthetic_score
|
||||
)
|
||||
) {
|
||||
dispatch(
|
||||
setRefinerNegativeAestheticScore(refiner_negative_aesthetic_store)
|
||||
setRefinerNegativeAestheticScore(refiner_negative_aesthetic_score)
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -28,6 +28,8 @@ import {
|
||||
} from '../util';
|
||||
import { boardsApi } from './boards';
|
||||
import { ImageMetadataAndWorkflow } from 'features/nodes/types/types';
|
||||
import { fetchBaseQuery } from '@reduxjs/toolkit/dist/query';
|
||||
import { $authToken, $projectId } from '../client';
|
||||
|
||||
export const imagesApi = api.injectEndpoints({
|
||||
endpoints: (build) => ({
|
||||
@ -115,18 +117,40 @@ export const imagesApi = api.injectEndpoints({
|
||||
],
|
||||
keepUnusedDataFor: 86400, // 24 hours
|
||||
}),
|
||||
getImageMetadataFromFile: build.query<ImageMetadataAndWorkflow, string>({
|
||||
query: (image_name) => ({
|
||||
url: `images/i/${image_name}/full`,
|
||||
responseHandler: async (res) => {
|
||||
return await res.blob();
|
||||
},
|
||||
}),
|
||||
providesTags: (result, error, image_name) => [
|
||||
{ type: 'ImageMetadataFromFile', id: image_name },
|
||||
getImageMetadataFromFile: build.query<ImageMetadataAndWorkflow, ImageDTO>({
|
||||
queryFn: async (args: ImageDTO, api, extraOptions) => {
|
||||
const authToken = $authToken.get();
|
||||
const projectId = $projectId.get();
|
||||
const customBaseQuery = fetchBaseQuery({
|
||||
baseUrl: '',
|
||||
prepareHeaders: (headers) => {
|
||||
if (authToken) {
|
||||
headers.set('Authorization', `Bearer ${authToken}`);
|
||||
}
|
||||
if (projectId) {
|
||||
headers.set('project-id', projectId);
|
||||
}
|
||||
|
||||
return headers;
|
||||
},
|
||||
responseHandler: async (res) => {
|
||||
return await res.blob();
|
||||
},
|
||||
});
|
||||
|
||||
const response = await customBaseQuery(
|
||||
args.image_url,
|
||||
api,
|
||||
extraOptions
|
||||
);
|
||||
const data = await getMetadataAndWorkflowFromImageBlob(
|
||||
response.data as Blob
|
||||
);
|
||||
return { data };
|
||||
},
|
||||
providesTags: (result, error, image_dto) => [
|
||||
{ type: 'ImageMetadataFromFile', id: image_dto.image_name },
|
||||
],
|
||||
transformResponse: (response: Blob) =>
|
||||
getMetadataAndWorkflowFromImageBlob(response),
|
||||
keepUnusedDataFor: 86400, // 24 hours
|
||||
}),
|
||||
clearIntermediates: build.mutation<number, void>({
|
||||
|
Loading…
Reference in New Issue
Block a user