Merge branch 'main' into fix-inpainting

This commit is contained in:
blessedcoolant 2023-07-13 22:31:08 +12:00
commit 16f53228c2
13 changed files with 47 additions and 41 deletions

View File

@ -37,7 +37,7 @@ const ParamLora = (props: Props) => {
return (
<Flex sx={{ gap: 2.5, alignItems: 'flex-end' }}>
<IAISlider
label={lora.name}
label={lora.model_name}
value={lora.weight}
onChange={handleChange}
min={-1}

View File

@ -18,7 +18,7 @@ const selector = createSelector(
const ParamLoraList = () => {
const { loras } = useAppSelector(selector);
return map(loras, (lora) => <ParamLora key={lora.name} lora={lora} />);
return map(loras, (lora) => <ParamLora key={lora.model_name} lora={lora} />);
};
export default ParamLoraList;

View File

@ -1,12 +1,8 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { LoRAModelParam } from 'features/parameters/store/parameterZodSchemas';
import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
import { BaseModelType } from 'services/api/types';
export type Lora = {
id: string;
base_model: BaseModelType;
name: string;
export type Lora = LoRAModelParam & {
weight: number;
};
@ -27,8 +23,8 @@ export const loraSlice = createSlice({
initialState: intialLoraState,
reducers: {
loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => {
const { name, id, base_model } = action.payload;
state.loras[id] = { id, name, base_model, ...defaultLoRAConfig };
const { model_name, id, base_model } = action.payload;
state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig };
},
loraRemoved: (state, action: PayloadAction<string>) => {
const id = action.payload;

View File

@ -45,7 +45,7 @@ const LoRAModelInputFieldComponent = (
data.push({
value: id,
label: model.name,
label: model.model_name,
group: BASE_MODEL_NAME_MAP[model.base_model],
});
});

View File

@ -38,7 +38,7 @@ const ModelInputFieldComponent = (
data.push({
value: id,
label: model.name,
label: model.model_name,
group: BASE_MODEL_NAME_MAP[model.base_model],
});
});

View File

@ -45,7 +45,7 @@ const VaeModelInputFieldComponent = (
data.push({
value: id,
label: model.name,
label: model.model_name,
group: BASE_MODEL_NAME_MAP[model.base_model],
});
});

View File

@ -19,9 +19,9 @@ export const addControlNetToLinearGraph = (
const validControlNets = getValidControlNets(controlNets);
const metadataAccumulator = graph.nodes[
METADATA_ACCUMULATOR
] as MetadataAccumulatorInvocation;
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (isControlNetEnabled && Boolean(validControlNets.length)) {
if (validControlNets.length) {
@ -79,13 +79,15 @@ export const addControlNetToLinearGraph = (
graph.nodes[controlNetNode.id] = controlNetNode;
// metadata accumulator only needs a control field - not the whole node
// extract what we need and add to the accumulator
const controlField = omit(controlNetNode, [
'id',
'type',
]) as ControlField;
metadataAccumulator.controlnets.push(controlField);
if (metadataAccumulator) {
// metadata accumulator only needs a control field - not the whole node
// extract what we need and add to the accumulator
const controlField = omit(controlNetNode, [
'id',
'type',
]) as ControlField;
metadataAccumulator.controlnets.push(controlField);
}
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },

View File

@ -32,9 +32,9 @@ export const addDynamicPromptsToGraph = (
maxPrompts,
} = state.dynamicPrompts;
const metadataAccumulator = graph.nodes[
METADATA_ACCUMULATOR
] as MetadataAccumulatorInvocation;
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (isDynamicPromptsEnabled) {
// iteration is handled via dynamic prompts
@ -116,11 +116,15 @@ export const addDynamicPromptsToGraph = (
(graph.nodes[NOISE] as NoiseInvocation).seed = seed;
// hook up seed to metadata
metadataAccumulator.seed = seed;
if (metadataAccumulator) {
metadataAccumulator.seed = seed;
}
}
} else {
// no dynamic prompt - hook up positive prompt
metadataAccumulator.positive_prompt = positivePrompt;
if (metadataAccumulator) {
metadataAccumulator.positive_prompt = positivePrompt;
}
const rangeOfSizeNode: RangeOfSizeInvocation = {
id: RANGE_OF_SIZE,

View File

@ -30,9 +30,9 @@ export const addLoRAsToGraph = (
const { loras } = state.lora;
const loraCount = size(loras);
const metadataAccumulator = graph.nodes[
METADATA_ACCUMULATOR
] as MetadataAccumulatorInvocation;
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (loraCount > 0) {
// Remove MAIN_MODEL_LOADER unet connection to feed it to LoRAs
@ -70,7 +70,9 @@ export const addLoRAsToGraph = (
};
// add the lora to the metadata accumulator
metadataAccumulator.loras.push({ lora: loraField, weight });
if (metadataAccumulator) {
metadataAccumulator.loras.push({ lora: loraField, weight });
}
// add to graph
graph.nodes[currentLoraNodeId] = loraLoaderNode;

View File

@ -22,9 +22,9 @@ export const addVAEToGraph = (
const vae_model = modelIdToVAEModelField(vae?.id || '');
const isAutoVae = !vae;
const metadataAccumulator = graph.nodes[
METADATA_ACCUMULATOR
] as MetadataAccumulatorInvocation;
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (!isAutoVae) {
graph.nodes[VAE_LOADER] = {
@ -73,7 +73,7 @@ export const addVAEToGraph = (
});
}
if (vae) {
if (vae && metadataAccumulator) {
metadataAccumulator.vae = vae_model;
}
};

View File

@ -7,7 +7,6 @@ import {
RandomIntInvocation,
RangeOfSizeInvocation,
} from 'services/api/types';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import {
@ -35,7 +34,7 @@ export const buildCanvasInpaintGraph = (
const {
positivePrompt,
negativePrompt,
model: currentModel,
model,
cfgScale: cfg_scale,
scheduler,
steps,
@ -53,14 +52,17 @@ export const buildCanvasInpaintGraph = (
clipSkip,
} = state.generation;
if (!model) {
moduleLog.error('No model found in state');
throw new Error('No model found in state');
}
// The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions;
// We may need to set the inpaint width and height to scale the image
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const model = modelIdToMainModelField(currentModel?.id || '');
const graph: NonNullableGraph = {
id: INPAINT_GRAPH,
nodes: {

View File

@ -13,7 +13,7 @@ import {
buildOutputFieldTemplates,
} from './fieldTemplateBuilders';
const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate', 'core_metadata'];
const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate', 'metadata'];
const invocationDenylist = [
'Graph',

View File

@ -170,7 +170,7 @@ export const isValidVaeModel = (val: unknown): val is VaeModelParam =>
*/
export const zLoRAModel = z.object({
id: z.string(),
name: z.string(),
model_name: z.string(),
base_model: zBaseModel,
});
/**