package org.jgroups.tests;

import com.lmax.disruptor.*;
import com.lmax.disruptor.dsl.Disruptor;
import org.jgroups.Address;
import org.jgroups.Message;
import org.jgroups.Version;
import org.jgroups.logging.Log;
import org.jgroups.logging.LogFactory;
import org.jgroups.util.*;

import java.io.DataOutputStream;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;


public class BundlerStressTest {
    // Tunables; main() may override them from the command line
    static int NUM_THREADS=10;   // number of Adder threads (-adders)
    static int NUM=1000000;      // total number of messages to push through the bundler (-num)
    static int PRINT=NUM / 10;   // progress line is printed every PRINT removed messages

    // Progress counters shared between the adders, the bundler and main()
    static final AtomicInteger added=new AtomicInteger();   // messages claimed by the adders
    static final AtomicInteger removed=new AtomicInteger(); // messages drained by the bundler

    /**
     * Runs the stress test: NUM_THREADS adders push NUM messages through a TransferQueueBundler and the time until
     * the bundler has removed all of them is reported.
     *
     * Options (each takes a value): -num, -adders, -capacity, -max_bundle_size. Any other argument, an option
     * without a value, or fewer than one adder prints the usage line and returns.
     */
    public static void main(String[] args) throws InterruptedException {
        int capacity=50000, max_bundle_size=60000;
        for(int i=0; i < args.length; i++) {
            // every option needs a value; without this check a trailing flag throws ArrayIndexOutOfBoundsException
            if(i + 1 >= args.length) {
                printUsage();
                return;
            }
            if(args[i].equals("-num")) {
                NUM=Integer.parseInt(args[++i]);
                // the bundler uses PRINT as a modulus, so it must never become 0 (e.g. for NUM < 10)
                PRINT=Math.max(1, NUM / 10);
                continue;
            }
            if(args[i].equals("-adders")) {
                NUM_THREADS=Integer.parseInt(args[++i]);
                continue;
            }
            if(args[i].equals("-capacity")) {
                capacity=Integer.parseInt(args[++i]);
                continue;
            }
            if(args[i].equals("-max_bundle_size")) {
                max_bundle_size=Integer.parseInt(args[++i]);
                continue;
            }
            printUsage();
            return;
        }

        // without at least one adder nothing is ever sent and waitForCompletion() would block forever
        if(NUM_THREADS < 1) {
            printUsage();
            return;
        }

        TransferQueueBundler bundler=new TransferQueueBundler(capacity, max_bundle_size);
        try {
            final CountDownLatch latch=new CountDownLatch(1);

            Adder[] adders=new Adder[NUM_THREADS];
            for(int i=0; i < adders.length; i++) {
                adders[i]=new Adder(bundler, latch, added);
                adders[i].start();
            }

            bundler.start();

            long start=System.currentTimeMillis();
            latch.countDown(); // release all adders at the same time
            bundler.waitForCompletion();
            long diff=System.currentTimeMillis() - start;
            double msgs_sec=NUM / (diff / 1000.0);

            System.out.println("added messages: " + added + ", removed messages: " + removed);
            System.out.println("took " + diff + " ms to insert and remove " + NUM + " messages: " + String.format("%.2f msgs/sec", msgs_sec));
        }
        finally {
            // shut down the disruptor and its executor even if the run fails, so the JVM can exit
            bundler.stop();
        }
    }

    /** Prints the command line syntax of this test. */
    private static void printUsage() {
        System.out.println("BundlerStressTest [-num numbers] [-adders <number of adder threads>] " +
                             "[-capacity <capacity>] [-max_bundle_size <max_bundle_size>]");
    }


    /**
     * Producer thread: waits for the start latch, then repeatedly claims a number from the shared counter and sends
     * one message with a 1000-byte payload per claimed number, until NUM numbers have been claimed across all adders.
     */
    protected static class Adder extends Thread {
        protected final TransferQueueBundler bundler;
        protected final AtomicInteger num;    // shared count of claimed messages
        protected final CountDownLatch latch; // counted down once to start all adders together

        public Adder(TransferQueueBundler bundler, CountDownLatch latch, AtomicInteger num) {
            this.bundler=bundler;
            this.num=num;
            this.latch=latch;
            setName("Adder");
        }

        @Override
        public void run() {
            try {
                latch.await();
            }
            catch(InterruptedException e) {
                // Restore the interrupt status instead of swallowing it, and don't take part in this run.
                // The remaining adders share the counter and keep sending until NUM messages have been claimed.
                Thread.currentThread().interrupt();
                return;
            }

            final byte[] buf=new byte[1000]; // payload shared by every message this adder sends
            while(true) {
                int seqno=num.incrementAndGet();
                if(seqno > NUM) {
                    num.decrementAndGet(); // undo the overshoot so the counter settles at exactly NUM
                    break;
                }
                try {
                    bundler.send(new Message(null, null, buf));
                }
                catch(Exception e) {
                    e.printStackTrace();
                }
            }
        }
    }



    /**
     * Bundler used by the stress test. Adders publish messages into an LMAX Disruptor ring buffer via
     * {@link #send(Message)}. A single event handler ({@link #onEvent}) groups them by destination and flushes a
     * bundle when the next message would reach {@code max_bundle_size} or when the current batch of events ends.
     * Flushed bundles are serialized with {@link #writeMessageList}, but {@code doSend()} is a no-op, so nothing is
     * put on the wire. Despite its name, this class does not use a TransferQueue.
     */
    protected static class TransferQueueBundler implements EventHandler<TransferQueueBundler.Value> {
        final com.lmax.disruptor.RingBuffer<Value> buffer;
        final ExecutorService executor=Executors.newSingleThreadExecutor(); // handed to the Disruptor below
        final Disruptor<Value> disruptor;
        volatile Thread                    bundler_thread; // never assigned in this class; only exposed via getThread()
        final Log                          log=LogFactory.getLog(getClass());

        /** Keys are destinations, values are lists of Messages */
        final Map<Address,List<Message>>   msgs=new HashMap<Address,List<Message>>(36);
        long                               count=0;    // current number of bytes accumulated
        int                                num_msgs=0; // messages added since the last bundle was sent
        final int                          max_bundle_size;
        volatile boolean                   running=true; // cleared by stop(); onEvent() ignores events afterwards
        protected final Lock               lock=new ReentrantLock();
        protected final Condition          cond=lock.newCondition(); // signalled once NUM messages were removed
        public static final String         THREAD_NAME="TransferQueueBundler";

        // count is still 0 when this initializer runs (and max_bundle_size is not assigned yet), so the stream
        // starts at 50 bytes; presumably ExposedByteArrayOutputStream grows on demand
        protected final ExposedByteArrayOutputStream bundler_out_stream=new ExposedByteArrayOutputStream((int)(count + 50));
        protected final ExposedDataOutputStream bundler_dos=new ExposedDataOutputStream(bundler_out_stream);

        protected static final byte LIST=1; // we have a list of messages rather than a single message when set
        protected static final byte MULTICAST=2; // message is a multicast (versus a unicast) message when set
        protected static final byte OOB=4; // message has OOB flag set (Message.OOB)


        /**
         * Creates the bundler and starts the disruptor, so events are consumed as soon as this constructor returns.
         * @param capacity        requested number of ring buffer slots, rounded up to the next power of two
         * @param max_bundle_size a bundle is flushed before its accumulated size would reach this many bytes
         * @throws IllegalArgumentException if capacity is not positive or larger than 2^30
         */
        protected TransferQueueBundler(int capacity, int max_bundle_size) {
            if(capacity <=0) throw new IllegalArgumentException("Bundler capacity cannot be " + capacity);
            this.max_bundle_size=max_bundle_size;

            // honour the requested capacity; the ring buffer size has to be a power of two
            ClaimStrategy claim_strategy=new MultiThreadedClaimStrategy(ringBufferSize(capacity));
            // alternatives worth benchmarking: SleepingWaitStrategy, BusySpinWaitStrategy, YieldingWaitStrategy
            WaitStrategy wait_strategy=new BlockingWaitStrategy();

            disruptor = new Disruptor<Value>(new EventFactory<Value>() {
                public Value newInstance() {
                    return new Value(null);
                }
            }, executor, claim_strategy, wait_strategy);

            disruptor.handleEventsWith(this);
            buffer = disruptor.start();
        }

        /** Returns the smallest power of two that is >= capacity (capacity must be in 1..2^30). */
        private static int ringBufferSize(int capacity) {
            if(capacity > (1 << 30))
                throw new IllegalArgumentException("Bundler capacity " + capacity + " exceeds the maximum of " + (1 << 30));
            int size=Integer.highestOneBit(capacity);
            return size == capacity? size : size << 1;
        }

        /** No-op: the disruptor, and with it the consuming thread, is already started by the constructor. */
        public void start() {
        }

        public Thread getThread() {return bundler_thread;}

        /** Stops processing: further events are ignored, then the disruptor and its executor are shut down. */
        public void stop() {
            running=false;
            disruptor.shutdown();
            executor.shutdown();
        }

        /**
         * Publishes a message into the ring buffer; it is picked up asynchronously by {@link #onEvent}.
         * Safe to call from multiple adder threads (multi-threaded claim strategy).
         */
        public void send(Message msg) throws Exception {
            // Publishers claim events in sequence
            long sequence = buffer.next();
            Value event = buffer.get(sequence);

            event.setVal(msg);

            // make the event available to EventProcessors
            buffer.publish(sequence);
        }


        /**
         * Blocks until the bundler has removed NUM messages. Interrupts don't abort the wait, but the interrupt
         * status is restored before returning.
         */
        protected void waitForCompletion() {
            boolean interrupted=false;
            lock.lock();
            try {
                while(removed.get() < NUM) {
                    try {
                        cond.await();
                    }
                    catch(InterruptedException e) {
                        interrupted=true; // keep waiting, but remember the interrupt for the caller
                    }
                }
            }
            finally {
                lock.unlock();
            }
            if(interrupted)
                Thread.currentThread().interrupt();
        }

        /**
         * Consumes one ring buffer event: flushes the pending bundle if this message would make it reach
         * max_bundle_size, adds the message, and flushes again at the end of a batch.
         */
        public void onEvent(Value event, long sequence, boolean endOfBatch) throws Exception {
            if(!running)
                return;
            Message msg=event.getVal();
            event.setVal(null); // don't let the reusable ring buffer slot keep the consumed message reachable
            long size=msg.size();
            if(!msgs.isEmpty() && count + size >= max_bundle_size) {
                sendMessages();
            }
            addMessage(msg);
            count+=size;
            if(endOfBatch)
                sendMessages();
        }

        /**
         * Not used: messages are consumed by the disruptor through {@link #onEvent}. This appears to be left over
         * from a queue-polling bundler loop that could only spin here, so it now fails fast instead.
         * @throws UnsupportedOperationException always
         */
        public void run() {
            throw new UnsupportedOperationException("messages are consumed by the disruptor via onEvent()");
        }


        /**
         * Accounts all pending messages as removed (printing progress every PRINT messages), serializes the pending
         * bundles and resets them. Once NUM messages have been removed, threads in waitForCompletion() are woken up.
         */
        void sendMessages() {
            for(List<Message> list: msgs.values()) {
                if(list == null)
                    continue;
                for(int i=0, size=list.size(); i < size; i++) {
                    int total=removed.incrementAndGet();
                    if(PRINT > 0 && total % PRINT == 0) // guard: PRINT is 0 when NUM < 10
                        System.out.println("added messages: " + added + ", removed messages: " + removed);
                }
            }

            sendBundledMessages(msgs);
            msgs.clear();
            count=0;

            // signal only after the bundle has been serialized, so the measured time includes that work
            if(removed.get() >= NUM) {
                lock.lock();
                try {
                    cond.signalAll();
                }
                finally {
                    lock.unlock();
                }
            }
        }


        /** Appends msg to the pending list for its destination (null destination = multicast). */
        private void addMessage(Message msg) {
            Address dst=msg.getDest();
            List<Message> tmp=msgs.get(dst);
            if(tmp == null) {
                tmp=new LinkedList<Message>();
                msgs.put(dst, tmp);
            }
            tmp.add(msg);
            num_msgs++;
        }



        /**
         * Serializes all messages from the map; all messages for the same destination are bundled into one buffer.
         * Only called from sendMessages(), i.e. on the thread that runs onEvent().
         * @param msgs pending messages keyed by destination (null key = multicast)
         */
        private void sendBundledMessages(final Map<Address,List<Message>> msgs) {
            boolean   multicast;

            if(log.isTraceEnabled()) {
                double percentage=100.0 / max_bundle_size * count;
                StringBuilder sb=new StringBuilder("sending ").append(num_msgs).append(" msgs (");
                sb.append(count).append(" bytes (" + percentage + "% of max_bundle_size)");
                sb.append(" to ").append(msgs.size()).append(" destination(s)");
                if(msgs.size() > 1) sb.append(" (dests=").append(msgs.keySet()).append(")");
                log.trace(sb);
            }
            num_msgs=0; // reset unconditionally, otherwise the counter grows forever when tracing is off

            for(Map.Entry<Address,List<Message>> entry: msgs.entrySet()) {
                List<Message> list=entry.getValue();
                if(list.isEmpty())
                    continue;

                Address dest=entry.getKey();
                Address src_addr=list.get(0).getSrc();

                multicast=dest == null;
                try {
                    bundler_out_stream.reset();
                    bundler_dos.reset();
                    writeMessageList(dest, src_addr, list, bundler_dos, multicast);
                    Buffer buf=new Buffer(bundler_out_stream.getRawBuffer(), 0, bundler_out_stream.size());
                    doSend(buf, dest, multicast);
                }
                catch(Throwable e) {
                    if(log.isErrorEnabled()) log.error("exception sending bundled msgs: " + e + ":, cause: " + e.getCause());
                }
            }
        }

        /** Intentionally empty: the serialized bundle is discarded, so no network I/O is measured. */
        private static void doSend(Buffer buf, Address dest, boolean multicast) {
        }


        /**
         * Writes a bundle: version (short), flags (byte: LIST, plus MULTICAST when set), destination and source
         * addresses, then for each message a {@code true} marker followed by the message without its addresses,
         * and finally a terminating {@code false}.
         */
        protected static void writeMessageList(Address dest, Address src, List<Message> msgs,
                                               DataOutputStream dos, boolean multicast) throws Exception {
            dos.writeShort(Version.version);

            byte flags=LIST;
            if(multicast)
                flags|=MULTICAST;

            dos.writeByte(flags);

            Util.writeAddress(dest, dos);

            Util.writeAddress(src, dos);

            if(msgs != null) {
                for(Message msg: msgs) {
                    dos.writeBoolean(true);
                    msg.writeToNoAddrs(src, dos);
                }
            }

            dos.writeBoolean(false); // terminating presence - no more messages will follow
        }


        /** Mutable ring buffer slot; the disruptor pre-allocates one per entry via the EventFactory (initially empty). */
        public static class Value {
            protected Message val;

            public Value(Message val) {
                this.val=val;
            }

            public Message getVal() {
                return val;
            }

            public void setVal(Message val) {
                this.val=val;
            }

            /** Returns the message's string form, or "null" for an empty slot (instead of throwing an NPE). */
            public String toString() {
                return String.valueOf(val);
            }
        }

    }
}
