package org.springframework.ai.rag.retrieval.search;

import java.util.List;
import java.util.function.Supplier;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
 * {@link DocumentRetriever} that looks up documents in a {@link VectorStore} using similarity
 * search.
 *
 * <p>Each call to {@link #retrieve(Query)} issues a single {@link SearchRequest}. The request is
 * built from:
 * <ul>
 * <li>the query text;</li>
 * <li>the configured similarity threshold and top-k;</li>
 * <li>a filter expression.</li>
 * </ul>
 * The filter comes from the query-context entry {@link #FILTER_EXPRESSION} when that entry is
 * usable. Otherwise the filter supplier configured at construction time is consulted.
 */
public final class VectorStoreDocumentRetriever implements DocumentRetriever {

    /**
     * Query-context key carrying a per-request filter that takes precedence over the configured
     * one. The value may be:
     * <ul>
     * <li>a {@link Filter.Expression}, which is used as-is;</li>
     * <li>any other object whose {@code toString()} contains text, which is parsed with
     * {@link FilterExpressionTextParser}.</li>
     * </ul>
     * A value whose text is blank is ignored.
     */
    public static final String FILTER_EXPRESSION = "vector_store_filter_expression";

    /** Similarity threshold applied when the caller does not supply one. */
    private static final double DEFAULT_SIMILARITY_THRESHOLD = 0.0d;

    /** Number of results requested when the caller does not supply a top-k. */
    private static final int DEFAULT_TOP_K = 4;

    private final VectorStore vectorStore;

    private final double similarityThreshold;

    private final int topK;

    /** Fallback filter source; never null itself, though it may yield null (meaning "no filter"). */
    private final Supplier<Filter.Expression> filterExpression;

    /**
     * Creates a retriever over the given vector store.
     *
     * @param vectorStore the store to search; must not be null
     * @param similarityThreshold minimum similarity, or null for 0.0; must not be negative
     * @param topK maximum number of documents to return, or null for 4; must be positive
     * @param filterExpression supplier of the default filter, or null for no filter
     * @throws IllegalArgumentException if any argument violates the constraints above
     */
    public VectorStoreDocumentRetriever(VectorStore vectorStore, @Nullable Double similarityThreshold,
            @Nullable Integer topK, @Nullable Supplier<Filter.Expression> filterExpression) {
        Assert.notNull(vectorStore, "vectorStore cannot be null");
        Assert.isTrue(similarityThreshold == null || similarityThreshold >= 0.0d,
                "similarityThreshold must be equal to or greater than 0.0");
        Assert.isTrue(topK == null || topK > 0, "topK must be greater than 0");

        this.vectorStore = vectorStore;
        this.similarityThreshold = (similarityThreshold == null) ? DEFAULT_SIMILARITY_THRESHOLD
                : similarityThreshold;
        this.topK = (topK == null) ? DEFAULT_TOP_K : topK;
        if (filterExpression != null) {
            this.filterExpression = filterExpression;
        }
        else {
            this.filterExpression = () -> null;
        }
    }

    /**
     * Runs a similarity search for the given query.
     *
     * @param query the query whose text is searched and whose context may carry a filter override
     * @return the documents returned by {@link VectorStore#similaritySearch(SearchRequest)}
     * @throws IllegalArgumentException if {@code query} is null
     */
    @Override
    public List<Document> retrieve(Query query) {
        Assert.notNull(query, "query cannot be null");
        SearchRequest request = SearchRequest.builder()
            .query(query.text())
            .filterExpression(computeRequestFilterExpression(query))
            .similarityThreshold(this.similarityThreshold)
            .topK(this.topK)
            .build();
        return this.vectorStore.similaritySearch(request);
    }

    /**
     * Picks the filter for one request. The lookup order is:
     * <ol>
     * <li>an explicit {@link Filter.Expression} in the query context;</li>
     * <li>non-blank text in the query context, parsed into an expression;</li>
     * <li>the configured fallback supplier.</li>
     * </ol>
     */
    private Filter.Expression computeRequestFilterExpression(Query query) {
        Object contextFilter = query.context().get(FILTER_EXPRESSION);
        if (contextFilter instanceof Filter.Expression) {
            return (Filter.Expression) contextFilter;
        }
        if (contextFilter != null) {
            String filterText = contextFilter.toString();
            if (StringUtils.hasText(filterText)) {
                return new FilterExpressionTextParser().parse(filterText);
            }
        }
        // No usable override in the context: defer to the configured default.
        return this.filterExpression.get();
    }

    /**
     * Starts a fluent builder for this retriever.
     *
     * @return a new, empty {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Fluent builder for {@link VectorStoreDocumentRetriever}. Values are only validated when
     * {@link #build()} invokes the retriever's constructor.
     */
    public static final class Builder {

        private VectorStore vectorStore;

        private Double similarityThreshold;

        private Integer topK;

        private Supplier<Filter.Expression> filterExpression;

        private Builder() {
        }

        /** Sets the vector store to search (required). */
        public Builder vectorStore(VectorStore vectorStore) {
            this.vectorStore = vectorStore;
            return this;
        }

        /** Sets the minimum similarity score; null means the retriever default. */
        public Builder similarityThreshold(Double similarityThreshold) {
            this.similarityThreshold = similarityThreshold;
            return this;
        }

        /** Sets the maximum number of documents to return; null means the retriever default. */
        public Builder topK(Integer topK) {
            this.topK = topK;
            return this;
        }

        /** Uses a fixed filter expression as the default filter. */
        public Builder filterExpression(Filter.Expression filterExpression) {
            this.filterExpression = () -> filterExpression;
            return this;
        }

        /** Uses a supplier, evaluated per request, as the source of the default filter. */
        public Builder filterExpression(Supplier<Filter.Expression> filterExpression) {
            this.filterExpression = filterExpression;
            return this;
        }

        /**
         * Creates the retriever from the collected settings.
         *
         * @throws IllegalArgumentException if the settings fail the constructor's validation
         */
        public VectorStoreDocumentRetriever build() {
            return new VectorStoreDocumentRetriever(this.vectorStore, this.similarityThreshold, this.topK,
                    this.filterExpression);
        }

    }

}
