Merge branch 'main' into feat/taesd

This commit is contained in:
Kevin Turner 2023-09-01 22:18:40 -07:00 committed by GitHub
commit 7df67d077a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 235 additions and 153 deletions

View File

@ -279,8 +279,8 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
crop_left: int = InputField(default=0, description="") crop_left: int = InputField(default=0, description="")
target_width: int = InputField(default=1024, description="") target_width: int = InputField(default=1024, description="")
target_height: int = InputField(default=1024, description="") target_height: int = InputField(default=1024, description="")
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput: def invoke(self, context: InvocationContext) -> ConditioningOutput:

View File

@ -72,10 +72,10 @@ class CoreMetadata(BaseModelExcludeNull):
) )
refiner_steps: Optional[int] = Field(default=None, description="The number of steps used for the refiner") refiner_steps: Optional[int] = Field(default=None, description="The number of steps used for the refiner")
refiner_scheduler: Optional[str] = Field(default=None, description="The scheduler used for the refiner") refiner_scheduler: Optional[str] = Field(default=None, description="The scheduler used for the refiner")
refiner_positive_aesthetic_store: Optional[float] = Field( refiner_positive_aesthetic_score: Optional[float] = Field(
default=None, description="The aesthetic score used for the refiner" default=None, description="The aesthetic score used for the refiner"
) )
refiner_negative_aesthetic_store: Optional[float] = Field( refiner_negative_aesthetic_score: Optional[float] = Field(
default=None, description="The aesthetic score used for the refiner" default=None, description="The aesthetic score used for the refiner"
) )
refiner_start: Optional[float] = Field(default=None, description="The start value used for refiner denoising") refiner_start: Optional[float] = Field(default=None, description="The start value used for refiner denoising")
@ -160,11 +160,11 @@ class MetadataAccumulatorInvocation(BaseInvocation):
default=None, default=None,
description="The scheduler used for the refiner", description="The scheduler used for the refiner",
) )
refiner_positive_aesthetic_store: Optional[float] = InputField( refiner_positive_aesthetic_score: Optional[float] = InputField(
default=None, default=None,
description="The aesthetic score used for the refiner", description="The aesthetic score used for the refiner",
) )
refiner_negative_aesthetic_store: Optional[float] = InputField( refiner_negative_aesthetic_score: Optional[float] = InputField(
default=None, default=None,
description="The aesthetic score used for the refiner", description="The aesthetic score used for the refiner",
) )

View File

@ -249,14 +249,14 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder.""" """Apply selected lora to unet and text_encoder."""
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA") lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
weight: float = Field(default=0.75, description=FieldDescriptions.lora_weight) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = Field( unet: Optional[UNetField] = InputField(
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNET" default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
) )
clip: Optional[ClipField] = Field( clip: Optional[ClipField] = InputField(
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1" default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1"
) )
clip2: Optional[ClipField] = Field( clip2: Optional[ClipField] = InputField(
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2" default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2"
) )

View File

@ -49,6 +49,7 @@ class ModelProbe(object):
"StableDiffusionInpaintPipeline": ModelType.Main, "StableDiffusionInpaintPipeline": ModelType.Main,
"StableDiffusionXLPipeline": ModelType.Main, "StableDiffusionXLPipeline": ModelType.Main,
"StableDiffusionXLImg2ImgPipeline": ModelType.Main, "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"StableDiffusionXLInpaintPipeline": ModelType.Main,
"AutoencoderKL": ModelType.Vae, "AutoencoderKL": ModelType.Vae,
"AutoencoderTiny": ModelType.Vae, "AutoencoderTiny": ModelType.Vae,
"ControlNetModel": ModelType.ControlNet, "ControlNetModel": ModelType.ControlNet,

View File

@ -215,7 +215,10 @@ class InvokeAIDiffuserComponent:
dim=0, dim=0,
), ),
} }
(encoder_hidden_states, encoder_attention_mask,) = self._concat_conditionings_for_batch( (
encoder_hidden_states,
encoder_attention_mask,
) = self._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings.embeds, conditioning_data.unconditioned_embeddings.embeds,
conditioning_data.text_embeddings.embeds, conditioning_data.text_embeddings.embeds,
) )
@ -277,7 +280,10 @@ class InvokeAIDiffuserComponent:
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0 wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
if wants_cross_attention_control: if wants_cross_attention_control:
(unconditioned_next_x, conditioned_next_x,) = self._apply_cross_attention_controlled_conditioning( (
unconditioned_next_x,
conditioned_next_x,
) = self._apply_cross_attention_controlled_conditioning(
sample, sample,
timestep, timestep,
conditioning_data, conditioning_data,
@ -285,7 +291,10 @@ class InvokeAIDiffuserComponent:
**kwargs, **kwargs,
) )
elif self.sequential_guidance: elif self.sequential_guidance:
(unconditioned_next_x, conditioned_next_x,) = self._apply_standard_conditioning_sequentially( (
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning_sequentially(
sample, sample,
timestep, timestep,
conditioning_data, conditioning_data,
@ -293,7 +302,10 @@ class InvokeAIDiffuserComponent:
) )
else: else:
(unconditioned_next_x, conditioned_next_x,) = self._apply_standard_conditioning( (
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning(
sample, sample,
timestep, timestep,
conditioning_data, conditioning_data,

View File

@ -562,18 +562,14 @@ def rgb2ycbcr(img, only_y=True):
if only_y: if only_y:
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
else: else:
rlt = ( rlt = np.matmul(
np.matmul(
img, img,
[ [
[65.481, -37.797, 112.0], [65.481, -37.797, 112.0],
[128.553, -74.203, -93.786], [128.553, -74.203, -93.786],
[24.966, 112.0, -18.214], [24.966, 112.0, -18.214],
], ],
) ) / 255.0 + [16, 128, 128]
/ 255.0
+ [16, 128, 128]
)
if in_img_type == np.uint8: if in_img_type == np.uint8:
rlt = rlt.round() rlt = rlt.round()
else: else:
@ -592,18 +588,14 @@ def ycbcr2rgb(img):
if in_img_type != np.uint8: if in_img_type != np.uint8:
img *= 255.0 img *= 255.0
# convert # convert
rlt = ( rlt = np.matmul(
np.matmul(
img, img,
[ [
[0.00456621, 0.00456621, 0.00456621], [0.00456621, 0.00456621, 0.00456621],
[0, -0.00153632, 0.00791071], [0, -0.00153632, 0.00791071],
[0.00625893, -0.00318811, 0], [0.00625893, -0.00318811, 0],
], ],
) ) * 255.0 + [-222.921, 135.576, -276.836]
* 255.0
+ [-222.921, 135.576, -276.836]
)
if in_img_type == np.uint8: if in_img_type == np.uint8:
rlt = rlt.round() rlt = rlt.round()
else: else:
@ -626,18 +618,14 @@ def bgr2ycbcr(img, only_y=True):
if only_y: if only_y:
rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
else: else:
rlt = ( rlt = np.matmul(
np.matmul(
img, img,
[ [
[24.966, 112.0, -18.214], [24.966, 112.0, -18.214],
[128.553, -74.203, -93.786], [128.553, -74.203, -93.786],
[65.481, -37.797, 112.0], [65.481, -37.797, 112.0],
], ],
) ) / 255.0 + [16, 128, 128]
/ 255.0
+ [16, 128, 128]
)
if in_img_type == np.uint8: if in_img_type == np.uint8:
rlt = rlt.round() rlt = rlt.round()
else: else:

View File

@ -475,7 +475,10 @@ class TextualInversionDataset(Dataset):
if self.center_crop: if self.center_crop:
crop = min(img.shape[0], img.shape[1]) crop = min(img.shape[0], img.shape[1])
(h, w,) = ( (
h,
w,
) = (
img.shape[0], img.shape[0],
img.shape[1], img.shape[1],
) )

View File

@ -1,7 +1,7 @@
import math import math
import torch
import diffusers
import diffusers
import torch
if torch.backends.mps.is_available(): if torch.backends.mps.is_available():
torch.empty = torch.zeros torch.empty = torch.zeros

View File

@ -104,22 +104,22 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
]); ]);
const handleSetControlImageToDimensions = useCallback(() => { const handleSetControlImageToDimensions = useCallback(() => {
if (!processedControlImage) { if (!controlImage) {
return; return;
} }
if (activeTabName === 'unifiedCanvas') { if (activeTabName === 'unifiedCanvas') {
dispatch( dispatch(
setBoundingBoxDimensions({ setBoundingBoxDimensions({
width: processedControlImage.width, width: controlImage.width,
height: processedControlImage.height, height: controlImage.height,
}) })
); );
} else { } else {
dispatch(setWidth(processedControlImage.width)); dispatch(setWidth(controlImage.width));
dispatch(setHeight(processedControlImage.height)); dispatch(setHeight(controlImage.height));
} }
}, [processedControlImage, activeTabName, dispatch]); }, [controlImage, activeTabName, dispatch]);
const handleMouseEnter = useCallback(() => { const handleMouseEnter = useCallback(() => {
setIsMouseOverImage(true); setIsMouseOverImage(true);

View File

@ -110,7 +110,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
); );
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery( const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
lastSelectedImage?.image_name ?? skipToken, lastSelectedImage ?? skipToken,
{ {
selectFromResult: (res) => ({ selectFromResult: (res) => ({
isLoading: res.isFetching, isLoading: res.isFetching,

View File

@ -52,7 +52,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled; const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery( const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
imageDTO.image_name, imageDTO,
{ {
selectFromResult: (res) => ({ selectFromResult: (res) => ({
isLoading: res.isFetching, isLoading: res.isFetching,

View File

@ -101,7 +101,9 @@ const ImageMetadataActions = (props: Props) => {
onClick={handleRecallSeed} onClick={handleRecallSeed}
/> />
)} )}
{metadata.model !== undefined && metadata.model !== null && ( {metadata.model !== undefined &&
metadata.model !== null &&
metadata.model.model_name && (
<ImageMetadataItem <ImageMetadataItem
label="Model" label="Model"
value={metadata.model.model_name} value={metadata.model.model_name}

View File

@ -27,15 +27,12 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
// dispatch(setShouldShowImageDetails(false)); // dispatch(setShouldShowImageDetails(false));
// }); // });
const { metadata, workflow } = useGetImageMetadataFromFileQuery( const { metadata, workflow } = useGetImageMetadataFromFileQuery(image, {
image.image_name,
{
selectFromResult: (res) => ({ selectFromResult: (res) => ({
metadata: res?.currentData?.metadata, metadata: res?.currentData?.metadata,
workflow: res?.currentData?.workflow, workflow: res?.currentData?.workflow,
}), }),
} });
);
return ( return (
<Flex <Flex

View File

@ -1,8 +1,9 @@
import { store } from 'app/store/store';
import { import {
SchedulerParam, SchedulerParam,
zBaseModel, zBaseModel,
zMainModel,
zMainOrOnnxModel, zMainOrOnnxModel,
zOnnxModel,
zSDXLRefinerModel, zSDXLRefinerModel,
zScheduler, zScheduler,
} from 'features/parameters/types/parameterSchemas'; } from 'features/parameters/types/parameterSchemas';
@ -10,7 +11,6 @@ import { keyBy } from 'lodash-es';
import { OpenAPIV3 } from 'openapi-types'; import { OpenAPIV3 } from 'openapi-types';
import { RgbaColor } from 'react-colorful'; import { RgbaColor } from 'react-colorful';
import { Node } from 'reactflow'; import { Node } from 'reactflow';
import { JsonObject } from 'type-fest';
import { Graph, ImageDTO, _InputField, _OutputField } from 'services/api/types'; import { Graph, ImageDTO, _InputField, _OutputField } from 'services/api/types';
import { import {
AnyInvocationType, AnyInvocationType,
@ -18,6 +18,7 @@ import {
ProgressImage, ProgressImage,
} from 'services/events/types'; } from 'services/events/types';
import { O } from 'ts-toolbelt'; import { O } from 'ts-toolbelt';
import { JsonObject } from 'type-fest';
import { z } from 'zod'; import { z } from 'zod';
export type NonNullableGraph = O.Required<Graph, 'nodes' | 'edges'>; export type NonNullableGraph = O.Required<Graph, 'nodes' | 'edges'>;
@ -770,12 +771,14 @@ export const zCoreMetadata = z
steps: z.number().int().nullish(), steps: z.number().int().nullish(),
scheduler: z.string().nullish(), scheduler: z.string().nullish(),
clip_skip: z.number().int().nullish(), clip_skip: z.number().int().nullish(),
model: zMainOrOnnxModel.nullish(), model: z
controlnets: z.array(zControlField).nullish(), .union([zMainModel.deepPartial(), zOnnxModel.deepPartial()])
.nullish(),
controlnets: z.array(zControlField.deepPartial()).nullish(),
loras: z loras: z
.array( .array(
z.object({ z.object({
lora: zLoRAModelField, lora: zLoRAModelField.deepPartial(),
weight: z.number(), weight: z.number(),
}) })
) )
@ -785,15 +788,15 @@ export const zCoreMetadata = z
init_image: z.string().nullish(), init_image: z.string().nullish(),
positive_style_prompt: z.string().nullish(), positive_style_prompt: z.string().nullish(),
negative_style_prompt: z.string().nullish(), negative_style_prompt: z.string().nullish(),
refiner_model: zSDXLRefinerModel.nullish(), refiner_model: zSDXLRefinerModel.deepPartial().nullish(),
refiner_cfg_scale: z.number().nullish(), refiner_cfg_scale: z.number().nullish(),
refiner_steps: z.number().int().nullish(), refiner_steps: z.number().int().nullish(),
refiner_scheduler: z.string().nullish(), refiner_scheduler: z.string().nullish(),
refiner_positive_aesthetic_store: z.number().nullish(), refiner_positive_aesthetic_score: z.number().nullish(),
refiner_negative_aesthetic_store: z.number().nullish(), refiner_negative_aesthetic_score: z.number().nullish(),
refiner_start: z.number().nullish(), refiner_start: z.number().nullish(),
}) })
.catchall(z.record(z.any())); .passthrough();
export type CoreMetadata = z.infer<typeof zCoreMetadata>; export type CoreMetadata = z.infer<typeof zCoreMetadata>;
@ -936,22 +939,10 @@ export const zWorkflow = z.object({
}); });
export const zValidatedWorkflow = zWorkflow.transform((workflow) => { export const zValidatedWorkflow = zWorkflow.transform((workflow) => {
const nodeTemplates = store.getState().nodes.nodeTemplates;
const { nodes, edges } = workflow; const { nodes, edges } = workflow;
const warnings: WorkflowWarning[] = []; const warnings: WorkflowWarning[] = [];
const invocationNodes = nodes.filter(isWorkflowInvocationNode); const invocationNodes = nodes.filter(isWorkflowInvocationNode);
const keyedNodes = keyBy(invocationNodes, 'id'); const keyedNodes = keyBy(invocationNodes, 'id');
invocationNodes.forEach((node, i) => {
const nodeTemplate = nodeTemplates[node.data.type];
if (!nodeTemplate) {
warnings.push({
message: `Node "${node.data.label || node.data.id}" skipped`,
issues: [`Unable to find template for type "${node.data.type}"`],
data: node,
});
delete nodes[i];
}
});
edges.forEach((edge, i) => { edges.forEach((edge, i) => {
const sourceNode = keyedNodes[edge.source]; const sourceNode = keyedNodes[edge.source];
const targetNode = keyedNodes[edge.target]; const targetNode = keyedNodes[edge.target];

View File

@ -1,4 +1,6 @@
import * as png from '@stevebel/png'; import * as png from '@stevebel/png';
import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
import { import {
ImageMetadataAndWorkflow, ImageMetadataAndWorkflow,
zCoreMetadata, zCoreMetadata,
@ -18,6 +20,11 @@ export const getMetadataAndWorkflowFromImageBlob = async (
const metadataResult = zCoreMetadata.safeParse(JSON.parse(rawMetadata)); const metadataResult = zCoreMetadata.safeParse(JSON.parse(rawMetadata));
if (metadataResult.success) { if (metadataResult.success) {
data.metadata = metadataResult.data; data.metadata = metadataResult.data;
} else {
logger('system').error(
{ error: parseify(metadataResult.error) },
'Problem reading metadata from image'
);
} }
} }
@ -26,6 +33,11 @@ export const getMetadataAndWorkflowFromImageBlob = async (
const workflowResult = zWorkflow.safeParse(JSON.parse(rawWorkflow)); const workflowResult = zWorkflow.safeParse(JSON.parse(rawWorkflow));
if (workflowResult.success) { if (workflowResult.success) {
data.workflow = workflowResult.data; data.workflow = workflowResult.data;
} else {
logger('system').error(
{ error: parseify(workflowResult.error) },
'Problem reading workflow from image'
);
} }
} }

View File

@ -60,9 +60,9 @@ export const addSDXLRefinerToGraph = (
if (metadataAccumulator) { if (metadataAccumulator) {
metadataAccumulator.refiner_model = refinerModel; metadataAccumulator.refiner_model = refinerModel;
metadataAccumulator.refiner_positive_aesthetic_store = metadataAccumulator.refiner_positive_aesthetic_score =
refinerPositiveAestheticScore; refinerPositiveAestheticScore;
metadataAccumulator.refiner_negative_aesthetic_store = metadataAccumulator.refiner_negative_aesthetic_score =
refinerNegativeAestheticScore; refinerNegativeAestheticScore;
metadataAccumulator.refiner_cfg_scale = refinerCFGScale; metadataAccumulator.refiner_cfg_scale = refinerCFGScale;
metadataAccumulator.refiner_scheduler = refinerScheduler; metadataAccumulator.refiner_scheduler = refinerScheduler;

View File

@ -341,8 +341,8 @@ export const useRecallParameters = () => {
refiner_cfg_scale, refiner_cfg_scale,
refiner_steps, refiner_steps,
refiner_scheduler, refiner_scheduler,
refiner_positive_aesthetic_store, refiner_positive_aesthetic_score,
refiner_negative_aesthetic_store, refiner_negative_aesthetic_score,
refiner_start, refiner_start,
} = metadata; } = metadata;
@ -403,21 +403,21 @@ export const useRecallParameters = () => {
if ( if (
isValidSDXLRefinerPositiveAestheticScore( isValidSDXLRefinerPositiveAestheticScore(
refiner_positive_aesthetic_store refiner_positive_aesthetic_score
) )
) { ) {
dispatch( dispatch(
setRefinerPositiveAestheticScore(refiner_positive_aesthetic_store) setRefinerPositiveAestheticScore(refiner_positive_aesthetic_score)
); );
} }
if ( if (
isValidSDXLRefinerNegativeAestheticScore( isValidSDXLRefinerNegativeAestheticScore(
refiner_negative_aesthetic_store refiner_negative_aesthetic_score
) )
) { ) {
dispatch( dispatch(
setRefinerNegativeAestheticScore(refiner_negative_aesthetic_store) setRefinerNegativeAestheticScore(refiner_negative_aesthetic_score)
); );
} }

View File

@ -28,6 +28,8 @@ import {
} from '../util'; } from '../util';
import { boardsApi } from './boards'; import { boardsApi } from './boards';
import { ImageMetadataAndWorkflow } from 'features/nodes/types/types'; import { ImageMetadataAndWorkflow } from 'features/nodes/types/types';
import { fetchBaseQuery } from '@reduxjs/toolkit/dist/query';
import { $authToken, $projectId } from '../client';
export const imagesApi = api.injectEndpoints({ export const imagesApi = api.injectEndpoints({
endpoints: (build) => ({ endpoints: (build) => ({
@ -115,18 +117,40 @@ export const imagesApi = api.injectEndpoints({
], ],
keepUnusedDataFor: 86400, // 24 hours keepUnusedDataFor: 86400, // 24 hours
}), }),
getImageMetadataFromFile: build.query<ImageMetadataAndWorkflow, string>({ getImageMetadataFromFile: build.query<ImageMetadataAndWorkflow, ImageDTO>({
query: (image_name) => ({ queryFn: async (args: ImageDTO, api, extraOptions) => {
url: `images/i/${image_name}/full`, const authToken = $authToken.get();
const projectId = $projectId.get();
const customBaseQuery = fetchBaseQuery({
baseUrl: '',
prepareHeaders: (headers) => {
if (authToken) {
headers.set('Authorization', `Bearer ${authToken}`);
}
if (projectId) {
headers.set('project-id', projectId);
}
return headers;
},
responseHandler: async (res) => { responseHandler: async (res) => {
return await res.blob(); return await res.blob();
}, },
}), });
providesTags: (result, error, image_name) => [
{ type: 'ImageMetadataFromFile', id: image_name }, const response = await customBaseQuery(
args.image_url,
api,
extraOptions
);
const data = await getMetadataAndWorkflowFromImageBlob(
response.data as Blob
);
return { data };
},
providesTags: (result, error, image_dto) => [
{ type: 'ImageMetadataFromFile', id: image_dto.image_name },
], ],
transformResponse: (response: Blob) =>
getMetadataAndWorkflowFromImageBlob(response),
keepUnusedDataFor: 86400, // 24 hours keepUnusedDataFor: 86400, // 24 hours
}), }),
clearIntermediates: build.mutation<number, void>({ clearIntermediates: build.mutation<number, void>({

File diff suppressed because one or more lines are too long