/*
 * Decompiled with CFR 0.152.
 */
package io.milvus.orm.iterator;

import com.amazonaws.util.CollectionUtils;
import com.fasterxml.jackson.core.type.TypeReference;
import com.google.common.collect.Lists;
import io.milvus.common.utils.ExceptionUtils;
import io.milvus.common.utils.JacksonUtils;
import io.milvus.exception.ParamException;
import io.milvus.grpc.DataType;
import io.milvus.grpc.MilvusServiceGrpc;
import io.milvus.grpc.SearchRequest;
import io.milvus.grpc.SearchResults;
import io.milvus.orm.iterator.IteratorCache;
import io.milvus.param.MetricType;
import io.milvus.param.ParamUtils;
import io.milvus.param.collection.FieldType;
import io.milvus.param.dml.SearchIteratorParam;
import io.milvus.param.dml.SearchParam;
import io.milvus.response.QueryResultsWrapper;
import io.milvus.response.SearchResultsWrapper;
import io.milvus.v2.utils.RpcUtils;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Batch-wise iterator over the results of a vector search.
 *
 * <p>The constructor runs an initial search and stores the rows in an {@link IteratorCache}.
 * Each call to {@link #next()} returns up to {@code batchSize} rows, served from the cache when
 * it holds enough rows, otherwise by issuing follow-up searches whose {@code radius} and
 * {@code range_filter} parameters are derived from the distance of the last row seen so far.
 * Rows whose distance equals the distance of the last returned row are excluded from follow-up
 * searches through a {@code not in [...]} primary-key filter.
 *
 * <p>Instances keep mutable iteration state in plain fields and perform no synchronization.
 */
public class SearchIterator {
    private static final Logger logger = LoggerFactory.getLogger(SearchIterator.class);

    /** Value of {@code topK} for which {@link #next()} applies no overall row limit. */
    private static final int UNLIMITED_TOP_K = -1;
    /**
     * Cache id used before a cache entry exists.
     * NOTE(review): presumably tells {@link IteratorCache#cache} to allocate a new id - confirm.
     */
    private static final int NO_CACHE_ID = -1;
    /** Upper bound applied to the topK of every search request issued by this iterator. */
    private static final int MAX_BATCH_SIZE = 16384;
    /** Factor by which the batch size is multiplied for follow-up searches. */
    private static final int BATCH_EXTEND_RATE = 10;
    /** Follow-up searches stop once the number of attempts exceeds this value. */
    private static final int MAX_TRY_TIME = 20;
    /** Upper bound for the number of ids kept for duplicate filtering. */
    private static final int MAX_FILTERED_IDS_COUNT = 100000;
    /** Width used when the first and last row of a page have the same distance. */
    private static final float DEFAULT_WIDTH = 0.05f;

    private static final String KEY_EF = "ef";
    private static final String KEY_RADIUS = "radius";
    private static final String KEY_RANGE_FILTER = "range_filter";

    private final IteratorCache iteratorCache = new IteratorCache();
    private final MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub;
    private final FieldType primaryField;
    private final SearchIteratorParam searchIteratorParam;
    private final int batchSize;
    private final int topK;
    private final String expr;
    private final String metricType;
    private int cacheId;
    private boolean initSuccess;
    private int returnedCount;
    private float width;
    private float tailBand;
    private List<Object> filteredIds;
    private Float filteredDistance = null;
    private Map<String, Object> params;
    private final RpcUtils rpcUtils;

    /**
     * Creates the iterator, validates the search parameters and runs the initial search.
     *
     * @param searchIteratorParam search configuration (collection, vectors, batch size, params...)
     * @param blockingStub        gRPC stub used to execute the search requests
     * @param primaryField        primary key field of the collection, used for duplicate filtering
     */
    public SearchIterator(SearchIteratorParam searchIteratorParam, MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, FieldType primaryField) {
        this.searchIteratorParam = searchIteratorParam;
        this.blockingStub = blockingStub;
        this.primaryField = primaryField;
        this.metricType = searchIteratorParam.getMetricType();
        this.batchSize = (int)searchIteratorParam.getBatchSize();
        this.expr = searchIteratorParam.getExpr();
        this.topK = searchIteratorParam.getTopK();
        this.rpcUtils = new RpcUtils();
        this.initParams();
        this.checkForSpecialIndexParam();
        this.checkRmRangeSearchParameters();
        this.initSearchIterator();
    }

    /**
     * Returns the next page of results.
     *
     * @return up to {@code batchSize} rows (fewer when the {@code topK} limit or the end of the
     *         results is reached); an empty list when initialization failed or the limit has
     *         already been reached. The returned list is independent of the internal cache.
     */
    public List<QueryResultsWrapper.RowRecord> next() {
        if (!this.initSuccess || this.checkReachedLimit()) {
            return Lists.newArrayList();
        }
        int retLen = this.batchSize;
        if (this.topK != UNLIMITED_TOP_K) {
            retLen = Math.min(this.topK - this.returnedCount, retLen);
        }
        if (this.isCacheEnough(retLen)) {
            List<QueryResultsWrapper.RowRecord> cachedRows = this.extractPageFromCache(retLen);
            this.returnedCount += cachedRows.size();
            return cachedRows;
        }
        // The cache cannot satisfy the request: fetch more rows and append them to the cache.
        List<QueryResultsWrapper.RowRecord> newPage = this.trySearchFill();
        int cachedPageLen = this.pushNewPageToCache(newPage);
        retLen = Math.min(cachedPageLen, retLen);
        List<QueryResultsWrapper.RowRecord> retPage = this.extractPageFromCache(retLen);
        if (retPage.size() == this.batchSize) {
            this.updateWidth(retPage);
        }
        if (retPage.isEmpty() && this.filteredIds != null) {
            this.filteredIds.clear();
        }
        this.returnedCount += retPage.size();
        return retPage;
    }

    /** Releases the cache entry held by this iterator. */
    public void close() {
        this.iteratorCache.releaseCache(this.cacheId);
    }

    /**
     * Parses the JSON search params into {@link #params}; a missing or empty params string (or a
     * parse result of {@code null}) yields an empty map so later lookups never hit {@code null}.
     */
    private void initParams() {
        String rawParams = this.searchIteratorParam.getParams();
        Map<String, Object> parsed = null;
        if (rawParams != null && !rawParams.isEmpty()) {
            parsed = JacksonUtils.fromJson(rawParams, new TypeReference<Map<String, Object>>(){});
        }
        this.params = parsed != null ? parsed : new HashMap<String, Object>();
    }

    /** Rejects an {@code ef} search param that is smaller than the batch size. */
    private void checkForSpecialIndexParam() {
        if (this.params.containsKey(KEY_EF) && this.getIntValue(this.params, KEY_EF) < this.batchSize) {
            ExceptionUtils.throwUnExpectedException("When using hnsw index, provided ef must be larger than or equal to batch size");
        }
    }

    /**
     * When both {@code radius} and {@code range_filter} are supplied, verifies that their order
     * is consistent with the metric type.
     */
    private void checkRmRangeSearchParameters() {
        if (!this.params.containsKey(KEY_RADIUS) || !this.params.containsKey(KEY_RANGE_FILTER)) {
            return;
        }
        float radius = this.getFloatValue(KEY_RADIUS);
        float rangeFilter = this.getFloatValue(KEY_RANGE_FILTER);
        boolean positiveRelated = this.metricsPositiveRelated(this.metricType);
        if (positiveRelated && radius <= rangeFilter) {
            String msg = String.format("for metrics:%s, radius must be larger than range_filter, please adjust your parameter", this.metricType);
            ExceptionUtils.throwUnExpectedException(msg);
        }
        if (!positiveRelated && radius >= rangeFilter) {
            String msg = String.format("for metrics:%s, radius must be smalled than range_filter, please adjust your parameter", this.metricType);
            ExceptionUtils.throwUnExpectedException(msg);
        }
    }

    /**
     * Runs the first search and seeds the cache, the range parameters and the duplicate filter.
     * When the first page is empty the iterator is marked as not initialized and every call to
     * {@link #next()} returns an empty list.
     */
    private void initSearchIterator() {
        SearchResultsWrapper searchResultsWrapper = this.executeNextSearch(this.params, this.expr, false);
        List<QueryResultsWrapper.RowRecord> result = searchResultsWrapper.getRowRecords(0);
        if (CollectionUtils.isNullOrEmpty(result)) {
            String msg = "Cannot init search iterator because init page contains no matched rows, please check the radius and range_filter set up by searchParams";
            logger.error(msg);
            this.cacheId = NO_CACHE_ID;
            this.initSuccess = false;
            return;
        }
        this.cacheId = this.iteratorCache.cache(NO_CACHE_ID, result);
        this.setUpRangeParameters(result);
        this.updateFilteredIds(searchResultsWrapper);
        this.initSuccess = true;
    }

    /** Initializes {@link #width} and {@link #tailBand} from the (non-empty) first page. */
    private void setUpRangeParameters(List<QueryResultsWrapper.RowRecord> page) {
        this.updateWidth(page);
        QueryResultsWrapper.RowRecord lastHit = page.get(page.size() - 1);
        this.tailBand = this.getDistance(lastHit);
        logger.debug("set up init parameter for searchIterator width:{} tail_band:{}", this.width, this.tailBand);
    }

    /**
     * Records the ids of all hits whose score equals the score of the last hit, so that they can
     * be excluded from the next search. The id list is restarted whenever the last score differs
     * from the previously recorded one.
     */
    private void updateFilteredIds(SearchResultsWrapper searchResultsWrapper) {
        List<SearchResultsWrapper.IDScore> idScores = searchResultsWrapper.getIDScore(0);
        if (CollectionUtils.isNullOrEmpty(idScores)) {
            return;
        }
        SearchResultsWrapper.IDScore lastHit = idScores.get(idScores.size() - 1);
        if (lastHit == null) {
            return;
        }
        if (this.filteredDistance == null || lastHit.getScore() != this.filteredDistance.floatValue()) {
            this.filteredIds = Lists.newArrayList();
            this.filteredDistance = Float.valueOf(lastHit.getScore());
        }
        boolean varCharKey = this.primaryField.getDataType() == DataType.VarChar;
        for (SearchResultsWrapper.IDScore hit : idScores) {
            if (hit.getScore() != lastHit.getScore()) {
                continue;
            }
            if (varCharKey) {
                this.filteredIds.add(hit.getStrID());
            } else {
                this.filteredIds.add(hit.getLongID());
            }
        }
        if (this.filteredIds.size() > MAX_FILTERED_IDS_COUNT) {
            String msg = String.format("filtered ids length has accumulated to more than %s, there is a danger of overly memory consumption", MAX_FILTERED_IDS_COUNT);
            ExceptionUtils.throwUnExpectedException(msg);
        }
    }

    /**
     * Builds and executes one search request.
     *
     * @param params        search params to send; {@code ef} may be lowered in place by
     *                      {@link #extendBatchSize}
     * @param nextExpr      filter expression for this request
     * @param toExtendBatch whether to request an enlarged batch
     * @return wrapper around the results of the search
     */
    private SearchResultsWrapper executeNextSearch(Map<String, Object> params, String nextExpr, boolean toExtendBatch) {
        SearchParam searchParam = SearchParam.newBuilder()
                .withDatabaseName(this.searchIteratorParam.getDatabaseName())
                .withCollectionName(this.searchIteratorParam.getCollectionName())
                .withPartitionNames(this.searchIteratorParam.getPartitionNames())
                .withConsistencyLevel(this.searchIteratorParam.getConsistencyLevel())
                .withVectorFieldName(this.searchIteratorParam.getVectorFieldName())
                .withTopK(this.extendBatchSize(this.batchSize, toExtendBatch, params))
                .withExpr(nextExpr)
                .withOutFields(this.searchIteratorParam.getOutFields())
                .withVectors(this.searchIteratorParam.getVectors())
                .withRoundDecimal(this.searchIteratorParam.getRoundDecimal())
                .withParams(JacksonUtils.toJsonString(params))
                .withMetricType(MetricType.valueOf(this.searchIteratorParam.getMetricType()))
                .withIgnoreGrowing(this.searchIteratorParam.isIgnoreGrowing())
                .build();
        SearchRequest searchRequest = ParamUtils.convertSearchParam(searchParam);
        SearchResults response = this.blockingStub.search(searchRequest);
        String title = String.format("SearchRequest collectionName:%s", this.searchIteratorParam.getCollectionName());
        this.rpcUtils.handleResponse(title, response.getStatus());
        return new SearchResultsWrapper(response.getResults());
    }

    /**
     * Computes the topK for a search request: the (optionally enlarged) batch size, capped at
     * {@link #MAX_BATCH_SIZE} and, when present, at {@code ef}. If {@code ef} exceeds the result
     * it is lowered to the result inside {@code nextParams}.
     */
    private int extendBatchSize(int batchSize, boolean toExtendBatchSize, Map<String, Object> nextParams) {
        int extendRate = toExtendBatchSize ? BATCH_EXTEND_RATE : 1;
        // Multiply in long so that a very large batch size cannot overflow before the cap applies.
        int cappedBatch = (int)Math.min((long)MAX_BATCH_SIZE, (long)batchSize * extendRate);
        if (!nextParams.containsKey(KEY_EF)) {
            return cappedBatch;
        }
        int ef = this.getIntValue(nextParams, KEY_EF);
        int realBatch = Math.min(cappedBatch, ef);
        if (ef > realBatch) {
            nextParams.put(KEY_EF, realBatch);
        }
        return realBatch;
    }

    /**
     * Sets {@link #width} to the distance span between the first and last row of the page,
     * falling back to {@link #DEFAULT_WIDTH} when the span is zero. An empty page leaves the
     * width unchanged.
     */
    private void updateWidth(List<QueryResultsWrapper.RowRecord> page) {
        if (page == null || page.isEmpty()) {
            return;
        }
        float firstDistance = this.getDistance(page.get(0));
        float lastDistance = this.getDistance(page.get(page.size() - 1));
        this.width = this.metricsPositiveRelated(this.metricType)
                ? lastDistance - firstDistance
                : firstDistance - lastDistance;
        if ((double)this.width == 0.0) {
            this.width = DEFAULT_WIDTH;
        }
    }

    /**
     * @return {@code true} for L2, JACCARD and HAMMING, {@code false} for IP and COSINE; any other
     *         metric type is reported through {@link ExceptionUtils#throwUnExpectedException}
     */
    private boolean metricsPositiveRelated(String metricType) {
        if (MetricType.L2.name().equals(metricType)
                || MetricType.JACCARD.name().equals(metricType)
                || MetricType.HAMMING.name().equals(metricType)) {
            return true;
        }
        if (MetricType.IP.name().equals(metricType) || MetricType.COSINE.name().equals(metricType)) {
            return false;
        }
        String msg = String.format("unsupported metrics type for search iteration: %s", metricType);
        ExceptionUtils.throwUnExpectedException(msg);
        return false;
    }

    /** @return {@code true} when a topK limit is set and at least that many rows were returned */
    private boolean checkReachedLimit() {
        if (this.topK == UNLIMITED_TOP_K || this.returnedCount < this.topK) {
            return false;
        }
        logger.debug("reached search limit:{}, returned_count:{}, directly return", this.topK, this.returnedCount);
        return true;
    }

    /** @return {@code true} when the cache currently holds at least {@code count} rows */
    private boolean isCacheEnough(int count) {
        List<QueryResultsWrapper.RowRecord> cachedPage = this.iteratorCache.fetchCache(this.cacheId);
        return cachedPage != null && cachedPage.size() >= count;
    }

    /**
     * Removes the first {@code count} rows from the cache and returns them.
     *
     * <p>Both the returned page and the remaining cache content are copied into fresh lists.
     * Sub-list views of the same backing list would invalidate each other as soon as the cache
     * is modified again, and would keep every previously returned row reachable.
     *
     * @throws ParamException when the cache holds fewer than {@code count} rows
     */
    private List<QueryResultsWrapper.RowRecord> extractPageFromCache(int count) {
        List<QueryResultsWrapper.RowRecord> cachedPage = this.iteratorCache.fetchCache(this.cacheId);
        if (cachedPage == null || cachedPage.size() < count) {
            String msg = String.format("Wrong, try to extract %s result from cache, more than %s there must be sth wrong with code", count, cachedPage == null ? 0 : cachedPage.size());
            throw new ParamException(msg);
        }
        List<QueryResultsWrapper.RowRecord> retPageRes = new ArrayList<>(cachedPage.subList(0, count));
        List<QueryResultsWrapper.RowRecord> leftCachePage = new ArrayList<>(cachedPage.subList(count, cachedPage.size()));
        this.iteratorCache.cache(this.cacheId, leftCachePage);
        return retPageRes;
    }

    /**
     * Issues follow-up searches with a growing radius until at least {@code batchSize} new rows
     * were collected or the attempt limit is exceeded.
     *
     * @return the collected rows, possibly fewer than {@code batchSize} and possibly empty
     */
    private List<QueryResultsWrapper.RowRecord> trySearchFill() {
        List<QueryResultsWrapper.RowRecord> finalPage = new ArrayList<>();
        int tryTime = 0;
        int coefficient = 1;
        while (true) {
            Map<String, Object> nextParams = this.nextParams(coefficient);
            String nextExpr = this.filteredDuplicatedResultExpr(this.expr);
            SearchResultsWrapper searchResultsWrapper = this.executeNextSearch(nextParams, nextExpr, true);
            this.updateFilteredIds(searchResultsWrapper);
            List<QueryResultsWrapper.RowRecord> newPage = searchResultsWrapper.getRowRecords(0);
            ++tryTime;
            if (newPage != null && !newPage.isEmpty()) {
                finalPage.addAll(newPage);
                this.tailBand = this.getDistance(newPage.get(newPage.size() - 1));
            }
            if (finalPage.size() >= this.batchSize) {
                break;
            }
            if (tryTime > MAX_TRY_TIME) {
                String msg = String.format("Search exceed max try times:%s directly break", MAX_TRY_TIME);
                logger.warn(msg);
                break;
            }
            ++coefficient;
        }
        return finalPage;
    }

    /**
     * Builds the params of the next follow-up search: a copy of the user params whose
     * {@code range_filter} is the current tail distance and whose {@code radius} is the tail
     * distance moved by {@code width * coefficient}, clamped to a user supplied {@code radius}.
     */
    private Map<String, Object> nextParams(int coefficient) {
        coefficient = Math.max(1, coefficient);
        // JSON round trip = copy of the user params, so the original map is never modified.
        Map<String, Object> nextParams = JacksonUtils.fromJson(JacksonUtils.toJsonString(this.params), new TypeReference<Map<String, Object>>(){});
        boolean hasUserRadius = this.params.containsKey(KEY_RADIUS);
        if (this.metricsPositiveRelated(this.metricType)) {
            float nextRadius = this.tailBand + this.width * (float)coefficient;
            if (hasUserRadius && nextRadius > this.getFloatValue(KEY_RADIUS)) {
                nextParams.put(KEY_RADIUS, Float.valueOf(this.getFloatValue(KEY_RADIUS)));
            } else {
                nextParams.put(KEY_RADIUS, Float.valueOf(nextRadius));
            }
        } else {
            double nextRadius = this.tailBand - this.width * (float)coefficient;
            if (hasUserRadius && nextRadius < (double)this.getFloatValue(KEY_RADIUS)) {
                nextParams.put(KEY_RADIUS, Float.valueOf(this.getFloatValue(KEY_RADIUS)));
            } else {
                nextParams.put(KEY_RADIUS, nextRadius);
            }
        }
        nextParams.put(KEY_RANGE_FILTER, Float.valueOf(this.tailBand));
        if (logger.isDebugEnabled()) {
            String msg = String.format("next round search iteration radius:%s,range_filter:%s,coefficient:%s", this.convertToStr(nextParams.get(KEY_RADIUS)), this.convertToStr(nextParams.get(KEY_RANGE_FILTER)), coefficient);
            logger.debug(msg);
        }
        return nextParams;
    }

    /**
     * Extends {@code expr} with a {@code <pk> not in [...]} clause that excludes the ids recorded
     * in {@link #filteredIds}.
     *
     * <p>A non-empty user expression is wrapped in parentheses before the clause is appended;
     * otherwise an expression containing {@code or} would only have the exclusion applied to its
     * last operand.
     *
     * @return the combined expression, or {@code expr} unchanged when there is nothing to filter
     */
    private String filteredDuplicatedResultExpr(String expr) {
        if (CollectionUtils.isNullOrEmpty(this.filteredIds)) {
            return expr;
        }
        boolean varCharKey = this.primaryField.getDataType() == DataType.VarChar;
        StringBuilder idList = new StringBuilder();
        for (Object filteredId : this.filteredIds) {
            if (idList.length() > 0) {
                idList.append(",");
            }
            if (varCharKey) {
                idList.append("\"").append(filteredId).append("\"");
            } else {
                idList.append(filteredId);
            }
        }
        String notInClause = String.format("%s not in [%s]", this.primaryField.getName(), idList);
        if (expr != null && !expr.isEmpty()) {
            return "(" + expr + ") and " + notInClause;
        }
        return notInClause;
    }

    /**
     * Appends {@code page} to the cached rows.
     *
     * <p>The merged rows are stored as a new list instead of mutating the cached one in place.
     *
     * @return number of rows in the cache after the append
     * @throws ParamException when {@code page} is {@code null}
     */
    private int pushNewPageToCache(List<QueryResultsWrapper.RowRecord> page) {
        if (page == null) {
            throw new ParamException("Cannot push None page into cache");
        }
        List<QueryResultsWrapper.RowRecord> cachedPage = this.iteratorCache.fetchCache(this.cacheId);
        List<QueryResultsWrapper.RowRecord> merged = new ArrayList<>();
        if (cachedPage != null) {
            merged.addAll(cachedPage);
        }
        merged.addAll(page);
        this.iteratorCache.cache(this.cacheId, merged);
        return merged.size();
    }

    /** @return the value of the {@code distance} entry of the row */
    private float getDistance(QueryResultsWrapper.RowRecord record) {
        return ((Float)record.get("distance")).floatValue();
    }

    /** Formats a numeric value with the pattern {@code 0.0} for log output. */
    private String convertToStr(Object value) {
        DecimalFormat df = new DecimalFormat("0.0");
        return df.format(value);
    }

    /**
     * Reads a numeric user search param as {@code float}. Any {@link Number} is accepted, because
     * the JSON parser yields an integer type for values written without a fraction.
     */
    private float getFloatValue(String key) {
        return this.getNumber(this.params, key).floatValue();
    }

    /** Reads a numeric entry of {@code source} as {@code int}; any {@link Number} is accepted. */
    private int getIntValue(Map<String, Object> source, String key) {
        return this.getNumber(source, key).intValue();
    }

    /** Fetches {@code key} from {@code source} and reports a non-numeric value as an error. */
    private Number getNumber(Map<String, Object> source, String key) {
        Object value = source.get(key);
        if (!(value instanceof Number)) {
            String msg = String.format("search param '%s' must be a number, but got: %s", key, value);
            ExceptionUtils.throwUnExpectedException(msg);
        }
        return (Number)value;
    }
}

