export only user style presets

This commit is contained in:
Mary Hipp 2024-08-15 16:07:32 -04:00
parent 24f298283f
commit 599db7296f
3 changed files with 15 additions and 7 deletions

View File

@ -245,7 +245,7 @@ async def export_style_presets():
# Write the header # Write the header
writer.writerow(["name", "prompt", "negative_prompt"]) writer.writerow(["name", "prompt", "negative_prompt"])
style_presets = ApiDependencies.invoker.services.style_preset_records.get_many() style_presets = ApiDependencies.invoker.services.style_preset_records.get_many(type=PresetType.User)
for preset in style_presets: for preset in style_presets:
writer.writerow([preset.name, preset.preset_data.positive_prompt, preset.preset_data.negative_prompt]) writer.writerow([preset.name, preset.preset_data.positive_prompt, preset.preset_data.negative_prompt])

View File

@ -1,6 +1,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from invokeai.app.services.style_preset_records.style_preset_records_common import ( from invokeai.app.services.style_preset_records.style_preset_records_common import (
PresetType,
StylePresetChanges, StylePresetChanges,
StylePresetRecordDTO, StylePresetRecordDTO,
StylePresetWithoutId, StylePresetWithoutId,
@ -36,6 +37,6 @@ class StylePresetRecordsStorageBase(ABC):
pass pass
@abstractmethod @abstractmethod
def get_many(self) -> list[StylePresetRecordDTO]: def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
"""Gets many workflows.""" """Gets many workflows."""
pass pass

View File

@ -5,6 +5,7 @@ from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
from invokeai.app.services.style_preset_records.style_preset_records_common import ( from invokeai.app.services.style_preset_records.style_preset_records_common import (
PresetType,
StylePresetChanges, StylePresetChanges,
StylePresetNotFoundError, StylePresetNotFoundError,
StylePresetRecordDTO, StylePresetRecordDTO,
@ -159,19 +160,25 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
self._lock.release() self._lock.release()
return None return None
def get_many( def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
self,
) -> list[StylePresetRecordDTO]:
try: try:
self._lock.acquire() self._lock.acquire()
main_query = """ main_query = """
SELECT SELECT
* *
FROM style_presets FROM style_presets
ORDER BY LOWER(name) ASC
""" """
self._cursor.execute(main_query) if type is not None:
main_query += "WHERE type = ? "
main_query += "ORDER BY LOWER(name) ASC"
if type is not None:
self._cursor.execute(main_query, (type,))
else:
self._cursor.execute(main_query)
rows = self._cursor.fetchall() rows = self._cursor.fetchall()
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows] style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]