/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index;

import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SegmentCommitInfo;
import org.apache.lucene.index.SegmentInfo;
import org.apache.lucene.index.SegmentReader;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.Version;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.index.engine.Engine;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.util.KNNCodecUtil;
import org.opensearch.knn.index.codec.util.NativeMemoryCacheKeyHelper;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.engine.MemoryOptimizedSearchSupportSpec;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.mapper.KNNVectorFieldType;
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.index.memory.NativeMemoryEntryContext;
import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy;
import org.opensearch.knn.index.query.SegmentLevelQuantizationInfo;
import org.opensearch.knn.index.query.SegmentLevelQuantizationUtil;
import org.opensearch.knn.index.util.IndexUtil;

/**
 * Shard-level helper of the k-NN plugin that warms up and evicts the structures backing k-NN
 * vector fields of a single {@link IndexShard}.
 *
 * <p>Warm-up walks every segment of the shard. k-NN fields that support memory optimized search
 * are initialized through the segment's vector reader. The remaining k-NN fields have their engine
 * files loaded through the {@link NativeMemoryCacheManager}. Cache eviction invalidates the cache
 * entries of all engine files of the shard.
 */
public class KNNIndexShard {
    @Generated
    private static final Logger log = LogManager.getLogger(KNNIndexShard.class);
    private final IndexShard indexShard;
    private final NativeMemoryCacheManager nativeMemoryCacheManager;
    /** Source name used when acquiring a searcher to enumerate segments for cache eviction. */
    private static final String INDEX_SHARD_CLEAR_CACHE_SEARCHER = "knn-clear-cache";

    /**
     * Creates a helper for the given shard, backed by the singleton {@link NativeMemoryCacheManager}.
     *
     * @param indexShard shard whose k-NN structures this instance manages
     */
    public KNNIndexShard(IndexShard indexShard) {
        this.indexShard = indexShard;
        this.nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance();
    }

    /**
     * Returns the name of the index the wrapped shard belongs to.
     *
     * @return index name taken from the shard id
     */
    public String getIndexName() {
        return this.indexShard.shardId().getIndexName();
    }

    /**
     * Initializes memory optimized searchers for every k-NN field of one segment that supports
     * memory optimized search.
     *
     * @param leafReader    reader of a single segment
     * @param mapperService mapper service used to resolve the mapped type of each field
     * @param indexName     name of the index, forwarded to the support check
     * @return names of all fields selected for memory optimized search (including fields whose
     *     warm-up call failed); the caller skips these fields when loading off-heap indices
     */
    private Set<String> warmUpMemoryOptimizedSearcher(LeafReader leafReader, MapperService mapperService, String indexName) {
        Set<FieldInfo> fieldsForMemoryOptimizedSearch = StreamSupport.stream(leafReader.getFieldInfos().spliterator(), false)
            .filter(fieldInfo -> fieldInfo.attributes().containsKey("knn_field"))
            .filter(fieldInfo -> isMemoryOptimizedSearchField(fieldInfo, mapperService, indexName))
            .collect(Collectors.toSet());
        SegmentReader segmentReader = Lucene.segmentReader(leafReader);
        for (FieldInfo field : fieldsForMemoryOptimizedSearch) {
            String dataTypeStr = field.getAttribute("data_type");
            if (dataTypeStr == null) {
                continue;
            }
            try {
                VectorDataType vectorDataType = VectorDataType.get(dataTypeStr);
                // The search result is discarded; the call is issued only for its side effect on the
                // vector reader. NOTE(review): presumably this lazily loads the field's searcher — confirm.
                if (vectorDataType == VectorDataType.FLOAT) {
                    segmentReader.getVectorReader().search(field.getName(), (float[]) null, null, null);
                } else {
                    segmentReader.getVectorReader().search(field.getName(), (byte[]) null, null, null);
                }
            } catch (Exception e) {
                // Warm-up is best-effort: a failure on one field must not abort the others.
                // NOTE(review): the null query vector likely makes the search throw after loading,
                // so this is only logged at debug level — confirm.
                log.debug("[KNN] Ignored exception while warming up memory optimized searcher for field [{}]", field.getName(), e);
            }
        }
        return fieldsForMemoryOptimizedSearch.stream().map(FieldInfo::getName).collect(Collectors.toSet());
    }

    /**
     * Returns whether the field is mapped as a k-NN vector field that supports memory optimized search.
     */
    private static boolean isMemoryOptimizedSearchField(FieldInfo fieldInfo, MapperService mapperService, String indexName) {
        MappedFieldType fieldType = mapperService.fieldType(fieldInfo.getName());
        if (!(fieldType instanceof KNNVectorFieldType)) {
            return false;
        }
        return MemoryOptimizedSearchSupportSpec.isSupportedFieldType((KNNVectorFieldType) fieldType, indexName);
    }

    /**
     * Warms up every segment of the shard.
     *
     * <p>For each segment, fields eligible for memory optimized search are initialized first. The
     * engine files of the remaining k-NN fields are then loaded through the native memory cache.
     *
     * @throws IOException if the segment files cannot be listed; any failure is logged and rethrown
     */
    public void warmup() throws IOException {
        String indexName = this.getIndexName();
        log.info("[KNN] Warming up index: [{}]", indexName);
        MapperService mapperService = this.indexShard.mapperService();
        Directory directory = this.indexShard.store().directory();
        try (Engine.Searcher searcher = this.indexShard.acquireSearcher("knn-warmup-mem")) {
            for (LeafReaderContext leafReaderContext : searcher.getIndexReader().leaves()) {
                Set<String> loadedFieldNames = this.warmUpMemoryOptimizedSearcher(leafReaderContext.reader(), mapperService, indexName);
                log.info("[KNN] Loaded memory optimized searchers for fields {}", loadedFieldNames);
                // Fields already handled by memory optimized search are excluded from off-heap loading.
                List<EngineFileContext> engineFileContexts = this.getAllEngineFileContexts(loadedFieldNames, leafReaderContext);
                this.warmUpOffHeapIndex(engineFileContexts, directory);
                log.info(
                    "[KNN] Loaded off-heap indices for fields {}",
                    engineFileContexts.stream().map(EngineFileContext::getFieldName).collect(Collectors.toSet())
                );
            }
        } catch (Exception e) {
            log.error("Failed warm-up index: [{}]", indexName, e);
            // Precise rethrow: only IOException or unchecked exceptions can reach here.
            throw e;
        }
    }

    /**
     * Loads each engine file into the native memory cache.
     *
     * @param engineFileContexts engine files to load
     * @param directory          shard directory the files are read from
     * @throws RuntimeException wrapping the {@link ExecutionException} raised by the cache on load failure
     */
    private void warmUpOffHeapIndex(List<EngineFileContext> engineFileContexts, Directory directory) {
        String indexName = this.getIndexName();
        for (EngineFileContext engineFileContext : engineFileContexts) {
            String vectorFileName = engineFileContext.getVectorFileName();
            try {
                String cacheKey = NativeMemoryCacheKeyHelper.constructCacheKey(vectorFileName, engineFileContext.getSegmentInfo());
                this.nativeMemoryCacheManager.get(
                    new NativeMemoryEntryContext.IndexEntryContext(
                        directory,
                        cacheKey,
                        NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(),
                        IndexUtil.getParametersAtLoading(
                            engineFileContext.getSpaceType(),
                            KNNEngine.getEngineNameFromPath(vectorFileName),
                            indexName,
                            engineFileContext.getVectorDataType(),
                            engineFileContext.getSegmentLevelQuantizationInfo()
                        ),
                        indexName,
                        engineFileContext.getModelId()
                    ),
                    true
                );
            } catch (ExecutionException ex) {
                throw new RuntimeException("Failed to load engine file [" + vectorFileName + "] of index [" + indexName + "]", ex);
            }
        }
    }

    /**
     * Invalidates the native memory cache entries of every engine file in the shard.
     *
     * <p>This does nothing when the cache manager has no allocation for the index. Otherwise the
     * index allocation's write lock is held for the whole eviction.
     *
     * @throws RuntimeException wrapping an {@link IOException} raised while listing segment files
     */
    public void clearCache() {
        String indexName = this.getIndexName();
        Optional<NativeMemoryAllocation> indexAllocationOptional = this.nativeMemoryCacheManager.getIndexMemoryAllocation(indexName);
        if (!indexAllocationOptional.isPresent()) {
            return;
        }
        NativeMemoryAllocation indexAllocation = indexAllocationOptional.get();
        log.info("[KNN] Evicting index from cache: [{}]", indexName);
        // Acquire the lock right before the guarded region so the finally block always releases it.
        indexAllocation.writeLock();
        try (Engine.Searcher searcher = this.indexShard.acquireSearcher(INDEX_SHARD_CLEAR_CACHE_SEARCHER)) {
            for (LeafReaderContext leafReaderContext : searcher.getIndexReader().leaves()) {
                // An empty exclusion set selects the engine files of every k-NN field.
                for (EngineFileContext engineFileContext : this.getAllEngineFileContexts(Collections.emptySet(), leafReaderContext)) {
                    String cacheKey = NativeMemoryCacheKeyHelper.constructCacheKey(
                        engineFileContext.getVectorFileName(),
                        engineFileContext.getSegmentInfo()
                    );
                    this.nativeMemoryCacheManager.invalidate(cacheKey);
                }
            }
        } catch (IOException ex) {
            log.error("[KNN] Failed to evict index from cache: [{}]", indexName, ex);
            throw new RuntimeException(ex);
        } finally {
            indexAllocation.writeUnlock();
        }
    }

    /**
     * Collects the engine files of one segment for every engine that creates custom segment files.
     *
     * @param loadedFieldNames  names of fields to skip
     * @param leafReaderContext segment to inspect
     * @return engine file descriptors for all non-skipped k-NN fields
     * @throws IOException if the segment's file list cannot be read
     */
    @VisibleForTesting
    List<EngineFileContext> getAllEngineFileContexts(Set<String> loadedFieldNames, LeafReaderContext leafReaderContext) throws IOException {
        List<EngineFileContext> engineFiles = new ArrayList<>();
        for (KNNEngine knnEngine : KNNEngine.getEnginesThatCreateCustomSegmentFiles()) {
            engineFiles.addAll(this.getEngineFileContexts(loadedFieldNames, leafReaderContext, knnEngine));
        }
        return engineFiles;
    }

    /**
     * Collects the engine files that one engine produced for the k-NN fields of one segment.
     *
     * <p>Fields without the {@code knn_field} attribute, and fields listed in
     * {@code loadedFieldNames}, are skipped. The space type defaults to L2 when the field does not
     * record one.
     *
     * @param loadedFieldNames  names of fields to skip
     * @param leafReaderContext segment to inspect
     * @param knnEngine         engine whose file extension is matched (compound or plain, per segment)
     * @return engine file descriptors, in field order
     * @throws IOException if the segment's file list cannot be read
     */
    List<EngineFileContext> getEngineFileContexts(Set<String> loadedFieldNames, LeafReaderContext leafReaderContext, KNNEngine knnEngine) throws IOException {
        List<EngineFileContext> engineFiles = new ArrayList<>();
        SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
        // Segment-level values are the same for every field, so compute them once.
        SegmentCommitInfo segmentCommitInfo = reader.getSegmentInfo();
        Version segmentVersion = segmentCommitInfo.info.getVersion();
        String fileExtension = segmentCommitInfo.info.getUseCompoundFile() ? knnEngine.getCompoundExtension() : knnEngine.getExtension();
        for (FieldInfo fieldInfo : reader.getFieldInfos()) {
            if (loadedFieldNames.contains(fieldInfo.getName()) || !fieldInfo.attributes().containsKey("knn_field")) {
                continue;
            }
            String spaceTypeName = fieldInfo.attributes().getOrDefault("spaceType", SpaceType.L2.getValue());
            SpaceType spaceType = SpaceType.getSpace(spaceTypeName);
            String modelId = fieldInfo.attributes().getOrDefault("model_id", null);
            SegmentLevelQuantizationInfo segmentLevelQuantizationInfo = SegmentLevelQuantizationInfo.build(
                (LeafReader) reader,
                fieldInfo,
                fieldInfo.name,
                segmentVersion
            );
            VectorDataType vectorDataType = this.determineVectorDataType(fieldInfo, segmentLevelQuantizationInfo, segmentVersion);
            engineFiles.addAll(
                this.getEngineFileContexts(
                    segmentCommitInfo,
                    segmentLevelQuantizationInfo,
                    fieldInfo.name,
                    fileExtension,
                    spaceType,
                    modelId,
                    vectorDataType
                )
            );
        }
        return engineFiles;
    }

    /**
     * Selects the files of a segment that belong to the given field and extension.
     *
     * <p>A file matches when its name starts with the engine-file prefix built from the segment
     * name and ends with the suffix built from the field name and extension.
     *
     * @return one {@link EngineFileContext} per matching file
     * @throws IOException if the segment's file list cannot be read
     */
    @VisibleForTesting
    List<EngineFileContext> getEngineFileContexts(SegmentCommitInfo segmentCommitInfo, SegmentLevelQuantizationInfo segmentLevelQuantizationInfo, String fieldName, String fileExtension, SpaceType spaceType, String modelId, VectorDataType vectorDataType) throws IOException {
        String prefix = KNNCodecUtil.buildEngineFilePrefix(segmentCommitInfo.info.name);
        String suffix = KNNCodecUtil.buildEngineFileSuffix(fieldName, fileExtension);
        return segmentCommitInfo.files()
            .stream()
            .filter(fileName -> fileName.startsWith(prefix))
            .filter(fileName -> fileName.endsWith(suffix))
            .map(
                vectorFileName -> new EngineFileContext(
                    fieldName,
                    spaceType,
                    modelId,
                    vectorFileName,
                    vectorDataType,
                    segmentCommitInfo.info,
                    segmentLevelQuantizationInfo
                )
            )
            .collect(Collectors.toList());
    }

    /**
     * Determines the vector data type used when loading a field's engine file.
     *
     * <p>Without a quantization config, the field's {@code data_type} attribute is used (FLOAT when
     * absent). With a quantization config, FLOAT is used when ADC is enabled for the segment;
     * otherwise BINARY is used. NOTE(review): presumably BINARY because the quantized codes are
     * stored as bits — confirm.
     */
    @VisibleForTesting
    VectorDataType determineVectorDataType(FieldInfo fieldInfo, SegmentLevelQuantizationInfo segmentLevelQuantizationInfo, Version segmentVersion) {
        if (FieldInfoExtractor.extractQuantizationConfig(fieldInfo, segmentVersion) == QuantizationConfig.EMPTY) {
            return VectorDataType.get(fieldInfo.attributes().getOrDefault("data_type", VectorDataType.FLOAT.getValue()));
        }
        if (SegmentLevelQuantizationUtil.isAdcEnabled(segmentLevelQuantizationInfo)) {
            return VectorDataType.FLOAT;
        }
        return VectorDataType.BINARY;
    }

    /**
     * Returns the wrapped shard.
     */
    @Generated
    public IndexShard getIndexShard() {
        return this.indexShard;
    }

    /**
     * Immutable descriptor of one engine file in a segment, together with the field metadata
     * needed to load it into the native memory cache.
     */
    @VisibleForTesting
    static class EngineFileContext {
        private final String fieldName;
        private final SpaceType spaceType;
        private final String modelId;
        private final String vectorFileName;
        private final VectorDataType vectorDataType;
        private final SegmentInfo segmentInfo;
        private final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo;

        @Generated
        public EngineFileContext(String fieldName, SpaceType spaceType, String modelId, String vectorFileName, VectorDataType vectorDataType, SegmentInfo segmentInfo, SegmentLevelQuantizationInfo segmentLevelQuantizationInfo) {
            this.fieldName = fieldName;
            this.spaceType = spaceType;
            this.modelId = modelId;
            this.vectorFileName = vectorFileName;
            this.vectorDataType = vectorDataType;
            this.segmentInfo = segmentInfo;
            this.segmentLevelQuantizationInfo = segmentLevelQuantizationInfo;
        }

        /** @return name of the k-NN field the file belongs to */
        @Generated
        public String getFieldName() {
            return this.fieldName;
        }

        /** @return space type recorded for the field (L2 when the field had none) */
        @Generated
        public SpaceType getSpaceType() {
            return this.spaceType;
        }

        /** @return model id recorded for the field, or {@code null} when absent */
        @Generated
        public String getModelId() {
            return this.modelId;
        }

        /** @return name of the engine file within the segment */
        @Generated
        public String getVectorFileName() {
            return this.vectorFileName;
        }

        /** @return data type used when loading the file */
        @Generated
        public VectorDataType getVectorDataType() {
            return this.vectorDataType;
        }

        /** @return segment the file belongs to */
        @Generated
        public SegmentInfo getSegmentInfo() {
            return this.segmentInfo;
        }

        /** @return segment-level quantization info of the field */
        @Generated
        public SegmentLevelQuantizationInfo getSegmentLevelQuantizationInfo() {
            return this.segmentLevelQuantizationInfo;
        }
    }
}

