feat(ui, api): prompt template export (#6745)

## Summary

Adds option to download all prompt templates to a CSV

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
This commit is contained in:
psychedelicious 2024-08-16 10:38:50 +10:00 committed by GitHub
commit 713bd11177
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 311 additions and 127 deletions

View File

@ -1,10 +1,11 @@
import csv
import io
import json
import traceback
from typing import Optional
import pydantic
from fastapi import APIRouter, File, Form, HTTPException, Path, UploadFile
from fastapi import APIRouter, File, Form, HTTPException, Path, Response, UploadFile
from fastapi.responses import FileResponse
from PIL import Image
from pydantic import BaseModel, Field
@ -230,6 +231,35 @@ async def get_style_preset_image(
raise HTTPException(status_code=404)
@style_presets_router.get(
"/export",
operation_id="export_style_presets",
responses={200: {"content": {"text/csv": {}}, "description": "A CSV file with the requested data."}},
status_code=200,
)
async def export_style_presets():
# Create an in-memory stream to store the CSV data
output = io.StringIO()
writer = csv.writer(output)
# Write the header
writer.writerow(["name", "prompt", "negative_prompt"])
style_presets = ApiDependencies.invoker.services.style_preset_records.get_many(type=PresetType.User)
for preset in style_presets:
writer.writerow([preset.name, preset.preset_data.positive_prompt, preset.preset_data.negative_prompt])
csv_data = output.getvalue()
output.close()
return Response(
content=csv_data,
media_type="text/csv",
headers={"Content-Disposition": "attachment; filename=prompt_templates.csv"},
)
@style_presets_router.post(
"/import",
operation_id="import_style_presets",

View File

@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from invokeai.app.services.style_preset_records.style_preset_records_common import (
PresetType,
StylePresetChanges,
StylePresetRecordDTO,
StylePresetWithoutId,
@ -36,6 +37,6 @@ class StylePresetRecordsStorageBase(ABC):
pass
@abstractmethod
def get_many(self) -> list[StylePresetRecordDTO]:
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
"""Gets many workflows."""
pass

View File

@ -128,9 +128,9 @@ async def parse_presets_from_file(file: UploadFile) -> list[StylePresetWithoutId
style_presets.append(style_preset)
except pydantic.ValidationError as e:
if file.content_type == "text/csv":
msg = "Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt'"
msg = "Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt' and name cannot be blank"
else: # file.content_type == "application/json":
msg = "Invalid JSON format: must be a list of objects with keys 'name', 'prompt', and 'negative_prompt'"
msg = "Invalid JSON format: must be a list of objects with keys 'name', 'prompt', and 'negative_prompt' and name cannot be blank"
raise InvalidPresetImportDataError(msg) from e
finally:
file.file.close()

View File

@ -5,6 +5,7 @@ from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
from invokeai.app.services.style_preset_records.style_preset_records_common import (
PresetType,
StylePresetChanges,
StylePresetNotFoundError,
StylePresetRecordDTO,
@ -159,19 +160,25 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
self._lock.release()
return None
def get_many(
self,
) -> list[StylePresetRecordDTO]:
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
try:
self._lock.acquire()
main_query = """
SELECT
*
FROM style_presets
ORDER BY LOWER(name) ASC
"""
if type is not None:
main_query += "WHERE type = ? "
main_query += "ORDER BY LOWER(name) ASC"
if type is not None:
self._cursor.execute(main_query, (type,))
else:
self._cursor.execute(main_query)
rows = self._cursor.fetchall()
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]

View File

@ -1701,14 +1701,21 @@
"deleteImage": "Delete Image",
"deleteTemplate": "Delete Template",
"deleteTemplate2": "Are you sure you want to delete this template? This cannot be undone.",
"exportPromptTemplates": "Export My Prompt Templates (CSV)",
"editTemplate": "Edit Template",
"exportDownloaded": "Export Downloaded",
"exportFailed": "Unable to generate and download CSV",
"flatten": "Flatten selected template into current prompt",
"importTemplates": "Import Prompt Templates",
"importTemplatesDesc": "Format must be either a CSV with columns: 'name', 'prompt' or 'positive_prompt', and 'negative_prompt' included, or a JSON file with keys 'name', 'prompt' or 'positive_prompt', and 'negative_prompt",
"importTemplates": "Import Prompt Templates (CSV/JSON)",
"acceptedColumnsKeys": "Accepted columns/keys:",
"nameColumn": "'name'",
"positivePromptColumn": "'prompt' or 'positive_prompt'",
"negativePromptColumn": "'negative_prompt'",
"insertPlaceholder": "Insert placeholder",
"myTemplates": "My Templates",
"name": "Name",
"negativePrompt": "Negative Prompt",
"noTemplates": "No templates",
"noMatchingTemplates": "No matching templates",
"promptTemplatesDesc1": "Prompt templates add text to the prompts you write in the prompt box.",
"promptTemplatesDesc2": "Use the placeholder string <Pre>{{placeholder}}</Pre> to specify where your prompt should be included in the template.",
@ -1719,6 +1726,7 @@
"searchByName": "Search by name",
"shared": "Shared",
"sharedTemplates": "Shared Templates",
"templateActions": "Template Actions",
"templateDeleted": "Prompt template deleted",
"toggleViewMode": "Toggle View Mode",
"type": "Type",

View File

@ -47,6 +47,7 @@ export const IAINoContentFallback = memo((props: IAINoImageFallbackProps) => {
userSelect: 'none',
opacity: 0.7,
color: 'base.500',
fontSize: 'md',
...sx,
}),
[sx]
@ -55,11 +56,7 @@ export const IAINoContentFallback = memo((props: IAINoImageFallbackProps) => {
return (
<Flex sx={styles} {...rest}>
{icon && <Icon as={icon} boxSize={boxSize} opacity={0.7} />}
{props.label && (
<Text textAlign="center" fontSize="md">
{props.label}
</Text>
)}
{props.label && <Text textAlign="center">{props.label}</Text>}
</Flex>
);
});

View File

@ -0,0 +1,30 @@
import { IconButton } from '@invoke-ai/ui-library';
import { $stylePresetModalState } from 'features/stylePresets/store/stylePresetModal';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
export const StylePresetCreateButton = () => {
const handleClickAddNew = useCallback(() => {
$stylePresetModalState.set({
prefilledFormData: null,
updatingStylePresetId: null,
isModalOpen: true,
});
}, []);
const { t } = useTranslation();
return (
<IconButton
icon={<PiPlusBold />}
tooltip={t('stylePresets.createPromptTemplate')}
aria-label={t('stylePresets.createPromptTemplate')}
onClick={handleClickAddNew}
size="md"
variant="ghost"
w={8}
h={8}
/>
);
};

View File

@ -0,0 +1,68 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { IconButton, spinAnimation } from '@invoke-ai/ui-library';
import { EMPTY_ARRAY } from 'app/store/constants';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiDownloadSimpleBold, PiSpinner } from 'react-icons/pi';
import { useLazyExportStylePresetsQuery, useListStylePresetsQuery } from 'services/api/endpoints/stylePresets';
const loadingStyles: SystemStyleObject = {
svg: { animation: spinAnimation },
};
export const StylePresetExportButton = () => {
const [exportStylePresets, { isLoading }] = useLazyExportStylePresetsQuery();
const { t } = useTranslation();
const { presetCount } = useListStylePresetsQuery(undefined, {
selectFromResult: ({ data }) => {
const userPresets = data?.filter((preset) => preset.type === 'user') ?? EMPTY_ARRAY;
return {
presetCount: userPresets.length,
};
},
});
const handleClickDownloadCsv = useCallback(async () => {
let blob;
try {
const response = await exportStylePresets().unwrap();
blob = new Blob([response], { type: 'text/csv' });
} catch (error) {
toast({
status: 'error',
title: t('stylePresets.exportFailed'),
});
return;
}
if (blob) {
const url = window.URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = 'data.csv';
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
window.URL.revokeObjectURL(url);
toast({
status: 'success',
title: t('stylePresets.exportDownloaded'),
});
}
}, [exportStylePresets, t]);
return (
<IconButton
onClick={handleClickDownloadCsv}
icon={!isLoading ? <PiDownloadSimpleBold /> : <PiSpinner />}
tooltip={t('stylePresets.exportPromptTemplates')}
aria-label={t('stylePresets.exportPromptTemplates')}
size="md"
variant="link"
w={8}
h={8}
sx={isLoading ? loadingStyles : undefined}
isDisabled={isLoading || presetCount === 0}
/>
);
};

View File

@ -39,6 +39,8 @@ export const StylePresetPromptField = (props: Props) => {
} else {
field.onChange(value + PRESET_PLACEHOLDER);
}
textareaRef.current?.focus();
}, [value, field, textareaRef]);
const isPromptPresent = useMemo(() => value?.includes(PRESET_PLACEHOLDER), [value]);

View File

@ -1,69 +0,0 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { IconButton, spinAnimation, Text } from '@invoke-ai/ui-library';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { PiSpinner, PiUploadBold } from 'react-icons/pi';
import { useImportStylePresetsMutation } from 'services/api/endpoints/stylePresets';
const loadingStyles: SystemStyleObject = {
svg: { animation: spinAnimation },
};
export const StylePresetImport = () => {
const [importStylePresets, { isLoading }] = useImportStylePresetsMutation();
const { t } = useTranslation();
const onDropAccepted = useCallback(
async (files: File[]) => {
const file = files[0];
if (!file) {
return;
}
try {
await importStylePresets(file).unwrap();
toast({
status: 'success',
title: t('toast.importSuccessful'),
});
} catch (error) {
toast({
status: 'error',
title: t('toast.importFailed'),
});
}
},
[importStylePresets, t]
);
const { getInputProps, getRootProps } = useDropzone({
accept: { 'text/csv': ['.csv'], 'application/json': ['.json'] },
onDropAccepted,
noDrag: true,
multiple: false,
});
return (
<>
<IconButton
icon={!isLoading ? <PiUploadBold /> : <PiSpinner />}
tooltip={
<>
<Text fontWeight="semibold">{t('stylePresets.importTemplates')}</Text>
<Text>{t('stylePresets.importTemplatesDesc')}</Text>
</>
}
aria-label={t('stylePresets.importTemplates')}
size="md"
variant="link"
w={8}
h={8}
sx={isLoading ? loadingStyles : undefined}
isDisabled={isLoading}
{...getRootProps()}
/>
<input {...getInputProps()} />
</>
);
};

View File

@ -0,0 +1,84 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, IconButton, ListItem, spinAnimation, Text, UnorderedList } from '@invoke-ai/ui-library';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { PiSpinner, PiUploadSimpleBold } from 'react-icons/pi';
import { useImportStylePresetsMutation } from 'services/api/endpoints/stylePresets';
const loadingStyles: SystemStyleObject = {
svg: { animation: spinAnimation },
};
export const StylePresetImportButton = () => {
const [importStylePresets, { isLoading }] = useImportStylePresetsMutation();
const { t } = useTranslation();
const onDropAccepted = useCallback(
(files: File[]) => {
const file = files[0];
if (!file) {
return;
}
importStylePresets(file)
.unwrap()
.then(() => {
toast({
status: 'success',
title: t('toast.importSuccessful'),
});
})
.catch((error) => {
toast({
status: 'error',
title: t('toast.importFailed'),
description: error ? `${error.data.detail}` : undefined,
});
});
},
[importStylePresets, t]
);
const { getInputProps, getRootProps } = useDropzone({
accept: { 'text/csv': ['.csv'], 'application/json': ['.json'] },
onDropAccepted,
noDrag: true,
multiple: false,
});
return (
<>
<IconButton
icon={!isLoading ? <PiUploadSimpleBold /> : <PiSpinner />}
tooltip={<TooltipContent />}
aria-label={t('stylePresets.importTemplates')}
size="md"
variant="link"
w={8}
h={8}
sx={isLoading ? loadingStyles : undefined}
isDisabled={isLoading}
{...getRootProps()}
/>
<input {...getInputProps()} />
</>
);
};
const TooltipContent = () => {
const { t } = useTranslation();
return (
<Flex flexDir="column">
<Text pb={1} fontWeight="semibold">
{t('stylePresets.importTemplates')}
</Text>
<Text>{t('stylePresets.acceptedColumnsKeys')}</Text>
<UnorderedList>
<ListItem>{t('stylePresets.nameColumn')}</ListItem>
<ListItem>{t('stylePresets.positivePromptColumn')}</ListItem>
<ListItem>{t('stylePresets.negativePromptColumn')}</ListItem>
</UnorderedList>
</Flex>
);
};

View File

@ -1,15 +1,16 @@
import { Button, Collapse, Flex, Icon, Text, useDisclosure } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { useTranslation } from 'react-i18next';
import { PiCaretDownBold } from 'react-icons/pi';
import type { StylePresetRecordWithImage } from 'services/api/endpoints/stylePresets';
import { StylePresetListItem } from './StylePresetListItem';
export const StylePresetList = ({ title, data }: { title: string; data: StylePresetRecordWithImage[] }) => {
const { t } = useTranslation();
const { onToggle, isOpen } = useDisclosure({ defaultIsOpen: true });
if (!data.length) {
return <></>;
}
const searchTerm = useAppSelector((s) => s.stylePreset.searchTerm);
return (
<Flex flexDir="column">
@ -22,9 +23,16 @@ export const StylePresetList = ({ title, data }: { title: string; data: StylePre
</Flex>
</Button>
<Collapse in={isOpen}>
{data.map((preset) => (
<StylePresetListItem preset={preset} key={preset.id} />
))}
{data.length ? (
data.map((preset) => <StylePresetListItem preset={preset} key={preset.id} />)
) : (
<IAINoContentFallback
fontSize="sm"
py={4}
label={searchTerm ? t('stylePresets.noMatchingTemplates') : t('stylePresets.noTemplates')}
icon={null}
/>
)}
</Collapse>
</Flex>
);

View File

@ -1,14 +1,13 @@
import { Flex, IconButton, Text } from '@invoke-ai/ui-library';
import { Flex } from '@invoke-ai/ui-library';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppSelector } from 'app/store/storeHooks';
import { $stylePresetModalState } from 'features/stylePresets/store/stylePresetModal';
import { useCallback } from 'react';
import { StylePresetExportButton } from 'features/stylePresets/components/StylePresetExportButton';
import { StylePresetImportButton } from 'features/stylePresets/components/StylePresetImportButton';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
import type { StylePresetRecordWithImage } from 'services/api/endpoints/stylePresets';
import { useListStylePresetsQuery } from 'services/api/endpoints/stylePresets';
import { StylePresetImport } from './StylePresetImport';
import { StylePresetCreateButton } from './StylePresetCreateButton';
import { StylePresetList } from './StylePresetList';
import StylePresetSearch from './StylePresetSearch';
@ -49,46 +48,19 @@ export const StylePresetMenu = () => {
const { t } = useTranslation();
const handleClickAddNew = useCallback(() => {
$stylePresetModalState.set({
prefilledFormData: null,
updatingStylePresetId: null,
isModalOpen: true,
});
}, []);
return (
<Flex flexDir="column" gap={2} padding={3} layerStyle="second" borderRadius="base">
<Flex alignItems="center" gap={2} w="full" justifyContent="space-between">
<StylePresetSearch />
<Flex alignItems="center" justifyContent="space-between">
<StylePresetImport />
<IconButton
icon={<PiPlusBold />}
tooltip={t('stylePresets.createPromptTemplate')}
aria-label={t('stylePresets.createPromptTemplate')}
onClick={handleClickAddNew}
size="md"
variant="link"
w={8}
h={8}
/>
<StylePresetCreateButton />
<StylePresetImportButton />
<StylePresetExportButton />
</Flex>
</Flex>
{data.presets.length === 0 && data.defaultPresets.length === 0 && (
<Text p={10} textAlign="center">
{t('stylePresets.noMatchingTemplates')}
</Text>
)}
<StylePresetList title={t('stylePresets.myTemplates')} data={data.presets} />
{allowPrivateStylePresets && (
<StylePresetList title={t('stylePresets.sharedTemplates')} data={data.sharedPresets} />
)}
<StylePresetList title={t('stylePresets.defaultTemplates')} data={data.defaultPresets} />
</Flex>
);

View File

@ -92,6 +92,13 @@ export const stylePresetsApi = api.injectEndpoints({
}),
providesTags: ['FetchOnReconnect', { type: 'StylePreset', id: LIST_TAG }],
}),
exportStylePresets: build.query<string, void>({
query: () => ({
url: buildStylePresetsUrl('/export'),
responseHandler: (response) => response.text(),
}),
providesTags: ['FetchOnReconnect', { type: 'StylePreset', id: LIST_TAG }],
}),
importStylePresets: build.mutation<
paths['/api/v1/style_presets/import']['post']['responses']['200']['content']['application/json'],
paths['/api/v1/style_presets/import']['post']['requestBody']['content']['multipart/form-data']['file']
@ -117,5 +124,6 @@ export const {
useDeleteStylePresetMutation,
useUpdateStylePresetMutation,
useListStylePresetsQuery,
useLazyExportStylePresetsQuery,
useImportStylePresetsMutation,
} = stylePresetsApi;

View File

@ -1344,6 +1344,23 @@ export type paths = {
patch?: never;
trace?: never;
};
"/api/v1/style_presets/export": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
/** Export Style Presets */
get: operations["export_style_presets"];
put?: never;
post?: never;
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/v1/style_presets/import": {
parameters: {
query?: never;
@ -18109,6 +18126,27 @@ export interface operations {
};
};
};
export_style_presets: {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description A CSV file with the requested data. */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": unknown;
"text/csv": unknown;
};
};
};
};
import_style_presets: {
parameters: {
query?: never;