package org.apache.lucene.util.hnsw;

import java.io.IOException;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.SplittableRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import org.apache.lucene.internal.hppc.IntHashSet;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.hnsw.HnswUtil;
import org.hsqldb.types.DTIType;

/* loaded from: input_file:org/apache/lucene/util/hnsw/HnswGraphBuilder.class */
public class HnswGraphBuilder implements HnswBuilder {
    // Not referenced elsewhere in this file; presumably the default for the "M" constructor argument -- TODO confirm.
    public static final int DEFAULT_MAX_CONN = 16;
    // Not referenced elsewhere in this file; presumably the default for the "beamWidth" constructor argument -- TODO confirm.
    public static final int DEFAULT_BEAM_WIDTH = 100;
    // Initial value assigned to randSeed by the static initializer.
    private static final long DEFAULT_RAND_SEED = 42;
    /** Component name passed to {@link InfoStream} for every message emitted by this builder. */
    public static final String HNSW_COMPONENT = "HNSW";
    // Public mutable static; set to DEFAULT_RAND_SEED at class initialization and not read in this file.
    public static long randSeed;
    // Maximum number of connections per node; level 0 uses twice this value (see addDiverseNeighbors).
    private final int M;
    // Level scaling factor fed to getRandomGraphLevel: 1 / ln(M), or 1 when M == 1.
    private final double ml;
    // Source of randomness for level assignment, seeded from the constructor's seed argument.
    private final SplittableRandom random;
    // Supplies the scorers used to compare a node against other nodes.
    protected final RandomVectorScorerSupplier scorerSupplier;
    // Performs the per-level searches while inserting nodes.
    private final HnswGraphSearcher graphSearcher;
    // Collector with k = 1, used for the descent through levels above the new node's level.
    private final GraphBuilderKnnCollector entryCandidates;
    // Collector with k = beamWidth, used on the levels where the new node gets linked.
    private final GraphBuilderKnnCollector beamCandidates;
    // Collector with k = min(beamWidth / 2, 3 * M), used on level 0 when explicit entry points are supplied.
    private final GraphBuilderKnnCollector beamCandidates0;
    // The graph being built.
    protected final OnHeapHnswGraph hnsw;
    // May be null; when null, neighbor updates are performed without taking a lock.
    protected final HnswLock hnswLock;
    // Destination for diagnostic messages; defaults to InfoStream.getDefault().
    protected InfoStream infoStream;
    // Once true, the builder rejects further additions with IllegalStateException.
    protected boolean frozen;
    // Explicit stand-in for the compiler's assertion flag (decompiler output); set in the static initializer.
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:org/apache/lucene/util/hnsw/HnswGraphBuilder$GraphBuilderKnnCollector.class */
    /**
     * {@link KnnCollector} used while building the graph. Candidates are kept in a
     * {@link NeighborQueue} created for {@code k} entries. The collector never terminates early,
     * has no visit limit, cannot produce {@link TopDocs} and exposes no search strategy.
     */
    public static final class GraphBuilderKnnCollector implements KnnCollector {
        private final NeighborQueue queue;
        private final int k;
        private long visitedCount;

        /**
         * @param i the number of candidates to retain ({@code k}); also passed to the queue
         *     constructor
         */
        public GraphBuilderKnnCollector(int i) {
            this.k = i;
            this.queue = new NeighborQueue(i, false);
        }

        /** Returns the number of candidates currently held. */
        public int size() {
            return queue.size();
        }

        /** Removes the entry at the top of the queue and returns its node. */
        public int popNode() {
            return queue.pop();
        }

        /**
         * Pops entries until at most {@code k} remain, then returns the queue's node array.
         *
         * @return the result of {@code NeighborQueue.nodes()} after trimming
         */
        public int[] popUntilNearestKNodes() {
            while (queue.size() > k) {
                queue.pop();
            }
            return queue.nodes();
        }

        /** Returns the score currently at the top of the queue. */
        float minimumScore() {
            return queue.topScore();
        }

        /** Empties the queue and resets the visited counter. */
        public void clear() {
            visitedCount = 0L;
            queue.clear();
        }

        /** Always {@code false}: graph construction never stops a search early. */
        @Override
        public boolean earlyTerminated() {
            return false;
        }

        /** Adds {@code i} to the running count of visited nodes. */
        @Override
        public void incVisitedCount(int i) {
            visitedCount = visitedCount + i;
        }

        /** Returns the number of visited nodes recorded since the last {@link #clear()}. */
        @Override
        public long visitedCount() {
            return visitedCount;
        }

        /** Always {@link Long#MAX_VALUE}, i.e. effectively unlimited. */
        @Override
        public long visitLimit() {
            return Long.MAX_VALUE;
        }

        /** Returns the {@code k} this collector was created with. */
        @Override
        public int k() {
            return k;
        }

        /**
         * Offers a candidate to the queue.
         *
         * @return whatever {@code NeighborQueue.insertWithOverflow} reports for this candidate
         */
        @Override
        public boolean collect(int i, float f) {
            return queue.insertWithOverflow(i, f);
        }

        /**
         * Returns the top-of-queue score once at least {@code k} candidates are held, otherwise
         * {@link Float#NEGATIVE_INFINITY}.
         */
        @Override
        public float minCompetitiveSimilarity() {
            return queue.size() >= k ? queue.topScore() : Float.NEGATIVE_INFINITY;
        }

        /**
         * Not supported by this collector.
         *
         * @throws IllegalArgumentException always
         */
        @Override
        public TopDocs topDocs() {
            throw new IllegalArgumentException();
        }

        /** Always {@code null}: this collector carries no search strategy. */
        @Override
        public KnnSearchStrategy getSearchStrategy() {
            return null;
        }
    }

    /**
     * Creates a builder whose graph is constructed with {@code -1} as its second argument.
     *
     * @param randomVectorScorerSupplier supplies scorers for comparing vectors; must not be null
     * @param i maximum number of connections per node (M); must be positive
     * @param i2 beam width used while searching for neighbors; must be positive
     * @param j seed for the random level generator
     * @return a new builder
     * @throws IOException declared by the underlying constructor
     */
    public static HnswGraphBuilder create(RandomVectorScorerSupplier randomVectorScorerSupplier, int i, int i2, long j) throws IOException {
        // -1 is the sentinel forwarded to the OnHeapHnswGraph constructor by the original code.
        final int unspecifiedGraphSize = -1;
        return create(randomVectorScorerSupplier, i, i2, j, unspecifiedGraphSize);
    }

    /**
     * Creates a builder whose graph is constructed as {@code new OnHeapHnswGraph(i, i3)}.
     *
     * @param randomVectorScorerSupplier supplies scorers for comparing vectors; must not be null
     * @param i maximum number of connections per node (M); must be positive
     * @param i2 beam width used while searching for neighbors; must be positive
     * @param j seed for the random level generator
     * @param i3 second argument of the {@code OnHeapHnswGraph} constructor (presumably the
     *     expected graph size -- TODO confirm)
     * @return a new builder
     * @throws IOException declared by the underlying constructor
     */
    public static HnswGraphBuilder create(RandomVectorScorerSupplier randomVectorScorerSupplier, int i, int i2, long j, int i3) throws IOException {
        final HnswGraphBuilder builder = new HnswGraphBuilder(randomVectorScorerSupplier, i, i2, j, i3);
        return builder;
    }

    /**
     * Creates a builder over a freshly allocated {@code OnHeapHnswGraph(i, i3)}.
     *
     * @param randomVectorScorerSupplier supplies scorers for comparing vectors; must not be null
     * @param i maximum number of connections per node (M); must be positive
     * @param i2 beam width; must be positive
     * @param j seed for the random level generator
     * @param i3 second argument of the {@code OnHeapHnswGraph} constructor (presumably the
     *     expected graph size -- TODO confirm)
     * @throws IOException declared by the delegated constructor
     */
    protected HnswGraphBuilder(RandomVectorScorerSupplier randomVectorScorerSupplier, int i, int i2, long j, int i3) throws IOException {
        this(randomVectorScorerSupplier, i, i2, j, new OnHeapHnswGraph(i, i3));
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Creates a builder over an existing graph, without a lock ({@code hnswLock} is null) and with
     * a new {@code HnswGraphSearcher} built from a {@code NeighborQueue(i2, true)} and a
     * {@code FixedBitSet} sized to {@code onHeapHnswGraph.size()}.
     *
     * @param randomVectorScorerSupplier supplies scorers for comparing vectors; must not be null
     * @param i maximum number of connections per node (M); must be positive
     * @param i2 beam width; must be positive
     * @param j seed for the random level generator
     * @param onHeapHnswGraph the graph to add nodes to
     * @throws IOException declared by the delegated constructor
     */
    public HnswGraphBuilder(RandomVectorScorerSupplier randomVectorScorerSupplier, int i, int i2, long j, OnHeapHnswGraph onHeapHnswGraph) throws IOException {
        // Note: the searcher is allocated before the delegated constructor validates i and i2.
        this(randomVectorScorerSupplier, i, i2, j, onHeapHnswGraph, null, new HnswGraphSearcher(new NeighborQueue(i2, true), new FixedBitSet(onHeapHnswGraph.size())));
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public HnswGraphBuilder(RandomVectorScorerSupplier randomVectorScorerSupplier, int i, int i2, long j, OnHeapHnswGraph onHeapHnswGraph, HnswLock hnswLock, HnswGraphSearcher hnswGraphSearcher) throws IOException {
        this.infoStream = InfoStream.getDefault();
        if (i <= 0) {
            throw new IllegalArgumentException("M (max connections) must be positive");
        }
        if (i2 <= 0) {
            throw new IllegalArgumentException("beamWidth must be positive");
        }
        this.M = i;
        this.scorerSupplier = (RandomVectorScorerSupplier) Objects.requireNonNull(randomVectorScorerSupplier, "scorer supplier must not be null");
        this.ml = i == 1 ? 1.0d : 1.0d / Math.log(1.0d * i);
        this.random = new SplittableRandom(j);
        this.hnsw = onHeapHnswGraph;
        this.hnswLock = hnswLock;
        this.graphSearcher = hnswGraphSearcher;
        this.entryCandidates = new GraphBuilderKnnCollector(1);
        this.beamCandidates = new GraphBuilderKnnCollector(i2);
        this.beamCandidates0 = new GraphBuilderKnnCollector(Math.min(i2 / 2, i * 3));
    }

    /**
     * Adds the nodes with ordinals {@code [0, i)} to the graph and returns the completed graph.
     *
     * @param i number of vectors to add
     * @return the graph, after {@link #getCompletedGraph()} has frozen this builder
     * @throws IllegalStateException if this builder is already frozen
     * @throws IOException if adding a node fails
     */
    @Override
    public OnHeapHnswGraph build(int i) throws IOException {
        if (frozen) {
            throw new IllegalStateException("This HnswGraphBuilder is frozen and cannot be updated");
        }
        final boolean verbose = infoStream.isEnabled(HNSW_COMPONENT);
        if (verbose) {
            infoStream.message(HNSW_COMPONENT, "build graph from " + i + " vectors");
        }
        addVectors(0, i);
        return getCompletedGraph();
    }

    /**
     * Replaces the stream that receives this builder's diagnostic messages.
     *
     * @param infoStream the new stream; stored as-is with no null check, and later dereferenced by
     *     the logging call sites
     */
    @Override // org.apache.lucene.util.hnsw.HnswBuilder
    public void setInfoStream(InfoStream infoStream) {
        this.infoStream = infoStream;
    }

    /**
     * Returns the graph, first calling {@link #finish()} if this builder is not yet frozen.
     *
     * @return the graph returned by {@link #getGraph()}
     * @throws IOException if {@link #finish()} fails
     */
    @Override
    public OnHeapHnswGraph getCompletedGraph() throws IOException {
        final boolean needsFinishing = !frozen;
        if (needsFinishing) {
            finish();
        }
        return getGraph();
    }

    /**
     * Returns the graph this builder writes to, regardless of whether the builder is frozen.
     *
     * @return the live graph instance (not a copy)
     */
    @Override // org.apache.lucene.util.hnsw.HnswBuilder
    public OnHeapHnswGraph getGraph() {
        return this.hnsw;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Adds the nodes with ordinals {@code [i, i2)} to the graph, in ascending order, using a
     * single scorer obtained from the scorer supplier.
     *
     * <p>When the HNSW info stream component is enabled, a progress message is emitted for every
     * ordinal that is a multiple of 10000.
     *
     * @param i first ordinal to add (inclusive)
     * @param i2 end ordinal (exclusive)
     * @throws IllegalStateException if this builder is frozen
     * @throws IOException if scoring or insertion fails
     */
    public void addVectors(int i, int i2) throws IOException {
        if (frozen) {
            throw new IllegalStateException("This HnswGraphBuilder is frozen and cannot be updated");
        }
        final long startNanos = System.nanoTime();
        long lastReportNanos = startNanos;
        if (infoStream.isEnabled(HNSW_COMPONENT)) {
            infoStream.message(HNSW_COMPONENT, "addVectors [" + i + " " + i2 + ")");
        }
        final UpdateableRandomVectorScorer scorer = scorerSupplier.scorer();
        int node = i;
        while (node < i2) {
            scorer.setScoringOrdinal(node);
            addGraphNode(node, scorer);
            final boolean atReportingInterval = node % 10000 == 0;
            if (atReportingInterval && infoStream.isEnabled(HNSW_COMPONENT)) {
                lastReportNanos = printGraphBuildStatus(node, startNanos, lastReportNanos);
            }
            node++;
        }
    }

    /**
     * Adds the nodes with ordinals {@code [0, maxOrd)} to the graph.
     *
     * @param maxOrd end ordinal (exclusive)
     * @throws IOException if insertion fails
     */
    private void addVectors(int maxOrd) throws IOException {
        final int firstOrd = 0;
        addVectors(firstOrd, maxOrd);
    }

    /**
     * Inserts node {@code i} using a caller-supplied scorer and the graph's own entry node.
     *
     * <p>The scoring ordinal is not set here; the in-file caller ({@code addVectors}) sets it to
     * {@code i} before calling.
     *
     * @param i ordinal of the node to insert
     * @param updateableRandomVectorScorer scorer used to compare the new node with existing nodes
     * @throws IllegalStateException if this builder is frozen
     * @throws IOException if scoring or searching fails
     */
    public void addGraphNode(int i, UpdateableRandomVectorScorer updateableRandomVectorScorer) throws IOException {
        final IntHashSet noLevelZeroEntryPoints = null;
        addGraphNodeInternal(i, updateableRandomVectorScorer, noLevelZeroEntryPoints);
    }

    /**
     * Inserts node {@code i} into the graph: assigns it a random level, registers it on every
     * level from that level down to 0, searches each of those levels for neighbor candidates and
     * links the node to a diverse subset of them.
     *
     * <p>If the node's level is above the graph's current top level, the method then tries to
     * promote the node to entry node. When promotion fails because the graph's level count changed
     * in the meantime, the loop repeats and links the levels that were not yet handled.
     *
     * @param i ordinal of the node to insert
     * @param updateableRandomVectorScorer scorer used for searches and diversity checks; its
     *     scoring ordinal is changed by the diversity selection in {@code selectAndLinkDiverse}
     * @param intHashSet optional entry points for the level-0 search; when null or empty the entry
     *     points come from the search on the level above
     * @throws IllegalStateException if the builder is frozen, or if promotion fails although the
     *     graph's level count did not change
     * @throws IOException if scoring or searching fails
     */
    private void addGraphNodeInternal(int i, UpdateableRandomVectorScorer updateableRandomVectorScorer, IntHashSet intHashSet) throws IOException {
        int numLevels;
        if (this.frozen) {
            throw new IllegalStateException("Graph builder is already frozen");
        }
        int randomGraphLevel = getRandomGraphLevel(this.ml, this.random);
        // Register the node on every level it belongs to, top-down.
        for (int i2 = randomGraphLevel; i2 >= 0; i2--) {
            this.hnsw.addNode(i2, i);
        }
        // If the node was installed as the entry node there is nothing to link against.
        // (Presumably this succeeds only when the graph had no entry node yet -- TODO confirm.)
        if (this.hnsw.trySetNewEntryNode(i, randomGraphLevel)) {
            return;
        }
        // i3 is the lowest level that has not been linked yet for this node.
        int i3 = 0;
        do {
            // Snapshot of the graph's current top level index for this pass.
            numLevels = this.hnsw.numLevels() - 1;
            int[] iArr = {this.hnsw.entryNode()};
            // Levels above the node's own level: search with k = 1 only to find a better entry point.
            GraphBuilderKnnCollector graphBuilderKnnCollector = this.entryCandidates;
            for (int i4 = numLevels; i4 > randomGraphLevel; i4--) {
                graphBuilderKnnCollector.clear();
                this.graphSearcher.searchLevel(graphBuilderKnnCollector, updateableRandomVectorScorer, i4, iArr, this.hnsw, null);
                iArr[0] = graphBuilderKnnCollector.popNode();
            }
            // Levels the node lives on (bounded by the snapshot top level): full beam search,
            // top-down, keeping one scratch candidate array per level.
            GraphBuilderKnnCollector graphBuilderKnnCollector2 = this.beamCandidates;
            NeighborArray[] neighborArrayArr = new NeighborArray[(Math.min(randomGraphLevel, numLevels) - i3) + 1];
            for (int length = neighborArrayArr.length - 1; length >= 0; length--) {
                // Index 0 of the scratch array corresponds to level i3.
                int i5 = length + i3;
                graphBuilderKnnCollector2.clear();
                // Caller-supplied entry points replace the inherited ones on level 0, and the
                // level-0 collector (beamCandidates0) is used instead of beamCandidates.
                if (i5 == 0 && intHashSet != null && intHashSet.size() > 0) {
                    iArr = intHashSet.toArray();
                    graphBuilderKnnCollector2 = this.beamCandidates0;
                }
                this.graphSearcher.searchLevel(graphBuilderKnnCollector2, updateableRandomVectorScorer, i5, iArr, this.hnsw, null);
                // The trimmed candidates become the entry points for the next lower level.
                iArr = graphBuilderKnnCollector2.popUntilNearestKNodes();
                neighborArrayArr[length] = new NeighborArray(Math.max(graphBuilderKnnCollector2.k(), this.M + 1), false);
                // Drains the collector, leaving it empty for its next use.
                popToScratch(graphBuilderKnnCollector2, neighborArrayArr[length]);
            }
            // Link bottom-up, only after all searches for this pass have finished.
            for (int i6 = 0; i6 < neighborArrayArr.length; i6++) {
                addDiverseNeighbors(i6 + i3, i, neighborArrayArr[i6], updateableRandomVectorScorer);
            }
            i3 += neighborArrayArr.length;
            if (!$assertionsDisabled && i3 != Math.min(randomGraphLevel, numLevels) + 1) {
                throw new AssertionError();
            }
            // Every level of the node has been linked: done.
            if (i3 > randomGraphLevel) {
                return;
            }
            // Otherwise the node's level is above the snapshot top level.
            if (!$assertionsDisabled && (i3 != numLevels + 1 || randomGraphLevel <= numLevels)) {
                throw new AssertionError();
            }
            if (this.hnsw.tryPromoteNewEntryNode(i, randomGraphLevel, numLevels)) {
                return;
            }
            // Promotion failed: retry only if the graph's level count differs from the snapshot.
        } while (this.hnsw.numLevels() != numLevels + 1);
        throw new IllegalStateException("We're not able to promote node " + i + " at level " + randomGraphLevel + " as entry node. But the max graph level " + numLevels + " has not changed while we are inserting the node.");
    }

    /**
     * Inserts node {@code i} using a fresh scorer from the scorer supplier, positioned on
     * ordinal {@code i}.
     *
     * @param i ordinal of the node to insert
     * @throws IllegalStateException if this builder is frozen
     * @throws IOException if scoring or searching fails
     */
    @Override
    public void addGraphNode(int i) throws IOException {
        final UpdateableRandomVectorScorer nodeScorer = scorerSupplier.scorer();
        nodeScorer.setScoringOrdinal(i);
        final IntHashSet noLevelZeroEntryPoints = null;
        addGraphNodeInternal(i, nodeScorer, noLevelZeroEntryPoints);
    }

    /**
     * Inserts node {@code i} using a fresh scorer positioned on ordinal {@code i} and, when
     * {@code intHashSet} is non-null and non-empty, uses its members as the entry points of the
     * level-0 search.
     *
     * @param i ordinal of the node to insert
     * @param intHashSet entry points for the level-0 search; may be null or empty
     * @throws IllegalStateException if this builder is frozen
     * @throws IOException if scoring or searching fails
     */
    public void addGraphNodeWithEps(int i, IntHashSet intHashSet) throws IOException {
        final UpdateableRandomVectorScorer nodeScorer = scorerSupplier.scorer();
        nodeScorer.setScoringOrdinal(i);
        addGraphNodeInternal(i, nodeScorer, intHashSet);
    }

    /**
     * Emits a progress message of the form {@code built <node> in <sinceLast>/<sinceStart> ms}.
     *
     * @param node ordinal of the node that was just added
     * @param startNanos {@link System#nanoTime()} value taken when the batch started
     * @param lastReportNanos {@link System#nanoTime()} value of the previous report
     * @return the current {@link System#nanoTime()}, to be passed as {@code lastReportNanos} next
     *     time
     */
    private long printGraphBuildStatus(int node, long startNanos, long lastReportNanos) {
        final long now = System.nanoTime();
        final long sinceLastMillis = TimeUnit.NANOSECONDS.toMillis(now - lastReportNanos);
        final long sinceStartMillis = TimeUnit.NANOSECONDS.toMillis(now - startNanos);
        final String status = String.format(Locale.ROOT, "built %d in %d/%d ms", node, sinceLastMillis, sinceStartMillis);
        this.infoStream.message(HNSW_COMPONENT, status);
        return now;
    }

    /**
     * Links a newly added node to a diverse subset of its candidates on one level, and adds the
     * reverse link from each selected candidate back to the new node.
     *
     * <p>Level 0 allows up to {@code 2 * M} connections, every other level up to {@code M}.
     *
     * @param level graph level being linked
     * @param node ordinal of the newly added node; its neighbor array on this level must be empty
     * @param candidates candidate neighbors with their scores, as filled by {@code popToScratch}
     * @param scorer scorer for the new node; its scoring ordinal is changed during selection
     * @throws IOException if scoring fails
     */
    private void addDiverseNeighbors(int level, int node, NeighborArray candidates, UpdateableRandomVectorScorer scorer) throws IOException {
        NeighborArray neighbors = this.hnsw.getNeighbors(level, node);
        if (!$assertionsDisabled && neighbors.size() != 0) {
            throw new AssertionError();
        }
        int maxConnOnLevel = level == 0 ? this.M * 2 : this.M;
        boolean[] selected = selectAndLinkDiverse(neighbors, candidates, maxConnOnLevel, scorer);
        // Iterate over the local candidates and mask rather than the node's own neighbor array.
        for (int idx = 0; idx < candidates.size(); idx++) {
            if (!selected[idx]) {
                continue;
            }
            int neighbor = candidates.nodes()[idx];
            if (this.hnswLock == null) {
                this.hnsw.getNeighbors(level, neighbor).addAndEnsureDiversity(node, candidates.getScores(idx), neighbor, scorer);
                continue;
            }
            Lock writeLock = this.hnswLock.write(level, neighbor);
            try {
                getGraph().getNeighbors(level, neighbor).addAndEnsureDiversity(node, candidates.getScores(idx), neighbor, scorer);
            } finally {
                // Release exactly once, whether or not the update threw.
                writeLock.unlock();
            }
        }
    }

    /**
     * Walks the candidates from the last index to the first and adds each one that passes
     * {@link #diversityCheck} to {@code neighbors}, stopping once {@code neighbors} holds
     * {@code maxConnOnLevel} entries.
     *
     * @param neighbors neighbor array of the new node; receives the selected candidates
     * @param candidates candidate nodes and scores
     * @param maxConnOnLevel maximum number of neighbors to select
     * @param scorer scorer whose scoring ordinal is set to each candidate in turn
     * @return a mask parallel to {@code candidates}; {@code true} marks a selected candidate
     * @throws IOException if scoring fails
     */
    private boolean[] selectAndLinkDiverse(NeighborArray neighbors, NeighborArray candidates, int maxConnOnLevel, UpdateableRandomVectorScorer scorer) throws IOException {
        final boolean[] selected = new boolean[candidates.size()];
        int idx = candidates.size() - 1;
        while (idx >= 0 && neighbors.size() < maxConnOnLevel) {
            final int candidateNode = candidates.nodes()[idx];
            final float candidateScore = candidates.getScores(idx);
            if (!$assertionsDisabled && candidateNode > this.hnsw.maxNodeId()) {
                throw new AssertionError();
            }
            // Score from the candidate's point of view against the neighbors chosen so far.
            scorer.setScoringOrdinal(candidateNode);
            if (diversityCheck(candidateScore, neighbors, scorer)) {
                selected[idx] = true;
                neighbors.addInOrder(candidateNode, candidateScore);
            }
            idx--;
        }
        return selected;
    }

    /**
     * Drains every candidate from the collector into {@code scratch}, in pop order, pairing each
     * node with its own score. The collector is empty afterwards.
     *
     * @param candidates collector to drain
     * @param scratch destination array; cleared before being filled
     */
    private static void popToScratch(GraphBuilderKnnCollector candidates, NeighborArray scratch) {
        scratch.clear();
        final int candidateCount = candidates.size();
        for (int n = 0; n < candidateCount; n++) {
            // The score must be read BEFORE popping. Java evaluates arguments left to right, so
            // passing popNode() and minimumScore() in one call would pair each node with the
            // score of the next queue entry (and read an empty queue on the last iteration).
            final float score = candidates.minimumScore();
            final int node = candidates.popNode();
            scratch.addInOrder(node, score);
        }
    }

    /**
     * Checks whether a candidate is diverse with respect to the already selected neighbors.
     *
     * @param score the candidate's score relative to the new node
     * @param neighbors the neighbors selected so far
     * @param scorer scorer already positioned on the candidate
     * @return {@code false} as soon as the candidate scores at least {@code score} against any
     *     selected neighbor, otherwise {@code true}
     * @throws IOException if scoring fails
     */
    private boolean diversityCheck(float score, NeighborArray neighbors, RandomVectorScorer scorer) throws IOException {
        int idx = 0;
        while (idx < neighbors.size()) {
            final float neighborSimilarity = scorer.score(neighbors.nodes()[idx]);
            if (neighborSimilarity >= score) {
                return false;
            }
            idx++;
        }
        return true;
    }

    /**
     * Draws a random level as {@code (int) (-ln(u) * ml)}, where {@code u} is a uniform sample
     * from {@code random}. Samples equal to zero are redrawn because {@code ln(0)} is undefined.
     *
     * @param ml level scaling factor
     * @param random source of uniform samples
     * @return the level for the next node, truncated toward zero
     */
    private static int getRandomGraphLevel(double ml, SplittableRandom random) {
        double sample = random.nextDouble();
        while (sample == 0.0d) {
            sample = random.nextDouble();
        }
        final double scaledLevel = (-Math.log(sample)) * ml;
        return (int) scaledLevel;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Freezes this builder so that subsequent additions throw {@link IllegalStateException}.
     *
     * <p>This method does not call {@code connectComponents()}; within this file that method has
     * no caller.
     *
     * @throws IOException never thrown by this body
     */
    public void finish() throws IOException {
        this.frozen = true;
    }

    /**
     * Runs {@link #connectComponents(int)} on every level of the graph, starting at level 0, and
     * reports failures and the total elapsed time to the info stream when the HNSW component is
     * enabled.
     *
     * @throws IOException if searching or scoring fails on any level
     */
    private void connectComponents() throws IOException {
        final long startNanos = System.nanoTime();
        for (int level = 0; level < this.hnsw.numLevels(); level++) {
            if (!connectComponents(level) && this.infoStream.isEnabled(HNSW_COMPONENT)) {
                this.infoStream.message(HNSW_COMPONENT, "connectComponents failed on level " + level);
            }
        }
        if (this.infoStream.isEnabled(HNSW_COMPONENT)) {
            // Convert with the JDK's TimeUnit instead of depending on an unrelated HSQLDB constant.
            final long elapsedMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos);
            this.infoStream.message(HNSW_COMPONENT, "connectComponents " + elapsedMillis + " ms");
        }
    }

    /**
     * Tries to connect every component found on one level to the largest component of that level.
     *
     * <p>For each other component, the level is searched starting from the largest component's
     * start node, scoring against the other component's start node. Every returned node that is
     * still marked in the not-fully-connected bit set (and is not the start node itself) is linked
     * to the component's start node.
     *
     * @param level graph level to process
     * @return {@code true} if there was at most one component or every processed component got at
     *     least one link; {@code false} if the largest component's start is
     *     {@link Integer#MAX_VALUE} or some component could not be linked
     * @throws IOException if searching or scoring fails
     */
    private boolean connectComponents(int level) throws IOException {
        final FixedBitSet notFullyConnected = new FixedBitSet(this.hnsw.size());
        // Level 0 allows twice as many connections as the other levels.
        final int maxConn = (level == 0) ? this.M * 2 : this.M;
        final List<HnswUtil.Component> components = HnswUtil.components(this.hnsw, level, notFullyConnected, maxConn);
        if (this.infoStream.isEnabled(HNSW_COMPONENT)) {
            this.infoStream.message(HNSW_COMPONENT, "connect " + components.size() + " components on level=" + level);
        }
        if (components.size() <= 1) {
            return true;
        }

        final HnswUtil.Component largest = components.stream().max(Comparator.comparingInt(HnswUtil.Component::size)).get();
        // Integer.MAX_VALUE as a start is treated as "no usable start node" for a component.
        if (largest.start() == Integer.MAX_VALUE) {
            return false;
        }

        final GraphBuilderKnnCollector beam = new GraphBuilderKnnCollector(2);
        final int[] entryPoints = new int[1];
        final UpdateableRandomVectorScorer scorer = this.scorerSupplier.scorer();
        boolean allConnected = true;
        for (HnswUtil.Component other : components) {
            if (other == largest || other.start() == Integer.MAX_VALUE) {
                continue;
            }
            if (this.infoStream.isEnabled(HNSW_COMPONENT)) {
                this.infoStream.message(HNSW_COMPONENT, "connect component " + other + " to " + largest);
            }
            beam.clear();
            entryPoints[0] = largest.start();
            scorer.setScoringOrdinal(other.start());
            this.graphSearcher.searchLevel(beam, scorer, level, entryPoints, this.hnsw, notFullyConnected);

            boolean linked = false;
            while (beam.size() > 0) {
                final int candidate = beam.popNode();
                if (candidate == other.start() || !notFullyConnected.get(candidate)) {
                    continue;
                }
                // NOTE(review): the score is read after the pop, so it is the top-of-queue score
                // of the remaining entries, not necessarily the popped node's -- confirm intent.
                final float score = beam.minimumScore();
                if (!$assertionsDisabled && !notFullyConnected.get(candidate)) {
                    throw new AssertionError();
                }
                link(level, candidate, other.start(), score, notFullyConnected);
                linked = true;
                if (this.infoStream.isEnabled(HNSW_COMPONENT)) {
                    this.infoStream.message(HNSW_COMPONENT, "connected ok " + candidate + " -> " + other.start());
                }
            }
            if (linked) {
                continue;
            }
            if (this.infoStream.isEnabled(HNSW_COMPONENT)) {
                this.infoStream.message(HNSW_COMPONENT, "not connected; no free nodes found");
            }
            allConnected = false;
        }
        return allConnected;
    }

    /**
     * Adds a link from {@code n0} to {@code n1} on the given level and, if {@code n1} still has
     * room, the reverse link as well. A node whose neighbor count reaches the limit is cleared
     * from {@code notFullyConnected}.
     *
     * <p>The limit is {@code maxSize() - 1} of {@code n0}'s neighbor array and is applied to both
     * nodes.
     *
     * @param level graph level
     * @param n0 node that receives the forward link; must be set in {@code notFullyConnected} and
     *     below the limit
     * @param n1 node to link to
     * @param score score stored with both links
     * @param notFullyConnected bit set of nodes that can still accept neighbors
     */
    private void link(int level, int n0, int n1, float score, FixedBitSet notFullyConnected) {
        final NeighborArray forward = this.hnsw.getNeighbors(level, n0);
        final NeighborArray backward = this.hnsw.getNeighbors(level, n1);
        final int maxConn = forward.maxSize() - 1;
        if (!$assertionsDisabled) {
            if (!notFullyConnected.get(n0)) {
                throw new AssertionError();
            }
            if (forward.size() >= maxConn) {
                throw new AssertionError("node " + n0 + " is full, has " + forward.size() + " friends");
            }
        }

        forward.addOutOfOrder(n1, score);
        if (forward.size() == maxConn) {
            notFullyConnected.clear(n0);
        }

        if (backward.size() >= maxConn) {
            return;
        }
        backward.addOutOfOrder(n0, score);
        if (backward.size() == maxConn) {
            notFullyConnected.clear(n1);
        }
    }

    static {
        // Start from the default seed; randSeed is a public mutable static.
        randSeed = DEFAULT_RAND_SEED;
        // Explicit stand-in for the compiler's assertion flag: assertion checks in this class run
        // only when assertions are enabled for it.
        $assertionsDisabled = !HnswGraphBuilder.class.desiredAssertionStatus();
    }
}
