Merge branch 'main' into bugfix/set-vram-on-macs

This commit is contained in:
psychedelicious 2023-09-02 11:33:20 +10:00 committed by GitHub
commit 4b78deba92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 10 additions and 23 deletions

View File

@ -249,7 +249,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder.""" """Apply selected lora to unet and text_encoder."""
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA") lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
weight: float = Field(default=0.75, description=FieldDescriptions.lora_weight) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = Field( unet: Optional[UNetField] = Field(
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNET" default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNET"
) )

View File

@ -1,7 +1,7 @@
import math import math
import torch
import diffusers
import diffusers
import torch
if torch.backends.mps.is_available(): if torch.backends.mps.is_available():
torch.empty = torch.zeros torch.empty = torch.zeros

View File

@ -104,22 +104,22 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
]); ]);
const handleSetControlImageToDimensions = useCallback(() => { const handleSetControlImageToDimensions = useCallback(() => {
if (!processedControlImage) { if (!controlImage) {
return; return;
} }
if (activeTabName === 'unifiedCanvas') { if (activeTabName === 'unifiedCanvas') {
dispatch( dispatch(
setBoundingBoxDimensions({ setBoundingBoxDimensions({
width: processedControlImage.width, width: controlImage.width,
height: processedControlImage.height, height: controlImage.height,
}) })
); );
} else { } else {
dispatch(setWidth(processedControlImage.width)); dispatch(setWidth(controlImage.width));
dispatch(setHeight(processedControlImage.height)); dispatch(setHeight(controlImage.height));
} }
}, [processedControlImage, activeTabName, dispatch]); }, [controlImage, activeTabName, dispatch]);
const handleMouseEnter = useCallback(() => { const handleMouseEnter = useCallback(() => {
setIsMouseOverImage(true); setIsMouseOverImage(true);

View File

@ -1,4 +1,3 @@
import { store } from 'app/store/store';
import { import {
SchedulerParam, SchedulerParam,
zBaseModel, zBaseModel,
@ -10,7 +9,6 @@ import { keyBy } from 'lodash-es';
import { OpenAPIV3 } from 'openapi-types'; import { OpenAPIV3 } from 'openapi-types';
import { RgbaColor } from 'react-colorful'; import { RgbaColor } from 'react-colorful';
import { Node } from 'reactflow'; import { Node } from 'reactflow';
import { JsonObject } from 'type-fest';
import { Graph, ImageDTO, _InputField, _OutputField } from 'services/api/types'; import { Graph, ImageDTO, _InputField, _OutputField } from 'services/api/types';
import { import {
AnyInvocationType, AnyInvocationType,
@ -18,6 +16,7 @@ import {
ProgressImage, ProgressImage,
} from 'services/events/types'; } from 'services/events/types';
import { O } from 'ts-toolbelt'; import { O } from 'ts-toolbelt';
import { JsonObject } from 'type-fest';
import { z } from 'zod'; import { z } from 'zod';
export type NonNullableGraph = O.Required<Graph, 'nodes' | 'edges'>; export type NonNullableGraph = O.Required<Graph, 'nodes' | 'edges'>;
@ -936,22 +935,10 @@ export const zWorkflow = z.object({
}); });
export const zValidatedWorkflow = zWorkflow.transform((workflow) => { export const zValidatedWorkflow = zWorkflow.transform((workflow) => {
const nodeTemplates = store.getState().nodes.nodeTemplates;
const { nodes, edges } = workflow; const { nodes, edges } = workflow;
const warnings: WorkflowWarning[] = []; const warnings: WorkflowWarning[] = [];
const invocationNodes = nodes.filter(isWorkflowInvocationNode); const invocationNodes = nodes.filter(isWorkflowInvocationNode);
const keyedNodes = keyBy(invocationNodes, 'id'); const keyedNodes = keyBy(invocationNodes, 'id');
invocationNodes.forEach((node, i) => {
const nodeTemplate = nodeTemplates[node.data.type];
if (!nodeTemplate) {
warnings.push({
message: `Node "${node.data.label || node.data.id}" skipped`,
issues: [`Unable to find template for type "${node.data.type}"`],
data: node,
});
delete nodes[i];
}
});
edges.forEach((edge, i) => { edges.forEach((edge, i) => {
const sourceNode = keyedNodes[edge.source]; const sourceNode = keyedNodes[edge.source];
const targetNode = keyedNodes[edge.target]; const targetNode = keyedNodes[edge.target];