mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): handle base model compat when recalling parameters
We had a one-behind issue with recalling metadata items that had a model. For example, when recalling LoRAs, we check against the current main model to decide whether or not the requested LoRA is compatible and may be recalled. When recalling all params, we are often also recalling the main model, but the compat logic didn't compare against this new main model. The logic is updated to check against the new main model, if one is being set. Closes #5512
This commit is contained in:
parent
022b32c724
commit
b76d2cd716
@ -47,6 +47,7 @@ import {
|
|||||||
vaeSelected,
|
vaeSelected,
|
||||||
widthChanged,
|
widthChanged,
|
||||||
} from 'features/parameters/store/generationSlice';
|
} from 'features/parameters/store/generationSlice';
|
||||||
|
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||||
import {
|
import {
|
||||||
isParameterCFGRescaleMultiplier,
|
isParameterCFGRescaleMultiplier,
|
||||||
isParameterCFGScale,
|
isParameterCFGScale,
|
||||||
@ -480,7 +481,7 @@ export const useRecallParameters = () => {
|
|||||||
const { data: loraModels } = useGetLoRAModelsQuery(undefined);
|
const { data: loraModels } = useGetLoRAModelsQuery(undefined);
|
||||||
|
|
||||||
const prepareLoRAMetadataItem = useCallback(
|
const prepareLoRAMetadataItem = useCallback(
|
||||||
(loraMetadataItem: LoRAMetadataItem) => {
|
(loraMetadataItem: LoRAMetadataItem, newModel?: ParameterModel) => {
|
||||||
if (!isParameterLoRAModel(loraMetadataItem.lora)) {
|
if (!isParameterLoRAModel(loraMetadataItem.lora)) {
|
||||||
return { lora: null, error: 'Invalid LoRA model' };
|
return { lora: null, error: 'Invalid LoRA model' };
|
||||||
}
|
}
|
||||||
@ -499,7 +500,7 @@ export const useRecallParameters = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const isCompatibleBaseModel =
|
const isCompatibleBaseModel =
|
||||||
matchingLoRA?.base_model === model?.base_model;
|
matchingLoRA?.base_model === (newModel ?? model)?.base_model;
|
||||||
|
|
||||||
if (!isCompatibleBaseModel) {
|
if (!isCompatibleBaseModel) {
|
||||||
return {
|
return {
|
||||||
@ -510,7 +511,7 @@ export const useRecallParameters = () => {
|
|||||||
|
|
||||||
return { lora: matchingLoRA, error: null };
|
return { lora: matchingLoRA, error: null };
|
||||||
},
|
},
|
||||||
[loraModels, model?.base_model]
|
[loraModels, model]
|
||||||
);
|
);
|
||||||
|
|
||||||
const recallLoRA = useCallback(
|
const recallLoRA = useCallback(
|
||||||
@ -538,7 +539,10 @@ export const useRecallParameters = () => {
|
|||||||
const { data: controlNetModels } = useGetControlNetModelsQuery(undefined);
|
const { data: controlNetModels } = useGetControlNetModelsQuery(undefined);
|
||||||
|
|
||||||
const prepareControlNetMetadataItem = useCallback(
|
const prepareControlNetMetadataItem = useCallback(
|
||||||
(controlnetMetadataItem: ControlNetMetadataItem) => {
|
(
|
||||||
|
controlnetMetadataItem: ControlNetMetadataItem,
|
||||||
|
newModel?: ParameterModel
|
||||||
|
) => {
|
||||||
if (!isParameterControlNetModel(controlnetMetadataItem.control_model)) {
|
if (!isParameterControlNetModel(controlnetMetadataItem.control_model)) {
|
||||||
return { controlnet: null, error: 'Invalid ControlNet model' };
|
return { controlnet: null, error: 'Invalid ControlNet model' };
|
||||||
}
|
}
|
||||||
@ -565,7 +569,7 @@ export const useRecallParameters = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const isCompatibleBaseModel =
|
const isCompatibleBaseModel =
|
||||||
matchingControlNetModel?.base_model === model?.base_model;
|
matchingControlNetModel?.base_model === (newModel ?? model)?.base_model;
|
||||||
|
|
||||||
if (!isCompatibleBaseModel) {
|
if (!isCompatibleBaseModel) {
|
||||||
return {
|
return {
|
||||||
@ -600,7 +604,7 @@ export const useRecallParameters = () => {
|
|||||||
|
|
||||||
return { controlnet, error: null };
|
return { controlnet, error: null };
|
||||||
},
|
},
|
||||||
[controlNetModels, model?.base_model]
|
[controlNetModels, model]
|
||||||
);
|
);
|
||||||
|
|
||||||
const recallControlNet = useCallback(
|
const recallControlNet = useCallback(
|
||||||
@ -631,7 +635,10 @@ export const useRecallParameters = () => {
|
|||||||
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery(undefined);
|
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery(undefined);
|
||||||
|
|
||||||
const prepareT2IAdapterMetadataItem = useCallback(
|
const prepareT2IAdapterMetadataItem = useCallback(
|
||||||
(t2iAdapterMetadataItem: T2IAdapterMetadataItem) => {
|
(
|
||||||
|
t2iAdapterMetadataItem: T2IAdapterMetadataItem,
|
||||||
|
newModel?: ParameterModel
|
||||||
|
) => {
|
||||||
if (
|
if (
|
||||||
!isParameterControlNetModel(t2iAdapterMetadataItem.t2i_adapter_model)
|
!isParameterControlNetModel(t2iAdapterMetadataItem.t2i_adapter_model)
|
||||||
) {
|
) {
|
||||||
@ -659,7 +666,7 @@ export const useRecallParameters = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const isCompatibleBaseModel =
|
const isCompatibleBaseModel =
|
||||||
matchingT2IAdapterModel?.base_model === model?.base_model;
|
matchingT2IAdapterModel?.base_model === (newModel ?? model)?.base_model;
|
||||||
|
|
||||||
if (!isCompatibleBaseModel) {
|
if (!isCompatibleBaseModel) {
|
||||||
return {
|
return {
|
||||||
@ -690,7 +697,7 @@ export const useRecallParameters = () => {
|
|||||||
|
|
||||||
return { t2iAdapter, error: null };
|
return { t2iAdapter, error: null };
|
||||||
},
|
},
|
||||||
[model?.base_model, t2iAdapterModels]
|
[model, t2iAdapterModels]
|
||||||
);
|
);
|
||||||
|
|
||||||
const recallT2IAdapter = useCallback(
|
const recallT2IAdapter = useCallback(
|
||||||
@ -721,7 +728,10 @@ export const useRecallParameters = () => {
|
|||||||
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery(undefined);
|
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery(undefined);
|
||||||
|
|
||||||
const prepareIPAdapterMetadataItem = useCallback(
|
const prepareIPAdapterMetadataItem = useCallback(
|
||||||
(ipAdapterMetadataItem: IPAdapterMetadataItem) => {
|
(
|
||||||
|
ipAdapterMetadataItem: IPAdapterMetadataItem,
|
||||||
|
newModel?: ParameterModel
|
||||||
|
) => {
|
||||||
if (!isParameterIPAdapterModel(ipAdapterMetadataItem?.ip_adapter_model)) {
|
if (!isParameterIPAdapterModel(ipAdapterMetadataItem?.ip_adapter_model)) {
|
||||||
return { ipAdapter: null, error: 'Invalid IP Adapter model' };
|
return { ipAdapter: null, error: 'Invalid IP Adapter model' };
|
||||||
}
|
}
|
||||||
@ -746,7 +756,7 @@ export const useRecallParameters = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const isCompatibleBaseModel =
|
const isCompatibleBaseModel =
|
||||||
matchingIPAdapterModel?.base_model === model?.base_model;
|
matchingIPAdapterModel?.base_model === (newModel ?? model)?.base_model;
|
||||||
|
|
||||||
if (!isCompatibleBaseModel) {
|
if (!isCompatibleBaseModel) {
|
||||||
return {
|
return {
|
||||||
@ -768,7 +778,7 @@ export const useRecallParameters = () => {
|
|||||||
|
|
||||||
return { ipAdapter, error: null };
|
return { ipAdapter, error: null };
|
||||||
},
|
},
|
||||||
[ipAdapterModels, model?.base_model]
|
[ipAdapterModels, model]
|
||||||
);
|
);
|
||||||
|
|
||||||
const recallIPAdapter = useCallback(
|
const recallIPAdapter = useCallback(
|
||||||
@ -840,6 +850,13 @@ export const useRecallParameters = () => {
|
|||||||
t2iAdapters,
|
t2iAdapters,
|
||||||
} = metadata;
|
} = metadata;
|
||||||
|
|
||||||
|
let newModel: ParameterModel | undefined = undefined;
|
||||||
|
|
||||||
|
if (isParameterModel(model)) {
|
||||||
|
newModel = model;
|
||||||
|
dispatch(modelSelected(model));
|
||||||
|
}
|
||||||
|
|
||||||
if (isParameterCFGScale(cfg_scale)) {
|
if (isParameterCFGScale(cfg_scale)) {
|
||||||
dispatch(setCfgScale(cfg_scale));
|
dispatch(setCfgScale(cfg_scale));
|
||||||
}
|
}
|
||||||
@ -848,10 +865,6 @@ export const useRecallParameters = () => {
|
|||||||
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
|
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isParameterModel(model)) {
|
|
||||||
dispatch(modelSelected(model));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isParameterPositivePrompt(positive_prompt)) {
|
if (isParameterPositivePrompt(positive_prompt)) {
|
||||||
dispatch(setPositivePrompt(positive_prompt));
|
dispatch(setPositivePrompt(positive_prompt));
|
||||||
}
|
}
|
||||||
@ -953,7 +966,7 @@ export const useRecallParameters = () => {
|
|||||||
|
|
||||||
dispatch(lorasCleared());
|
dispatch(lorasCleared());
|
||||||
loras?.forEach((lora) => {
|
loras?.forEach((lora) => {
|
||||||
const result = prepareLoRAMetadataItem(lora);
|
const result = prepareLoRAMetadataItem(lora, newModel);
|
||||||
if (result.lora) {
|
if (result.lora) {
|
||||||
dispatch(loraRecalled({ ...result.lora, weight: lora.weight }));
|
dispatch(loraRecalled({ ...result.lora, weight: lora.weight }));
|
||||||
}
|
}
|
||||||
@ -961,21 +974,21 @@ export const useRecallParameters = () => {
|
|||||||
|
|
||||||
dispatch(controlAdaptersReset());
|
dispatch(controlAdaptersReset());
|
||||||
controlnets?.forEach((controlnet) => {
|
controlnets?.forEach((controlnet) => {
|
||||||
const result = prepareControlNetMetadataItem(controlnet);
|
const result = prepareControlNetMetadataItem(controlnet, newModel);
|
||||||
if (result.controlnet) {
|
if (result.controlnet) {
|
||||||
dispatch(controlAdapterRecalled(result.controlnet));
|
dispatch(controlAdapterRecalled(result.controlnet));
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
ipAdapters?.forEach((ipAdapter) => {
|
ipAdapters?.forEach((ipAdapter) => {
|
||||||
const result = prepareIPAdapterMetadataItem(ipAdapter);
|
const result = prepareIPAdapterMetadataItem(ipAdapter, newModel);
|
||||||
if (result.ipAdapter) {
|
if (result.ipAdapter) {
|
||||||
dispatch(controlAdapterRecalled(result.ipAdapter));
|
dispatch(controlAdapterRecalled(result.ipAdapter));
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
t2iAdapters?.forEach((t2iAdapter) => {
|
t2iAdapters?.forEach((t2iAdapter) => {
|
||||||
const result = prepareT2IAdapterMetadataItem(t2iAdapter);
|
const result = prepareT2IAdapterMetadataItem(t2iAdapter, newModel);
|
||||||
if (result.t2iAdapter) {
|
if (result.t2iAdapter) {
|
||||||
dispatch(controlAdapterRecalled(result.t2iAdapter));
|
dispatch(controlAdapterRecalled(result.t2iAdapter));
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user