/* * Created on 30-jan-2006 */ package eden.recent; import java.lang.reflect.Field; import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.FloatBuffer; import java.util.Random; import sun.misc.Unsafe; public class ArrayVersusBuffer { public static void main(String[] args) throws Exception { Thread.currentThread().setPriority(Thread.MAX_PRIORITY - 1); long tArrSum = 0; long tBufSum = 0; long tPntSum = 0; int runs = 16; int tightLoops = 1024 * 2; int vertices = 1024 * 2; init(vertices); for (int run = 0; run < runs; run++) { resetBuffer(bufA, 42L); resetBuffer(bufB, 4200L); resetArray(arrA, 42L); resetArray(arrB, 4200L); long pntA = getBase(bufA); long pntB = getBase(bufB); long pntC = getBase(bufC); if (run == runs / 2) { tArrSum = 0; tBufSum = 0; tPntSum = 0; System.out.println("<<<< WARMUP DONE >>>>"); } long t; long tArr, tBuf, tPnt; t = System.nanoTime(); for (int i = 0; i < tightLoops; i++) doWeightOnArrays(arrA, arrB, 0.75F, arrC); tArr = (System.nanoTime() - t); t = System.nanoTime(); for (int i = 0; i < tightLoops; i++) doWeightOnBuffers(fbA, fbB, 0.75F, fbC); tBuf = (System.nanoTime() - t); t = System.nanoTime(); for (int i = 0; i < tightLoops; i++) doWeightOnPointers(pntA, pntB, 0.75F, pntC, bufA.capacity()); tPnt = (System.nanoTime() - t); tArrSum += tArr; tBufSum += tBuf; tPntSum += tPnt; System.out.println("arr:\t" + tArr + "ns"); System.out.println("buf:\t" + tBuf + "ns"); System.out.println("pnt:\t" + tPnt + "ns"); System.out.println(); } // ns -> ms tArrSum /= 1000000L; tBufSum /= 1000000L; tPntSum /= 1000000L; System.out.println("duration of last " + (runs / 2) + " runs:"); System.out.println("---> arr:\t" + tArrSum + "ms\t" + (int) (((1000.0 / tArrSum) * tightLoops * vertices * (runs / 2)) / 1000000) + "M vertices/s"); System.out.println("---> buf:\t" + tBufSum + "ms\t" + (int) (((1000.0 / tBufSum) * tightLoops * vertices * (runs / 2)) / 1000000) + "M vertices/s"); System.out.println("---> pnt:\t" + tPntSum + "ms\t" + (int) (((1000.0 / tPntSum) * tightLoops * vertices * (runs / 2)) / 1000000) + "M vertices/s"); } private static final void init(int vectors) { // init buffers bufA = createByteBuffer(4 * (vectors * 3)); bufB = createByteBuffer(4 * (vectors * 3)); bufC = createByteBuffer(4 * (vectors * 3)); fbA = bufA.asFloatBuffer(); fbB = bufB.asFloatBuffer(); fbC = bufC.asFloatBuffer(); // init arrays arrA = new float[vectors * 3]; arrB = new float[vectors * 3]; arrC = new float[vectors * 3]; } static ByteBuffer bufA, bufB, bufC; static FloatBuffer fbA, fbB, fbC; static float[] arrA, arrB, arrC; private static final void resetBuffer(ByteBuffer buf, long seed) { Random r = new Random(seed); buf.clear(); while (buf.hasRemaining()) buf.putFloat(r.nextFloat()); buf.rewind(); } private static final void resetArray(float[] arr, long seed) { Random r = new Random(seed); for (int i = 0; i < arr.length; i++) arr[i] = r.nextFloat(); } private static final void doWeightOnArrays(float[] a, float[] b, float weight, float[] c) { float aMul = weight; float bMul = 1.0F - weight; // unrolling this loop makes it slower for (int i = 0; i < a.length; i++) c[i] = aMul * a[i] + bMul * b[i]; } private static final void doWeightOnBuffers(FloatBuffer a, FloatBuffer b, float weight, FloatBuffer c) { float aMul = weight; float bMul = 1.0F - weight; int i = a.position(); int len = a.remaining(); while (i < len) { c.put(i, aMul * a.get(i) + bMul * b.get(i++)); c.put(i, aMul * a.get(i) + bMul * b.get(i++)); c.put(i, aMul * a.get(i) + bMul * b.get(i++)); c.put(i, aMul * a.get(i) + bMul * b.get(i++)); } } private static final void doWeightOnPointers(long a, long b, float weight, long c, int bytes) { final float aMul = weight; final float bMul = 1.0F - weight; Unsafe unsafe = getAccess(); final int end = bytes - 12; int i = -4; // unroll 4x while (i < end) { unsafe.putFloat((i += 4) + c, unsafe.getFloat(i + a) * aMul + unsafe.getFloat(i + b) * bMul); unsafe.putFloat((i += 4) + c, unsafe.getFloat(i + a) * aMul + unsafe.getFloat(i + b) * bMul); unsafe.putFloat((i += 4) + c, unsafe.getFloat(i + a) * aMul + unsafe.getFloat(i + b) * bMul); unsafe.putFloat((i += 4) + c, unsafe.getFloat(i + a) * aMul + unsafe.getFloat(i + b) * bMul); unsafe.putFloat((i += 4) + c, unsafe.getFloat(i + a) * aMul + unsafe.getFloat(i + b) * bMul); unsafe.putFloat((i += 4) + c, unsafe.getFloat(i + a) * aMul + unsafe.getFloat(i + b) * bMul); unsafe.putFloat((i += 4) + c, unsafe.getFloat(i + a) * aMul + unsafe.getFloat(i + b) * bMul); unsafe.putFloat((i += 4) + c, unsafe.getFloat(i + a) * aMul + unsafe.getFloat(i + b) * bMul); unsafe.putFloat((i += 4) + c, unsafe.getFloat(i + a) * aMul + unsafe.getFloat(i + b) * bMul); // more unrolling will cause a slowdown } // finish remaining bytes while (i < bytes) { unsafe.putFloat((i += 4) + c, unsafe.getFloat(i + a) * aMul + unsafe.getFloat(i + b) * bMul); unsafe.putFloat((i += 4) + c, unsafe.getFloat(i + a) * aMul + unsafe.getFloat(i + b) * bMul); unsafe.putFloat((i += 4) + c, unsafe.getFloat(i + a) * aMul + unsafe.getFloat(i + b) * bMul); } } /** * UNSAFE STUFF */ private static Unsafe unsafe; private static Field addressHack; static { try { ByteBuffer tmp = ByteBuffer.allocateDirect(1); Field unsafeHack = tmp.getClass().getDeclaredField("unsafe"); unsafeHack.setAccessible(true); unsafe = (Unsafe) unsafeHack.get(tmp); addressHack = Buffer.class.getDeclaredField("address"); addressHack.setAccessible(true); tmp = null; } catch (Exception exc) { exc.printStackTrace(); throw new InternalError(); } } public static final Unsafe getAccess() { return unsafe; } public static final long getBase(ByteBuffer bb) { try { return addressHack.getLong(bb); } catch (Exception exc) { exc.printStackTrace(); throw new InternalError(); } } public static final ByteBuffer createByteBuffer(int size) { return ByteBuffer.allocateDirect(size).order(ByteOrder.nativeOrder()); } }