/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.codecs.lucene104;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorScorer;
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene104.OffHeapScalarQuantizedFloatVectorValues;
import org.apache.lucene.codecs.lucene104.OffHeapScalarQuantizedVectorValues;
import org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues;
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataAccessHint;
import org.apache.lucene.store.FileDataHint;
import org.apache.lucene.store.FileTypeHint;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
import org.apache.lucene.util.quantization.QuantizedVectorsReader;
import org.apache.lucene.util.quantization.ScalarQuantizer;

/**
 * Reader for scalar-quantized flat vectors.
 *
 * <p>Per-field metadata is read from the {@code .vemq} meta file and quantized vectors are read from
 * the {@code .veq} data file. Byte vectors, byte-query scoring and byte-query search are delegated to
 * {@code rawVectorsReader}. Float-query search is an exhaustive scan over the quantized vectors,
 * scored in batches of {@link #EXHAUSTIVE_BULK_SCORE_ORDS} ordinals.
 */
public class Lucene104ScalarQuantizedVectorsReader extends FlatVectorsReader implements QuantizedVectorsReader {
    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Lucene104ScalarQuantizedVectorsReader.class);

    /** Number of ordinals handed to a single {@code bulkScore} call during exhaustive float search. */
    public static final int EXHAUSTIVE_BULK_SCORE_ORDS = 64;

    private static final String META_EXTENSION = "vemq";
    private static final String VECTOR_DATA_EXTENSION = "veq";
    private static final String META_CODEC_NAME = "Lucene104ScalarQuantizedVectorsFormatMeta";
    private static final String VECTOR_DATA_CODEC_NAME = "Lucene104ScalarQuantizedVectorsFormatData";
    private static final int VERSION_START = 0;
    private static final int VERSION_CURRENT = 0;

    /**
     * Bytes stored per vector on top of the packed quantized components (12 + 4 bytes).
     * NOTE(review): presumably three float corrections plus one int; confirm against the writer.
     */
    private static final int PER_VECTOR_CORRECTION_BYTES = 3 * Float.BYTES + Integer.BYTES;

    private final Map<String, FieldEntry> fields = new HashMap<>();
    private final IndexInput quantizedVectorData;
    private final FlatVectorsReader rawVectorsReader;
    private final Lucene104ScalarQuantizedVectorScorer vectorScorer;

    /**
     * Opens the meta and data files for the given segment.
     *
     * <p>The meta file is fully read and footer-checked. The data file is header-checked and must
     * report the same format version as the meta file. On any failure this reader (including
     * {@code rawVectorsReader}) is closed before the exception propagates.
     *
     * @param state segment to read
     * @param rawVectorsReader reader for the unquantized vectors; owned (and closed) by this reader
     * @param vectorsScorer scorer used for quantized vectors
     * @throws IOException if the files cannot be read or are corrupt
     */
    public Lucene104ScalarQuantizedVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader, Lucene104ScalarQuantizedVectorScorer vectorsScorer) throws IOException {
        super(vectorsScorer);
        this.vectorScorer = vectorsScorer;
        this.rawVectorsReader = rawVectorsReader;
        int versionMeta = -1;
        String metaFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, META_EXTENSION);
        try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) {
            Throwable priorE = null;
            try {
                versionMeta = CodecUtil.checkIndexHeader(meta, META_CODEC_NAME, VERSION_START, VERSION_CURRENT, state.segmentInfo.getId(), state.segmentSuffix);
                readFields(meta, state.fieldInfos);
            } catch (Throwable exception) {
                // Deferred so checkFooter can report it together with any checksum mismatch.
                priorE = exception;
            } finally {
                CodecUtil.checkFooter(meta, priorE);
            }
            this.quantizedVectorData = openDataInput(state, versionMeta, VECTOR_DATA_EXTENSION, VECTOR_DATA_CODEC_NAME, state.context.withHints(FileTypeHint.DATA, FileDataHint.KNN_VECTORS, DataAccessHint.RANDOM));
        } catch (Throwable t) {
            IOUtils.closeWhileHandlingException(this);
            throw t;
        }
    }

    /** Reads field entries until the {@code -1} terminator, validating each against {@code infos}. */
    private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException {
        for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) {
            FieldInfo info = infos.fieldInfo(fieldNumber);
            if (info == null) {
                throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
            }
            FieldEntry fieldEntry = readField(meta, info);
            validateFieldEntry(info, fieldEntry);
            fields.put(info.name, fieldEntry);
        }
    }

    /**
     * Checks that the stored dimension matches the field info, and that the stored data length equals
     * {@code size * (packedLength + 16)}.
     *
     * @throws IllegalStateException on any mismatch
     */
    static void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) {
        int dimension = info.getVectorDimension();
        if (dimension != fieldEntry.dimension) {
            throw new IllegalStateException("Inconsistent vector dimension for field=\"" + info.name + "\"; " + dimension + " != " + fieldEntry.dimension);
        }
        long packedLength = fieldEntry.scalarEncoding.getDocPackedLength(dimension);
        long numQuantizedVectorBytes = Math.multiplyExact(packedLength + PER_VECTOR_CORRECTION_BYTES, (long) fieldEntry.size);
        if (numQuantizedVectorBytes != fieldEntry.vectorDataLength) {
            throw new IllegalStateException("vector data length " + fieldEntry.vectorDataLength + " not matching size = " + fieldEntry.size + " * (packedLength=" + packedLength + " + " + PER_VECTOR_CORRECTION_BYTES + ") = " + numQuantizedVectorBytes);
        }
    }

    /** Scores {@code target} against the quantized vectors of {@code field}; {@code null} if the field is unknown. */
    @Override
    public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
        FieldEntry fi = fields.get(field);
        if (fi == null) {
            return null;
        }
        // The explicit cast keeps the scorer overload that the compiled code selected.
        return vectorScorer.getRandomVectorScorer(fi.similarityFunction, (KnnVectorValues) loadQuantizedVectorValues(fi), target);
    }

    /** Byte queries are not quantized here; delegated to the raw reader. */
    @Override
    public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException {
        return rawVectorsReader.getRandomVectorScorer(field, target);
    }

    /** Verifies the raw reader, then checksums the entire quantized data file. */
    @Override
    public void checkIntegrity() throws IOException {
        rawVectorsReader.checkIntegrity();
        CodecUtil.checksumEntireFile(quantizedVectorData);
    }

    /**
     * Returns float vectors for {@code field}.
     *
     * <p>When the raw reader has vectors, they are returned, and scoring goes through the quantized
     * vectors. Otherwise values backed only by the quantized data are returned.
     *
     * @return the values, or {@code null} if the field is unknown
     * @throws IllegalArgumentException if the field is not FLOAT32-encoded
     */
    @Override
    public FloatVectorValues getFloatVectorValues(String field) throws IOException {
        FieldEntry fi = fields.get(field);
        if (fi == null) {
            return null;
        }
        requireFloat32(field, fi);
        FloatVectorValues rawFloatVectorValues = rawVectorsReader.getFloatVectorValues(field);
        if (rawFloatVectorValues == null || rawFloatVectorValues.size() == 0) {
            // No raw vectors to serve: fall back to float values over the quantized data.
            return OffHeapScalarQuantizedFloatVectorValues.load(fi.ordToDocDISIReaderConfiguration, fi.dimension, fi.size, fi.scalarEncoding, fi.similarityFunction, vectorScorer, fi.centroid, fi.vectorDataOffset, fi.vectorDataLength, quantizedVectorData);
        }
        return new ScalarQuantizedVectorValues(rawFloatVectorValues, loadQuantizedVectorValues(fi));
    }

    /** Byte vectors are delegated to the raw reader. */
    @Override
    public ByteVectorValues getByteVectorValues(String field) throws IOException {
        return rawVectorsReader.getByteVectorValues(field);
    }

    /** Byte-query search is delegated to the raw reader. */
    @Override
    public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
        rawVectorsReader.search(field, target, knnCollector, acceptDocs);
    }

    /**
     * Exhaustively scores every accepted ordinal of {@code field} against {@code target}.
     *
     * <p>Ordinals are buffered into batches of {@link #EXHAUSTIVE_BULK_SCORE_ORDS}. A batch is
     * collected only if its best score beats the collector's current minimum competitive similarity.
     */
    @Override
    public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
        if (knnCollector.k() == 0) {
            return;
        }
        RandomVectorScorer scorer = getRandomVectorScorer(field, target);
        if (scorer == null) {
            return;
        }
        Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs.bits());
        int[] ords = new int[EXHAUSTIVE_BULK_SCORE_ORDS];
        float[] scores = new float[EXHAUSTIVE_BULK_SCORE_ORDS];
        int numOrds = 0;
        int numVectors = scorer.maxOrd();
        for (int ord = 0; ord < numVectors; ++ord) {
            if (acceptedOrds != null && !acceptedOrds.get(ord)) {
                continue;
            }
            if (knnCollector.earlyTerminated()) {
                break;
            }
            ords[numOrds++] = ord;
            if (numOrds == ords.length) {
                scoreAndCollect(scorer, ords, scores, numOrds, knnCollector);
                numOrds = 0;
            }
        }
        if (numOrds > 0) {
            scoreAndCollect(scorer, ords, scores, numOrds, knnCollector);
        }
    }

    /** Bulk-scores the first {@code numOrds} ordinals and collects them if the batch is competitive. */
    private static void scoreAndCollect(RandomVectorScorer scorer, int[] ords, float[] scores, int numOrds, KnnCollector knnCollector) throws IOException {
        knnCollector.incVisitedCount(numOrds);
        float maxScore = scorer.bulkScore(ords, scores, numOrds);
        if (maxScore > knnCollector.minCompetitiveSimilarity()) {
            for (int j = 0; j < numOrds; ++j) {
                knnCollector.collect(scorer.ordToDoc(ords[j]), scores[j]);
            }
        }
    }

    /** Closes the quantized data input and the raw reader. */
    @Override
    public void close() throws IOException {
        IOUtils.close(quantizedVectorData, rawVectorsReader);
    }

    /** Heap usage: this instance, the field map, each field's centroid array, and the raw reader. */
    @Override
    public long ramBytesUsed() {
        long size = SHALLOW_SIZE;
        size += RamUsageEstimator.sizeOfMap(fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class));
        for (FieldEntry entry : fields.values()) {
            // The centroid lives on heap and is not covered by the shallow FieldEntry size.
            if (entry.centroid != null) {
                size += RamUsageEstimator.sizeOf(entry.centroid);
            }
        }
        size += rawVectorsReader.ramBytesUsed();
        return size;
    }

    /** Off-heap sizes of the raw reader, merged with this field's quantized data length under {@code "veq"}. */
    @Override
    public Map<String, Long> getOffHeapByteSize(FieldInfo fieldInfo) {
        Objects.requireNonNull(fieldInfo);
        Map<String, Long> raw = rawVectorsReader.getOffHeapByteSize(fieldInfo);
        FieldEntry fieldEntry = fields.get(fieldInfo.name);
        if (fieldEntry == null) {
            assert fieldInfo.getVectorEncoding() == VectorEncoding.BYTE;
            return raw;
        }
        Map<String, Long> quant = Map.of(VECTOR_DATA_EXTENSION, fieldEntry.vectorDataLength());
        return KnnVectorsReader.mergeOffHeapByteSizeMaps(raw, quant);
    }

    /**
     * Returns the stored centroid for {@code field}.
     *
     * @return the internal array (callers must not modify it), or {@code null} if the field is
     *     unknown or has no vectors
     */
    public float[] getCentroid(String field) {
        FieldEntry fieldEntry = fields.get(field);
        return fieldEntry == null ? null : fieldEntry.centroid;
    }

    /** Opens a data file, checks its header and that its version equals {@code versionMeta}; closes it on failure. */
    private static IndexInput openDataInput(SegmentReadState state, int versionMeta, String fileExtension, String codecName, IOContext context) throws IOException {
        String fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
        IndexInput in = state.directory.openInput(fileName, context);
        try {
            int versionVectorData = CodecUtil.checkIndexHeader(in, codecName, VERSION_START, VERSION_CURRENT, state.segmentInfo.getId(), state.segmentSuffix);
            if (versionMeta != versionVectorData) {
                throw new CorruptIndexException("Format versions mismatch: meta=" + versionMeta + ", " + codecName + "=" + versionVectorData, in);
            }
            CodecUtil.retrieveChecksum(in);
            return in;
        } catch (Throwable t) {
            IOUtils.closeWhileHandlingException(in);
            throw t;
        }
    }

    /** Reads one field's encoding and similarity and requires the similarity to match {@code info}. */
    private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException {
        VectorEncoding vectorEncoding = Lucene99HnswVectorsReader.readVectorEncoding(input);
        VectorSimilarityFunction similarityFunction = Lucene99HnswVectorsReader.readSimilarityFunction(input);
        if (similarityFunction != info.getVectorSimilarityFunction()) {
            throw new IllegalStateException("Inconsistent vector similarity function for field=\"" + info.name + "\"; " + similarityFunction + " != " + info.getVectorSimilarityFunction());
        }
        return FieldEntry.create(input, vectorEncoding, info.getVectorSimilarityFunction());
    }

    /**
     * Exposes the quantized vectors of a FLOAT32 field as a {@code QuantizedByteVectorValues}.
     *
     * @return the values, or {@code null} if the field is unknown
     * @throws IllegalArgumentException if the field is not FLOAT32-encoded
     */
    @Override
    public org.apache.lucene.util.quantization.QuantizedByteVectorValues getQuantizedVectorValues(String field) throws IOException {
        FieldEntry fi = fields.get(field);
        if (fi == null) {
            return null;
        }
        requireFloat32(field, fi);
        final OffHeapScalarQuantizedVectorValues qv = loadQuantizedVectorValues(fi);
        return new org.apache.lucene.util.quantization.QuantizedByteVectorValues() {

            @Override
            public float getScoreCorrectionConstant(int ord) throws IOException {
                return 0.0f;
            }

            @Override
            public byte[] vectorValue(int ord) throws IOException {
                return qv.vectorValue(ord);
            }

            @Override
            public int dimension() {
                return qv.dimension();
            }

            @Override
            public int size() {
                return qv.size();
            }

            // Delegate ord->doc mapping so sparse fields resolve to the right documents.
            @Override
            public int ordToDoc(int ord) {
                return qv.ordToDoc(ord);
            }

            @Override
            public KnnVectorValues.DocIndexIterator iterator() {
                return qv.iterator();
            }
        };
    }

    /** This reader does not expose a {@link ScalarQuantizer}; always returns {@code null}. */
    @Override
    public ScalarQuantizer getQuantizationState(String fieldName) {
        return null;
    }

    /** Opens an off-heap view over the quantized vectors described by {@code fi}. */
    private OffHeapScalarQuantizedVectorValues loadQuantizedVectorValues(FieldEntry fi) throws IOException {
        return OffHeapScalarQuantizedVectorValues.load(fi.ordToDocDISIReaderConfiguration, fi.dimension, fi.size, new OptimizedScalarQuantizer(fi.similarityFunction), fi.scalarEncoding, fi.similarityFunction, vectorScorer, fi.centroid, fi.centroidDP, fi.vectorDataOffset, fi.vectorDataLength, quantizedVectorData);
    }

    /** Throws {@link IllegalArgumentException} unless {@code fi} is FLOAT32-encoded. */
    private static void requireFloat32(String field, FieldEntry fi) {
        if (fi.vectorEncoding != VectorEncoding.FLOAT32) {
            throw new IllegalArgumentException("field=\"" + field + "\" is encoded as: " + fi.vectorEncoding + " expected: " + VectorEncoding.FLOAT32);
        }
    }

    /**
     * Per-field metadata read from the meta file.
     *
     * <p>When {@code size == 0}, {@code centroid} is {@code null}, {@code centroidDP} is 0, and
     * {@code scalarEncoding} defaults to {@code UNSIGNED_BYTE}.
     */
    private record FieldEntry(VectorSimilarityFunction similarityFunction, VectorEncoding vectorEncoding, int dimension, long vectorDataOffset, long vectorDataLength, int size, Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding scalarEncoding, float[] centroid, float centroidDP, OrdToDocDISIReaderConfiguration ordToDocDISIReaderConfiguration) {

        /** Reads the remaining entry fields sequentially from {@code input}. */
        static FieldEntry create(IndexInput input, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction) throws IOException {
            int dimension = input.readVInt();
            long vectorDataOffset = input.readVLong();
            long vectorDataLength = input.readVLong();
            int size = input.readVInt();
            Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding scalarEncoding = Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.UNSIGNED_BYTE;
            float[] centroid = null;
            float centroidDP = 0.0f;
            if (size > 0) {
                // Encoding, centroid and centroid dot product are only written for non-empty fields.
                int wireNumber = input.readVInt();
                scalarEncoding = Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.fromWireNumber(wireNumber).orElseThrow(() -> new IllegalStateException("Could not get ScalarEncoding from wire number: " + wireNumber));
                centroid = new float[dimension];
                input.readFloats(centroid, 0, dimension);
                centroidDP = Float.intBitsToFloat(input.readInt());
            }
            OrdToDocDISIReaderConfiguration conf = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size);
            return new FieldEntry(similarityFunction, vectorEncoding, dimension, vectorDataOffset, vectorDataLength, size, scalarEncoding, centroid, centroidDP, conf);
        }
    }

    /**
     * Float vector values that read vectors, ord/doc mapping and rescoring from the raw values, but
     * build query scorers from the quantized values.
     */
    protected static final class ScalarQuantizedVectorValues extends FloatVectorValues {
        private final FloatVectorValues rawVectorValues;
        private final QuantizedByteVectorValues quantizedVectorValues;

        ScalarQuantizedVectorValues(FloatVectorValues rawVectorValues, QuantizedByteVectorValues quantizedVectorValues) {
            this.rawVectorValues = rawVectorValues;
            this.quantizedVectorValues = quantizedVectorValues;
        }

        @Override
        public int dimension() {
            return rawVectorValues.dimension();
        }

        @Override
        public int size() {
            return rawVectorValues.size();
        }

        @Override
        public float[] vectorValue(int ord) throws IOException {
            return rawVectorValues.vectorValue(ord);
        }

        @Override
        public ScalarQuantizedVectorValues copy() throws IOException {
            return new ScalarQuantizedVectorValues(rawVectorValues.copy(), quantizedVectorValues.copy());
        }

        @Override
        public Bits getAcceptOrds(Bits acceptDocs) {
            return rawVectorValues.getAcceptOrds(acceptDocs);
        }

        @Override
        public int ordToDoc(int ord) {
            return rawVectorValues.ordToDoc(ord);
        }

        @Override
        public KnnVectorValues.DocIndexIterator iterator() {
            return rawVectorValues.iterator();
        }

        /** Query scoring uses the quantized vectors. */
        @Override
        public VectorScorer scorer(float[] query) throws IOException {
            return quantizedVectorValues.scorer(query);
        }

        /** Rescoring uses the full-precision raw vectors. */
        @Override
        public VectorScorer rescorer(float[] target) throws IOException {
            return rawVectorValues.rescorer(target);
        }

        QuantizedByteVectorValues getQuantizedVectorValues() throws IOException {
            return quantizedVectorValues;
        }
    }
}

