|
|
ThreeWayPartition
ThreeWayPartitionTest.java
package threeWayPartition;
/*
* To compile this test in Eclipse, add JUnit 5 to the project's build path:
*
* 1. Right-click the project --> Properties --> Java Build Path --> Libraries tab.
* 2. Click Add Library... --> select JUnit --> Next.
* 3. Choose JUnit 5 --> Finish --> Apply and Close.
* 4. Eclipse should then add the required JUnit jars automatically.
*
* If that doesn't work, ask a classmate or contact me as soon as possible.
*/
import static org.junit.jupiter.api.Assertions.*;

import java.util.Arrays;
import java.util.Random;

import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.RepetitionInfo;
import org.junit.jupiter.api.Test;
/**
 * JUnit 5 tests for {@code ThreeWayPartition.threeWayPartition(int[])}.
 *
 * <p>Every test treats the first element of the input as the pivot (it passes {@code a[0]} to the
 * invariant checker) and verifies the returned inclusive equal-block bounds {@code [i, j]}:
 * everything before {@code i} is smaller than the pivot, everything in {@code [i, j]} equals it,
 * everything after {@code j} is larger, and the partitioned array is a permutation of the input.
 */
public class ThreeWayPartitionTest {

    // ---- helpers ----

    /**
     * Asserts that {@code arr} is a correct three-way partition of {@code original} around
     * {@code pivot}, with the equal block occupying the inclusive index range {@code [i, j]}.
     *
     * @param original the input before partitioning (used only as the reference multiset)
     * @param arr      the array after partitioning
     * @param i        first index of the equal block
     * @param j        last index of the equal block, inclusive; {@code i - 1} denotes an empty block
     * @param pivot    the value the array was partitioned around
     */
    private static void assertThreeWayPartitionCorrect(int[] original, int[] arr,
            int i, int j, int pivot) {
        // shape of equal block
        assertTrue(i >= 0 && i < arr.length,
                "i in range [0, " + arr.length + "): i=" + i);
        assertTrue(j >= i - 1 && j < arr.length,
                "j in range (j can be i-1 if no equals): i=" + i + ", j=" + j
                        + ", length=" + arr.length);
        // left side: strictly less than pivot
        for (int idx = 0; idx < i; idx++) {
            assertTrue(arr[idx] < pivot,
                    "arr[" + idx + "] = " + arr[idx] + " should be < pivot " + pivot);
        }
        // middle: equal to pivot (if any); assertEquals already reports expected vs. actual
        for (int idx = i; idx <= j; idx++) {
            assertEquals(pivot, arr[idx], "arr[" + idx + "] == pivot");
        }
        // right side: strictly greater than pivot
        for (int idx = j + 1; idx < arr.length; idx++) {
            assertTrue(arr[idx] > pivot,
                    "arr[" + idx + "] = " + arr[idx] + " should be > pivot " + pivot);
        }
        // permutation check (multiset preserved): compare sorted copies so neither input mutates
        int[] sortedOrig = original.clone();
        int[] sortedArr = arr.clone();
        Arrays.sort(sortedOrig);
        Arrays.sort(sortedArr);
        assertArrayEquals(sortedOrig, sortedArr, "partition must be a permutation");
    }

    /**
     * Partitions a copy of {@code input} and checks all partition invariants against the pivot
     * {@code input[0]}.
     *
     * @param input the array to partition; it is not modified, so it can serve as the reference
     * @return the equal-block bounds reported by the implementation
     */
    private static ThreeWayPartition.Pair partitionAndVerify(int[] input) {
        int[] partitioned = input.clone();
        ThreeWayPartition.Pair p = ThreeWayPartition.threeWayPartition(partitioned);
        assertThreeWayPartitionCorrect(input, partitioned, p.i(), p.j(), input[0]);
        return p;
    }

    // ---- small, surgical cases ----

    @Test
    void singleElement() {
        ThreeWayPartition.Pair p = partitionAndVerify(new int[] {42});
        // the lone element is the pivot, so it forms the whole equal block
        assertEquals(0, p.i());
        assertEquals(0, p.j());
    }

    @Test
    void allEqual() {
        int[] a = {7, 7, 7, 7, 7};
        ThreeWayPartition.Pair p = partitionAndVerify(a);
        // whole array should be "equal" block
        assertEquals(0, p.i());
        assertEquals(a.length - 1, p.j());
    }

    @Test
    void noneEqual() {
        int[] a = {5, 1, 2, 3, 4}; // pivot=5, all others < pivot
        ThreeWayPartition.Pair p = partitionAndVerify(a);
        // equal block should be only the pivot (size 1)
        assertEquals(4, p.i());
        assertEquals(4, p.j());
    }

    @Test
    void noneEqual2() {
        int[] a = {51, 23, 43, 89, 11, 98, 43, 21, 7, 34, 56, 12, 9};
        ThreeWayPartition.Pair p = partitionAndVerify(a);
        // 9 values are < 51, so the pivot alone lands at index 9
        assertEquals(9, p.i());
        assertEquals(9, p.j());
    }

    @Test
    void pivotIsSmallestWithDuplicate() {
        int[] a = {1, 5, 3, 1, 9}; // nothing < pivot, two copies of the pivot
        ThreeWayPartition.Pair p = partitionAndVerify(a);
        assertEquals(0, p.i());
        assertEquals(1, p.j());
    }

    @Test
    void duplicatesOnBothSides() {
        int[] a = {4, 6, 4, 2, 9, 4, 1, 7, 4, 5, 3};
        ThreeWayPartition.Pair p = partitionAndVerify(a);
        // 3 values < 4 (2, 1, 3) and four 4s => equal block is [3, 6]
        assertEquals(3, p.i());
        assertEquals(6, p.j());
    }

    @Test
    void negativesAndZeros() {
        int[] a = {0, -2, 0, -2, 5, 0, 3, -1, 0};
        ThreeWayPartition.Pair p = partitionAndVerify(a);
        // 3 negatives and four 0s => equal block is [3, 6]
        assertEquals(3, p.i());
        assertEquals(6, p.j());
    }

    @Test
    void twoElements() {
        int[] a = {3, 1};
        ThreeWayPartition.Pair p = partitionAndVerify(a);
        assertEquals(1, p.i());
        assertEquals(1, p.j());
    }

    // ---- longer arrays ----

    @Test
    void longStructuredArray() {
        // Many duplicates around the pivot to stress the equal block
        int n = 1000;
        int[] a = new int[n];
        // pivot at index 0 is 10
        a[0] = 10;
        for (int i = 1; i < n; i++) {
            // mix of <, =, > driven by i mod 6
            if (i % 6 == 0) {
                a[i] = 10;
            } else if (i % 6 <= 2) {
                a[i] = 9; // <
            } else {
                a[i] = 11; // >
            }
        }
        partitionAndVerify(a);
    }

    @RepeatedTest(5)
    void longRandomArrayWithManyDuplicates(RepetitionInfo repetitionInfo) {
        int n = 2000;
        // The seed includes the repetition number, so each repetition exercises different data
        // while any failure stays reproducible from the repetition index in the test report.
        // (Previously the seed was constant, making all five repetitions identical.)
        Random rnd = new Random(12345L + n + repetitionInfo.getCurrentRepetition());
        int[] a = new int[n];
        // choose pivot first; keep values mostly around pivot to ensure big equal block
        a[0] = 50;
        for (int i = 1; i < n; i++) {
            // values drawn from {45..55} with bias to 50 to create duplicates
            int r = rnd.nextInt(100);
            if (r < 60) {
                a[i] = 50; // 60% equals
            } else if (r < 80) {
                a[i] = 50 - rnd.nextInt(5) - 1; // 20% <
            } else {
                a[i] = 50 + rnd.nextInt(5) + 1; // 20% >
            }
        }
        partitionAndVerify(a);
    }
}
|