import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.Future;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.ThreadPoolExecutor;

/**
 * Construct a binary tree and compute the sum over all tree nodes
 * of the number of primes between 2 and each node's number.
 *
 * The same sum is computed three times per test round: sequentially, with
 * tasks submitted to an {@link ExecutorService}, and with fork/join tasks,
 * so that the results and run times can be compared.
 *
 * @author Holger.Peine@hs-hannover.de
 */
@SuppressWarnings("serial")
public class TreeSumSequential {
    /**
     * Probability (in percent) for a node to have a child node. If this is low, a small tree
     * often results; if this is high, then the tree often has the maximum allowed number of nodes.
     */
    private static final int SUBTREE_THRESHOLD_PERCENT = 70;

    /**
     * Each node contains a random number from 0 (inclusive) to this value (exclusive);
     * the primes up to that number must be counted.
     */
    private static final int MAX_NUMBER_IN_NODE = 300;

    /** Upper bound for the number of nodes created by {@link #createTree(Queue)}. */
    private static final int MAX_NODES = 300_000;

    /** Number of repetitions of the test. */
    private static final int N_TESTS = 5;

    /** Name of the file the tree of each test round is serialized to. */
    private static final String TREE_FILE = "tree.ser";

    /**
     * A simple binary tree node, holding a random int from 0 (inclusive) to
     * MAX_NUMBER_IN_NODE (exclusive), the level/height of this node, and
     * references to up to two child nodes.
     */
    static final class Node implements Serializable {
        Node left, right;
        int number, level;

        /**
         * Creates a childless node with a random number and counts it in {@code nNodes}.
         *
         * @param level the depth of this node (the root is at level 0)
         */
        Node(int level) {
            this.number = (int) (Math.random() * MAX_NUMBER_IN_NODE);
            this.level = level;
            ++nNodes;
        }
    }

    /** Number of nodes of the tree most recently built by {@link #createTree(Queue)}. */
    private static int nNodes;

    /** Largest level of any node of the tree most recently built by {@link #createTree(Queue)}. */
    private static int maxLevel;

    /**
     * Create a random binary tree where each node has children with a certain probability.
     * Create the tree breadth-first by taking a node from the front of a work queue of nodes,
     * possibly creating child nodes for this node and adding those at the back of the queue,
     * and repeating this until the queue is empty or MAX_NODES nodes have been created.
     *
     * Note that nodes removed from the queue do not vanish, but are still linked with their
     * ancestor nodes higher up in the tree. At any time, the queue contains the leaves of the
     * current tree.
     *
     * @param queue work queue holding the initial node(s) of the tree, normally just the root
     */
    private static void createTree(Queue<Node> queue) {
        // The nodes already waiting in the queue (the root) are part of the tree
        // and must be counted; resetting to 0 would under-report by one.
        nNodes = queue.size();
        maxLevel = 0;
        while (!queue.isEmpty() && nNodes < MAX_NODES) {
            Node current = queue.remove();
            maxLevel = Math.max(maxLevel, current.level);
            if (wantsChild()) {
                current.left = addChild(current, queue);
            }
            // Re-check the limit so that the left child cannot push us past MAX_NODES.
            if (nNodes < MAX_NODES && wantsChild()) {
                current.right = addChild(current, queue);
            }
        }
    }

    /** @return true with a probability of SUBTREE_THRESHOLD_PERCENT percent */
    private static boolean wantsChild() {
        return (int) (Math.random() * 100) < SUBTREE_THRESHOLD_PERCENT;
    }

    /**
     * Creates a child one level below {@code parent}, appends it to the work queue and
     * records its level, so that {@code maxLevel} is correct even for nodes that are
     * never taken from the queue because the node limit was reached first.
     */
    private static Node addChild(Node parent, Queue<Node> queue) {
        Node child = new Node(parent.level + 1);
        maxLevel = Math.max(maxLevel, child.level);
        queue.add(child);
        return child;
    }

    /**
     * @param from lower bound of the range (inclusive)
     * @param to upper bound of the range (exclusive)
     * @return the number of primes n with from <= n < to,
     *         using a deliberately inefficient implementation
     */
    private static int countPrimes(int from, int to) {
        int nPrimes = 0;
        for (int n = from; n < to; n++) {
            // check if n is prime (using a deliberately inefficient test)
            int t = 2;
            while (t < n && n % t != 0) {
                t++;
            }
            if (t == n) { // no divisor found below n: n is prime (never true for n < 2)
                ++nPrimes;
            }
        }
        return nPrimes;
    }

    /** @return the number at which the work for a node is split into its two parts */
    private static int splitPoint(Node node) {
        return (int) (0.7 * node.number);
    }

    /**
     * The prime counting method used by the sequential implementation.
     *
     * @param node the tree; {@code null} denotes an empty tree
     * @return sum over all tree nodes of the number of primes between 2 and each node's number
     *
     * Note that the recursive calls occur after half of the work for this node has been completed.
     * (Half of the work is about 70% of the numbers, since the amount of work for a number increases
     * with the square of the number.) This is to sort of simulate the situation that the need for
     * further work to be done becomes clear only during the work for this node. While this does not
     * really make a difference in the sequential case here, the time when recursive tasks are submitted
     * does make a difference for the overall behavior in the parallel case.
     */
    public static int countPrimesInTree(Node node) {
        if (node == null) {
            return 0;
        }
        int split = splitPoint(node);
        int nPrimesFirstPart = countPrimes(2, split);
        int leftResult = countPrimesInTree(node.left);
        int rightResult = countPrimesInTree(node.right);
        int nPrimesSecondPart = countPrimes(split, node.number + 1);
        return nPrimesFirstPart + nPrimesSecondPart + leftResult + rightResult;
    }

    /**
     * Task for the executor-based implementation: counts the primes for one node and
     * submits one further task per child to the same executor.
     *
     * Every task blocks in {@code Future.get()} until its children are done, so the
     * executor must be able to run more tasks than it has blocked ones.
     */
    static final class TreeSumTaskThreadPool implements Callable<Integer> {
        private final Node node;
        private final ExecutorService executorService;

        public TreeSumTaskThreadPool(Node node, ExecutorService executorService) {
            this.node = node;
            this.executorService = executorService;
        }

        @Override
        public Integer call() throws InterruptedException, ExecutionException {
            Future<Integer> leftFuture = null;
            Future<Integer> rightFuture = null;
            int split = splitPoint(node);

            int nPrimesFirstPart = countPrimes(2, split);
            if (node.left != null) {
                leftFuture = executorService.submit(new TreeSumTaskThreadPool(node.left, executorService));
            }
            if (node.right != null) {
                rightFuture = executorService.submit(new TreeSumTaskThreadPool(node.right, executorService));
            }
            // The second half of the own work overlaps with the work of the children.
            int nPrimesSecondPart = countPrimes(split, node.number + 1);

            int leftResult = leftFuture == null ? 0 : leftFuture.get();
            int rightResult = rightFuture == null ? 0 : rightFuture.get();
            return nPrimesFirstPart + nPrimesSecondPart + leftResult + rightResult;
        }
    }

    /**
     * Task for the fork/join implementation: counts the primes for one node and
     * forks one further task per child.
     */
    static final class TreeSumTaskForkJoin extends RecursiveTask<Integer> {
        private final Node node;

        public TreeSumTaskForkJoin(Node node) {
            this.node = node;
        }

        @Override
        protected Integer compute() {
            int split = splitPoint(node);
            int nPrimesFirstPart = countPrimes(2, split);

            TreeSumTaskForkJoin leftTask = null, rightTask = null;
            if (node.left != null) {
                leftTask = new TreeSumTaskForkJoin(node.left);
                leftTask.fork();
            }
            if (node.right != null) {
                rightTask = new TreeSumTaskForkJoin(node.right);
                rightTask.fork();
            }

            // Do the second half of the own work while the forked children may run
            // on other workers, and only then wait for them.
            int nPrimesSecondPart = countPrimes(split, node.number + 1);

            // Join in reverse order of forking; join() rethrows a child's failure
            // unchecked, so nothing is silently swallowed.
            int rightResult = rightTask == null ? 0 : rightTask.join();
            int leftResult = leftTask == null ? 0 : leftTask.join();
            return nPrimesFirstPart + nPrimesSecondPart + leftResult + rightResult;
        }
    }

    /**
     * Counts the primes in the tree using one executor task per node.
     *
     * @param node the root of the tree, must not be {@code null}
     * @return sum over all tree nodes of the number of primes between 2 and each node's number
     * @throws InterruptedException if the calling thread is interrupted while waiting
     * @throws ExecutionException if one of the tasks failed
     */
    public static int countPrimesInTree_ThreadPoolExecutor(Node node) throws InterruptedException, ExecutionException {
        ExecutorService ecs = Executors.newCachedThreadPool(); // fixed size causes deadlock
        try {
            Future<Integer> result = ecs.submit(new TreeSumTaskThreadPool(node, ecs));
            return result.get();
        } finally {
            // Release the worker threads. On success all tasks have completed already
            // (the root waits for all its descendants); on failure this cancels the rest.
            ecs.shutdownNow();
        }
    }

    /**
     * Counts the primes in the tree using one fork/join task per node on the common pool.
     *
     * @param node the root of the tree, must not be {@code null}
     * @return sum over all tree nodes of the number of primes between 2 and each node's number
     */
    public static int countPrimesInTree_ForkJoinTask(Node node) {
        ForkJoinPool fjp = ForkJoinPool.commonPool();
        TreeSumTaskForkJoin task = new TreeSumTaskForkJoin(node);
        return fjp.invoke(task);
    }

    /**
     * Reads a serialized tree from a file.
     *
     * NOTE(review): this uses Java native deserialization; only use files from a trusted source.
     *
     * @param fileName name of the file holding the serialized tree
     * @return the root of the tree, or {@code null} if the file could not be read
     *         or does not contain a tree (the reason is printed)
     */
    private static Node loadTree(String fileName) {
        try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(fileName))) {
            Object content = ois.readObject();
            if (content instanceof Node) {
                return (Node) content;
            }
            System.out.println("File " + fileName + " does not contain a serialized tree");
        } catch (IOException | ClassNotFoundException e) {
            System.out.println(e);
        }
        return null;
    }

    /**
     * Serializes the tree to a file. A failure is reported, but is not fatal.
     *
     * @param tree the root of the tree
     * @param fileName name of the file to (over)write
     */
    private static void saveTree(Node tree, String fileName) {
        try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(fileName))) {
            oos.writeObject(tree);
        } catch (IOException e) {
            System.out.println(e);
        }
    }

    /** @return the milliseconds elapsed since {@code startNanos}, a value of {@code System.nanoTime()} */
    private static long millisSince(long startNanos) {
        return (System.nanoTime() - startNanos) / 1_000_000L;
    }

    /** @return the given number of milliseconds in seconds */
    private static double toSeconds(long millis) {
        return (double) millis / 1000;
    }

    /** @return the average run time per test round in seconds */
    private static double averageSeconds(long totalMillis) {
        return ((double) totalMillis) / (1000 * N_TESTS);
    }

    /**
     * @param args: args[0] may hold the name of file holding the serialized tree to be used for benchmarking
     */
    public static void main(String[] args) throws InterruptedException, ExecutionException {
        // One total per implementation; mixing them would falsify the averages.
        long totalSeqRuntime = 0;
        long totalThreadPoolRuntime = 0;
        long totalForkJoinRuntime = 0;

        for (int testRound = 0; testRound < N_TESTS; ++testRound) {
            // Construct tree: Either read it from a file, or create a new one
            Node tree;
            if (args.length > 0) {
                // 1st argument is name of file with serialized tree
                tree = loadTree(args[0]);
                if (tree == null) {
                    return; // nothing to benchmark; the reason has been printed
                }
            } else {
                // Create a new tree
                tree = new Node(0); // the root node (at level 0)
                Queue<Node> queue = new ArrayDeque<>();
                queue.add(tree);
                createTree(queue);
                System.out.println("Built binary tree with " + nNodes + " nodes, height " + maxLevel);
            }

            // Write the tree to file "tree.ser", no matter how it was constructed:
            saveTree(tree, TREE_FILE);

            // Count sequentially
            long start = System.nanoTime();
            int sequentialResult = countPrimesInTree(tree);
            long elapsed = millisSince(start);
            totalSeqRuntime += elapsed;
            System.out.println("Sequential result: Found " + sequentialResult + " primes in "
                    + toSeconds(elapsed) + " sec");

            // Count with one executor task per node
            start = System.nanoTime();
            int threadPoolExecutorResult = countPrimesInTree_ThreadPoolExecutor(tree);
            elapsed = millisSince(start);
            totalThreadPoolRuntime += elapsed;
            System.out.println("ThreadPoolExecutor result: Found " + threadPoolExecutorResult + " primes in "
                    + toSeconds(elapsed) + " sec");

            // Count with one fork/join task per node
            start = System.nanoTime();
            int forkJoinResult = countPrimesInTree_ForkJoinTask(tree);
            elapsed = millisSince(start);
            totalForkJoinRuntime += elapsed;
            System.out.println("ForkJoin result: Found " + forkJoinResult + " primes in "
                    + toSeconds(elapsed) + " sec");
        } // for testRound

        System.out.println();
        System.out.println("Average SEQ time after " + N_TESTS + " tests = "
                + averageSeconds(totalSeqRuntime) + " sec");
        System.out.println("Average ThreadPoolExecutor time after " + N_TESTS + " tests = "
                + averageSeconds(totalThreadPoolRuntime) + " sec");
        System.out.println("Average ForkJoin time after " + N_TESTS + " tests = "
                + averageSeconds(totalForkJoinRuntime) + " sec");
    }
}