Merge branch 'main' into fix-inpainting

This commit is contained in:
blessedcoolant 2023-07-13 22:31:08 +12:00
commit 16f53228c2
13 changed files with 47 additions and 41 deletions

View File

@ -37,7 +37,7 @@ const ParamLora = (props: Props) => {
return ( return (
<Flex sx={{ gap: 2.5, alignItems: 'flex-end' }}> <Flex sx={{ gap: 2.5, alignItems: 'flex-end' }}>
<IAISlider <IAISlider
label={lora.name} label={lora.model_name}
value={lora.weight} value={lora.weight}
onChange={handleChange} onChange={handleChange}
min={-1} min={-1}

View File

@ -18,7 +18,7 @@ const selector = createSelector(
const ParamLoraList = () => { const ParamLoraList = () => {
const { loras } = useAppSelector(selector); const { loras } = useAppSelector(selector);
return map(loras, (lora) => <ParamLora key={lora.name} lora={lora} />); return map(loras, (lora) => <ParamLora key={lora.model_name} lora={lora} />);
}; };
export default ParamLoraList; export default ParamLoraList;

View File

@ -1,12 +1,8 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit'; import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { LoRAModelParam } from 'features/parameters/store/parameterZodSchemas'; import { LoRAModelParam } from 'features/parameters/store/parameterZodSchemas';
import { LoRAModelConfigEntity } from 'services/api/endpoints/models'; import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
import { BaseModelType } from 'services/api/types';
export type Lora = { export type Lora = LoRAModelParam & {
id: string;
base_model: BaseModelType;
name: string;
weight: number; weight: number;
}; };
@ -27,8 +23,8 @@ export const loraSlice = createSlice({
initialState: intialLoraState, initialState: intialLoraState,
reducers: { reducers: {
loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => { loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => {
const { name, id, base_model } = action.payload; const { model_name, id, base_model } = action.payload;
state.loras[id] = { id, name, base_model, ...defaultLoRAConfig }; state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig };
}, },
loraRemoved: (state, action: PayloadAction<string>) => { loraRemoved: (state, action: PayloadAction<string>) => {
const id = action.payload; const id = action.payload;

View File

@ -45,7 +45,7 @@ const LoRAModelInputFieldComponent = (
data.push({ data.push({
value: id, value: id,
label: model.name, label: model.model_name,
group: BASE_MODEL_NAME_MAP[model.base_model], group: BASE_MODEL_NAME_MAP[model.base_model],
}); });
}); });

View File

@ -38,7 +38,7 @@ const ModelInputFieldComponent = (
data.push({ data.push({
value: id, value: id,
label: model.name, label: model.model_name,
group: BASE_MODEL_NAME_MAP[model.base_model], group: BASE_MODEL_NAME_MAP[model.base_model],
}); });
}); });

View File

@ -45,7 +45,7 @@ const VaeModelInputFieldComponent = (
data.push({ data.push({
value: id, value: id,
label: model.name, label: model.model_name,
group: BASE_MODEL_NAME_MAP[model.base_model], group: BASE_MODEL_NAME_MAP[model.base_model],
}); });
}); });

View File

@ -19,9 +19,9 @@ export const addControlNetToLinearGraph = (
const validControlNets = getValidControlNets(controlNets); const validControlNets = getValidControlNets(controlNets);
const metadataAccumulator = graph.nodes[ const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
METADATA_ACCUMULATOR | MetadataAccumulatorInvocation
] as MetadataAccumulatorInvocation; | undefined;
if (isControlNetEnabled && Boolean(validControlNets.length)) { if (isControlNetEnabled && Boolean(validControlNets.length)) {
if (validControlNets.length) { if (validControlNets.length) {
@ -79,13 +79,15 @@ export const addControlNetToLinearGraph = (
graph.nodes[controlNetNode.id] = controlNetNode; graph.nodes[controlNetNode.id] = controlNetNode;
// metadata accumulator only needs a control field - not the whole node if (metadataAccumulator) {
// extract what we need and add to the accumulator // metadata accumulator only needs a control field - not the whole node
const controlField = omit(controlNetNode, [ // extract what we need and add to the accumulator
'id', const controlField = omit(controlNetNode, [
'type', 'id',
]) as ControlField; 'type',
metadataAccumulator.controlnets.push(controlField); ]) as ControlField;
metadataAccumulator.controlnets.push(controlField);
}
graph.edges.push({ graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' }, source: { node_id: controlNetNode.id, field: 'control' },

View File

@ -32,9 +32,9 @@ export const addDynamicPromptsToGraph = (
maxPrompts, maxPrompts,
} = state.dynamicPrompts; } = state.dynamicPrompts;
const metadataAccumulator = graph.nodes[ const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
METADATA_ACCUMULATOR | MetadataAccumulatorInvocation
] as MetadataAccumulatorInvocation; | undefined;
if (isDynamicPromptsEnabled) { if (isDynamicPromptsEnabled) {
// iteration is handled via dynamic prompts // iteration is handled via dynamic prompts
@ -116,11 +116,15 @@ export const addDynamicPromptsToGraph = (
(graph.nodes[NOISE] as NoiseInvocation).seed = seed; (graph.nodes[NOISE] as NoiseInvocation).seed = seed;
// hook up seed to metadata // hook up seed to metadata
metadataAccumulator.seed = seed; if (metadataAccumulator) {
metadataAccumulator.seed = seed;
}
} }
} else { } else {
// no dynamic prompt - hook up positive prompt // no dynamic prompt - hook up positive prompt
metadataAccumulator.positive_prompt = positivePrompt; if (metadataAccumulator) {
metadataAccumulator.positive_prompt = positivePrompt;
}
const rangeOfSizeNode: RangeOfSizeInvocation = { const rangeOfSizeNode: RangeOfSizeInvocation = {
id: RANGE_OF_SIZE, id: RANGE_OF_SIZE,

View File

@ -30,9 +30,9 @@ export const addLoRAsToGraph = (
const { loras } = state.lora; const { loras } = state.lora;
const loraCount = size(loras); const loraCount = size(loras);
const metadataAccumulator = graph.nodes[ const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
METADATA_ACCUMULATOR | MetadataAccumulatorInvocation
] as MetadataAccumulatorInvocation; | undefined;
if (loraCount > 0) { if (loraCount > 0) {
// Remove MAIN_MODEL_LOADER unet connection to feed it to LoRAs // Remove MAIN_MODEL_LOADER unet connection to feed it to LoRAs
@ -70,7 +70,9 @@ export const addLoRAsToGraph = (
}; };
// add the lora to the metadata accumulator // add the lora to the metadata accumulator
metadataAccumulator.loras.push({ lora: loraField, weight }); if (metadataAccumulator) {
metadataAccumulator.loras.push({ lora: loraField, weight });
}
// add to graph // add to graph
graph.nodes[currentLoraNodeId] = loraLoaderNode; graph.nodes[currentLoraNodeId] = loraLoaderNode;

View File

@ -22,9 +22,9 @@ export const addVAEToGraph = (
const vae_model = modelIdToVAEModelField(vae?.id || ''); const vae_model = modelIdToVAEModelField(vae?.id || '');
const isAutoVae = !vae; const isAutoVae = !vae;
const metadataAccumulator = graph.nodes[ const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
METADATA_ACCUMULATOR | MetadataAccumulatorInvocation
] as MetadataAccumulatorInvocation; | undefined;
if (!isAutoVae) { if (!isAutoVae) {
graph.nodes[VAE_LOADER] = { graph.nodes[VAE_LOADER] = {
@ -73,7 +73,7 @@ export const addVAEToGraph = (
}); });
} }
if (vae) { if (vae && metadataAccumulator) {
metadataAccumulator.vae = vae_model; metadataAccumulator.vae = vae_model;
} }
}; };

View File

@ -7,7 +7,6 @@ import {
RandomIntInvocation, RandomIntInvocation,
RangeOfSizeInvocation, RangeOfSizeInvocation,
} from 'services/api/types'; } from 'services/api/types';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
import { import {
@ -35,7 +34,7 @@ export const buildCanvasInpaintGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: currentModel, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -53,14 +52,17 @@ export const buildCanvasInpaintGraph = (
clipSkip, clipSkip,
} = state.generation; } = state.generation;
if (!model) {
moduleLog.error('No model found in state');
throw new Error('No model found in state');
}
// The bounding box determines width and height, not the width and height params // The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions; const { width, height } = state.canvas.boundingBoxDimensions;
// We may need to set the inpaint width and height to scale the image // We may need to set the inpaint width and height to scale the image
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas; const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const model = modelIdToMainModelField(currentModel?.id || '');
const graph: NonNullableGraph = { const graph: NonNullableGraph = {
id: INPAINT_GRAPH, id: INPAINT_GRAPH,
nodes: { nodes: {

View File

@ -13,7 +13,7 @@ import {
buildOutputFieldTemplates, buildOutputFieldTemplates,
} from './fieldTemplateBuilders'; } from './fieldTemplateBuilders';
const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate', 'core_metadata']; const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate', 'metadata'];
const invocationDenylist = [ const invocationDenylist = [
'Graph', 'Graph',

View File

@ -170,7 +170,7 @@ export const isValidVaeModel = (val: unknown): val is VaeModelParam =>
*/ */
export const zLoRAModel = z.object({ export const zLoRAModel = z.object({
id: z.string(), id: z.string(),
name: z.string(), model_name: z.string(),
base_model: zBaseModel, base_model: zBaseModel,
}); });
/** /**