package org.apache.lucene.internal.vectorization;

import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.ByteOrder;
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.IntVector;
import jdk.incubator.vector.LongVector;
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.Vector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorShape;
import jdk.incubator.vector.VectorSpecies;
import org.apache.lucene.internal.vectorization.DefaultVectorUtilSupport;
import org.apache.lucene.util.Constants;
import org.apache.lucene.util.SuppressForbidden;

/* loaded from: input_file:org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.class */
final class PanamaVectorUtilSupport implements VectorUtilSupport {
    // Vector species used by the kernels in this class. These are blank static finals, so they
    // must be assigned in a static initializer, which is not visible in this part of the file.
    private static final VectorSpecies<Float> FLOAT_SPECIES;
    private static final VectorSpecies<Integer> INT_SPECIES;
    private static final VectorSpecies<Byte> BYTE_SPECIES;
    private static final VectorSpecies<Short> SHORT_SPECIES;
    private static final VectorSpecies<Byte> BYTE_SPECIES_128;
    private static final VectorSpecies<Byte> BYTE_SPECIES_256;
    // Compared against 128 / 256 / 512 by the byte kernels to choose which body variant runs.
    static final int VECTOR_BITSIZE;
    // Gates the vectorized scan in findNextGEQ; when false only the scalar scan is used.
    private static final boolean ENABLE_FIND_NEXT_GEQ_VECTOR_OPTO;
    // NOTE(review): appears to be the compiler-generated assertion switch (marked synthetic);
    // it is tested before every explicit `throw new AssertionError()` in this class.
    static final /* synthetic */ boolean $assertionsDisabled;

    /**
     * Lane-wise multiply-add: returns {@code floatVector * floatVector2 + floatVector3}.
     *
     * <p>The fused vector operation is used only when {@code Constants.HAS_FAST_VECTOR_FMA} is
     * set; otherwise a separate multiply followed by an add is performed.
     */
    private static FloatVector fma(FloatVector floatVector, FloatVector floatVector2, FloatVector floatVector3) {
        if (Constants.HAS_FAST_VECTOR_FMA) {
            return floatVector.fma(floatVector2, floatVector3);
        }
        FloatVector product = floatVector.mul(floatVector2);
        return product.add(floatVector3);
    }

    /**
     * Scalar multiply-add: returns {@code f * f2 + f3}.
     *
     * <p>{@link Math#fma(float, float, float)} is used only when
     * {@code Constants.HAS_FAST_SCALAR_FMA} is set; otherwise the plain expression is evaluated.
     */
    @SuppressForbidden(reason = "Uses FMA only where fast and carefully contained")
    private static float fma(float f, float f2, float f3) {
        if (Constants.HAS_FAST_SCALAR_FMA) {
            return Math.fma(f, f2, f3);
        }
        return (f * f2) + f3;
    }

    /**
     * Computes the dot product of two float arrays.
     *
     * <p>When {@code fArr} is longer than twice the float species length, the prefix up to the
     * species loop bound is handled by {@link #dotProductBody}; all remaining elements are
     * accumulated one at a time.
     *
     * @param fArr first vector; its length drives the iteration
     * @param fArr2 second vector; read at every index of {@code fArr}
     * @return the accumulated sum of products
     */
    @Override // org.apache.lucene.internal.vectorization.VectorUtilSupport
    public float dotProduct(float[] fArr, float[] fArr2) {
        final int length = fArr.length;
        float sum = 0.0f;
        int pos = 0;
        // Vectorize only when the input spans more than two full vectors.
        if (length > 2 * FLOAT_SPECIES.length()) {
            pos += FLOAT_SPECIES.loopBound(length);
            sum += dotProductBody(fArr, fArr2, pos);
        }
        // Scalar tail (or the whole input when the vector path was skipped).
        for (; pos < length; pos++) {
            sum = fma(fArr[pos], fArr2[pos], sum);
        }
        return sum;
    }

    /**
     * Vectorized part of the float dot product over indices {@code [0, i)}.
     *
     * <p>The main loop is unrolled four times with four independent accumulators; a second loop
     * consumes any remaining whole vectors into the first accumulator.
     *
     * @param fArr first vector
     * @param fArr2 second vector
     * @param i exclusive upper bound; the caller passes a species loop bound
     * @return the sum of all accumulator lanes
     */
    private float dotProductBody(float[] fArr, float[] fArr2, int i) {
        FloatVector acc0 = FloatVector.zero(FLOAT_SPECIES);
        FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES);
        FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES);
        FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES);
        final int unrolledLimit = i - (3 * FLOAT_SPECIES.length());
        int offset = 0;
        for (; offset < unrolledLimit; offset += 4 * FLOAT_SPECIES.length()) {
            final int second = offset + FLOAT_SPECIES.length();
            final int third = offset + (2 * FLOAT_SPECIES.length());
            final int fourth = offset + (3 * FLOAT_SPECIES.length());
            acc0 = fma(FloatVector.fromArray(FLOAT_SPECIES, fArr, offset), FloatVector.fromArray(FLOAT_SPECIES, fArr2, offset), acc0);
            acc1 = fma(FloatVector.fromArray(FLOAT_SPECIES, fArr, second), FloatVector.fromArray(FLOAT_SPECIES, fArr2, second), acc1);
            acc2 = fma(FloatVector.fromArray(FLOAT_SPECIES, fArr, third), FloatVector.fromArray(FLOAT_SPECIES, fArr2, third), acc2);
            acc3 = fma(FloatVector.fromArray(FLOAT_SPECIES, fArr, fourth), FloatVector.fromArray(FLOAT_SPECIES, fArr2, fourth), acc3);
        }
        // Whole vectors left over after the unrolled loop.
        for (; offset < i; offset += FLOAT_SPECIES.length()) {
            acc0 = fma(FloatVector.fromArray(FLOAT_SPECIES, fArr, offset), FloatVector.fromArray(FLOAT_SPECIES, fArr2, offset), acc0);
        }
        FloatVector firstPair = acc0.add(acc1);
        FloatVector secondPair = acc2.add(acc3);
        return firstPair.add(secondPair).reduceLanes(VectorOperators.ADD);
    }

    /**
     * Computes the cosine similarity of two float arrays.
     *
     * <p>Three sums are accumulated: the dot product and the squared norm of each input. When
     * {@code fArr} is longer than twice the float species length, the prefix up to the species
     * loop bound is handled by {@link #cosineBody}; the remaining elements are accumulated one
     * at a time.
     *
     * @param fArr first vector; its length drives the iteration
     * @param fArr2 second vector; read at every index of {@code fArr}
     * @return {@code dot / sqrt(norm1 * norm2)}, narrowed to float
     */
    @Override // org.apache.lucene.internal.vectorization.VectorUtilSupport
    public float cosine(float[] fArr, float[] fArr2) {
        int i = 0;
        float sum = 0.0f;
        float norm1 = 0.0f;
        float norm2 = 0.0f;
        if (fArr.length > 2 * FLOAT_SPECIES.length()) {
            i += FLOAT_SPECIES.loopBound(fArr.length);
            float[] partial = cosineBody(fArr, fArr2, i);
            sum += partial[0];
            norm1 += partial[1];
            norm2 += partial[2];
        }
        // Scalar tail (or the whole input when the vector path was skipped).
        for (; i < fArr.length; i++) {
            sum = fma(fArr[i], fArr2[i], sum);
            norm1 = fma(fArr[i], fArr[i], norm1);
            norm2 = fma(fArr2[i], fArr2[i], norm2);
        }
        // Multiply the norms in double: a float product can overflow to infinity (or lose
        // precision) well before the double product does, which would skew the result.
        return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
    }

    /**
     * Vectorized part of the float cosine over indices {@code [0, i)}.
     *
     * <p>The main loop is unrolled twice, keeping two accumulators for each of the three sums;
     * a second loop consumes any remaining whole vector into the first accumulator of each sum.
     *
     * @param fArr first vector
     * @param fArr2 second vector
     * @param i exclusive upper bound; the caller passes a species loop bound
     * @return a three-element array: dot product, squared norm of {@code fArr}, squared norm of
     *     {@code fArr2}
     */
    private float[] cosineBody(float[] fArr, float[] fArr2, int i) {
        FloatVector dot0 = FloatVector.zero(FLOAT_SPECIES);
        FloatVector dot1 = FloatVector.zero(FLOAT_SPECIES);
        FloatVector normA0 = FloatVector.zero(FLOAT_SPECIES);
        FloatVector normA1 = FloatVector.zero(FLOAT_SPECIES);
        FloatVector normB0 = FloatVector.zero(FLOAT_SPECIES);
        FloatVector normB1 = FloatVector.zero(FLOAT_SPECIES);
        final int unrolledLimit = i - FLOAT_SPECIES.length();
        int offset = 0;
        for (; offset < unrolledLimit; offset += 2 * FLOAT_SPECIES.length()) {
            // First vector of the pair.
            FloatVector a0 = FloatVector.fromArray(FLOAT_SPECIES, fArr, offset);
            FloatVector b0 = FloatVector.fromArray(FLOAT_SPECIES, fArr2, offset);
            dot0 = fma(a0, b0, dot0);
            normA0 = fma(a0, a0, normA0);
            normB0 = fma(b0, b0, normB0);
            // Second vector of the pair.
            final int next = offset + FLOAT_SPECIES.length();
            FloatVector a1 = FloatVector.fromArray(FLOAT_SPECIES, fArr, next);
            FloatVector b1 = FloatVector.fromArray(FLOAT_SPECIES, fArr2, next);
            dot1 = fma(a1, b1, dot1);
            normA1 = fma(a1, a1, normA1);
            normB1 = fma(b1, b1, normB1);
        }
        // Whole vectors left over after the unrolled loop.
        for (; offset < i; offset += FLOAT_SPECIES.length()) {
            FloatVector a = FloatVector.fromArray(FLOAT_SPECIES, fArr, offset);
            FloatVector b = FloatVector.fromArray(FLOAT_SPECIES, fArr2, offset);
            dot0 = fma(a, b, dot0);
            normA0 = fma(a, a, normA0);
            normB0 = fma(b, b, normB0);
        }
        final float dot = dot0.add(dot1).reduceLanes(VectorOperators.ADD);
        final float normA = normA0.add(normA1).reduceLanes(VectorOperators.ADD);
        final float normB = normB0.add(normB1).reduceLanes(VectorOperators.ADD);
        return new float[]{dot, normA, normB};
    }

    /**
     * Computes the squared Euclidean distance between two float arrays.
     *
     * <p>When {@code fArr} is longer than twice the float species length, the prefix up to the
     * species loop bound is handled by {@link #squareDistanceBody}; the remaining elements are
     * accumulated one at a time.
     *
     * @param fArr first vector; its length drives the iteration
     * @param fArr2 second vector; read at every index of {@code fArr}
     * @return the sum of squared element differences
     */
    @Override // org.apache.lucene.internal.vectorization.VectorUtilSupport
    public float squareDistance(float[] fArr, float[] fArr2) {
        final int length = fArr.length;
        float sum = 0.0f;
        int pos = 0;
        if (length > 2 * FLOAT_SPECIES.length()) {
            pos += FLOAT_SPECIES.loopBound(length);
            sum += squareDistanceBody(fArr, fArr2, pos);
        }
        // Scalar tail (or the whole input when the vector path was skipped).
        for (; pos < length; pos++) {
            final float diff = fArr[pos] - fArr2[pos];
            sum = fma(diff, diff, sum);
        }
        return sum;
    }

    /**
     * Vectorized part of the float squared distance over indices {@code [0, i)}.
     *
     * <p>The main loop is unrolled four times with four independent accumulators; a second loop
     * consumes any remaining whole vectors into the first accumulator.
     *
     * @param fArr first vector
     * @param fArr2 second vector
     * @param i exclusive upper bound; the caller passes a species loop bound
     * @return the sum of all accumulator lanes
     */
    private float squareDistanceBody(float[] fArr, float[] fArr2, int i) {
        FloatVector acc0 = FloatVector.zero(FLOAT_SPECIES);
        FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES);
        FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES);
        FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES);
        final int unrolledLimit = i - (3 * FLOAT_SPECIES.length());
        int offset = 0;
        for (; offset < unrolledLimit; offset += 4 * FLOAT_SPECIES.length()) {
            final int second = offset + FLOAT_SPECIES.length();
            final int third = offset + (2 * FLOAT_SPECIES.length());
            final int fourth = offset + (3 * FLOAT_SPECIES.length());

            FloatVector diff0 = FloatVector.fromArray(FLOAT_SPECIES, fArr, offset).sub(FloatVector.fromArray(FLOAT_SPECIES, fArr2, offset));
            acc0 = fma(diff0, diff0, acc0);

            FloatVector diff1 = FloatVector.fromArray(FLOAT_SPECIES, fArr, second).sub(FloatVector.fromArray(FLOAT_SPECIES, fArr2, second));
            acc1 = fma(diff1, diff1, acc1);

            FloatVector diff2 = FloatVector.fromArray(FLOAT_SPECIES, fArr, third).sub(FloatVector.fromArray(FLOAT_SPECIES, fArr2, third));
            acc2 = fma(diff2, diff2, acc2);

            FloatVector diff3 = FloatVector.fromArray(FLOAT_SPECIES, fArr, fourth).sub(FloatVector.fromArray(FLOAT_SPECIES, fArr2, fourth));
            acc3 = fma(diff3, diff3, acc3);
        }
        // Whole vectors left over after the unrolled loop.
        for (; offset < i; offset += FLOAT_SPECIES.length()) {
            FloatVector diff = FloatVector.fromArray(FLOAT_SPECIES, fArr, offset).sub(FloatVector.fromArray(FLOAT_SPECIES, fArr2, offset));
            acc0 = fma(diff, diff, acc0);
        }
        FloatVector firstPair = acc0.add(acc1);
        FloatVector secondPair = acc2.add(acc3);
        return firstPair.add(secondPair).reduceLanes(VectorOperators.ADD);
    }

    /**
     * Computes the dot product of two byte arrays by wrapping each array in a heap
     * {@link MemorySegment} and delegating to {@link #dotProduct(MemorySegment, MemorySegment)}.
     */
    @Override // org.apache.lucene.internal.vectorization.VectorUtilSupport
    public int dotProduct(byte[] bArr, byte[] bArr2) {
        final MemorySegment first = MemorySegment.ofArray(bArr);
        final MemorySegment second = MemorySegment.ofArray(bArr2);
        return dotProduct(first, second);
    }

    /**
     * Computes the dot product of two byte sequences held in memory segments.
     *
     * <p>The vector path is taken only when the first segment holds at least 16 bytes and
     * {@code PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS} is set; the body variant is chosen
     * from {@code VECTOR_BITSIZE}. Bytes past the vectorized prefix are accumulated one by one.
     *
     * @param memorySegment first operand; its byte size drives the iteration
     * @param memorySegment2 second operand; asserted to have the same byte size
     * @return the sum of products of corresponding (signed) bytes
     */
    public static int dotProduct(MemorySegment memorySegment, MemorySegment memorySegment2) {
        if (!$assertionsDisabled) {
            if (memorySegment.byteSize() != memorySegment2.byteSize()) {
                throw new AssertionError();
            }
        }
        int pos = 0;
        int sum = 0;
        if (memorySegment.byteSize() >= 16 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) {
            if (VECTOR_BITSIZE >= 512) {
                pos += (int) BYTE_SPECIES.loopBound(memorySegment.byteSize());
                sum += dotProductBody512(memorySegment, memorySegment2, pos);
            } else if (VECTOR_BITSIZE == 256) {
                pos += (int) BYTE_SPECIES.loopBound(memorySegment.byteSize());
                sum += dotProductBody256(memorySegment, memorySegment2, pos);
            } else {
                // The 128-bit body loads 8 bytes per step but advances by only 4, so the bound
                // is computed on a size reduced by one 64-bit vector to keep the last load
                // inside the segment.
                final long reducedSize = memorySegment.byteSize() - ByteVector.SPECIES_64.length();
                pos += (int) ByteVector.SPECIES_64.loopBound(reducedSize);
                sum += dotProductBody128(memorySegment, memorySegment2, pos);
            }
        }
        // Scalar tail (or the whole input when the vector path was skipped).
        for (; pos < memorySegment.byteSize(); pos++) {
            sum += memorySegment2.get(ValueLayout.JAVA_BYTE, pos) * memorySegment.get(ValueLayout.JAVA_BYTE, pos);
        }
        return sum;
    }

    /**
     * Vectorized byte dot product over byte offsets {@code [0, i)}, selected by the caller when
     * {@code VECTOR_BITSIZE >= 512}.
     *
     * <p>Each step loads one {@code BYTE_SPECIES} vector from each segment, widens both to
     * shorts, multiplies in 16 bits, widens the products to ints and adds them to the
     * accumulator.
     *
     * @param memorySegment first operand
     * @param memorySegment2 second operand
     * @param i exclusive upper byte offset
     * @return the sum of all accumulator lanes
     */
    private static int dotProductBody512(MemorySegment memorySegment, MemorySegment memorySegment2, int i) {
        IntVector acc = IntVector.zero(INT_SPECIES);
        for (int offset = 0; offset < i; offset += BYTE_SPECIES.length()) {
            ByteVector aBytes = ByteVector.fromMemorySegment(BYTE_SPECIES, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
            Vector<Short> aShorts = aBytes.convertShape(VectorOperators.B2S, SHORT_SPECIES, 0);
            ByteVector bBytes = ByteVector.fromMemorySegment(BYTE_SPECIES, memorySegment2, offset, ByteOrder.LITTLE_ENDIAN);
            Vector<Short> bShorts = bBytes.convertShape(VectorOperators.B2S, SHORT_SPECIES, 0);
            Vector<Short> products16 = aShorts.mul(bShorts);
            Vector<Integer> products32 = products16.convertShape(VectorOperators.S2I, INT_SPECIES, 0);
            acc = acc.add(products32);
        }
        return acc.reduceLanes(VectorOperators.ADD);
    }

    /**
     * Vectorized byte dot product over byte offsets {@code [0, i)}, selected by the caller when
     * {@code VECTOR_BITSIZE == 256}.
     *
     * <p>Each step loads a 64-bit byte vector from each segment, widens both directly to 256-bit
     * int vectors, multiplies in 32 bits and adds the products to the accumulator.
     *
     * @param memorySegment first operand
     * @param memorySegment2 second operand
     * @param i exclusive upper byte offset
     * @return the sum of all accumulator lanes
     */
    private static int dotProductBody256(MemorySegment memorySegment, MemorySegment memorySegment2, int i) {
        IntVector acc = IntVector.zero(IntVector.SPECIES_256);
        for (int offset = 0; offset < i; offset += ByteVector.SPECIES_64.length()) {
            ByteVector aBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
            Vector<Integer> aInts = aBytes.convertShape(VectorOperators.B2I, IntVector.SPECIES_256, 0);
            ByteVector bBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, memorySegment2, offset, ByteOrder.LITTLE_ENDIAN);
            Vector<Integer> bInts = bBytes.convertShape(VectorOperators.B2I, IntVector.SPECIES_256, 0);
            acc = acc.add(aInts.mul(bInts));
        }
        return acc.reduceLanes(VectorOperators.ADD);
    }

    /**
     * Vectorized byte dot product over byte offsets {@code [0, i)}, used by the caller when
     * neither the 512-bit nor the 256-bit variant applies.
     *
     * <p>Each step loads a 64-bit byte vector (8 bytes) from each segment but advances by only
     * half that length; only part 0 of each widening conversion is accumulated, so consecutive
     * loads overlap.
     *
     * @param memorySegment first operand
     * @param memorySegment2 second operand
     * @param i exclusive upper byte offset; the caller reduces it so the last 8-byte load stays
     *     inside the segment
     * @return the sum of all accumulator lanes
     */
    private static int dotProductBody128(MemorySegment memorySegment, MemorySegment memorySegment2, int i) {
        IntVector acc = IntVector.zero(IntVector.SPECIES_128);
        final int step = ByteVector.SPECIES_64.length() >> 1;
        for (int offset = 0; offset < i; offset += step) {
            ByteVector aBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
            Vector<Short> aShorts = aBytes.convert(VectorOperators.B2S, 0);
            ByteVector bBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, memorySegment2, offset, ByteOrder.LITTLE_ENDIAN);
            Vector<Short> bShorts = bBytes.convert(VectorOperators.B2S, 0);
            Vector<Short> products16 = aShorts.mul(bShorts);
            Vector<Integer> products32 = products16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0);
            acc = acc.add(products32);
        }
        return acc.reduceLanes(VectorOperators.ADD);
    }

    /**
     * Computes a dot product where at most one operand may be in packed form (two 4-bit values
     * per byte).
     *
     * <p>When one operand is packed, each packed byte at index {@code k} contributes its low
     * nibble times {@code unpacked[k + packed.length]} plus its high nibble times
     * {@code unpacked[k]}. When neither is packed, this is a plain byte dot product; for
     * {@code VECTOR_BITSIZE} of 256 or at least 512 it delegates to
     * {@link #dotProduct(byte[], byte[])}.
     *
     * @param bArr first operand
     * @param z whether {@code bArr} is packed
     * @param bArr2 second operand
     * @param z2 whether {@code bArr2} is packed; asserted not to be set together with {@code z}
     * @return the accumulated sum of products
     */
    @Override // org.apache.lucene.internal.vectorization.VectorUtilSupport
    public int int4DotProduct(byte[] bArr, boolean z, byte[] bArr2, boolean z2) {
        if (!$assertionsDisabled && z && z2) {
            throw new AssertionError();
        }
        int pos = 0;
        int sum = 0;

        if (!z && !z2) {
            // Neither operand is packed.
            if (VECTOR_BITSIZE >= 512 || VECTOR_BITSIZE == 256) {
                return dotProduct(bArr, bArr2);
            }
            if (bArr.length >= 32 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) {
                pos += ByteVector.SPECIES_128.loopBound(bArr.length);
                sum += int4DotProductBody128(bArr, bArr2, pos);
            }
            for (; pos < bArr.length; pos++) {
                sum += bArr2[pos] * bArr[pos];
            }
            return sum;
        }

        // Exactly one operand is packed.
        final byte[] packed = z ? bArr : bArr2;
        final byte[] unpacked = z ? bArr2 : bArr;
        if (packed.length >= 32) {
            if (VECTOR_BITSIZE >= 512) {
                pos += ByteVector.SPECIES_256.loopBound(packed.length);
                sum += dotProductBody512Int4Packed(unpacked, packed, pos);
            } else if (VECTOR_BITSIZE == 256) {
                pos += ByteVector.SPECIES_128.loopBound(packed.length);
                sum += dotProductBody256Int4Packed(unpacked, packed, pos);
            } else if (PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) {
                pos += ByteVector.SPECIES_64.loopBound(packed.length);
                sum += dotProductBody128Int4Packed(unpacked, packed, pos);
            }
        }
        // Scalar tail (or the whole input when no vector body ran).
        for (; pos < packed.length; pos++) {
            final byte packedByte = packed[pos];
            final int lowNibbleTerm = (packedByte & 15) * unpacked[pos + packed.length];
            final int highNibbleTerm = ((packedByte & 255) >> 4) * unpacked[pos];
            sum = sum + lowNibbleTerm + highNibbleTerm;
        }
        return sum;
    }

    /**
     * Vectorized packed-int4 dot product over packed indices {@code [0, i)} using 256-bit byte
     * loads and 512-bit short accumulators.
     *
     * <p>The range is processed in chunks of at most 4096 packed bytes. Within a chunk, products
     * are summed in two short accumulators (one for low nibbles, one for high nibbles); at the
     * end of the chunk both halves of each accumulator are widened to ints and reduced.
     * Presumably the chunk size is chosen so the 16-bit accumulators cannot overflow for 4-bit
     * inputs — verify against the value range callers guarantee.
     *
     * <p>The inner loop is a bounded {@code for}; the earlier {@code while (true)} form had no
     * exit, which left the chunk reduction unreachable.
     *
     * @param bArr unpacked operand; index {@code k} pairs with the high nibble and index
     *     {@code k + bArr2.length} with the low nibble of packed byte {@code k}
     * @param bArr2 packed operand (two 4-bit values per byte)
     * @param i exclusive upper bound on packed indices
     * @return the accumulated sum of products
     */
    private int dotProductBody512Int4Packed(byte[] bArr, byte[] bArr2, int i) {
        int total = 0;
        for (int chunkStart = 0; chunkStart < i; chunkStart += 4096) {
            ShortVector lowAcc = ShortVector.zero(ShortVector.SPECIES_512);
            ShortVector highAcc = ShortVector.zero(ShortVector.SPECIES_512);
            final int chunkLength = Math.min(i - chunkStart, 4096);
            for (int j = 0; j < chunkLength; j += ByteVector.SPECIES_256.length()) {
                final int offset = chunkStart + j;
                ByteVector packed = ByteVector.fromArray(ByteVector.SPECIES_256, bArr2, offset);

                // Low nibble pairs with the second half of the unpacked array.
                ByteVector lowProducts = packed.and((byte) 15).mul(ByteVector.fromArray(ByteVector.SPECIES_256, bArr, offset + bArr2.length));
                lowAcc = lowAcc.add(lowProducts.convertShape(VectorOperators.ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0));

                // High nibble pairs with the first half of the unpacked array.
                ByteVector highProducts = packed.lanewise(VectorOperators.LSHR, 4L).mul(ByteVector.fromArray(ByteVector.SPECIES_256, bArr, offset));
                highAcc = highAcc.add(highProducts.convertShape(VectorOperators.ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0));
            }
            // Widen both halves of each short accumulator to ints before reducing.
            IntVector low0 = lowAcc.convertShape(VectorOperators.S2I, IntVector.SPECIES_512, 0).reinterpretAsInts();
            IntVector low1 = lowAcc.convertShape(VectorOperators.S2I, IntVector.SPECIES_512, 1).reinterpretAsInts();
            IntVector high0 = highAcc.convertShape(VectorOperators.S2I, IntVector.SPECIES_512, 0).reinterpretAsInts();
            IntVector high1 = highAcc.convertShape(VectorOperators.S2I, IntVector.SPECIES_512, 1).reinterpretAsInts();
            total += low0.add(low1).add(high0).add(high1).reduceLanes(VectorOperators.ADD);
        }
        return total;
    }

    /**
     * Vectorized packed-int4 dot product over packed indices {@code [0, i)} using 128-bit byte
     * loads and 256-bit short accumulators.
     *
     * <p>The range is processed in chunks of at most 2048 packed bytes. Within a chunk, products
     * are summed in two short accumulators (one for low nibbles, one for high nibbles); at the
     * end of the chunk both halves of each accumulator are widened to ints and reduced.
     * Presumably the chunk size is chosen so the 16-bit accumulators cannot overflow for 4-bit
     * inputs — verify against the value range callers guarantee.
     *
     * <p>The inner loop is a bounded {@code for}; the earlier {@code while (true)} form had no
     * exit, which left the chunk reduction unreachable.
     *
     * @param bArr unpacked operand; index {@code k} pairs with the high nibble and index
     *     {@code k + bArr2.length} with the low nibble of packed byte {@code k}
     * @param bArr2 packed operand (two 4-bit values per byte)
     * @param i exclusive upper bound on packed indices
     * @return the accumulated sum of products
     */
    private int dotProductBody256Int4Packed(byte[] bArr, byte[] bArr2, int i) {
        int total = 0;
        for (int chunkStart = 0; chunkStart < i; chunkStart += 2048) {
            ShortVector lowAcc = ShortVector.zero(ShortVector.SPECIES_256);
            ShortVector highAcc = ShortVector.zero(ShortVector.SPECIES_256);
            final int chunkLength = Math.min(i - chunkStart, 2048);
            for (int j = 0; j < chunkLength; j += ByteVector.SPECIES_128.length()) {
                final int offset = chunkStart + j;
                ByteVector packed = ByteVector.fromArray(ByteVector.SPECIES_128, bArr2, offset);

                // Low nibble pairs with the second half of the unpacked array.
                ByteVector lowProducts = packed.and((byte) 15).mul(ByteVector.fromArray(ByteVector.SPECIES_128, bArr, offset + bArr2.length));
                lowAcc = lowAcc.add(lowProducts.convertShape(VectorOperators.ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0));

                // High nibble pairs with the first half of the unpacked array.
                ByteVector highProducts = packed.lanewise(VectorOperators.LSHR, 4L).mul(ByteVector.fromArray(ByteVector.SPECIES_128, bArr, offset));
                highAcc = highAcc.add(highProducts.convertShape(VectorOperators.ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0));
            }
            // Widen both halves of each short accumulator to ints before reducing.
            IntVector low0 = lowAcc.convertShape(VectorOperators.S2I, IntVector.SPECIES_256, 0).reinterpretAsInts();
            IntVector low1 = lowAcc.convertShape(VectorOperators.S2I, IntVector.SPECIES_256, 1).reinterpretAsInts();
            IntVector high0 = highAcc.convertShape(VectorOperators.S2I, IntVector.SPECIES_256, 0).reinterpretAsInts();
            IntVector high1 = highAcc.convertShape(VectorOperators.S2I, IntVector.SPECIES_256, 1).reinterpretAsInts();
            total += low0.add(low1).add(high0).add(high1).reduceLanes(VectorOperators.ADD);
        }
        return total;
    }

    /**
     * Vectorized packed-int4 dot product over packed indices {@code [0, i)} using 64-bit byte
     * loads and 128-bit short accumulators.
     *
     * <p>The range is processed in chunks of at most 1024 packed bytes. Byte products are widened
     * with a sign-extending conversion and then masked with {@code 0xFF}, which keeps only the
     * low 8 bits of each product. At the end of a chunk both halves of each short accumulator
     * are widened to ints and reduced. Presumably the chunk size is chosen so the 16-bit
     * accumulators cannot overflow for 4-bit inputs — verify against the value range callers
     * guarantee.
     *
     * <p>The inner loop is a bounded {@code for}; the earlier {@code while (true)} form had no
     * exit, which left the chunk reduction unreachable.
     *
     * @param bArr unpacked operand; index {@code k} pairs with the high nibble and index
     *     {@code k + bArr2.length} with the low nibble of packed byte {@code k}
     * @param bArr2 packed operand (two 4-bit values per byte)
     * @param i exclusive upper bound on packed indices
     * @return the accumulated sum of products
     */
    private int dotProductBody128Int4Packed(byte[] bArr, byte[] bArr2, int i) {
        int total = 0;
        for (int chunkStart = 0; chunkStart < i; chunkStart += 1024) {
            ShortVector lowAcc = ShortVector.zero(ShortVector.SPECIES_128);
            ShortVector highAcc = ShortVector.zero(ShortVector.SPECIES_128);
            final int chunkLength = Math.min(i - chunkStart, 1024);
            for (int j = 0; j < chunkLength; j += ByteVector.SPECIES_64.length()) {
                final int offset = chunkStart + j;
                ByteVector packed = ByteVector.fromArray(ByteVector.SPECIES_64, bArr2, offset);

                // Low nibble pairs with the second half of the unpacked array.
                ByteVector lowProducts = packed.and((byte) 15).mul(ByteVector.fromArray(ByteVector.SPECIES_64, bArr, offset + bArr2.length));
                ShortVector lowShorts = lowProducts.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
                lowAcc = lowAcc.add(lowShorts.and((short) 255));

                // High nibble pairs with the first half of the unpacked array.
                ByteVector highProducts = packed.lanewise(VectorOperators.LSHR, 4L).mul(ByteVector.fromArray(ByteVector.SPECIES_64, bArr, offset));
                ShortVector highShorts = highProducts.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
                highAcc = highAcc.add(highShorts.and((short) 255));
            }
            // Widen both halves of each short accumulator to ints before reducing.
            IntVector low0 = lowAcc.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0).reinterpretAsInts();
            IntVector low1 = lowAcc.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1).reinterpretAsInts();
            IntVector high0 = highAcc.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0).reinterpretAsInts();
            IntVector high1 = highAcc.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1).reinterpretAsInts();
            total += low0.add(low1).add(high0).add(high1).reduceLanes(VectorOperators.ADD);
        }
        return total;
    }

    /**
     * Vectorized unpacked-int4 dot product over indices {@code [0, i)} using 64-bit byte loads
     * and 128-bit short accumulators.
     *
     * <p>Each inner step covers 16 bytes as two 8-byte halves, one per accumulator. Byte
     * products are widened with a sign-extending conversion and then masked with {@code 0xFF}.
     * The range is processed in chunks of at most 1024 bytes; at the end of a chunk both halves
     * of each short accumulator are widened to ints and reduced. Presumably the chunk size is
     * chosen so the 16-bit accumulators cannot overflow for 4-bit inputs — verify against the
     * value range callers guarantee.
     *
     * <p>The inner loop is a bounded {@code for}; the earlier {@code while (true)} form had no
     * exit, which left the chunk reduction unreachable.
     *
     * @param bArr first operand
     * @param bArr2 second operand
     * @param i exclusive upper bound; the caller passes a 128-bit byte species loop bound
     * @return the accumulated sum of products
     */
    private int int4DotProductBody128(byte[] bArr, byte[] bArr2, int i) {
        int total = 0;
        for (int chunkStart = 0; chunkStart < i; chunkStart += 1024) {
            ShortVector firstAcc = ShortVector.zero(ShortVector.SPECIES_128);
            ShortVector secondAcc = ShortVector.zero(ShortVector.SPECIES_128);
            final int chunkLength = Math.min(i - chunkStart, 1024);
            for (int j = 0; j < chunkLength; j += ByteVector.SPECIES_128.length()) {
                final int offset = chunkStart + j;

                // First 8 bytes of the 16-byte step.
                ByteVector firstProducts = ByteVector.fromArray(ByteVector.SPECIES_64, bArr, offset).mul(ByteVector.fromArray(ByteVector.SPECIES_64, bArr2, offset));
                ShortVector firstShorts = firstProducts.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
                firstAcc = firstAcc.add(firstShorts.and((short) 255));

                // Second 8 bytes of the 16-byte step.
                ByteVector secondProducts = ByteVector.fromArray(ByteVector.SPECIES_64, bArr, offset + 8).mul(ByteVector.fromArray(ByteVector.SPECIES_64, bArr2, offset + 8));
                ShortVector secondShorts = secondProducts.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
                secondAcc = secondAcc.add(secondShorts.and((short) 255));
            }
            // Widen both halves of each short accumulator to ints before reducing.
            IntVector first0 = firstAcc.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0).reinterpretAsInts();
            IntVector first1 = firstAcc.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1).reinterpretAsInts();
            IntVector second0 = secondAcc.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0).reinterpretAsInts();
            IntVector second1 = secondAcc.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1).reinterpretAsInts();
            total += first0.add(first1).add(second0).add(second1).reduceLanes(VectorOperators.ADD);
        }
        return total;
    }

    /**
     * Computes the cosine similarity of two byte arrays by wrapping each array in a heap
     * {@link MemorySegment} and delegating to {@link #cosine(MemorySegment, MemorySegment)}.
     */
    @Override // org.apache.lucene.internal.vectorization.VectorUtilSupport
    public float cosine(byte[] bArr, byte[] bArr2) {
        final MemorySegment first = MemorySegment.ofArray(bArr);
        final MemorySegment second = MemorySegment.ofArray(bArr2);
        return cosine(first, second);
    }

    /**
     * Computes the cosine similarity of two byte sequences held in memory segments.
     *
     * <p>Three integer sums are accumulated: the dot product and the squared norm of each
     * operand. The vector path is taken only when the first segment holds at least 16 bytes and
     * {@code PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS} is set; the body variant is chosen
     * from {@code VECTOR_BITSIZE}. Bytes past the vectorized prefix are accumulated one by one.
     *
     * <p>Note: the vector bodies hand their sums back as {@code float[]}, which are truncated
     * to {@code int} here, so partial sums that are not exactly representable as a float are
     * rounded on the way through.
     *
     * @param memorySegment first operand; its byte size drives the iteration
     * @param memorySegment2 second operand; read at every offset of the first
     * @return {@code dot / sqrt(norm1 * norm2)}, narrowed to float
     */
    public static float cosine(MemorySegment memorySegment, MemorySegment memorySegment2) {
        int i = 0;
        int sum = 0;
        int norm1 = 0;
        int norm2 = 0;
        if (memorySegment.byteSize() >= 16 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) {
            final float[] partial;
            if (VECTOR_BITSIZE >= 512) {
                i += BYTE_SPECIES.loopBound((int) memorySegment.byteSize());
                partial = cosineBody512(memorySegment, memorySegment2, i);
            } else if (VECTOR_BITSIZE == 256) {
                i += BYTE_SPECIES.loopBound((int) memorySegment.byteSize());
                partial = cosineBody256(memorySegment, memorySegment2, i);
            } else {
                // The 128-bit body loads 8 bytes per step but advances by only 4, so the bound
                // is computed on a size reduced by one 64-bit vector to keep the last load
                // inside the segment.
                i += (int) ByteVector.SPECIES_64.loopBound(memorySegment.byteSize() - ByteVector.SPECIES_64.length());
                partial = cosineBody128(memorySegment, memorySegment2, i);
            }
            sum += (int) partial[0];
            norm1 += (int) partial[1];
            norm2 += (int) partial[2];
        }
        // Scalar tail (or the whole input when the vector path was skipped).
        for (; i < memorySegment.byteSize(); i++) {
            byte a = memorySegment.get(ValueLayout.JAVA_BYTE, i);
            byte b = memorySegment2.get(ValueLayout.JAVA_BYTE, i);
            sum += a * b;
            norm1 += a * a;
            norm2 += b * b;
        }
        // Multiply the norms in double: the int product of two squared norms overflows for
        // ordinary inputs (e.g. two norms near 50,000 already exceed Integer.MAX_VALUE).
        return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
    }

    /**
     * Vectorized byte cosine sums over byte offsets {@code [0, i)}, selected by the caller when
     * {@code VECTOR_BITSIZE >= 512}.
     *
     * <p>Each step loads one {@code BYTE_SPECIES} vector from each segment, widens both to
     * shorts, forms the three 16-bit products, widens those to ints and adds them to the
     * matching accumulator.
     *
     * @param memorySegment first operand
     * @param memorySegment2 second operand
     * @param i exclusive upper byte offset
     * @return a three-element array: dot product, squared norm of the first operand, squared
     *     norm of the second operand (each an int lane sum stored as float)
     */
    private static float[] cosineBody512(MemorySegment memorySegment, MemorySegment memorySegment2, int i) {
        IntVector dotAcc = IntVector.zero(INT_SPECIES);
        IntVector normAAcc = IntVector.zero(INT_SPECIES);
        IntVector normBAcc = IntVector.zero(INT_SPECIES);
        for (int offset = 0; offset < i; offset += BYTE_SPECIES.length()) {
            ByteVector aBytes = ByteVector.fromMemorySegment(BYTE_SPECIES, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
            ByteVector bBytes = ByteVector.fromMemorySegment(BYTE_SPECIES, memorySegment2, offset, ByteOrder.LITTLE_ENDIAN);
            Vector<Short> aShorts = aBytes.convertShape(VectorOperators.B2S, SHORT_SPECIES, 0);
            Vector<Short> bShorts = bBytes.convertShape(VectorOperators.B2S, SHORT_SPECIES, 0);

            // 16-bit products, widened to 32 bits before accumulation.
            Vector<Integer> aSquared = aShorts.mul(aShorts).convertShape(VectorOperators.S2I, INT_SPECIES, 0);
            Vector<Integer> bSquared = bShorts.mul(bShorts).convertShape(VectorOperators.S2I, INT_SPECIES, 0);
            Vector<Integer> cross = aShorts.mul(bShorts).convertShape(VectorOperators.S2I, INT_SPECIES, 0);

            normAAcc = normAAcc.add(aSquared);
            normBAcc = normBAcc.add(bSquared);
            dotAcc = dotAcc.add(cross);
        }
        final float dot = dotAcc.reduceLanes(VectorOperators.ADD);
        final float normA = normAAcc.reduceLanes(VectorOperators.ADD);
        final float normB = normBAcc.reduceLanes(VectorOperators.ADD);
        return new float[]{dot, normA, normB};
    }

    /**
     * Vectorized byte cosine sums over byte offsets {@code [0, i)}, selected by the caller when
     * {@code VECTOR_BITSIZE == 256}.
     *
     * <p>Each step loads a 64-bit byte vector from each segment, widens both directly to 256-bit
     * int vectors and accumulates the three 32-bit products.
     *
     * @param memorySegment first operand
     * @param memorySegment2 second operand
     * @param i exclusive upper byte offset
     * @return a three-element array: dot product, squared norm of the first operand, squared
     *     norm of the second operand (each an int lane sum stored as float)
     */
    private static float[] cosineBody256(MemorySegment memorySegment, MemorySegment memorySegment2, int i) {
        IntVector dotAcc = IntVector.zero(IntVector.SPECIES_256);
        IntVector normAAcc = IntVector.zero(IntVector.SPECIES_256);
        IntVector normBAcc = IntVector.zero(IntVector.SPECIES_256);
        for (int offset = 0; offset < i; offset += ByteVector.SPECIES_64.length()) {
            ByteVector aBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
            ByteVector bBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, memorySegment2, offset, ByteOrder.LITTLE_ENDIAN);
            Vector<Integer> aInts = aBytes.convertShape(VectorOperators.B2I, IntVector.SPECIES_256, 0);
            Vector<Integer> bInts = bBytes.convertShape(VectorOperators.B2I, IntVector.SPECIES_256, 0);

            normAAcc = normAAcc.add(aInts.mul(aInts));
            normBAcc = normBAcc.add(bInts.mul(bInts));
            dotAcc = dotAcc.add(aInts.mul(bInts));
        }
        final float dot = dotAcc.reduceLanes(VectorOperators.ADD);
        final float normA = normAAcc.reduceLanes(VectorOperators.ADD);
        final float normB = normBAcc.reduceLanes(VectorOperators.ADD);
        return new float[]{dot, normA, normB};
    }

    /**
     * Vectorized byte cosine sums over byte offsets {@code [0, i)}, used by the caller when
     * neither the 512-bit nor the 256-bit variant applies.
     *
     * <p>Each step loads a 64-bit byte vector (8 bytes) from each segment but advances by only
     * half that length; only part 0 of each widening conversion is accumulated, so consecutive
     * loads overlap.
     *
     * @param memorySegment first operand
     * @param memorySegment2 second operand
     * @param i exclusive upper byte offset; the caller reduces it so the last 8-byte load stays
     *     inside the segment
     * @return a three-element array: dot product, squared norm of the first operand, squared
     *     norm of the second operand (each an int lane sum stored as float)
     */
    private static float[] cosineBody128(MemorySegment memorySegment, MemorySegment memorySegment2, int i) {
        IntVector dotAcc = IntVector.zero(IntVector.SPECIES_128);
        IntVector normAAcc = IntVector.zero(IntVector.SPECIES_128);
        IntVector normBAcc = IntVector.zero(IntVector.SPECIES_128);
        final int step = ByteVector.SPECIES_64.length() >> 1;
        for (int offset = 0; offset < i; offset += step) {
            ByteVector aBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
            ByteVector bBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, memorySegment2, offset, ByteOrder.LITTLE_ENDIAN);
            Vector<Short> aShorts = aBytes.convert(VectorOperators.B2S, 0);
            Vector<Short> bShorts = bBytes.convert(VectorOperators.B2S, 0);

            Vector<Short> aSquared = aShorts.mul(aShorts);
            Vector<Short> bSquared = bShorts.mul(bShorts);
            Vector<Short> cross = aShorts.mul(bShorts);

            normAAcc = normAAcc.add(aSquared.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0));
            normBAcc = normBAcc.add(bSquared.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0));
            dotAcc = dotAcc.add(cross.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0));
        }
        final float dot = dotAcc.reduceLanes(VectorOperators.ADD);
        final float normA = normAAcc.reduceLanes(VectorOperators.ADD);
        final float normB = normBAcc.reduceLanes(VectorOperators.ADD);
        return new float[]{dot, normA, normB};
    }

    /**
     * Computes the squared Euclidean distance between two byte arrays by wrapping each array in
     * a heap {@link MemorySegment} and delegating to
     * {@link #squareDistance(MemorySegment, MemorySegment)}.
     */
    @Override // org.apache.lucene.internal.vectorization.VectorUtilSupport
    public int squareDistance(byte[] bArr, byte[] bArr2) {
        final MemorySegment first = MemorySegment.ofArray(bArr);
        final MemorySegment second = MemorySegment.ofArray(bArr2);
        return squareDistance(first, second);
    }

    /**
     * Computes the squared Euclidean distance between two byte sequences held in memory
     * segments.
     *
     * <p>The vector path is taken only when the first segment holds at least 16 bytes and
     * {@code PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS} is set: the 256-bit body for
     * {@code VECTOR_BITSIZE >= 256}, otherwise the 128-bit body. Bytes past the vectorized
     * prefix are accumulated one by one.
     *
     * @param memorySegment first operand; its byte size drives the iteration
     * @param memorySegment2 second operand; asserted to have the same byte size
     * @return the sum of squared differences of corresponding (signed) bytes
     */
    public static int squareDistance(MemorySegment memorySegment, MemorySegment memorySegment2) {
        if (!$assertionsDisabled) {
            if (memorySegment.byteSize() != memorySegment2.byteSize()) {
                throw new AssertionError();
            }
        }
        int pos = 0;
        int sum = 0;
        if (memorySegment.byteSize() >= 16 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) {
            final int size = (int) memorySegment.byteSize();
            if (VECTOR_BITSIZE >= 256) {
                pos += BYTE_SPECIES.loopBound(size);
                sum += squareDistanceBody256(memorySegment, memorySegment2, pos);
            } else {
                pos += ByteVector.SPECIES_64.loopBound(size);
                sum += squareDistanceBody128(memorySegment, memorySegment2, pos);
            }
        }
        // Scalar tail (or the whole input when the vector path was skipped).
        for (; pos < memorySegment.byteSize(); pos++) {
            final int diff = memorySegment.get(ValueLayout.JAVA_BYTE, pos) - memorySegment2.get(ValueLayout.JAVA_BYTE, pos);
            sum += diff * diff;
        }
        return sum;
    }

    /**
     * Vectorized byte squared distance over byte offsets {@code [0, i)}, selected by the caller
     * when {@code VECTOR_BITSIZE >= 256}.
     *
     * <p>Each step loads one {@code BYTE_SPECIES} vector from each segment, widens both to ints,
     * subtracts, squares the difference and adds it to the accumulator.
     *
     * @param memorySegment first operand
     * @param memorySegment2 second operand
     * @param i exclusive upper byte offset
     * @return the sum of all accumulator lanes
     */
    private static int squareDistanceBody256(MemorySegment memorySegment, MemorySegment memorySegment2, int i) {
        IntVector acc = IntVector.zero(INT_SPECIES);
        for (int offset = 0; offset < i; offset += BYTE_SPECIES.length()) {
            ByteVector aBytes = ByteVector.fromMemorySegment(BYTE_SPECIES, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
            Vector<Integer> aInts = aBytes.convertShape(VectorOperators.B2I, INT_SPECIES, 0);
            ByteVector bBytes = ByteVector.fromMemorySegment(BYTE_SPECIES, memorySegment2, offset, ByteOrder.LITTLE_ENDIAN);
            Vector<Integer> bInts = bBytes.convertShape(VectorOperators.B2I, INT_SPECIES, 0);
            Vector<Integer> diff = aInts.sub(bInts);
            acc = acc.add(diff.mul(diff));
        }
        return acc.reduceLanes(VectorOperators.ADD);
    }

    /**
     * Vectorized byte squared distance over byte offsets {@code [0, i)}, used by the caller when
     * {@code VECTOR_BITSIZE} is below 256.
     *
     * <p>Each step loads a 64-bit byte vector from each segment, widens both to shorts and
     * subtracts in 16 bits. The short difference is then widened to ints in two parts (part 0
     * and part 1), each squared and added to its own accumulator.
     *
     * @param memorySegment first operand
     * @param memorySegment2 second operand
     * @param i exclusive upper byte offset
     * @return the sum of the lanes of both accumulators
     */
    private static int squareDistanceBody128(MemorySegment memorySegment, MemorySegment memorySegment2, int i) {
        IntVector lowerAcc = IntVector.zero(IntVector.SPECIES_128);
        IntVector upperAcc = IntVector.zero(IntVector.SPECIES_128);
        for (int offset = 0; offset < i; offset += ByteVector.SPECIES_64.length()) {
            ByteVector aBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
            Vector<Short> aShorts = aBytes.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
            ByteVector bBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, memorySegment2, offset, ByteOrder.LITTLE_ENDIAN);
            Vector<Short> bShorts = bBytes.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
            Vector<Short> diff16 = aShorts.sub(bShorts);

            // Split the 16-bit differences into two 32-bit vectors before multiplying.
            Vector<Integer> diffLower = diff16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0);
            Vector<Integer> diffUpper = diff16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1);
            lowerAcc = lowerAcc.add(diffLower.mul(diffLower));
            upperAcc = upperAcc.add(diffUpper.mul(diffUpper));
        }
        return lowerAcc.add(upperAcc).reduceLanes(VectorOperators.ADD);
    }

    /**
     * Finds the first index in {@code [i2, i3)} whose element is greater than or equal to
     * {@code i}, returning {@code i3} when there is none.
     *
     * <p>When {@code ENABLE_FIND_NEXT_GEQ_VECTOR_OPTO} is set, the range is first skipped through
     * in strides of one vector length plus one, probing the element just past each window. When
     * a probe is {@code >= i}, the window is loaded and the number of lanes below {@code i} is
     * added to the window start. NOTE(review): that count equals the position of the first match
     * only if the scanned range is sorted ascending — presumably a caller precondition; confirm.
     *
     * @param iArr values to scan
     * @param i target value
     * @param i2 inclusive start index
     * @param i3 exclusive end index
     * @return index of the first element {@code >= i}, or {@code i3}
     */
    @Override // org.apache.lucene.internal.vectorization.VectorUtilSupport
    public int findNextGEQ(int[] iArr, int i, int i2, int i3) {
        int from = i2;
        if (ENABLE_FIND_NEXT_GEQ_VECTOR_OPTO) {
            for (; from + INT_SPECIES.length() < i3; from += INT_SPECIES.length() + 1) {
                final int probe = iArr[from + INT_SPECIES.length()];
                if (probe >= i) {
                    IntVector window = IntVector.fromArray(INT_SPECIES, iArr, from);
                    final int lanesBelowTarget = window.compare(VectorOperators.LT, i).trueCount();
                    return from + lanesBelowTarget;
                }
            }
        }
        // Linear scan over whatever the strided search did not cover.
        int idx = from;
        while (idx < i3) {
            if (iArr[idx] >= i) {
                return idx;
            }
            idx++;
        }
        return i3;
    }

    /**
     * Dispatches the int4-by-bit dot product to a vectorized or default implementation.
     *
     * <p>A vectorized variant is used only when {@code bArr2} holds at least 16 bytes and
     * {@code PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS} is set: the 256-bit variant for
     * {@code VECTOR_BITSIZE >= 256}, the 128-bit variant for {@code VECTOR_BITSIZE == 128}. In
     * every other case {@code DefaultVectorUtilSupport.int4BitDotProductImpl} is called.
     *
     * @param bArr first operand; asserted to be exactly four times as long as {@code bArr2}
     * @param bArr2 second operand
     * @return the result of the selected implementation
     */
    @Override // org.apache.lucene.internal.vectorization.VectorUtilSupport
    public long int4BitDotProduct(byte[] bArr, byte[] bArr2) {
        if (!$assertionsDisabled) {
            if (bArr.length != bArr2.length * 4) {
                throw new AssertionError();
            }
        }
        final boolean vectorizable = bArr2.length >= 16 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS;
        if (vectorizable && VECTOR_BITSIZE >= 256) {
            return int4BitDotProduct256(bArr, bArr2);
        }
        if (vectorizable && VECTOR_BITSIZE == 128) {
            return int4BitDotProduct128(bArr, bArr2);
        }
        return DefaultVectorUtilSupport.int4BitDotProductImpl(bArr, bArr2);
    }

    /**
     * 4-bit dot product using 256-bit vectors, then 128-bit vectors, then a scalar tail.
     *
     * <p>{@code bArr} is read as four consecutive planes of {@code bArr2.length} bytes each. For
     * every plane {@code p} the method counts the bits set in {@code plane_p & bArr2}; the result
     * is {@code count0 + (count1 << 1) + (count2 << 2) + (count3 << 3)}.
     *
     * @param bArr the four planes, back to back
     * @param bArr2 the bytes every plane is AND-ed with
     * @return the weighted sum of the four bit counts
     */
    static long int4BitDotProduct256(byte[] bArr, byte[] bArr2) {
        final int len = bArr2.length;
        long count0 = 0;
        long count1 = 0;
        long count2 = 0;
        long count3 = 0;
        int pos = 0;

        // 256-bit pass: only taken when at least two full 256-bit vectors of input are available.
        if (len >= ByteVector.SPECIES_256.vectorByteSize() * 2) {
            LongVector acc0 = LongVector.zero(LongVector.SPECIES_256);
            LongVector acc1 = LongVector.zero(LongVector.SPECIES_256);
            LongVector acc2 = LongVector.zero(LongVector.SPECIES_256);
            LongVector acc3 = LongVector.zero(LongVector.SPECIES_256);
            final int bound = ByteVector.SPECIES_256.loopBound(len);
            for (; pos < bound; pos += ByteVector.SPECIES_256.length()) {
                LongVector q0 = ByteVector.fromArray(BYTE_SPECIES_256, bArr, pos).reinterpretAsLongs();
                LongVector q1 = ByteVector.fromArray(BYTE_SPECIES_256, bArr, pos + len).reinterpretAsLongs();
                LongVector q2 = ByteVector.fromArray(BYTE_SPECIES_256, bArr, pos + (len * 2)).reinterpretAsLongs();
                LongVector q3 = ByteVector.fromArray(BYTE_SPECIES_256, bArr, pos + (len * 3)).reinterpretAsLongs();
                LongVector d = ByteVector.fromArray(BYTE_SPECIES_256, bArr2, pos).reinterpretAsLongs();
                acc0 = acc0.add(q0.and(d).lanewise(VectorOperators.BIT_COUNT));
                acc1 = acc1.add(q1.and(d).lanewise(VectorOperators.BIT_COUNT));
                acc2 = acc2.add(q2.and(d).lanewise(VectorOperators.BIT_COUNT));
                acc3 = acc3.add(q3.and(d).lanewise(VectorOperators.BIT_COUNT));
            }
            count0 += acc0.reduceLanes(VectorOperators.ADD);
            count1 += acc1.reduceLanes(VectorOperators.ADD);
            count2 += acc2.reduceLanes(VectorOperators.ADD);
            count3 += acc3.reduceLanes(VectorOperators.ADD);
        }

        // 128-bit pass over what is left, if at least one full 128-bit vector remains.
        if (len - pos >= ByteVector.SPECIES_128.vectorByteSize()) {
            LongVector acc0 = LongVector.zero(LongVector.SPECIES_128);
            LongVector acc1 = LongVector.zero(LongVector.SPECIES_128);
            LongVector acc2 = LongVector.zero(LongVector.SPECIES_128);
            LongVector acc3 = LongVector.zero(LongVector.SPECIES_128);
            final int bound = ByteVector.SPECIES_128.loopBound(len);
            for (; pos < bound; pos += ByteVector.SPECIES_128.length()) {
                LongVector q0 = ByteVector.fromArray(BYTE_SPECIES_128, bArr, pos).reinterpretAsLongs();
                LongVector q1 = ByteVector.fromArray(BYTE_SPECIES_128, bArr, pos + len).reinterpretAsLongs();
                LongVector q2 = ByteVector.fromArray(BYTE_SPECIES_128, bArr, pos + (len * 2)).reinterpretAsLongs();
                LongVector q3 = ByteVector.fromArray(BYTE_SPECIES_128, bArr, pos + (len * 3)).reinterpretAsLongs();
                LongVector d = ByteVector.fromArray(BYTE_SPECIES_128, bArr2, pos).reinterpretAsLongs();
                acc0 = acc0.add(q0.and(d).lanewise(VectorOperators.BIT_COUNT));
                acc1 = acc1.add(q1.and(d).lanewise(VectorOperators.BIT_COUNT));
                acc2 = acc2.add(q2.and(d).lanewise(VectorOperators.BIT_COUNT));
                acc3 = acc3.add(q3.and(d).lanewise(VectorOperators.BIT_COUNT));
            }
            count0 += acc0.reduceLanes(VectorOperators.ADD);
            count1 += acc1.reduceLanes(VectorOperators.ADD);
            count2 += acc2.reduceLanes(VectorOperators.ADD);
            count3 += acc3.reduceLanes(VectorOperators.ADD);
        }

        // Scalar tail, one byte at a time; the 0xFF mask discards sign-extension bits.
        for (; pos < len; pos++) {
            count0 += Integer.bitCount(bArr[pos] & bArr2[pos] & 0xFF);
            count1 += Integer.bitCount(bArr[pos + len] & bArr2[pos] & 0xFF);
            count2 += Integer.bitCount(bArr[pos + (2 * len)] & bArr2[pos] & 0xFF);
            count3 += Integer.bitCount(bArr[pos + (3 * len)] & bArr2[pos] & 0xFF);
        }
        return count0 + (count1 << 1) + (count2 << 2) + (count3 << 3);
    }

    /**
     * 4-bit dot product using 128-bit vectors followed by a scalar tail.
     *
     * <p>{@code bArr} is read as four consecutive planes of {@code bArr2.length} bytes each. For
     * every plane {@code p} the method counts the bits set in {@code bArr2 & plane_p}; the result
     * is {@code count0 + (count1 << 1) + (count2 << 2) + (count3 << 3)}.
     *
     * @param bArr the four planes, back to back
     * @param bArr2 the bytes every plane is AND-ed with
     * @return the weighted sum of the four bit counts
     */
    public static long int4BitDotProduct128(byte[] bArr, byte[] bArr2) {
        int pos = 0;
        IntVector acc0 = IntVector.zero(IntVector.SPECIES_128);
        IntVector acc1 = IntVector.zero(IntVector.SPECIES_128);
        IntVector acc2 = IntVector.zero(IntVector.SPECIES_128);
        IntVector acc3 = IntVector.zero(IntVector.SPECIES_128);

        // Vector pass: per-lane bit counts are accumulated in int lanes.
        final int bound = ByteVector.SPECIES_128.loopBound(bArr2.length);
        for (; pos < bound; pos += ByteVector.SPECIES_128.length()) {
            IntVector d = ByteVector.fromArray(BYTE_SPECIES_128, bArr2, pos).reinterpretAsInts();
            IntVector q0 = ByteVector.fromArray(BYTE_SPECIES_128, bArr, pos).reinterpretAsInts();
            IntVector q1 = ByteVector.fromArray(BYTE_SPECIES_128, bArr, pos + bArr2.length).reinterpretAsInts();
            IntVector q2 = ByteVector.fromArray(BYTE_SPECIES_128, bArr, pos + (bArr2.length * 2)).reinterpretAsInts();
            IntVector q3 = ByteVector.fromArray(BYTE_SPECIES_128, bArr, pos + (bArr2.length * 3)).reinterpretAsInts();
            acc0 = acc0.add(d.and(q0).lanewise(VectorOperators.BIT_COUNT));
            acc1 = acc1.add(d.and(q1).lanewise(VectorOperators.BIT_COUNT));
            acc2 = acc2.add(d.and(q2).lanewise(VectorOperators.BIT_COUNT));
            acc3 = acc3.add(d.and(q3).lanewise(VectorOperators.BIT_COUNT));
        }

        // Each lane sum is reduced as an int and then widened to long.
        long count0 = acc0.reduceLanes(VectorOperators.ADD);
        long count1 = acc1.reduceLanes(VectorOperators.ADD);
        long count2 = acc2.reduceLanes(VectorOperators.ADD);
        long count3 = acc3.reduceLanes(VectorOperators.ADD);

        // Scalar tail, one byte at a time; the 0xFF mask discards sign-extension bits.
        for (; pos < bArr2.length; pos++) {
            final byte d = bArr2[pos];
            count0 += Integer.bitCount(d & bArr[pos] & 0xFF);
            count1 += Integer.bitCount(d & bArr[pos + bArr2.length] & 0xFF);
            count2 += Integer.bitCount(d & bArr[pos + (2 * bArr2.length)] & 0xFF);
            count3 += Integer.bitCount(d & bArr[pos + (3 * bArr2.length)] & 0xFF);
        }
        return count0 + (count1 << 1) + (count2 << 2) + (count3 << 3);
    }

    /**
     * Quantizes {@code fArr} into {@code bArr} and returns an accumulated correction term.
     *
     * <p>Per element the vector path clamps the value to {@code [f3, f4]}, subtracts {@code f3},
     * multiplies by {@code f}, adds 0.5 and converts to int; that integer is stored as a byte.
     * The integer times {@code f2} is then used, together with the unclamped input, to accumulate
     * the correction. Elements not covered by full vectors are delegated to
     * {@link DefaultVectorUtilSupport.ScalarQuantizer#quantize}, whose result is added to the
     * vector sum.
     *
     * @param fArr input values
     * @param bArr destination for the quantized bytes; expected to have the same length as
     *     {@code fArr} (checked only when assertions are enabled)
     * @param f multiplier applied to the clamped, shifted value before rounding
     * @param f2 multiplier applied to the rounded integer when computing the correction
     * @param f3 lower clamp bound, also subtracted as the offset
     * @param f4 upper clamp bound
     * @return vector-path correction plus the correction returned by the scalar quantizer
     */
    @Override
    public float minMaxScalarQuantize(float[] fArr, byte[] bArr, float f, float f2, float f3, float f4) {
        // Explicit assertion check: this class declares the $assertionsDisabled flag itself.
        if (!$assertionsDisabled && fArr.length != bArr.length) {
            throw new AssertionError();
        }
        float correction = 0.0f;
        int i = 0;
        // BYTE_SPECIES is only non-null when VECTOR_BITSIZE >= 256 (see the static initializer).
        if (VECTOR_BITSIZE >= 256) {
            FloatVector sum = FloatVector.zero(FLOAT_SPECIES);
            final int bound = FLOAT_SPECIES.loopBound(fArr.length);
            for (; i < bound; i += FLOAT_SPECIES.length()) {
                FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, fArr, i);
                // Clamp to [f3, f4], then shift so the lower bound maps to zero.
                FloatVector shifted = v.min(f4).max(f3).sub(f3);
                // Scale and add 0.5 before the F2I conversion, which truncates like an (int) cast.
                Vector<Integer> rounded =
                        fma(shifted, shifted.broadcast(f), shifted.broadcast(0.5f)).convert(VectorOperators.F2I, 0);
                // castShape returns Vector<F>; downcast to reach the array store and mul(float).
                ((ByteVector) rounded.castShape(BYTE_SPECIES, 0)).intoArray(bArr, i);
                FloatVector rescaled = ((FloatVector) rounded.castShape(FLOAT_SPECIES, 0)).mul(f2);
                // sum += (v - f3/2) * f3 + (v - f3 - rescaled) * rescaled
                sum = fma(v.sub(f3 / 2.0f), v.broadcast(f3), fma(v.sub(f3).sub(rescaled), rescaled, sum));
            }
            correction = sum.reduceLanes(VectorOperators.ADD);
        }
        // Scalar quantizer handles everything from index i onward.
        return correction + new DefaultVectorUtilSupport.ScalarQuantizer(f2, f, f3, f4).quantize(fArr, bArr, i);
    }

    /**
     * Recomputes the correction term for bytes that were already quantized.
     *
     * <p>Per element the vector path widens the byte to float and maps it through
     * {@code byte * f + f2}. It then applies the same steps as quantization: clamp to
     * {@code [f5, f6]}, subtract {@code f5}, multiply by {@code f3}, add 0.5, convert to int, and
     * multiply the integer by {@code f4}. The correction is accumulated from those values.
     * {@code bArr} is not modified by the vector path. Elements not covered by full vectors are
     * delegated to {@link DefaultVectorUtilSupport.ScalarQuantizer#recalculateOffset}.
     *
     * @param bArr previously quantized bytes
     * @param f multiplier applied to each byte to map it back to a float
     * @param f2 offset added after multiplying by {@code f}
     * @param f3 multiplier applied to the clamped, shifted value before rounding
     * @param f4 multiplier applied to the rounded integer when computing the correction
     * @param f5 lower clamp bound, also subtracted as the offset
     * @param f6 upper clamp bound
     * @return vector-path correction plus the correction returned by the scalar quantizer
     */
    @Override
    public float recalculateScalarQuantizationOffset(byte[] bArr, float f, float f2, float f3, float f4, float f5, float f6) {
        float correction = 0.0f;
        int i = 0;
        // BYTE_SPECIES is only non-null when VECTOR_BITSIZE >= 256 (see the static initializer).
        if (VECTOR_BITSIZE >= 256) {
            FloatVector sum = FloatVector.zero(FLOAT_SPECIES);
            final int bound = BYTE_SPECIES.loopBound(bArr.length);
            for (; i < bound; i += BYTE_SPECIES.length()) {
                // castShape returns Vector<Float>; downcast to use the FloatVector operations.
                FloatVector widened =
                        (FloatVector) ByteVector.fromArray(BYTE_SPECIES, bArr, i).castShape(FLOAT_SPECIES, 0);
                // v = byte * f + f2
                FloatVector v = fma(widened, widened.broadcast(f), widened.broadcast(f2));
                // Clamp to [f5, f6], then shift so the lower bound maps to zero.
                FloatVector shifted = v.min(f6).max(f5).sub(f5);
                // Scale and add 0.5 before the F2I conversion, which truncates like an (int) cast.
                Vector<Integer> rounded =
                        fma(shifted, shifted.broadcast(f3), shifted.broadcast(0.5f)).convert(VectorOperators.F2I, 0);
                FloatVector rescaled = ((FloatVector) rounded.castShape(FLOAT_SPECIES, 0)).mul(f4);
                // sum += (v - f5/2) * f5 + (v - f5 - rescaled) * rescaled
                sum = fma(v.sub(f5 / 2.0f), v.broadcast(f5), fma(v.sub(f5).sub(rescaled), rescaled, sum));
            }
            correction = sum.reduceLanes(VectorOperators.ADD);
        }
        // Scalar quantizer handles everything from index i onward.
        return correction + new DefaultVectorUtilSupport.ScalarQuantizer(f4, f3, f5, f6).recalculateOffset(bArr, i, f, f2);
    }

    static {
        $assertionsDisabled = !PanamaVectorUtilSupport.class.desiredAssertionStatus();
        INT_SPECIES = PanamaVectorConstants.PRERERRED_INT_SPECIES;
        BYTE_SPECIES_128 = ByteVector.SPECIES_128;
        BYTE_SPECIES_256 = ByteVector.SPECIES_256;
        VECTOR_BITSIZE = PanamaVectorConstants.PREFERRED_VECTOR_BITSIZE;
        FLOAT_SPECIES = INT_SPECIES.withLanes(Float.TYPE);
        if (VECTOR_BITSIZE >= 256) {
            BYTE_SPECIES = ByteVector.SPECIES_MAX.withShape(VectorShape.forBitSize(VECTOR_BITSIZE >> 2));
            SHORT_SPECIES = ShortVector.SPECIES_MAX.withShape(VectorShape.forBitSize(VECTOR_BITSIZE >> 1));
        } else {
            BYTE_SPECIES = null;
            SHORT_SPECIES = null;
        }
        ENABLE_FIND_NEXT_GEQ_VECTOR_OPTO = INT_SPECIES.length() >= 8;
    }
}
