fix(ui): handle base model compat when recalling parameters

We had a one-behind issue with recalling metadata items that had a model.

For example, when recalling LoRAs, we check against the current main model to decide whether or not the requested LoRA is compatible and may be recalled.

When recalling all params, we are often also recalling the main model, but the compat logic didn't compare against this new main model.

The logic is updated to check against the new main model, if one is being set.

Closes #5512
This commit is contained in:
psychedelicious 2024-01-23 18:58:30 +11:00 committed by Kent Keirsey
parent 022b32c724
commit b76d2cd716

View File

@ -47,6 +47,7 @@ import {
vaeSelected, vaeSelected,
widthChanged, widthChanged,
} from 'features/parameters/store/generationSlice'; } from 'features/parameters/store/generationSlice';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import { import {
isParameterCFGRescaleMultiplier, isParameterCFGRescaleMultiplier,
isParameterCFGScale, isParameterCFGScale,
@ -480,7 +481,7 @@ export const useRecallParameters = () => {
const { data: loraModels } = useGetLoRAModelsQuery(undefined); const { data: loraModels } = useGetLoRAModelsQuery(undefined);
const prepareLoRAMetadataItem = useCallback( const prepareLoRAMetadataItem = useCallback(
(loraMetadataItem: LoRAMetadataItem) => { (loraMetadataItem: LoRAMetadataItem, newModel?: ParameterModel) => {
if (!isParameterLoRAModel(loraMetadataItem.lora)) { if (!isParameterLoRAModel(loraMetadataItem.lora)) {
return { lora: null, error: 'Invalid LoRA model' }; return { lora: null, error: 'Invalid LoRA model' };
} }
@ -499,7 +500,7 @@ export const useRecallParameters = () => {
} }
const isCompatibleBaseModel = const isCompatibleBaseModel =
matchingLoRA?.base_model === model?.base_model; matchingLoRA?.base_model === (newModel ?? model)?.base_model;
if (!isCompatibleBaseModel) { if (!isCompatibleBaseModel) {
return { return {
@ -510,7 +511,7 @@ export const useRecallParameters = () => {
return { lora: matchingLoRA, error: null }; return { lora: matchingLoRA, error: null };
}, },
[loraModels, model?.base_model] [loraModels, model]
); );
const recallLoRA = useCallback( const recallLoRA = useCallback(
@ -538,7 +539,10 @@ export const useRecallParameters = () => {
const { data: controlNetModels } = useGetControlNetModelsQuery(undefined); const { data: controlNetModels } = useGetControlNetModelsQuery(undefined);
const prepareControlNetMetadataItem = useCallback( const prepareControlNetMetadataItem = useCallback(
(controlnetMetadataItem: ControlNetMetadataItem) => { (
controlnetMetadataItem: ControlNetMetadataItem,
newModel?: ParameterModel
) => {
if (!isParameterControlNetModel(controlnetMetadataItem.control_model)) { if (!isParameterControlNetModel(controlnetMetadataItem.control_model)) {
return { controlnet: null, error: 'Invalid ControlNet model' }; return { controlnet: null, error: 'Invalid ControlNet model' };
} }
@ -565,7 +569,7 @@ export const useRecallParameters = () => {
} }
const isCompatibleBaseModel = const isCompatibleBaseModel =
matchingControlNetModel?.base_model === model?.base_model; matchingControlNetModel?.base_model === (newModel ?? model)?.base_model;
if (!isCompatibleBaseModel) { if (!isCompatibleBaseModel) {
return { return {
@ -600,7 +604,7 @@ export const useRecallParameters = () => {
return { controlnet, error: null }; return { controlnet, error: null };
}, },
[controlNetModels, model?.base_model] [controlNetModels, model]
); );
const recallControlNet = useCallback( const recallControlNet = useCallback(
@ -631,7 +635,10 @@ export const useRecallParameters = () => {
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery(undefined); const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery(undefined);
const prepareT2IAdapterMetadataItem = useCallback( const prepareT2IAdapterMetadataItem = useCallback(
(t2iAdapterMetadataItem: T2IAdapterMetadataItem) => { (
t2iAdapterMetadataItem: T2IAdapterMetadataItem,
newModel?: ParameterModel
) => {
if ( if (
!isParameterControlNetModel(t2iAdapterMetadataItem.t2i_adapter_model) !isParameterControlNetModel(t2iAdapterMetadataItem.t2i_adapter_model)
) { ) {
@ -659,7 +666,7 @@ export const useRecallParameters = () => {
} }
const isCompatibleBaseModel = const isCompatibleBaseModel =
matchingT2IAdapterModel?.base_model === model?.base_model; matchingT2IAdapterModel?.base_model === (newModel ?? model)?.base_model;
if (!isCompatibleBaseModel) { if (!isCompatibleBaseModel) {
return { return {
@ -690,7 +697,7 @@ export const useRecallParameters = () => {
return { t2iAdapter, error: null }; return { t2iAdapter, error: null };
}, },
[model?.base_model, t2iAdapterModels] [model, t2iAdapterModels]
); );
const recallT2IAdapter = useCallback( const recallT2IAdapter = useCallback(
@ -721,7 +728,10 @@ export const useRecallParameters = () => {
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery(undefined); const { data: ipAdapterModels } = useGetIPAdapterModelsQuery(undefined);
const prepareIPAdapterMetadataItem = useCallback( const prepareIPAdapterMetadataItem = useCallback(
(ipAdapterMetadataItem: IPAdapterMetadataItem) => { (
ipAdapterMetadataItem: IPAdapterMetadataItem,
newModel?: ParameterModel
) => {
if (!isParameterIPAdapterModel(ipAdapterMetadataItem?.ip_adapter_model)) { if (!isParameterIPAdapterModel(ipAdapterMetadataItem?.ip_adapter_model)) {
return { ipAdapter: null, error: 'Invalid IP Adapter model' }; return { ipAdapter: null, error: 'Invalid IP Adapter model' };
} }
@ -746,7 +756,7 @@ export const useRecallParameters = () => {
} }
const isCompatibleBaseModel = const isCompatibleBaseModel =
matchingIPAdapterModel?.base_model === model?.base_model; matchingIPAdapterModel?.base_model === (newModel ?? model)?.base_model;
if (!isCompatibleBaseModel) { if (!isCompatibleBaseModel) {
return { return {
@ -768,7 +778,7 @@ export const useRecallParameters = () => {
return { ipAdapter, error: null }; return { ipAdapter, error: null };
}, },
[ipAdapterModels, model?.base_model] [ipAdapterModels, model]
); );
const recallIPAdapter = useCallback( const recallIPAdapter = useCallback(
@ -840,6 +850,13 @@ export const useRecallParameters = () => {
t2iAdapters, t2iAdapters,
} = metadata; } = metadata;
let newModel: ParameterModel | undefined = undefined;
if (isParameterModel(model)) {
newModel = model;
dispatch(modelSelected(model));
}
if (isParameterCFGScale(cfg_scale)) { if (isParameterCFGScale(cfg_scale)) {
dispatch(setCfgScale(cfg_scale)); dispatch(setCfgScale(cfg_scale));
} }
@ -848,10 +865,6 @@ export const useRecallParameters = () => {
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier)); dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
} }
if (isParameterModel(model)) {
dispatch(modelSelected(model));
}
if (isParameterPositivePrompt(positive_prompt)) { if (isParameterPositivePrompt(positive_prompt)) {
dispatch(setPositivePrompt(positive_prompt)); dispatch(setPositivePrompt(positive_prompt));
} }
@ -953,7 +966,7 @@ export const useRecallParameters = () => {
dispatch(lorasCleared()); dispatch(lorasCleared());
loras?.forEach((lora) => { loras?.forEach((lora) => {
const result = prepareLoRAMetadataItem(lora); const result = prepareLoRAMetadataItem(lora, newModel);
if (result.lora) { if (result.lora) {
dispatch(loraRecalled({ ...result.lora, weight: lora.weight })); dispatch(loraRecalled({ ...result.lora, weight: lora.weight }));
} }
@ -961,21 +974,21 @@ export const useRecallParameters = () => {
dispatch(controlAdaptersReset()); dispatch(controlAdaptersReset());
controlnets?.forEach((controlnet) => { controlnets?.forEach((controlnet) => {
const result = prepareControlNetMetadataItem(controlnet); const result = prepareControlNetMetadataItem(controlnet, newModel);
if (result.controlnet) { if (result.controlnet) {
dispatch(controlAdapterRecalled(result.controlnet)); dispatch(controlAdapterRecalled(result.controlnet));
} }
}); });
ipAdapters?.forEach((ipAdapter) => { ipAdapters?.forEach((ipAdapter) => {
const result = prepareIPAdapterMetadataItem(ipAdapter); const result = prepareIPAdapterMetadataItem(ipAdapter, newModel);
if (result.ipAdapter) { if (result.ipAdapter) {
dispatch(controlAdapterRecalled(result.ipAdapter)); dispatch(controlAdapterRecalled(result.ipAdapter));
} }
}); });
t2iAdapters?.forEach((t2iAdapter) => { t2iAdapters?.forEach((t2iAdapter) => {
const result = prepareT2IAdapterMetadataItem(t2iAdapter); const result = prepareT2IAdapterMetadataItem(t2iAdapter, newModel);
if (result.t2iAdapter) { if (result.t2iAdapter) {
dispatch(controlAdapterRecalled(result.t2iAdapter)); dispatch(controlAdapterRecalled(result.t2iAdapter));
} }