mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): update model identifier to be key (wip)
- Update most model identifiers to be `{key: string}` instead of name/base/type. Doesn't change the model select components yet. - Update model _parameters_, stored in redux, to be `{key: string, base: BaseModel}` - we need to store the base model to be able to check model compatibility. May want to store the whole config? Not sure...
This commit is contained in:
@ -14,7 +14,7 @@ import { upsertMetadata } from './metadata';
|
||||
|
||||
export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
|
||||
const validControlNets = selectValidControlNets(state.controlAdapters).filter(
|
||||
(ca) => ca.model?.base_model === state.generation.model?.base_model
|
||||
(ca) => ca.model?.base === state.generation.model?.base
|
||||
);
|
||||
|
||||
// const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
||||
|
@ -14,7 +14,7 @@ import { upsertMetadata } from './metadata';
|
||||
|
||||
export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
|
||||
const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(
|
||||
(ca) => ca.model?.base_model === state.generation.model?.base_model
|
||||
(ca) => ca.model?.base === state.generation.model?.base
|
||||
);
|
||||
|
||||
if (validIPAdapters.length) {
|
||||
|
@ -28,6 +28,7 @@ export const addLoRAsToGraph = (
|
||||
* So we need to inject a LoRA chain into the graph.
|
||||
*/
|
||||
|
||||
// TODO(MM2): check base model
|
||||
const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
|
||||
const loraCount = size(enabledLoRAs);
|
||||
|
||||
@ -48,19 +49,19 @@ export const addLoRAsToGraph = (
|
||||
const loraMetadata: CoreMetadataInvocation['loras'] = [];
|
||||
|
||||
enabledLoRAs.forEach((lora) => {
|
||||
const { model_name, base_model, weight } = lora;
|
||||
const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`;
|
||||
const { key, weight } = lora;
|
||||
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
|
||||
|
||||
const loraLoaderNode: LoraLoaderInvocation = {
|
||||
type: 'lora_loader',
|
||||
id: currentLoraNodeId,
|
||||
is_intermediate: true,
|
||||
lora: { model_name, base_model },
|
||||
lora: { key },
|
||||
weight,
|
||||
};
|
||||
|
||||
loraMetadata.push({
|
||||
lora: { model_name, base_model },
|
||||
lora: { key },
|
||||
weight,
|
||||
});
|
||||
|
||||
|
@ -31,6 +31,7 @@ export const addSDXLLoRAsToGraph = (
|
||||
* So we need to inject a LoRA chain into the graph.
|
||||
*/
|
||||
|
||||
// TODO(MM2): check base model
|
||||
const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
|
||||
const loraCount = size(enabledLoRAs);
|
||||
|
||||
@ -60,20 +61,20 @@ export const addSDXLLoRAsToGraph = (
|
||||
let currentLoraIndex = 0;
|
||||
|
||||
enabledLoRAs.forEach((lora) => {
|
||||
const { model_name, base_model, weight } = lora;
|
||||
const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`;
|
||||
const { key, weight } = lora;
|
||||
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
|
||||
|
||||
const loraLoaderNode: SDXLLoraLoaderInvocation = {
|
||||
type: 'sdxl_lora_loader',
|
||||
id: currentLoraNodeId,
|
||||
is_intermediate: true,
|
||||
lora: { model_name, base_model },
|
||||
lora: { key },
|
||||
weight,
|
||||
};
|
||||
|
||||
loraMetadata.push(
|
||||
zLoRAMetadataItem.parse({
|
||||
lora: { model_name, base_model },
|
||||
lora: { key },
|
||||
weight,
|
||||
})
|
||||
);
|
||||
|
@ -14,7 +14,7 @@ import { upsertMetadata } from './metadata';
|
||||
|
||||
export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
|
||||
const validT2IAdapters = selectValidT2IAdapters(state.controlAdapters).filter(
|
||||
(ca) => ca.model?.base_model === state.generation.model?.base_model
|
||||
(ca) => ca.model?.base === state.generation.model?.base
|
||||
);
|
||||
|
||||
if (validT2IAdapters.length) {
|
||||
|
@ -19,7 +19,7 @@ export const buildCanvasGraph = (
|
||||
let graph: NonNullableGraph;
|
||||
|
||||
if (generationMode === 'txt2img') {
|
||||
if (state.generation.model && state.generation.model.base_model === 'sdxl') {
|
||||
if (state.generation.model && state.generation.model.base === 'sdxl') {
|
||||
graph = buildCanvasSDXLTextToImageGraph(state);
|
||||
} else {
|
||||
graph = buildCanvasTextToImageGraph(state);
|
||||
@ -28,7 +28,7 @@ export const buildCanvasGraph = (
|
||||
if (!canvasInitImage) {
|
||||
throw new Error('Missing canvas init image');
|
||||
}
|
||||
if (state.generation.model && state.generation.model.base_model === 'sdxl') {
|
||||
if (state.generation.model && state.generation.model.base === 'sdxl') {
|
||||
graph = buildCanvasSDXLImageToImageGraph(state, canvasInitImage);
|
||||
} else {
|
||||
graph = buildCanvasImageToImageGraph(state, canvasInitImage);
|
||||
@ -37,7 +37,7 @@ export const buildCanvasGraph = (
|
||||
if (!canvasInitImage || !canvasMaskImage) {
|
||||
throw new Error('Missing canvas init and mask images');
|
||||
}
|
||||
if (state.generation.model && state.generation.model.base_model === 'sdxl') {
|
||||
if (state.generation.model && state.generation.model.base === 'sdxl') {
|
||||
graph = buildCanvasSDXLInpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||
} else {
|
||||
graph = buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||
@ -46,7 +46,7 @@ export const buildCanvasGraph = (
|
||||
if (!canvasInitImage) {
|
||||
throw new Error('Missing canvas init image');
|
||||
}
|
||||
if (state.generation.model && state.generation.model.base_model === 'sdxl') {
|
||||
if (state.generation.model && state.generation.model.base === 'sdxl') {
|
||||
graph = buildCanvasSDXLOutpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||
} else {
|
||||
graph = buildCanvasOutpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||
|
@ -105,7 +105,7 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
|
||||
});
|
||||
}
|
||||
|
||||
if (shouldConcatSDXLStylePrompt && model?.base_model === 'sdxl') {
|
||||
if (shouldConcatSDXLStylePrompt && model?.base === 'sdxl') {
|
||||
if (graph.nodes[POSITIVE_CONDITIONING]) {
|
||||
firstBatchDatumList.push({
|
||||
node_path: POSITIVE_CONDITIONING,
|
||||
|
Reference in New Issue
Block a user