package com.ironmountain.bedrock.groups.demo;

import java.io.Serializable;
import java.net.InetAddress;
import java.net.URL;
import java.util.Map;
import java.util.Random;
import java.util.TreeMap;
import java.util.Vector;

import org.jgroups.Address;
import org.jgroups.Channel;
import org.jgroups.JChannelFactory;
import org.jgroups.MergeView;
import org.jgroups.View;
import org.jgroups.blocks.ReplicatedHashMap;
import org.jgroups.blocks.ReplicatedHashMap.Notification;
import org.jgroups.mux.MuxChannel;

/**
 * Demonstrates several JGroups {@link ReplicatedHashMap}s replicating in
 * parallel.
 * <p>
 * Each node creates {@code mapCount} maps, each on its own multiplexer
 * channel (service {@code "map" + i} on the {@code "udp"} stack configured in
 * {@code META-INF/groups/jgroups.xml}). It writes {@code keyCount} entries
 * keyed {@code <hostname>_<index>} into every map and checks that its own
 * writes read back correctly. It then polls until every map holds
 * {@code keyCount} entries from each of {@code nodeCount} hosts.
 * <p>
 * Usage: {@code ReplicatedMapDemo [mapCount [nodeCount [keyCount]]]}
 * (defaults 1, 1 and 50). Keys are derived from the local host name, so run
 * at most one instance per host.
 * 
 * @author rnewson
 * 
 */
public class ReplicatedMapDemo {

	/** When true, prints map-size changes and the underlying cluster view. */
	private static final boolean VERBOSE = false;

	/**
	 * When true, block after start-up until every map's view contains
	 * {@code nodeCount} members before writing any entries.
	 */
	private static final boolean WAIT_FOR_COMPLETE_VIEW = false;

	/** Maximum number of replication polls before giving up. */
	private static final int MAX_POLL_RETRIES = 1000;

	/** Delay between replication polls, in milliseconds. */
	private static final long POLL_INTERVAL_MILLIS = 10000;

	/**
	 * Monitor notified on every view change. It is used to wake
	 * {@link #awaitCompleteView(int)}.
	 */
	private static final Object viewChange = new Object();

	/** One replicated map per multiplexer service, indexed by map number. */
	private static ReplicatedHashMap<String, Integer>[] maps;

	/**
	 * Replication listener. It prints each view change, tracks the map's size
	 * and wakes any thread waiting on {@code viewChange}.
	 */
	private static final class MyNotification<K extends Serializable, V extends Serializable>
			implements Notification<K, V> {

		private final Channel channel;
		private final Map<?, ?> map;
		/** Last size observed by {@link #reportSizeChange()}. */
		private int size = 0;

		MyNotification(final Channel channel, final Map<?, ?> map) {
			this.channel = channel;
			this.map = map;
		}

		/** Records the map's current size, printing it when VERBOSE and changed. */
		private void reportSizeChange() {
			final int newSize = map.size();
			if (newSize != size) {
				size = newSize;
				if (VERBOSE) {
					System.out.printf("Map now contains %,d entries.%n", newSize);
				}
			}
		}

		@Override
		public void contentsCleared() {
			reportSizeChange();
		}

		@Override
		public void contentsSet(final Map newEntries) {
			reportSizeChange();
		}

		@Override
		public void entryRemoved(final K key) {
			reportSizeChange();
		}

		@Override
		public void entrySet(final K key, final V value) {
			reportSizeChange();
		}

		@Override
		public void viewChange(final View view, final Vector<Address> newMembers,
				final Vector<Address> oldMembers) {
			System.out.printf("%s:%,d size: %,d%n",
					view instanceof MergeView ? "merge_view" : "view",
					view.getVid().getId(), view.getMembers().size());
			if (VERBOSE && channel instanceof MuxChannel) {
				final View clusterView = ((MuxChannel) channel).getClusterView();
				if (clusterView != null) {
					System.out.printf("cluster_view %,d size: %,d%n%n",
							clusterView.getVid().getId(), clusterView.size());
				}
			}
			reportSizeChange();

			// Wake awaitCompleteView so it re-checks every map's view.
			synchronized (viewChange) {
				viewChange.notifyAll();
			}
		}

	}

	/**
	 * Runs the demo on this node. It never returns normally.
	 * 
	 * @param args
	 *            optional {@code [mapCount [nodeCount [keyCount]]]}; the
	 *            defaults are 1, 1 and 50
	 * @throws Exception
	 *             if an argument is not an integer, the JGroups configuration
	 *             is missing, or a channel cannot be created, connected or
	 *             started
	 */
	public static void main(final String[] args) throws Exception {
		final int mapCount = args.length > 0 ? Integer.parseInt(args[0]) : 1;
		final int nodeCount = args.length > 1 ? Integer.parseInt(args[1]) : 1;
		final int keyCount = args.length > 2 ? Integer.parseInt(args[2]) : 50;
		final String nodeName = InetAddress.getLocalHost().getHostName();

		// Discovery can be tuned through the jgroups.ping.timeout and
		// jgroups.ping.num_initial_members system properties (e.g. via -D).
		// NOTE(review): presumably these are referenced by jgroups.xml;
		// confirm there.

		maps = createMaps(mapCount);
		if (WAIT_FOR_COMPLETE_VIEW) {
			awaitCompleteView(nodeCount);
		}
		writeOwnEntries(nodeName, keyCount);
		verifyOwnEntries(nodeName, keyCount);
		pollForReplication(nodeCount, keyCount);

		// Never return and never stop the maps, so this node stays in the
		// groups and its entries remain visible to the other nodes.
		Thread.sleep(Long.MAX_VALUE);
	}

	/**
	 * Creates, connects and starts {@code mapCount} replicated maps. Each one
	 * uses its own multiplexer service named {@code "map" + i}.
	 * 
	 * @throws IllegalStateException
	 *             if the JGroups configuration is not on the classpath
	 */
	private static ReplicatedHashMap<String, Integer>[] createMaps(
			final int mapCount) throws Exception {
		final String config = "META-INF/groups/jgroups.xml";
		final URL url = ReplicatedMapDemo.class.getClassLoader().getResource(
				config);
		if (url == null) {
			throw new IllegalStateException(config
					+ " not found on the classpath");
		}
		final JChannelFactory channelFactory = new JChannelFactory();
		channelFactory.setMultiplexerConfig(url);

		// Java forbids generic array creation, hence the unchecked raw array.
		@SuppressWarnings("unchecked")
		final ReplicatedHashMap<String, Integer>[] result = new ReplicatedHashMap[mapCount];
		for (int i = 0; i < mapCount; i++) {
			final String name = "map" + i;
			final Channel channel = channelFactory.createMultiplexerChannel(
					"udp", name);
			channel.setOpt(Channel.AUTO_RECONNECT, true);
			System.out.println("mux : " + (channel instanceof MuxChannel));
			final ReplicatedHashMap<String, Integer> map = new ReplicatedHashMap<String, Integer>(
					channel);
			map.setBlockingUpdates(true);
			map.addNotifier(new MyNotification<String, Integer>(channel, map));
			map.setTimeout(2000);
			channel.connect(name);
			final View view = channel.getView();
			System.out.printf("initial view %,d has %,d members.%n", view
					.getVid().getId(), view.size());
			map.start(1000);
			result[i] = map;
		}
		return result;
	}

	/**
	 * Blocks until every map's view has at least {@code nodeCount} members.
	 * Each map is its own group, so its view counts the nodes that joined it,
	 * not the number of maps.
	 */
	private static void awaitCompleteView(final int nodeCount)
			throws InterruptedException {
		synchronized (viewChange) {
			// Re-check after every wake-up. Each notification covers a single
			// view change, and wait() may also return spuriously.
			while (!allViewsComplete(nodeCount)) {
				System.err.println("Waiting for complete view.");
				viewChange.wait();
			}
		}
	}

	/** Returns whether every map's current view has at least nodeCount members. */
	private static boolean allViewsComplete(final int nodeCount) {
		for (final ReplicatedHashMap<String, Integer> map : maps) {
			final View view = map.getChannel().getView();
			if (view == null || view.size() < nodeCount) {
				return false;
			}
		}
		return true;
	}

	/** Writes {@code keyCount} entries keyed by this node's name into every map. */
	private static void writeOwnEntries(final String nodeName,
			final int keyCount) throws InterruptedException {
		final Random random = new Random();
		for (int i = 0; i < maps.length; i++) {
			for (int j = 0; j < keyCount; j++) {
				maps[i].put(keyFor(nodeName, j), valueFor(i, j));
				// Pause after roughly half of the puts to spread the writes
				// out over time.
				if (random.nextInt(10) < 5) {
					Thread.sleep(100);
				}
			}
		}
	}

	/** Reports any entry written by this node that does not read back correctly. */
	private static void verifyOwnEntries(final String nodeName,
			final int keyCount) {
		for (int i = 0; i < maps.length; i++) {
			for (int j = 0; j < keyCount; j++) {
				final Integer value = maps[i].get(keyFor(nodeName, j));
				if (value == null || value.intValue() != valueFor(i, j)) {
					System.out.println("My own value is wrong for i,j=" + i
							+ "," + j + " value=" + value);
				}
			}
		}
	}

	/**
	 * Polls once every {@link #POLL_INTERVAL_MILLIS} ms, up to
	 * {@link #MAX_POLL_RETRIES} times. Each poll prints a per-host table of
	 * entry counts. Polling stops early, printing {@code PASSED}, once exactly
	 * {@code nodeCount} hosts each have {@code keyCount} entries in every map.
	 */
	private static void pollForReplication(final int nodeCount,
			final int keyCount) throws InterruptedException {
		for (int retries = 0; retries < MAX_POLL_RETRIES; retries++) {
			final Map<String, int[]> counts = countEntriesPerHost(keyCount);
			System.out.println();
			System.out.println("Retry #" + retries);
			final boolean allPresent = printCounts(counts, keyCount);
			if (counts.size() == nodeCount && allPresent) {
				System.out.println();
				System.out.println("PASSED");
				return;
			}
			Thread.sleep(POLL_INTERVAL_MILLIS);
		}
	}

	/**
	 * Counts each map's entries by the host that wrote them. The host is parsed
	 * from keys of the form {@code <host>_<index>}.
	 * 
	 * @return a map from host name (sorted) to an array of per-map entry counts
	 */
	private static Map<String, int[]> countEntriesPerHost(final int keyCount) {
		final Map<String, int[]> counts = new TreeMap<String, int[]>();
		for (int i = 0; i < maps.length; i++) {
			for (final Map.Entry<String, Integer> entry : maps[i].entrySet()) {
				final String key = entry.getKey();
				// Split on the last '_' so host names containing '_' still
				// parse correctly.
				final int separator = key.lastIndexOf('_');
				assert separator > 0 && separator < key.length() - 1 : "Invalid map key format for "
						+ key;
				final String host = key.substring(0, separator);
				final int keyNum = Integer.parseInt(key.substring(separator + 1));
				assert keyNum >= 0 && keyNum < keyCount;
				int[] perMap = counts.get(host);
				if (perMap == null) {
					perMap = new int[maps.length];
					counts.put(host, perMap);
				}
				perMap[i]++;
			}
		}
		return counts;
	}

	/**
	 * Prints one row per host, showing its entry count in each map.
	 * 
	 * @return whether every count reached {@code keyCount}
	 */
	private static boolean printCounts(final Map<String, int[]> counts,
			final int keyCount) {
		boolean allPresent = true;
		for (final Map.Entry<String, int[]> row : counts.entrySet()) {
			System.out.printf("  %15s", row.getKey());
			for (final int count : row.getValue()) {
				if (count < keyCount) {
					allPresent = false;
				}
				System.out.printf("  %02d/%02d", count, keyCount);
			}
			System.out.println();
		}
		return allPresent;
	}

	/** Builds the key for entry {@code index} written by {@code host}. */
	private static String keyFor(final String host, final int index) {
		return host + "_" + index;
	}

	/**
	 * Returns the value stored for key {@code index} in map {@code mapIndex}.
	 * It is unique per (map, key) pair while {@code index < 1,000,000}.
	 */
	private static int valueFor(final int mapIndex, final int index) {
		return mapIndex * 1000000 + index;
	}

}
