feat(ui): update model identifier to be key (wip)

- Update most model identifiers to be `{key: string}` instead of name/base/type. Doesn't change the model select components yet.
- Update model _parameters_, stored in redux, to be `{key: string, base: BaseModel}` - we need to store the base model to be able to check model compatibility. May want to store the whole config? Not sure...
This commit is contained in:
psychedelicious
2024-02-16 18:56:02 +11:00
parent 6df3c450e8
commit dab939f7d1
54 changed files with 267 additions and 453 deletions

View File

@ -14,7 +14,7 @@ import { upsertMetadata } from './metadata';
export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
const validControlNets = selectValidControlNets(state.controlAdapters).filter(
(ca) => ca.model?.base_model === state.generation.model?.base_model
(ca) => ca.model?.base === state.generation.model?.base
);
// const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as

View File

@ -14,7 +14,7 @@ import { upsertMetadata } from './metadata';
export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(
(ca) => ca.model?.base_model === state.generation.model?.base_model
(ca) => ca.model?.base === state.generation.model?.base
);
if (validIPAdapters.length) {

View File

@ -28,6 +28,7 @@ export const addLoRAsToGraph = (
* So we need to inject a LoRA chain into the graph.
*/
// TODO(MM2): check base model
const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
const loraCount = size(enabledLoRAs);
@ -48,19 +49,19 @@ export const addLoRAsToGraph = (
const loraMetadata: CoreMetadataInvocation['loras'] = [];
enabledLoRAs.forEach((lora) => {
const { model_name, base_model, weight } = lora;
const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`;
const { key, weight } = lora;
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
const loraLoaderNode: LoraLoaderInvocation = {
type: 'lora_loader',
id: currentLoraNodeId,
is_intermediate: true,
lora: { model_name, base_model },
lora: { key },
weight,
};
loraMetadata.push({
lora: { model_name, base_model },
lora: { key },
weight,
});

View File

@ -31,6 +31,7 @@ export const addSDXLLoRAsToGraph = (
* So we need to inject a LoRA chain into the graph.
*/
// TODO(MM2): check base model
const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
const loraCount = size(enabledLoRAs);
@ -60,20 +61,20 @@ export const addSDXLLoRAsToGraph = (
let currentLoraIndex = 0;
enabledLoRAs.forEach((lora) => {
const { model_name, base_model, weight } = lora;
const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`;
const { key, weight } = lora;
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
const loraLoaderNode: SDXLLoraLoaderInvocation = {
type: 'sdxl_lora_loader',
id: currentLoraNodeId,
is_intermediate: true,
lora: { model_name, base_model },
lora: { key },
weight,
};
loraMetadata.push(
zLoRAMetadataItem.parse({
lora: { model_name, base_model },
lora: { key },
weight,
})
);

View File

@ -14,7 +14,7 @@ import { upsertMetadata } from './metadata';
export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
const validT2IAdapters = selectValidT2IAdapters(state.controlAdapters).filter(
(ca) => ca.model?.base_model === state.generation.model?.base_model
(ca) => ca.model?.base === state.generation.model?.base
);
if (validT2IAdapters.length) {

View File

@ -19,7 +19,7 @@ export const buildCanvasGraph = (
let graph: NonNullableGraph;
if (generationMode === 'txt2img') {
if (state.generation.model && state.generation.model.base_model === 'sdxl') {
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = buildCanvasSDXLTextToImageGraph(state);
} else {
graph = buildCanvasTextToImageGraph(state);
@ -28,7 +28,7 @@ export const buildCanvasGraph = (
if (!canvasInitImage) {
throw new Error('Missing canvas init image');
}
if (state.generation.model && state.generation.model.base_model === 'sdxl') {
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = buildCanvasSDXLImageToImageGraph(state, canvasInitImage);
} else {
graph = buildCanvasImageToImageGraph(state, canvasInitImage);
@ -37,7 +37,7 @@ export const buildCanvasGraph = (
if (!canvasInitImage || !canvasMaskImage) {
throw new Error('Missing canvas init and mask images');
}
if (state.generation.model && state.generation.model.base_model === 'sdxl') {
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = buildCanvasSDXLInpaintGraph(state, canvasInitImage, canvasMaskImage);
} else {
graph = buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage);
@ -46,7 +46,7 @@ export const buildCanvasGraph = (
if (!canvasInitImage) {
throw new Error('Missing canvas init image');
}
if (state.generation.model && state.generation.model.base_model === 'sdxl') {
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = buildCanvasSDXLOutpaintGraph(state, canvasInitImage, canvasMaskImage);
} else {
graph = buildCanvasOutpaintGraph(state, canvasInitImage, canvasMaskImage);

View File

@ -105,7 +105,7 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
});
}
if (shouldConcatSDXLStylePrompt && model?.base_model === 'sdxl') {
if (shouldConcatSDXLStylePrompt && model?.base === 'sdxl') {
if (graph.nodes[POSITIVE_CONDITIONING]) {
firstBatchDatumList.push({
node_path: POSITIVE_CONDITIONING,