From f2b2f0214a0bab30ae95f9dfb9c19d6fd00a953d Mon Sep 17 00:00:00 2001 From: deep-learning-dynamo Date: Thu, 28 Sep 2023 18:20:05 +0200 Subject: [PATCH] Redis: more tests --- langchain4j-redis/pom.xml | 25 +- .../embedding/redis/RedisEmbeddingStore.java | 2 +- .../redis/RedisEmbeddingStoreTest.java | 274 ++++++++++++++---- 3 files changed, 239 insertions(+), 62 deletions(-) diff --git a/langchain4j-redis/pom.xml b/langchain4j-redis/pom.xml index 66c5fe2..230543e 100644 --- a/langchain4j-redis/pom.xml +++ b/langchain4j-redis/pom.xml @@ -46,14 +46,8 @@ - org.mockito - mockito-core - test - - - - org.mockito - mockito-junit-jupiter + org.assertj + assertj-core test @@ -69,6 +63,21 @@ test + + dev.langchain4j + langchain4j-embeddings-all-minilm-l6-v2-q + test + + + + + Apache-2.0 + https://www.apache.org/licenses/LICENSE-2.0.txt + repo + A business-friendly OSS license + + + \ No newline at end of file diff --git a/langchain4j-redis/src/main/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStore.java b/langchain4j-redis/src/main/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStore.java index 3f74ae1..3c5fb2d 100644 --- a/langchain4j-redis/src/main/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStore.java +++ b/langchain4j-redis/src/main/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStore.java @@ -219,7 +219,7 @@ public class RedisEmbeddingStore implements EmbeddingStore { private String user; private String password; private Integer dimension; - private List metadataFieldsName; + private List metadataFieldsName = new ArrayList<>(); /** * @param host Redis Stack host diff --git a/langchain4j-redis/src/test/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStoreTest.java b/langchain4j-redis/src/test/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStoreTest.java index ce02422..10b6275 100644 --- a/langchain4j-redis/src/test/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStoreTest.java +++ b/langchain4j-redis/src/test/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStoreTest.java @@ -3,17 +3,24 @@ package dev.langchain4j.store.embedding.redis; import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.internal.Utils; +import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.store.embedding.CosineSimilarity; import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.RelevanceScore; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import redis.clients.jedis.JedisPooled; import java.util.List; +import static dev.langchain4j.internal.Utils.randomUUID; import static java.util.Arrays.asList; -import static java.util.Collections.emptyList; import static java.util.Collections.singletonList; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.data.Percentage.withPercentage; @Disabled("needs Redis running locally") class RedisEmbeddingStoreTest { @@ -24,73 +31,234 @@ class RedisEmbeddingStoreTest { * docker run -d -p 6379:6379 -p 8001:8001 redis/redis-stack:latest */ - private final EmbeddingStore store = new RedisEmbeddingStore( - "localhost", - 6379, - "default", - "password", - 4, - singletonList("field") - ); + private static final String HOST = "localhost"; + private static final int PORT = 6379; + private static final String METADATA_KEY = "test-key"; - @Test - void testAdd() { - // test add without id - String id = store.add(Embedding.from(asList(0.50f, 0.85f, 0.760f, 0.24f)), - TextSegment.from("test string", Metadata.from("field", "value"))); - System.out.println("id=" + id); + private EmbeddingStore embeddingStore; - // test add with id - String selfId = Utils.randomUUID(); - store.add(selfId, Embedding.from(asList(0.80f, 0.45f, 0.89f, 0.24f))); - System.out.println("id=" + selfId); + private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + + @BeforeEach + void initEmptyRedisEmbeddingStore() { + + flushDB(); + + embeddingStore = RedisEmbeddingStore.builder() + .host(HOST) + .port(PORT) + .dimension(384) + .build(); + } + + private static void flushDB() { + try (JedisPooled jedis = new JedisPooled(HOST, PORT)) { + jedis.flushDB(); + } } @Test - void testAddAll() { - // test add All Method without embedded - List ids = store.addAll(asList( - Embedding.from(asList(0.3f, 0.87f, 0.90f, 0.24f)), - Embedding.from(asList(0.54f, 0.34f, 0.67f, 0.24f)), - Embedding.from(asList(0.80f, 0.45f, 0.779f, 0.5556f)) - )); - System.out.println("ids=" + ids); + void should_add_embedding() { - // test add all method with embedded - ids = store.addAll(asList( - Embedding.from(asList(0.3f, 0.87f, 0.90f, 0.24f)), - Embedding.from(asList(0.54f, 0.34f, 0.67f, 0.24f)), - Embedding.from(asList(0.80f, 0.45f, 0.779f, 0.5556f)) - ), asList( - TextSegment.from("testString1", Metadata.from("field", "value1")), - TextSegment.from("testString2", Metadata.from("field", "value2")), - TextSegment.from("testingString3", Metadata.from("field", "value3")) - )); - System.out.println("ids=" + ids); + Embedding embedding = embeddingModel.embed(randomUUID()).content(); + + String id = embeddingStore.add(embedding); + assertThat(id).isNotNull(); + + List> relevant = embeddingStore.findRelevant(embedding, 10); + assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo(1, withPercentage(1)); + assertThat(match.embeddingId()).isEqualTo(id); + assertThat(match.embedding()).isEqualTo(embedding); + assertThat(match.embedded()).isNull(); } @Test - void testAddEmpty() { - // see log - store.addAll(emptyList()); + void should_add_embedding_with_id() { + + String id = randomUUID(); + Embedding embedding = embeddingModel.embed(randomUUID()).content(); + + embeddingStore.add(id, embedding); + + List> relevant = embeddingStore.findRelevant(embedding, 10); + assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo(1, withPercentage(1)); + assertThat(match.embeddingId()).isEqualTo(id); + assertThat(match.embedding()).isEqualTo(embedding); + assertThat(match.embedded()).isNull(); } @Test - void testFindRelevant() { - List> res = store.findRelevant(Embedding.from(asList(0.80f, 0.45f, 0.89f, 0.24f)), 5); - res.forEach(System.out::println); + void should_add_embedding_with_segment() { + + TextSegment segment = TextSegment.from(randomUUID()); + Embedding embedding = embeddingModel.embed(segment.text()).content(); + + String id = embeddingStore.add(embedding, segment); + assertThat(id).isNotNull(); + + List> relevant = embeddingStore.findRelevant(embedding, 10); + assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo(1, withPercentage(1)); + assertThat(match.embeddingId()).isEqualTo(id); + assertThat(match.embedding()).isEqualTo(embedding); + assertThat(match.embedded()).isEqualTo(segment); } @Test - void testScore() { - String id = store.add(Embedding.from(asList(0.50f, 0.85f, 0.760f, 0.24f)), - TextSegment.from("test string", Metadata.from("field", "value"))); - System.out.println("id=" + id); + void should_add_embedding_with_segment_with_metadata() { - // use the same embedding to search - List> res = store.findRelevant(Embedding.from(asList(0.50f, 0.85f, 0.760f, 0.24f)), 1); - res.forEach(System.out::println); + flushDB(); - // the result embeddingMatch score is 5.96046447754E-8, but expected is 1 because they are same vectors. + embeddingStore = RedisEmbeddingStore.builder() + .host(HOST) + .port(PORT) + .dimension(384) + .metadataFieldsName(singletonList(METADATA_KEY)) + .build(); + + TextSegment segment = TextSegment.from(randomUUID(), Metadata.from(METADATA_KEY, "test-value")); + Embedding embedding = embeddingModel.embed(segment.text()).content(); + + String id = embeddingStore.add(embedding, segment); + assertThat(id).isNotNull(); + + List> relevant = embeddingStore.findRelevant(embedding, 10); + assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo(1, withPercentage(1)); + assertThat(match.embeddingId()).isEqualTo(id); + assertThat(match.embedding()).isEqualTo(embedding); + assertThat(match.embedded()).isEqualTo(segment); + } + + @Test + void should_add_multiple_embeddings() { + + Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content(); + Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content(); + + List ids = embeddingStore.addAll(asList(firstEmbedding, secondEmbedding)); + assertThat(ids).hasSize(2); + + List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); + assertThat(relevant).hasSize(2); + + EmbeddingMatch firstMatch = relevant.get(0); + assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); + assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0)); + assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding); + assertThat(firstMatch.embedded()).isNull(); + + EmbeddingMatch secondMatch = relevant.get(1); + assertThat(secondMatch.score()).isBetween(0d, 1d); + assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); + assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding); + assertThat(secondMatch.embedded()).isNull(); + } + + @Test + void should_add_multiple_embeddings_with_segments() { + + TextSegment firstSegment = TextSegment.from(randomUUID()); + Embedding firstEmbedding = embeddingModel.embed(firstSegment.text()).content(); + TextSegment secondSegment = TextSegment.from(randomUUID()); + Embedding secondEmbedding = embeddingModel.embed(secondSegment.text()).content(); + + List ids = embeddingStore.addAll( + asList(firstEmbedding, secondEmbedding), + asList(firstSegment, secondSegment) + ); + assertThat(ids).hasSize(2); + + List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); + assertThat(relevant).hasSize(2); + + EmbeddingMatch firstMatch = relevant.get(0); + assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); + assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0)); + assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding); + assertThat(firstMatch.embedded()).isEqualTo(firstSegment); + + EmbeddingMatch secondMatch = relevant.get(1); + assertThat(secondMatch.score()).isBetween(0d, 1d); + assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); + assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding); + assertThat(secondMatch.embedded()).isEqualTo(secondSegment); + } + + @Test + void should_find_with_min_score() { + + String firstId = randomUUID(); + Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content(); + embeddingStore.add(firstId, firstEmbedding); + + String secondId = randomUUID(); + Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content(); + embeddingStore.add(secondId, secondEmbedding); + + List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); + assertThat(relevant).hasSize(2); + EmbeddingMatch firstMatch = relevant.get(0); + assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); + assertThat(firstMatch.embeddingId()).isEqualTo(firstId); + EmbeddingMatch secondMatch = relevant.get(1); + assertThat(secondMatch.score()).isBetween(0d, 1d); + assertThat(secondMatch.embeddingId()).isEqualTo(secondId); + + List> relevant2 = embeddingStore.findRelevant( + firstEmbedding, + 10, + secondMatch.score() - 0.01 + ); + assertThat(relevant2).hasSize(2); + assertThat(relevant2.get(0).embeddingId()).isEqualTo(firstId); + assertThat(relevant2.get(1).embeddingId()).isEqualTo(secondId); + + List> relevant3 = embeddingStore.findRelevant( + firstEmbedding, + 10, + secondMatch.score() + ); + assertThat(relevant3).hasSize(2); + assertThat(relevant3.get(0).embeddingId()).isEqualTo(firstId); + assertThat(relevant3.get(1).embeddingId()).isEqualTo(secondId); + + List> relevant4 = embeddingStore.findRelevant( + firstEmbedding, + 10, + secondMatch.score() + 0.01 + ); + assertThat(relevant4).hasSize(1); + assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId); + } + + @Test + void should_return_correct_score() { + + Embedding embedding = embeddingModel.embed("hello").content(); + + String id = embeddingStore.add(embedding); + assertThat(id).isNotNull(); + + Embedding referenceEmbedding = embeddingModel.embed("hi").content(); + + List> relevant = embeddingStore.findRelevant(referenceEmbedding, 1); + assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo( + RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)), + withPercentage(1) + ); } }