Merge branch 'main' into bugfix/set-vram-on-macs

This commit is contained in:
Lincoln Stein 2023-09-02 10:08:13 -04:00 committed by GitHub
commit c965d3eb6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 168 additions and 76 deletions

View File

@ -279,8 +279,8 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
crop_left: int = InputField(default=0, description="")
target_width: int = InputField(default=1024, description="")
target_height: int = InputField(default=1024, description="")
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:

View File

@ -72,10 +72,10 @@ class CoreMetadata(BaseModelExcludeNull):
)
refiner_steps: Optional[int] = Field(default=None, description="The number of steps used for the refiner")
refiner_scheduler: Optional[str] = Field(default=None, description="The scheduler used for the refiner")
refiner_positive_aesthetic_store: Optional[float] = Field(
refiner_positive_aesthetic_score: Optional[float] = Field(
default=None, description="The aesthetic score used for the refiner"
)
refiner_negative_aesthetic_store: Optional[float] = Field(
refiner_negative_aesthetic_score: Optional[float] = Field(
default=None, description="The aesthetic score used for the refiner"
)
refiner_start: Optional[float] = Field(default=None, description="The start value used for refiner denoising")
@ -160,11 +160,11 @@ class MetadataAccumulatorInvocation(BaseInvocation):
default=None,
description="The scheduler used for the refiner",
)
refiner_positive_aesthetic_store: Optional[float] = InputField(
refiner_positive_aesthetic_score: Optional[float] = InputField(
default=None,
description="The aesthetic score used for the refiner",
)
refiner_negative_aesthetic_store: Optional[float] = InputField(
refiner_negative_aesthetic_score: Optional[float] = InputField(
default=None,
description="The aesthetic score used for the refiner",
)

View File

@ -250,13 +250,13 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = Field(
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNET"
unet: Optional[UNetField] = InputField(
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
)
clip: Optional[ClipField] = Field(
clip: Optional[ClipField] = InputField(
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1"
)
clip2: Optional[ClipField] = Field(
clip2: Optional[ClipField] = InputField(
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2"
)

View File

@ -50,6 +50,7 @@ class ModelProbe(object):
"StableDiffusionInpaintPipeline": ModelType.Main,
"StableDiffusionXLPipeline": ModelType.Main,
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"StableDiffusionXLInpaintPipeline": ModelType.Main,
"AutoencoderKL": ModelType.Vae,
"ControlNetModel": ModelType.ControlNet,
}

View File

@ -110,7 +110,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
);
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
lastSelectedImage?.image_name ?? skipToken,
lastSelectedImage ?? skipToken,
{
selectFromResult: (res) => ({
isLoading: res.isFetching,

View File

@ -52,7 +52,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
imageDTO.image_name,
imageDTO,
{
selectFromResult: (res) => ({
isLoading: res.isFetching,

View File

@ -101,13 +101,15 @@ const ImageMetadataActions = (props: Props) => {
onClick={handleRecallSeed}
/>
)}
{metadata.model !== undefined && metadata.model !== null && (
<ImageMetadataItem
label="Model"
value={metadata.model.model_name}
onClick={handleRecallModel}
/>
)}
{metadata.model !== undefined &&
metadata.model !== null &&
metadata.model.model_name && (
<ImageMetadataItem
label="Model"
value={metadata.model.model_name}
onClick={handleRecallModel}
/>
)}
{metadata.width && (
<ImageMetadataItem
label="Width"

View File

@ -27,15 +27,12 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
// dispatch(setShouldShowImageDetails(false));
// });
const { metadata, workflow } = useGetImageMetadataFromFileQuery(
image.image_name,
{
selectFromResult: (res) => ({
metadata: res?.currentData?.metadata,
workflow: res?.currentData?.workflow,
}),
}
);
const { metadata, workflow } = useGetImageMetadataFromFileQuery(image, {
selectFromResult: (res) => ({
metadata: res?.currentData?.metadata,
workflow: res?.currentData?.workflow,
}),
});
return (
<Flex

View File

@ -1,7 +1,9 @@
import {
SchedulerParam,
zBaseModel,
zMainModel,
zMainOrOnnxModel,
zOnnxModel,
zSDXLRefinerModel,
zScheduler,
} from 'features/parameters/types/parameterSchemas';
@ -769,12 +771,14 @@ export const zCoreMetadata = z
steps: z.number().int().nullish(),
scheduler: z.string().nullish(),
clip_skip: z.number().int().nullish(),
model: zMainOrOnnxModel.nullish(),
controlnets: z.array(zControlField).nullish(),
model: z
.union([zMainModel.deepPartial(), zOnnxModel.deepPartial()])
.nullish(),
controlnets: z.array(zControlField.deepPartial()).nullish(),
loras: z
.array(
z.object({
lora: zLoRAModelField,
lora: zLoRAModelField.deepPartial(),
weight: z.number(),
})
)
@ -784,15 +788,15 @@ export const zCoreMetadata = z
init_image: z.string().nullish(),
positive_style_prompt: z.string().nullish(),
negative_style_prompt: z.string().nullish(),
refiner_model: zSDXLRefinerModel.nullish(),
refiner_model: zSDXLRefinerModel.deepPartial().nullish(),
refiner_cfg_scale: z.number().nullish(),
refiner_steps: z.number().int().nullish(),
refiner_scheduler: z.string().nullish(),
refiner_positive_aesthetic_store: z.number().nullish(),
refiner_negative_aesthetic_store: z.number().nullish(),
refiner_positive_aesthetic_score: z.number().nullish(),
refiner_negative_aesthetic_score: z.number().nullish(),
refiner_start: z.number().nullish(),
})
.catchall(z.record(z.any()));
.passthrough();
export type CoreMetadata = z.infer<typeof zCoreMetadata>;

View File

@ -1,4 +1,6 @@
import * as png from '@stevebel/png';
import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
import {
ImageMetadataAndWorkflow,
zCoreMetadata,
@ -18,6 +20,11 @@ export const getMetadataAndWorkflowFromImageBlob = async (
const metadataResult = zCoreMetadata.safeParse(JSON.parse(rawMetadata));
if (metadataResult.success) {
data.metadata = metadataResult.data;
} else {
logger('system').error(
{ error: parseify(metadataResult.error) },
'Problem reading metadata from image'
);
}
}
@ -26,6 +33,11 @@ export const getMetadataAndWorkflowFromImageBlob = async (
const workflowResult = zWorkflow.safeParse(JSON.parse(rawWorkflow));
if (workflowResult.success) {
data.workflow = workflowResult.data;
} else {
logger('system').error(
{ error: parseify(workflowResult.error) },
'Problem reading workflow from image'
);
}
}

View File

@ -60,9 +60,9 @@ export const addSDXLRefinerToGraph = (
if (metadataAccumulator) {
metadataAccumulator.refiner_model = refinerModel;
metadataAccumulator.refiner_positive_aesthetic_store =
metadataAccumulator.refiner_positive_aesthetic_score =
refinerPositiveAestheticScore;
metadataAccumulator.refiner_negative_aesthetic_store =
metadataAccumulator.refiner_negative_aesthetic_score =
refinerNegativeAestheticScore;
metadataAccumulator.refiner_cfg_scale = refinerCFGScale;
metadataAccumulator.refiner_scheduler = refinerScheduler;

View File

@ -341,8 +341,8 @@ export const useRecallParameters = () => {
refiner_cfg_scale,
refiner_steps,
refiner_scheduler,
refiner_positive_aesthetic_store,
refiner_negative_aesthetic_store,
refiner_positive_aesthetic_score,
refiner_negative_aesthetic_score,
refiner_start,
} = metadata;
@ -403,21 +403,21 @@ export const useRecallParameters = () => {
if (
isValidSDXLRefinerPositiveAestheticScore(
refiner_positive_aesthetic_store
refiner_positive_aesthetic_score
)
) {
dispatch(
setRefinerPositiveAestheticScore(refiner_positive_aesthetic_store)
setRefinerPositiveAestheticScore(refiner_positive_aesthetic_score)
);
}
if (
isValidSDXLRefinerNegativeAestheticScore(
refiner_negative_aesthetic_store
refiner_negative_aesthetic_score
)
) {
dispatch(
setRefinerNegativeAestheticScore(refiner_negative_aesthetic_store)
setRefinerNegativeAestheticScore(refiner_negative_aesthetic_score)
);
}

View File

@ -28,6 +28,8 @@ import {
} from '../util';
import { boardsApi } from './boards';
import { ImageMetadataAndWorkflow } from 'features/nodes/types/types';
import { fetchBaseQuery } from '@reduxjs/toolkit/dist/query';
import { $authToken, $projectId } from '../client';
export const imagesApi = api.injectEndpoints({
endpoints: (build) => ({
@ -115,18 +117,40 @@ export const imagesApi = api.injectEndpoints({
],
keepUnusedDataFor: 86400, // 24 hours
}),
getImageMetadataFromFile: build.query<ImageMetadataAndWorkflow, string>({
query: (image_name) => ({
url: `images/i/${image_name}/full`,
responseHandler: async (res) => {
return await res.blob();
},
}),
providesTags: (result, error, image_name) => [
{ type: 'ImageMetadataFromFile', id: image_name },
getImageMetadataFromFile: build.query<ImageMetadataAndWorkflow, ImageDTO>({
queryFn: async (args: ImageDTO, api, extraOptions) => {
const authToken = $authToken.get();
const projectId = $projectId.get();
const customBaseQuery = fetchBaseQuery({
baseUrl: '',
prepareHeaders: (headers) => {
if (authToken) {
headers.set('Authorization', `Bearer ${authToken}`);
}
if (projectId) {
headers.set('project-id', projectId);
}
return headers;
},
responseHandler: async (res) => {
return await res.blob();
},
});
const response = await customBaseQuery(
args.image_url,
api,
extraOptions
);
const data = await getMetadataAndWorkflowFromImageBlob(
response.data as Blob
);
return { data };
},
providesTags: (result, error, image_dto) => [
{ type: 'ImageMetadataFromFile', id: image_dto.image_name },
],
transformResponse: (response: Blob) =>
getMetadataAndWorkflowFromImageBlob(response),
keepUnusedDataFor: 86400, // 24 hours
}),
clearIntermediates: build.mutation<number, void>({

File diff suppressed because one or more lines are too long