From a0eecaecd01dabe9d2678cd39ede8e74706cfb38 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 1 Feb 2024 07:40:56 +1100 Subject: [PATCH] feat(item_storage): implement item_storage_memory max_size Implemented with unordered dict and set. --- .../item_storage/item_storage_memory.py | 33 +++++- tests/test_item_storage_memory.py | 111 ++++++++++++++++++ 2 files changed, 139 insertions(+), 5 deletions(-) create mode 100644 tests/test_item_storage_memory.py diff --git a/invokeai/app/services/item_storage/item_storage_memory.py b/invokeai/app/services/item_storage/item_storage_memory.py index 1958c3ee8b..e8846a01a7 100644 --- a/invokeai/app/services/item_storage/item_storage_memory.py +++ b/invokeai/app/services/item_storage/item_storage_memory.py @@ -1,3 +1,4 @@ +from contextlib import suppress from typing import Generic, Optional, TypeVar from pydantic import BaseModel @@ -8,21 +9,43 @@ T = TypeVar("T", bound=BaseModel) class ItemStorageMemory(ItemStorageABC, Generic[T]): - def __init__(self, id_field: str = "id") -> None: + """ + Provides a simple in-memory storage for items, with a maximum number of items to store. + An item is deleted when the maximum number of items is reached and a new item is added. + There is no guarantee about which item will be deleted. + """ + + def __init__(self, id_field: str = "id", max_items: int = 10) -> None: super().__init__() + if max_items < 1: + raise ValueError("max_items must be at least 1") + if not id_field: + raise ValueError("id_field must not be empty") self._id_field = id_field self._items: dict[str, T] = {} + self._item_ids: set[str] = set() + self._max_items = max_items def get(self, item_id: str) -> Optional[T]: return self._items.get(item_id) def set(self, item: T) -> None: - self._items[getattr(item, self._id_field)] = item + item_id = getattr(item, self._id_field) + assert isinstance(item_id, str) + if item_id in self._items or len(self._items) < self._max_items: + # If the item is already stored, or we have room for more items, we can just add it. + self._items[item_id] = item + self._item_ids.add(item_id) + else: + # Otherwise, we need to make room for it first. + self._items.pop(self._item_ids.pop()) + self._items[item_id] = item + self._item_ids.add(item_id) self._on_changed(item) def delete(self, item_id: str) -> None: - try: + # Both of these are no-ops if the item doesn't exist. + with suppress(KeyError): del self._items[item_id] + self._item_ids.remove(item_id) self._on_deleted(item_id) - except KeyError: - pass diff --git a/tests/test_item_storage_memory.py b/tests/test_item_storage_memory.py new file mode 100644 index 0000000000..601bc5c889 --- /dev/null +++ b/tests/test_item_storage_memory.py @@ -0,0 +1,111 @@ +import re + +import pytest +from pydantic import BaseModel + +from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory + + +class MockItemModel(BaseModel): + id: str + value: int + + +@pytest.fixture +def item_storage_memory(): + return ItemStorageMemory[MockItemModel]() + + +def test_item_storage_memory_initializes(): + item_storage_memory = ItemStorageMemory() + assert item_storage_memory._items == {} + assert item_storage_memory._item_ids == set() + assert item_storage_memory._id_field == "id" + assert item_storage_memory._max_items == 10 + + item_storage_memory = ItemStorageMemory(id_field="bananas", max_items=20) + assert item_storage_memory._id_field == "bananas" + assert item_storage_memory._max_items == 20 + + with pytest.raises(ValueError, match=re.escape("max_items must be at least 1")): + item_storage_memory = ItemStorageMemory(max_items=0) + with pytest.raises(ValueError, match=re.escape("id_field must not be empty")): + item_storage_memory = ItemStorageMemory(id_field="") + + +def test_item_storage_memory_sets(item_storage_memory: ItemStorageMemory[MockItemModel]): + item_1 = MockItemModel(id="1", value=1) + item_storage_memory.set(item_1) + assert item_storage_memory._items == {"1": item_1} + assert item_storage_memory._item_ids == {"1"} + + item_2 = MockItemModel(id="2", value=2) + item_storage_memory.set(item_2) + assert item_storage_memory._items == {"1": item_1, "2": item_2} + assert item_storage_memory._item_ids == {"1", "2"} + + # Updating value of existing item + item_2_updated = MockItemModel(id="2", value=9001) + item_storage_memory.set(item_2_updated) + assert item_storage_memory._items == {"1": item_1, "2": item_2_updated} + assert item_storage_memory._item_ids == {"1", "2"} + + +def test_item_storage_memory_gets(item_storage_memory: ItemStorageMemory[MockItemModel]): + item_1 = MockItemModel(id="1", value=1) + item_storage_memory.set(item_1) + item = item_storage_memory.get("1") + assert item == item_1 + + item_2 = MockItemModel(id="2", value=2) + item_storage_memory.set(item_2) + item = item_storage_memory.get("2") + assert item == item_2 + + item = item_storage_memory.get("3") + assert item is None + + +def test_item_storage_memory_deletes(item_storage_memory: ItemStorageMemory[MockItemModel]): + item_1 = MockItemModel(id="1", value=1) + item_2 = MockItemModel(id="2", value=2) + item_storage_memory.set(item_1) + item_storage_memory.set(item_2) + + item_storage_memory.delete("2") + assert item_storage_memory._items == {"1": item_1} + assert item_storage_memory._item_ids == {"1"} + + +def test_item_storage_memory_respects_max(): + item_storage_memory = ItemStorageMemory(max_items=3) + for i in range(10): + item_storage_memory.set(MockItemModel(id=str(i), value=i)) + assert len(item_storage_memory._items) == 3 + + +def test_item_storage_memory_calls_set_callback(item_storage_memory: ItemStorageMemory[MockItemModel]): + called_item = None + item = MockItemModel(id="1", value=1) + + def on_changed(item: MockItemModel): + nonlocal called_item + called_item = item + + item_storage_memory.on_changed(on_changed) + item_storage_memory.set(item) + assert called_item == item + + +def test_item_storage_memory_calls_delete_callback(item_storage_memory: ItemStorageMemory[MockItemModel]): + called_item_id = None + item = MockItemModel(id="1", value=1) + + def on_deleted(item_id: str): + nonlocal called_item_id + called_item_id = item_id + + item_storage_memory.on_deleted(on_deleted) + item_storage_memory.set(item) + item_storage_memory.delete("1") + assert called_item_id == "1"