From 97ecd99b9cbe764d6c3fd57e549bda52f7c1c149 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 27 Feb 2024 15:09:42 +1100 Subject: [PATCH] fix(ui): fix up MM queries & types (wip) --- .../AddModelPanel/AdvancedImport.tsx | 19 ++++----- .../ScanModels/ScanModelResultItem.tsx | 8 ++-- .../subpanels/AddModelPanel/SimpleImport.tsx | 8 ++-- .../web/src/services/api/endpoints/models.ts | 41 ++++++------------- .../frontend/web/src/services/api/types.ts | 14 ++++++- 5 files changed, 40 insertions(+), 50 deletions(-) diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/AdvancedImport.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/AdvancedImport.tsx index 8f474537cc..e115c68dca 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/AdvancedImport.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/AdvancedImport.tsx @@ -10,17 +10,18 @@ import PredictionTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/F import RepoVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect'; import { addToast } from 'features/system/store/systemSlice'; import { makeToast } from 'features/system/util/makeToast'; +import { isNil, omitBy } from 'lodash-es'; import { useCallback, useEffect } from 'react'; import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; -import { useImportAdvancedModelMutation } from 'services/api/endpoints/models'; +import { useInstallModelMutation } from 'services/api/endpoints/models'; import type { AnyModelConfig } from 'services/api/types'; export const AdvancedImport = () => { const dispatch = useAppDispatch(); - const [importAdvancedModel] = useImportAdvancedModelMutation(); + const [installModel] = useInstallModelMutation(); const { t } = useTranslation(); @@ -49,15 +50,9 @@ export const AdvancedImport = () => { const onSubmit = useCallback>( (values) => { - const cleanValues = Object.fromEntries( - Object.entries(values).filter(([value]) => value !== null && value !== undefined) - ); - importAdvancedModel({ - source: { - path: cleanValues.path as string, - type: 'local', - }, - config: cleanValues, + installModel({ + source: values.path, + config: omitBy(values, isNil), }) .unwrap() .then((_) => { @@ -86,7 +81,7 @@ export const AdvancedImport = () => { } }); }, - [dispatch, reset, t, importAdvancedModel] + [installModel, dispatch, t, reset] ); const watchedModelType = watch('type'); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ScanModels/ScanModelResultItem.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ScanModels/ScanModelResultItem.tsx index 6ae7573665..af3cb33a0c 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ScanModels/ScanModelResultItem.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ScanModels/ScanModelResultItem.tsx @@ -6,7 +6,7 @@ import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { IoAdd } from 'react-icons/io5'; import type { ScanFolderResponse } from 'services/api/endpoints/models'; -import { useImportMainModelsMutation } from 'services/api/endpoints/models'; +import { useInstallModelMutation } from 'services/api/endpoints/models'; type Props = { result: ScanFolderResponse[number]; @@ -15,10 +15,10 @@ export const ScanModelResultItem = ({ result }: Props) => { const { t } = useTranslation(); const dispatch = useAppDispatch(); - const [importMainModel] = useImportMainModelsMutation(); + const [installModel] = useInstallModelMutation(); const handleQuickAdd = useCallback(() => { - importMainModel({ source: result.path, config: undefined }) + installModel({ source: result.path }) .unwrap() .then((_) => { dispatch( @@ -42,7 +42,7 @@ export const ScanModelResultItem = ({ result }: Props) => { ); } }); - }, [importMainModel, result, dispatch, t]); + }, [installModel, result, dispatch, t]); return ( diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/SimpleImport.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/SimpleImport.tsx index c056048c5b..6f704abf4f 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/SimpleImport.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/SimpleImport.tsx @@ -6,7 +6,7 @@ import { t } from 'i18next'; import { useCallback } from 'react'; import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; -import { useImportMainModelsMutation } from 'services/api/endpoints/models'; +import { useInstallModelMutation } from 'services/api/endpoints/models'; type SimpleImportModelConfig = { location: string; @@ -15,7 +15,7 @@ type SimpleImportModelConfig = { export const SimpleImport = () => { const dispatch = useAppDispatch(); - const [importMainModel, { isLoading }] = useImportMainModelsMutation(); + const [installModel, { isLoading }] = useInstallModelMutation(); const { register, handleSubmit, formState, reset } = useForm({ defaultValues: { @@ -30,7 +30,7 @@ export const SimpleImport = () => { return; } - importMainModel({ source: values.location, config: undefined }) + installModel({ source: values.location }) .unwrap() .then((_) => { dispatch( @@ -57,7 +57,7 @@ export const SimpleImport = () => { } }); }, - [dispatch, reset, importMainModel] + [dispatch, reset, installModel] ); return ( diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 2d1d021bc6..827356cd22 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -1,6 +1,7 @@ import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit'; import { createEntityAdapter } from '@reduxjs/toolkit'; import { getSelectorsOptions } from 'app/store/createMemoizedSelector'; +import type { JSONObject } from 'common/types'; import queryString from 'query-string'; import type { operations, paths } from 'services/api/schema'; import type { @@ -19,8 +20,8 @@ import type { ApiTagDescription, tagTypes } from '..'; import { api, buildV2Url, LIST_TAG } from '..'; type UpdateModelArg = { - key: NonNullable; - body: NonNullable; + key: paths['/api/v2/models/i/{key}']['patch']['parameters']['path']['key']; + body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json']; }; type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json']; @@ -40,14 +41,15 @@ type DeleteMainModelResponse = void; type ConvertMainModelResponse = paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json']; -type ImportMainModelArg = { - source: NonNullable; - access_token?: operations['heuristic_install_model']['parameters']['query']['access_token']; - config: NonNullable; +type InstallModelArg = { + source: paths['/api/v2/models/install']['post']['parameters']['query']['source']; + access_token?: paths['/api/v2/models/install']['post']['parameters']['query']['access_token']; + // TODO(MM2): This is typed as `Optional[Dict[str, Any]]` in backend... + config?: JSONObject; + // config: NonNullable['content']['application/json']; }; -type ImportMainModelResponse = - paths['/api/v2/models/heuristic_install']['post']['responses']['201']['content']['application/json']; +type InstallModelResponse = paths['/api/v2/models/install']['post']['responses']['201']['content']['application/json']; type ListImportModelsResponse = paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json']; @@ -58,14 +60,6 @@ type DeleteImportModelsResponse = type PruneModelImportsResponse = paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json']; -type ImportAdvancedModelArg = { - source: NonNullable; - config: NonNullable; -}; - -type ImportAdvancedModelResponse = - paths['/api/v2/models/import']['post']['responses']['201']['content']['application/json']; - export type ScanFolderResponse = paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json']; type ScanFolderArg = operations['scan_for_models']['parameters']['query']; @@ -183,7 +177,7 @@ export const modelsApi = api.injectEndpoints({ }, invalidatesTags: ['Model'], }), - importMainModels: build.mutation({ + installModel: build.mutation({ query: ({ source, config, access_token }) => { return { url: buildModelsUrl('heuristic_install'), @@ -194,16 +188,6 @@ export const modelsApi = api.injectEndpoints({ }, invalidatesTags: ['Model', 'ModelImports'], }), - importAdvancedModel: build.mutation({ - query: ({ source, config }) => { - return { - url: buildModelsUrl('install'), - method: 'POST', - body: { source, config }, - }; - }, - invalidatesTags: ['Model', 'ModelImports'], - }), deleteModels: build.mutation({ query: ({ key }) => { return { @@ -365,8 +349,7 @@ export const { useGetVaeModelsQuery, useDeleteModelsMutation, useUpdateModelsMutation, - useImportMainModelsMutation, - useImportAdvancedModelMutation, + useInstallModelMutation, useConvertMainModelsMutation, useSyncModelsMutation, useScanModelsQuery, diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index cb848e168e..2ecf0a7a09 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -70,6 +70,7 @@ export type T2IAdapterModelConfig = S['T2IConfig']; export type TextualInversionModelConfig = S['TextualInversionConfig']; export type DiffusersModelConfig = S['MainDiffusersConfig']; export type CheckpointModelConfig = S['MainCheckpointConfig']; +type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig']; export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig; export type RefinerMainModelConfig = Omit & { base: 'sdxl-refiner' }; export type NonRefinerMainModelConfig = Omit & { base: 'any' | 'sd-1' | 'sd-2' | 'sdxl' }; @@ -81,7 +82,18 @@ export type AnyModelConfig = | T2IAdapterModelConfig | TextualInversionModelConfig | RefinerMainModelConfig - | NonRefinerMainModelConfig; + | NonRefinerMainModelConfig + | CLIPVisionDiffusersConfig; + +type AnyModelConfig2 = + | (S['MainDiffusersConfig'] | S['MainCheckpointConfig']) + | (S['VaeDiffusersConfig'] | S['VaeCheckpointConfig']) + | (S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig']) + | S['LoRAConfig'] + | S['TextualInversionConfig'] + | S['IPAdapterConfig'] + | S['CLIPVisionDiffusersConfig'] + | S['T2IConfig']; export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelConfig => { return config.type === 'lora';