diff --git a/langchain4j-redis/pom.xml b/langchain4j-redis/pom.xml
index 66c5fe2..230543e 100644
--- a/langchain4j-redis/pom.xml
+++ b/langchain4j-redis/pom.xml
@@ -46,14 +46,8 @@
- <groupId>org.mockito</groupId>
- <artifactId>mockito-core</artifactId>
- <scope>test</scope>
-
-
-
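- <groupId>org.mockito</groupId>
- <artifactId>mockito-junit-jupiter</artifactId>
+ <groupId>org.assertj</groupId>
+ <artifactId>assertj-core</artifactId>
<scope>test</scope>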
@@ -69,6 +63,21 @@
<scope>test</scope>
+
+ <groupId>dev.langchain4j</groupId>
+ <artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
+ <scope>test</scope>
+
+
+
+
+ <name>Apache-2.0</name>
+ <url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>
+ <distribution>repo</distribution>
+ <comments>A business-friendly OSS license</comments>
+
+
+
\ No newline at end of file
diff --git a/langchain4j-redis/src/main/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStore.java b/langchain4j-redis/src/main/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStore.java
index 3f74ae1..3c5fb2d 100644
--- a/langchain4j-redis/src/main/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStore.java
+++ b/langchain4j-redis/src/main/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStore.java
@@ -219,7 +219,7 @@ public class RedisEmbeddingStore implements EmbeddingStore {
private String user;
private String password;
private Integer dimension;
- private List metadataFieldsName;
+ private List metadataFieldsName = new ArrayList<>();
/**
* @param host Redis Stack host
diff --git a/langchain4j-redis/src/test/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStoreTest.java b/langchain4j-redis/src/test/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStoreTest.java
index ce02422..10b6275 100644
--- a/langchain4j-redis/src/test/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStoreTest.java
+++ b/langchain4j-redis/src/test/java/dev/langchain4j/store/embedding/redis/RedisEmbeddingStoreTest.java
@@ -3,17 +3,24 @@ package dev.langchain4j.store.embedding.redis;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
-import dev.langchain4j.internal.Utils;
+import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
+import dev.langchain4j.model.embedding.EmbeddingModel;
+import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
+import dev.langchain4j.store.embedding.RelevanceScore;
+import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
+import redis.clients.jedis.JedisPooled;
import java.util.List;
+import static dev.langchain4j.internal.Utils.randomUUID;
import static java.util.Arrays.asList;
-import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.data.Percentage.withPercentage;
@Disabled("needs Redis running locally")
class RedisEmbeddingStoreTest {
@@ -24,73 +31,234 @@ class RedisEmbeddingStoreTest {
* docker run -d -p 6379:6379 -p 8001:8001 redis/redis-stack:latest
*/
- private final EmbeddingStore store = new RedisEmbeddingStore(
- "localhost",
- 6379,
- "default",
- "password",
- 4,
- singletonList("field")
- );
+ // Address of the Redis instance the tests connect to.
+ private static final String HOST = "localhost";
+ private static final int PORT = 6379;
+ // Metadata key registered with the store and attached to the segment in the metadata test.
+ private static final String METADATA_KEY = "test-key";
- @Test
- void testAdd() {
- // test add without id
- String id = store.add(Embedding.from(asList(0.50f, 0.85f, 0.760f, 0.24f)),
- TextSegment.from("test string", Metadata.from("field", "value")));
- System.out.println("id=" + id);
+ // Store under test; reassigned before every test by initEmptyRedisEmbeddingStore().
+ private EmbeddingStore<TextSegment> embeddingStore;
- // test add with id
- String selfId = Utils.randomUUID();
- store.add(selfId, Embedding.from(asList(0.80f, 0.45f, 0.89f, 0.24f)));
- System.out.println("id=" + selfId);
+ // Model used to turn test strings into the embeddings that are stored and searched.
+ private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
+
+ /**
+ * Runs before every test: flushes the Redis database and then builds a fresh
+ * store (no metadata fields configured) against HOST:PORT.
+ */
+ @BeforeEach
+ void initEmptyRedisEmbeddingStore() {
+ flushDB();
+
+ // NOTE(review): presumably the vector length produced by embeddingModel - confirm
+ final int dimension = 384;
+ embeddingStore = RedisEmbeddingStore.builder().host(HOST).port(PORT).dimension(dimension).build();
+ }
+
+ /**
+ * Issues a flushDB call against the Redis instance at HOST:PORT so a test does not
+ * see data left behind by an earlier one. The client is opened only for this call
+ * and closed again by the try-with-resources block.
+ */
+ private static void flushDB() {
+ try (JedisPooled client = new JedisPooled(HOST, PORT)) {
+ client.flushDB();
+ }
}
+ /**
+ * Adds a single embedding without an explicit id or segment, then searches with the
+ * same embedding. Expects exactly one match carrying the returned id, a score within
+ * 1% of 1, the identical embedding and no embedded segment.
+ */
@Test
- void testAddAll() {
- // test add All Method without embedded
- List ids = store.addAll(asList(
- Embedding.from(asList(0.3f, 0.87f, 0.90f, 0.24f)),
- Embedding.from(asList(0.54f, 0.34f, 0.67f, 0.24f)),
- Embedding.from(asList(0.80f, 0.45f, 0.779f, 0.5556f))
- ));
- System.out.println("ids=" + ids);
+ void should_add_embedding() {
- // test add all method with embedded
- ids = store.addAll(asList(
- Embedding.from(asList(0.3f, 0.87f, 0.90f, 0.24f)),
- Embedding.from(asList(0.54f, 0.34f, 0.67f, 0.24f)),
- Embedding.from(asList(0.80f, 0.45f, 0.779f, 0.5556f))
- ), asList(
- TextSegment.from("testString1", Metadata.from("field", "value1")),
- TextSegment.from("testString2", Metadata.from("field", "value2")),
- TextSegment.from("testingString3", Metadata.from("field", "value3"))
- ));
- System.out.println("ids=" + ids);
+ Embedding embedding = embeddingModel.embed(randomUUID()).content();
+
+ String id = embeddingStore.add(embedding);
+ assertThat(id).isNotNull();
+
+ List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
+ assertThat(relevant).hasSize(1);
+
+ EmbeddingMatch<TextSegment> match = relevant.get(0);
+ assertThat(match.score()).isCloseTo(1, withPercentage(1));
+ assertThat(match.embeddingId()).isEqualTo(id);
+ assertThat(match.embedding()).isEqualTo(embedding);
+ assertThat(match.embedded()).isNull();
}
+ /**
+ * Adds a single embedding under a caller-chosen id, then searches with the same
+ * embedding. Expects exactly one match carrying that id, a score within 1% of 1,
+ * the identical embedding and no embedded segment.
+ */
@Test
- void testAddEmpty() {
- // see log
- store.addAll(emptyList());
+ void should_add_embedding_with_id() {
+
+ String id = randomUUID();
+ Embedding embedding = embeddingModel.embed(randomUUID()).content();
+
+ embeddingStore.add(id, embedding);
+
+ List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
+ assertThat(relevant).hasSize(1);
+
+ EmbeddingMatch<TextSegment> match = relevant.get(0);
+ assertThat(match.score()).isCloseTo(1, withPercentage(1));
+ assertThat(match.embeddingId()).isEqualTo(id);
+ assertThat(match.embedding()).isEqualTo(embedding);
+ assertThat(match.embedded()).isNull();
}
+ /**
+ * Adds an embedding together with the text segment it was computed from, then
+ * searches with the same embedding. Expects exactly one match carrying the returned
+ * id, a score within 1% of 1, the identical embedding and an equal segment.
+ */
@Test
- void testFindRelevant() {
- List> res = store.findRelevant(Embedding.from(asList(0.80f, 0.45f, 0.89f, 0.24f)), 5);
- res.forEach(System.out::println);
+ void should_add_embedding_with_segment() {
+
+ TextSegment segment = TextSegment.from(randomUUID());
+ Embedding embedding = embeddingModel.embed(segment.text()).content();
+
+ String id = embeddingStore.add(embedding, segment);
+ assertThat(id).isNotNull();
+
+ List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
+ assertThat(relevant).hasSize(1);
+
+ EmbeddingMatch<TextSegment> match = relevant.get(0);
+ assertThat(match.score()).isCloseTo(1, withPercentage(1));
+ assertThat(match.embeddingId()).isEqualTo(id);
+ assertThat(match.embedding()).isEqualTo(embedding);
+ assertThat(match.embedded()).isEqualTo(segment);
}
+ /**
+ * Adds an embedding with a segment that carries one metadata entry (METADATA_KEY),
+ * using a store built with that key in metadataFieldsName. Expects the single match
+ * to return a segment equal to the stored one, i.e. including its metadata.
+ */
@Test
- void testScore() {
- String id = store.add(Embedding.from(asList(0.50f, 0.85f, 0.760f, 0.24f)),
- TextSegment.from("test string", Metadata.from("field", "value")));
- System.out.println("id=" + id);
+ void should_add_embedding_with_segment_with_metadata() {
- // use the same embedding to search
- List> res = store.findRelevant(Embedding.from(asList(0.50f, 0.85f, 0.760f, 0.24f)), 1);
- res.forEach(System.out::println);
+ // NOTE(review): presumably flushed again so that whatever the store built in the
+ // @BeforeEach method created is gone before the metadata-aware store is built - confirm
+ flushDB();
- // the result embeddingMatch score is 5.96046447754E-8, but expected is 1 because they are same vectors.
+ embeddingStore = RedisEmbeddingStore.builder()
+ .host(HOST)
+ .port(PORT)
+ .dimension(384)
+ .metadataFieldsName(singletonList(METADATA_KEY))
+ .build();
+
+ TextSegment segment = TextSegment.from(randomUUID(), Metadata.from(METADATA_KEY, "test-value"));
+ Embedding embedding = embeddingModel.embed(segment.text()).content();
+
+ String id = embeddingStore.add(embedding, segment);
+ assertThat(id).isNotNull();
+
+ List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
+ assertThat(relevant).hasSize(1);
+
+ EmbeddingMatch<TextSegment> match = relevant.get(0);
+ assertThat(match.score()).isCloseTo(1, withPercentage(1));
+ assertThat(match.embeddingId()).isEqualTo(id);
+ assertThat(match.embedding()).isEqualTo(embedding);
+ assertThat(match.embedded()).isEqualTo(segment);
+ }
+
+ /**
+ * Adds two embeddings with one addAll call and searches with the first one.
+ * Expects both back in order: the first with a score within 1% of 1, the second
+ * with a score between 0 and 1; ids follow the order of the addAll result and
+ * neither match has an embedded segment.
+ */
+ @Test
+ void should_add_multiple_embeddings() {
+
+ Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content();
+ Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content();
+
+ List<String> ids = embeddingStore.addAll(asList(firstEmbedding, secondEmbedding));
+ assertThat(ids).hasSize(2);
+
+ List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
+ assertThat(relevant).hasSize(2);
+
+ EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
+ assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
+ assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
+ assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
+ assertThat(firstMatch.embedded()).isNull();
+
+ EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
+ assertThat(secondMatch.score()).isBetween(0d, 1d);
+ assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
+ assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
+ assertThat(secondMatch.embedded()).isNull();
+ }
+
+ /**
+ * Adds two embeddings with their segments in one addAll call and searches with the
+ * first embedding. Expects both back in order: the first with a score within 1% of 1,
+ * the second with a score between 0 and 1, each paired with its own id, embedding
+ * and segment.
+ */
+ @Test
+ void should_add_multiple_embeddings_with_segments() {
+
+ TextSegment firstSegment = TextSegment.from(randomUUID());
+ Embedding firstEmbedding = embeddingModel.embed(firstSegment.text()).content();
+ TextSegment secondSegment = TextSegment.from(randomUUID());
+ Embedding secondEmbedding = embeddingModel.embed(secondSegment.text()).content();
+
+ List<String> ids = embeddingStore.addAll(
+ asList(firstEmbedding, secondEmbedding),
+ asList(firstSegment, secondSegment)
+ );
+ assertThat(ids).hasSize(2);
+
+ List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
+ assertThat(relevant).hasSize(2);
+
+ EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
+ assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
+ assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
+ assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
+ assertThat(firstMatch.embedded()).isEqualTo(firstSegment);
+
+ EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
+ assertThat(secondMatch.score()).isBetween(0d, 1d);
+ assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
+ assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
+ assertThat(secondMatch.embedded()).isEqualTo(secondSegment);
+ }
+
+ /**
+ * Checks the three-argument findRelevant (minimum score). Two embeddings are stored
+ * and searched with the first one; the score of the second match is then used as the
+ * reference threshold: slightly below it and exactly at it both entries are expected,
+ * slightly above it only the first entry is expected.
+ */
+ @Test
+ void should_find_with_min_score() {
+
+ String firstId = randomUUID();
+ Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content();
+ embeddingStore.add(firstId, firstEmbedding);
+
+ String secondId = randomUUID();
+ Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content();
+ embeddingStore.add(secondId, secondEmbedding);
+
+ // Baseline without a threshold: both entries, the query's own embedding first.
+ List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
+ assertThat(relevant).hasSize(2);
+ EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
+ assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
+ assertThat(firstMatch.embeddingId()).isEqualTo(firstId);
+ EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
+ assertThat(secondMatch.score()).isBetween(0d, 1d);
+ assertThat(secondMatch.embeddingId()).isEqualTo(secondId);
+
+ // Threshold just below the second score: both entries still qualify.
+ List<EmbeddingMatch<TextSegment>> relevant2 = embeddingStore.findRelevant(
+ firstEmbedding,
+ 10,
+ secondMatch.score() - 0.01
+ );
+ assertThat(relevant2).hasSize(2);
+ assertThat(relevant2.get(0).embeddingId()).isEqualTo(firstId);
+ assertThat(relevant2.get(1).embeddingId()).isEqualTo(secondId);
+
+ // Threshold equal to the second score: the second entry is still expected.
+ List<EmbeddingMatch<TextSegment>> relevant3 = embeddingStore.findRelevant(
+ firstEmbedding,
+ 10,
+ secondMatch.score()
+ );
+ assertThat(relevant3).hasSize(2);
+ assertThat(relevant3.get(0).embeddingId()).isEqualTo(firstId);
+ assertThat(relevant3.get(1).embeddingId()).isEqualTo(secondId);
+
+ // Threshold just above the second score: only the first entry remains.
+ List<EmbeddingMatch<TextSegment>> relevant4 = embeddingStore.findRelevant(
+ firstEmbedding,
+ 10,
+ secondMatch.score() + 0.01
+ );
+ assertThat(relevant4).hasSize(1);
+ assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId);
+ }
+
+ /**
+ * Stores the embedding of "hello", searches with the embedding of "hi" and expects
+ * the reported score to be within 1% of the value obtained by passing the cosine
+ * similarity of the two embeddings through RelevanceScore.fromCosineSimilarity.
+ */
+ @Test
+ void should_return_correct_score() {
+
+ Embedding embedding = embeddingModel.embed("hello").content();
+
+ String id = embeddingStore.add(embedding);
+ assertThat(id).isNotNull();
+
+ Embedding referenceEmbedding = embeddingModel.embed("hi").content();
+
+ List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(referenceEmbedding, 1);
+ assertThat(relevant).hasSize(1);
+
+ EmbeddingMatch<TextSegment> match = relevant.get(0);
+ assertThat(match.score()).isCloseTo(
+ RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)),
+ withPercentage(1)
+ );
}
}