From 356661459bfaa617d723a165a94c0aacea272872 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 15 Aug 2024 14:45:04 +1000 Subject: [PATCH] feat(api): support JSON for preset imports This allows us to support Fooocus format presets. --- invokeai/app/api/routers/style_presets.py | 48 ++++++++++++------- .../style_preset_records_common.py | 19 +------- 2 files changed, 33 insertions(+), 34 deletions(-) diff --git a/invokeai/app/api/routers/style_presets.py b/invokeai/app/api/routers/style_presets.py index d7673cc25d..ccea914750 100644 --- a/invokeai/app/api/routers/style_presets.py +++ b/invokeai/app/api/routers/style_presets.py @@ -1,13 +1,15 @@ +import csv import io import json import traceback +from codecs import iterdecode from typing import Optional import pydantic from fastapi import APIRouter, File, Form, HTTPException, Path, UploadFile from fastapi.responses import FileResponse from PIL import Image -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ValidationError from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.api.routers.model_manager import IMAGE_MAX_AGE @@ -16,11 +18,10 @@ from invokeai.app.services.style_preset_records.style_preset_records_common impo PresetData, PresetType, StylePresetChanges, - StylePresetImportValidationError, + StylePresetImportListTypeAdapter, StylePresetNotFoundError, StylePresetRecordWithImage, StylePresetWithoutId, - parse_csv, ) @@ -234,21 +235,36 @@ async def get_style_preset_image( operation_id="import_style_presets", ) async def import_style_presets(file: UploadFile = File(description="The file to import")): - if not file.filename.endswith(".csv"): - raise HTTPException(status_code=400, detail="Invalid file type") + if file.content_type not in ["text/csv", "application/json"]: + raise HTTPException(status_code=400, detail="Unsupported file type") try: - parsed_data = parse_csv(file) - except StylePresetImportValidationError: - raise HTTPException( - status_code=400, detail="Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt'" - ) + if file.content_type == "text/csv": + csv_reader = csv.DictReader(iterdecode(file.file, "utf-8")) + data = list(csv_reader) + else: # file.content_type == "application/json": + json_data = await file.read() + data = json.loads(json_data) - style_presets: list[StylePresetWithoutId] = [] + imported_presets = StylePresetImportListTypeAdapter.validate_python(data) - for style_preset in parsed_data: - preset_data = PresetData(positive_prompt=style_preset.prompt, negative_prompt=style_preset.negative_prompt) - style_preset = StylePresetWithoutId(name=style_preset.name, preset_data=preset_data, type=PresetType.User) - style_presets.append(style_preset) + style_presets: list[StylePresetWithoutId] = [] - ApiDependencies.invoker.services.style_preset_records.create_many(style_presets) + for imported in imported_presets: + preset_data = PresetData(positive_prompt=imported.positive_prompt, negative_prompt=imported.negative_prompt) + style_preset = StylePresetWithoutId(name=imported.name, preset_data=preset_data, type=PresetType.User) + style_presets.append(style_preset) + ApiDependencies.invoker.services.style_preset_records.create_many(style_presets) + except ValidationError: + if file.content_type == "text/csv": + raise HTTPException( + status_code=400, + detail="Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt'", + ) + else: # file.content_type == "application/json": + raise HTTPException( + status_code=400, + detail="Invalid JSON format: must be a list of objects with keys 'name', 'prompt', and 'negative_prompt'", + ) + finally: + file.file.close() diff --git a/invokeai/app/services/style_preset_records/style_preset_records_common.py b/invokeai/app/services/style_preset_records/style_preset_records_common.py index 11c5c3b8c0..2d33a7ea76 100644 --- a/invokeai/app/services/style_preset_records/style_preset_records_common.py +++ b/invokeai/app/services/style_preset_records/style_preset_records_common.py @@ -1,9 +1,6 @@ -import csv -import io from enum import Enum -from typing import Any, Generator, Optional +from typing import Any, Optional -from fastapi import UploadFile from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter from invokeai.app.util.metaenum import MetaEnum @@ -72,17 +69,3 @@ class StylePresetImportRow(BaseModel): StylePresetImportList = list[StylePresetImportRow] StylePresetImportListTypeAdapter = TypeAdapter(StylePresetImportList) - - -def parse_csv(file: UploadFile) -> Generator[StylePresetImportRow, None, None]: - """Yield parsed and validated rows from the CSV file.""" - file_content = file.file.read().decode("utf-8") - csv_reader = csv.DictReader(io.StringIO(file_content)) - - for row in csv_reader: - if "name" not in row or "prompt" not in row or "negative_prompt" not in row: - raise StylePresetImportValidationError() - - yield StylePresetImportRow( - name=row["name"].strip(), positive_prompt=row["prompt"].strip(), negative_prompt=row["negative_prompt"].strip() - )