package threeWayPartition;

/*
 * To get this JUnit 5 test class to compile in Eclipse:
 * 
 * 1. Right-click your project --> Properties --> Java Build Path --> Libraries tab.
 * 2. Click Add Library... --> select JUnit --> Next.
 * 3. Choose JUnit 5 --> Finish --> Apply and Close.
 * 4. Eclipse should automatically add the right jars.
 * 
 * If that doesn't work, ask your friends or come to me ASAP.
 */

import static org.junit.jupiter.api.Assertions.*;

import java.util.Arrays;
import java.util.Random;
import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Test;

/**
 * JUnit 5 tests for {@code ThreeWayPartition.threeWayPartition(int[])}.
 *
 * <p>Every test uses the input's first element {@code a[0]} as the pivot and treats the
 * returned {@code Pair} as the inclusive bounds {@code [i, j]} of the block equal to the
 * pivot: everything before {@code i} must be smaller, everything after {@code j} larger,
 * and the partitioned array must be a permutation of the input.
 */
public class ThreeWayPartitionTest {

    /** Fixed seeds for the randomized test: each gives a different but reproducible input. */
    private static final long[] RANDOM_SEEDS = {12345L, 12346L, 12347L, 12348L, 12349L};

    // ---- helper: check invariants for 3-way partition (inclusive j) ----

    /**
     * Asserts that {@code arr} is a valid three-way partition of {@code original} around
     * {@code pivot}.
     *
     * @param original the input as it was before partitioning
     * @param arr      the partitioned array
     * @param i        first index of the equal-to-pivot block
     * @param j        last index (inclusive) of the equal-to-pivot block; {@code i - 1} if empty
     * @param pivot    the pivot value
     */
    private static void assertThreeWayPartitionCorrect(int[] original, int[] arr,
                                                       int i, int j, int pivot) {
        // Shape of the equal block. An empty block (j == i - 1) can sit at the very end
        // when every element is < pivot, so i == arr.length must be accepted too.
        assertTrue(i >= 0 && i <= arr.length, () -> "i in range, got i=" + i);
        assertTrue(j >= i - 1 && j < arr.length,
                () -> "j in range (j can be i-1 if no equals), got i=" + i + ", j=" + j);

        // Messages are Suppliers so they are only built when an assertion actually fails.
        // left side: strictly less than pivot
        for (int idx = 0; idx < i; idx++) {
            final int k = idx;
            assertTrue(arr[k] < pivot,
                    () -> "arr[" + k + "]=" + arr[k] + " should be < pivot " + pivot);
        }
        // middle: equal to pivot (if any)
        for (int idx = i; idx <= j; idx++) {
            final int k = idx;
            assertEquals(pivot, arr[k],
                    () -> "arr[" + k + "]=" + arr[k] + " should be == pivot " + pivot);
        }
        // right side: strictly greater than pivot
        for (int idx = j + 1; idx < arr.length; idx++) {
            final int k = idx;
            assertTrue(arr[k] > pivot,
                    () -> "arr[" + k + "]=" + arr[k] + " should be > pivot " + pivot);
        }

        // permutation check (multiset preserved)
        int[] sortedOrig = original.clone();
        int[] sortedArr  = arr.clone();
        Arrays.sort(sortedOrig);
        Arrays.sort(sortedArr);
        assertArrayEquals(sortedOrig, sortedArr, "partition must be a permutation");
    }

    /**
     * Builds an array of length {@code n} with {@code pivot} at index 0. Each remaining slot
     * is, with probability 60/20/20, equal to the pivot, in {@code [pivot-5, pivot-1]}, or in
     * {@code [pivot+1, pivot+5]}, so the equal block is large.
     */
    private static int[] randomArrayAroundPivot(int n, int pivot, Random rnd) {
        int[] a = new int[n];
        a[0] = pivot;
        for (int k = 1; k < n; k++) {
            int r = rnd.nextInt(100);
            if (r < 60) {
                a[k] = pivot;                         // 60% equal
            } else if (r < 80) {
                a[k] = pivot - rnd.nextInt(5) - 1;    // 20% less: pivot-5 .. pivot-1
            } else {
                a[k] = pivot + rnd.nextInt(5) + 1;    // 20% greater: pivot+1 .. pivot+5
            }
        }
        return a;
    }

    // ---- small, surgical cases ----

    @Test
    void allEqual() {
        int[] a = {7, 7, 7, 7, 7};
        int[] copy = a.clone();
        ThreeWayPartition.Pair p = ThreeWayPartition.threeWayPartition(copy);
        assertThreeWayPartitionCorrect(a, copy, p.i(), p.j(), a[0]);
        // whole array should be "equal" block
        assertEquals(0, p.i());
        assertEquals(a.length - 1, p.j());
    }

    @Test
    void noneEqual() {
        int[] a = {5, 1, 2, 3, 4}; // pivot=5, all others < pivot
        int[] copy = a.clone();
        ThreeWayPartition.Pair p = ThreeWayPartition.threeWayPartition(copy);
        assertThreeWayPartitionCorrect(a, copy, p.i(), p.j(), a[0]);
        // equal block should be only the pivot (size 1)
        assertEquals(4, p.i());
        assertEquals(4, p.j());
    }

    @Test
    void noneEqual2() {
        // pivot=51: 9 smaller values, 3 larger values, no other copy of 51
        int[] a = {51, 23, 43, 89, 11, 98, 43, 21, 7, 34, 56, 12, 9};
        int[] copy = a.clone();
        ThreeWayPartition.Pair p = ThreeWayPartition.threeWayPartition(copy);
        assertThreeWayPartitionCorrect(a, copy, p.i(), p.j(), a[0]);
        // equal block should be only the pivot (size 1)
        assertEquals(9, p.i());
        assertEquals(9, p.j());
    }

    @Test
    void duplicatesOnBothSides() {
        // pivot=4: {2,1,3} less, four 4s equal, {6,9,7,5} greater
        int[] a = {4, 6, 4, 2, 9, 4, 1, 7, 4, 5, 3};
        int[] copy = a.clone();
        ThreeWayPartition.Pair p = ThreeWayPartition.threeWayPartition(copy);
        assertThreeWayPartitionCorrect(a, copy, p.i(), p.j(), a[0]);
        assertEquals(3, p.i());
        assertEquals(6, p.j());
    }

    @Test
    void negativesAndZeros() {
        // pivot=0: {-2,-2,-1} less, four 0s equal, {5,3} greater
        int[] a = {0, -2, 0, -2, 5, 0, 3, -1, 0};
        int[] copy = a.clone();
        ThreeWayPartition.Pair p = ThreeWayPartition.threeWayPartition(copy);
        assertThreeWayPartitionCorrect(a, copy, p.i(), p.j(), a[0]);
        assertEquals(3, p.i());
        assertEquals(6, p.j());
    }

    @Test
    void singleElement() {
        int[] a = {42};
        int[] copy = a.clone();
        ThreeWayPartition.Pair p = ThreeWayPartition.threeWayPartition(copy);
        assertThreeWayPartitionCorrect(a, copy, p.i(), p.j(), a[0]);
        // the lone pivot is the whole equal block
        assertEquals(0, p.i());
        assertEquals(0, p.j());
    }

    @Test
    void twoElements() {
        int[] a = {3, 1};
        int[] copy = a.clone();
        ThreeWayPartition.Pair p = ThreeWayPartition.threeWayPartition(copy);
        assertThreeWayPartitionCorrect(a, copy, p.i(), p.j(), a[0]);
        // 1 < 3 goes first, pivot alone at index 1
        assertEquals(1, p.i());
        assertEquals(1, p.j());
    }

    // ---- longer arrays ----

    @Test
    void longStructuredArray() {
        // Many duplicates around the pivot to stress the equal block
        int n = 1000;
        int[] a = new int[n];
        // pivot at index 0 is 10
        a[0] = 10;
        for (int k = 1; k < n; k++) {
            // mix of <, =, > by index residue mod 6
            int residue = k % 6;
            if (residue == 0) {
                a[k] = 10;          // =
            } else if (residue <= 2) {
                a[k] = 9;           // <
            } else {
                a[k] = 11;          // >
            }
        }
        int[] copy = a.clone();
        ThreeWayPartition.Pair p = ThreeWayPartition.threeWayPartition(copy);
        assertThreeWayPartitionCorrect(a, copy, p.i(), p.j(), a[0]);
    }

    @Test
    void longRandomArrayWithManyDuplicates() {
        // Distinct fixed seeds: every iteration sees a different input, yet any failure
        // can be reproduced exactly from the reported seed.
        for (long seed : RANDOM_SEEDS) {
            try {
                int[] a = randomArrayAroundPivot(2000, 50, new Random(seed));
                int[] copy = a.clone();
                ThreeWayPartition.Pair p = ThreeWayPartition.threeWayPartition(copy);
                assertThreeWayPartitionCorrect(a, copy, p.i(), p.j(), a[0]);
            } catch (AssertionError | RuntimeException e) {
                // Attach the seed; the original failure is kept as the cause.
                throw new AssertionError("failed for seed " + seed, e);
            }
        }
    }
}
