/**
 * Copyright (c) 2012 - 2018 Data In Motion and others.
 * All rights reserved. 
 * 
 * This program and the accompanying materials are made available under the terms of the 
 * Eclipse Public License v1.0 which accompanies this distribution, and is available at
 * http://www.eclipse.org/legal/epl-v10.html
 * 
 * Contributors:
 *     Data In Motion - initial API and implementation
 */
package org.gecko.rsa.provider;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import java.util.logging.Level;
import java.util.logging.Logger;

import org.eclipse.emf.ecore.EObject;
import org.gecko.emf.osgi.ResourceSetFactory;
import org.gecko.osgi.messaging.Message;
import org.gecko.osgi.messaging.MessagingService;
import org.gecko.rsa.core.DeSerializationContext;
import org.gecko.rsa.core.DeSerializer;
import org.gecko.rsa.core.SerializationContext;
import org.gecko.rsa.core.Serializer;
import org.gecko.rsa.provider.ser.BasicObjectInputStream;
import org.gecko.rsa.provider.ser.BasicObjectOutputStream;
import org.gecko.rsa.provider.ser.RequestDeSerializer;
import org.gecko.rsa.provider.ser.ResponseDeSerializer;
import org.gecko.rsa.provider.ser.VersionMarker;
import org.gecko.rsa.rsaprovider.EObjectRequestParameter;
import org.gecko.rsa.rsaprovider.RSAProviderFactory;
import org.gecko.rsa.rsaprovider.RSARequest;
import org.gecko.rsa.rsaprovider.RSAResponse;
import org.gecko.rsa.rsaprovider.RequestParameter;
import org.osgi.framework.ServiceException;
import org.osgi.framework.Version;
import org.osgi.util.promise.Deferred;
import org.osgi.util.promise.Promise;
import org.osgi.util.pushstream.PushStream;

/**
 * Invocation handler that is used as proxy for the client implementation of a distribution provider.
 * <p>
 * Every call on the proxy is turned into an {@link RSARequest}, serialized and published on the
 * endpoint's data topic. The calling thread then waits (at most 30 seconds) on a
 * {@link CountDownLatch} that is released by the subscriber of the endpoint's response topic as soon
 * as an {@link RSAResponse} carrying the same id has been received.
 * <p>
 * Methods returning a {@link Future}, {@link CompletionStage} or {@link Promise} are executed
 * asynchronously; all other methods block the caller until the response arrived or the timeout elapsed.
 * 
 * @author Mark Hoffmann
 * @since 07.07.2018
 */
public class MessagingInvocationHandler implements InvocationHandler {

	private static final Logger logger = Logger.getLogger(MessagingInvocationHandler.class.getName());
	/** Maximum time in seconds a caller waits for the response of a remote call */
	private static final long RESPONSE_TIMEOUT_SECONDS = 30;

	private final MessagingService messaging;
	private final String topicAddress;
	private final ClassLoader cl;
	private PushStream<Message> receiveData = null;
	/** Latches of the calls that are currently waiting for a response, keyed by request id */
	private final Map<String, CountDownLatch> waitLatch = new ConcurrentHashMap<String, CountDownLatch>();
	/** Result streams of the legacy Java serialization protocol, keyed by request id */
	private final Map<String, ObjectInputStream> dataMap = new ConcurrentHashMap<String, ObjectInputStream>();
	/** Requests that were sent and not yet answered, keyed by request id */
	private final Map<String, RSARequest> requestMap = new ConcurrentHashMap<String, RSARequest>();
	private final Serializer<RSARequest, SerializationContext> serializer;
	private final DeSerializer<RSAResponse, DeSerializationContext> deserializer;

	/**
	 * Creates a new instance and subscribes to the response topic of the given endpoint.
	 * A failing subscription is logged and does not fail the construction.
	 * 
	 * @param messaging the messaging service used to publish requests and to receive responses
	 * @param resourceSetFactory the factory handed over to the request serializer and response de-serializer
	 * @param endpointId the endpoint id, that is used to build the request and response topic names
	 * @param cl the class loader used by the legacy object stream protocol to resolve classes
	 */
	public MessagingInvocationHandler(MessagingService messaging, ResourceSetFactory resourceSetFactory, String endpointId, ClassLoader cl) {
		this.messaging = messaging;
		serializer = new RequestDeSerializer(resourceSetFactory);
		deserializer = new ResponseDeSerializer(resourceSetFactory);
		this.topicAddress = String.format(MessagingRSAEndpoint.MA_DATA_TOPIC, endpointId);
		String responseTopicAddress = String.format(MessagingRSAEndpoint.MA_DATA_RESPONSE_TOPIC, endpointId);
		this.cl = cl;
		try {
			this.receiveData = messaging.subscribe(responseTopicAddress);
			this.receiveData.forEach(this::handleDataResponseNew);
		} catch (Exception e) {
			// keep the handler alive, but log the topic that failed together with the cause
			logger.log(Level.SEVERE, String.format("Error subscribing to receiver topic '%s'", responseTopicAddress), e);
		}

	}

	/**
	 * Returns the content of the given buffer as byte array.
	 * The backing array is used when it is accessible, which is what the callers always relied on.
	 * Read-only or direct buffers have no accessible backing array; for them the remaining bytes
	 * are copied without touching the position of the given buffer.
	 * 
	 * @param buffer the payload buffer
	 * @return the payload bytes
	 */
	private static byte[] toByteArray(ByteBuffer buffer) {
		if (buffer.hasArray()) {
			return buffer.array();
		}
		ByteBuffer copy = buffer.duplicate();
		byte[] bytes = new byte[copy.remaining()];
		copy.get(bytes);
		return bytes;
	}

	/**
	 * Legacy response handler for the Java object stream protocol. It is currently not subscribed,
	 * {@link #handleDataResponseNew(Message)} is used instead.
	 * Reads the request id from the stream, stores the stream for the waiting caller and releases its latch.
	 * <p>
	 * NOTE(review): the stream is closed by the try-with-resources statement when this method returns,
	 * although it was handed to the waiting caller via the data map - verify before re-enabling this path.
	 * 
	 * @param message the received response message
	 */
	@SuppressWarnings("unused")
	private void handleDataResponse(Message message) {
		ByteBuffer buffer = message.payload();
		ByteArrayInputStream bais = new ByteArrayInputStream(toByteArray(buffer));
		try (ObjectInputStream in = new BasicObjectInputStream(bais, cl)) {
			String id = (String) in.readObject();
			CountDownLatch latch = waitLatch.remove(id);
			if (latch != null) {
				dataMap.put(id, in);
				latch.countDown();
			} else {
				logger.severe(String.format("Did not found a count down latch for id '%s'", id));
			}
		} catch (IOException e) {
			logger.log(Level.SEVERE, "Cannot create BasicInputStream from byte array", e);
		} catch (ClassNotFoundException e) {
			logger.log(Level.SEVERE, "Cannot find class to read UUID", e);
		}
	}

	/**
	 * Handles a message received on the response topic. The payload is de-serialized into an
	 * {@link RSAResponse}, attached to the pending request with the same id and the caller waiting
	 * for that id is released. Errors are logged and never propagated into the push stream.
	 * 
	 * @param message the received response message
	 */
	private void handleDataResponseNew(Message message) {
		try {
			// decode inside the try block, a broken payload must not escape into the subscription
			ByteArrayInputStream bais = new ByteArrayInputStream(toByteArray(message.payload()));
			Promise<RSAResponse> responsePromise = deserializer.deserialize(bais);
			RSAResponse response = responsePromise.getValue();
			if (response == null) {
				logger.severe("Received a response message that was de-serialized to null");
				return;
			}
			String id  = response.getId();
			if (id == null) {
				// the concurrent maps do not accept null keys
				logger.severe("Received a response without an id, cannot assign it to a request");
				return;
			}
			CountDownLatch latch = waitLatch.remove(id);
			if (latch != null) {
				RSARequest request = requestMap.get(id);
				if (request != null) {
					// the waiting caller reads the response via request.getResponse(),
					// presumably request and response are linked bidirectionally in the model
					response.setRequest(request);
				}
				latch.countDown();
			} else {
				logger.severe(String.format("Did not found a count down latch for id '%s'", id));
			}
		} catch (Exception e) {
			logger.log(Level.SEVERE, "Cannot find handle data response", e);
		}
	}

	/* 
	 * (non-Javadoc)
	 * @see java.lang.reflect.InvocationHandler#invoke(java.lang.Object, java.lang.reflect.Method, java.lang.Object[])
	 */
	@Override
	public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
		Class<?> returnType = method.getReturnType();
		if (Future.class.isAssignableFrom(returnType) || CompletionStage.class.isAssignableFrom(returnType)) {
			return createFutureResult(method, args);
		}
		if (Promise.class.isAssignableFrom(returnType)) {
			return createPromiseResult(method, args);
		}
		return handleSyncCallNew(method, args);
	}

	/**
	 * Executes the remote call asynchronously and returns a {@link CompletableFuture} for the result.
	 * Checked failures are wrapped into a {@link RuntimeException}, because a supplier cannot throw them.
	 * 
	 * @param method the invoked method
	 * @param args the call arguments, can be <code>null</code>
	 * @return the future that completes with the result of the remote call
	 */
	private Object createFutureResult(final Method method, final Object[] args) {
		return CompletableFuture.<Object>supplyAsync(() -> {
			try {
				return handleSyncCallNew(method, args);
			} catch (RuntimeException e) {
				throw e;
			} catch (Throwable e) {
				throw new RuntimeException(e);
			}
		});
	}

	/**
	 * Executes the remote call in a new thread and returns a {@link Promise} for the result.
	 * 
	 * @param method the invoked method
	 * @param args the call arguments, can be <code>null</code>
	 * @return the promise that is resolved with the result or failed with the error of the remote call
	 */
	private Object createPromiseResult(final Method method, final Object[] args) {
		final Deferred<Object> deferred = new Deferred<Object>();
		new Thread(() -> {
			try {
				deferred.resolve(handleSyncCallNew(method, args));
			} catch (Throwable e) {
				deferred.fail(e);
			}
		}).start();
		return deferred.getPromise();
	}

	/**
	 * Legacy synchronous call using the Java object stream protocol. It is currently unused,
	 * {@link #handleSyncCallNew(Method, Object[])} is used instead.
	 * <p>
	 * NOTE(review): the latch is registered in {@link #waitForResult(String)} after the request was
	 * published, so a very fast response can be missed - fix this before re-enabling this path.
	 * 
	 * @param method the invoked method
	 * @param args the call arguments
	 * @return the result of the remote call
	 * @throws Throwable a {@link ServiceException} on communication errors or the {@link Throwable} returned as result
	 */
	@SuppressWarnings("unused")
	private Object handleSyncCall(Method method, Object[] args) throws Throwable {
		Object result;
		ByteArrayOutputStream baos = new ByteArrayOutputStream();
		try (ObjectOutputStream out = new BasicObjectOutputStream(baos)) {
			String id = UUID.randomUUID().toString();
			out.writeObject(id);
			out.writeObject(method.getName());

			out.writeObject(args);
			out.flush();
			ByteBuffer buffer = ByteBuffer.wrap(baos.toByteArray());
			messaging.publish(topicAddress, buffer);
			result = waitForResult(id);
		} catch (ServiceException e) {
			throw e;
		} catch (Throwable e) {
			throw new ServiceException("Error calling '" + topicAddress + "' method: " + method.getName(), ServiceException.REMOTE, e);
		}
		if (result instanceof Throwable) {
			throw (Throwable)result;
		}
		return result;
	}

	/**
	 * Sends the method call as {@link RSARequest} to the remote endpoint and blocks until the
	 * response arrived or the timeout elapsed.
	 * 
	 * @param method the invoked method
	 * @param args the call arguments, <code>null</code> for methods without parameters
	 * @return the result of the remote call
	 * @throws Throwable a {@link ServiceException} on communication errors, timeout or interruption,
	 * or the {@link Throwable} that was returned as result of the remote call
	 */
	private Object handleSyncCallNew(Method method, Object[] args) throws Throwable {
		Object result;
		String id = UUID.randomUUID().toString();
		try  {
			RSARequest request = createRequest(id, method, args);
			requestMap.put(id, request);
			// register the latch BEFORE publishing, otherwise a fast response finds no latch and gets lost
			CountDownLatch latch = new CountDownLatch(1);
			waitLatch.put(id, latch);
			Promise<OutputStream> requestPromise = serializer.serialize(request);
			ByteArrayOutputStream baos = requestPromise.filter(os->os instanceof ByteArrayOutputStream).map(os->(ByteArrayOutputStream)os).getValue();
			ByteBuffer buffer = ByteBuffer.wrap(baos.toByteArray());
			messaging.publish(topicAddress, buffer);
			result = waitForResultNew(id, latch);
		} catch (ServiceException e) {
			throw e;
		} catch (Throwable e) {
			throw new ServiceException("Error calling '" + topicAddress + "' method: " + method.getName(), ServiceException.REMOTE, e);
		} finally {
			// never keep book-keeping entries of failed or timed out calls
			waitLatch.remove(id);
			requestMap.remove(id);
		}
		if (result instanceof Throwable) {
			throw (Throwable)result;
		}
		return result;
	}

	/**
	 * Creates the request model object for the given method call.
	 * 
	 * @param id the unique request id
	 * @param method the invoked method, its name is used as service name
	 * @param args the call arguments, <code>null</code> is treated like an empty array
	 * @return the request with one numbered parameter per argument
	 */
	private RSARequest createRequest(String id, Method method, Object[] args) {
		RSARequest request = RSAProviderFactory.eINSTANCE.createRSARequest();
		request.setId(id);
		request.setServiceName(method.getName());
		// the proxy hands over null instead of an empty array for methods without parameters
		Object[] arguments = args == null ? new Object[0] : args;
		for (int i = 0; i < arguments.length; i++) {
			RequestParameter parameter = null;
			if (arguments[i] instanceof EObject) {
				parameter = RSAProviderFactory.eINSTANCE.createEObjectRequestParameter();
				((EObjectRequestParameter)parameter).setEObject((EObject) arguments[i]);
			} else {
				parameter = RSAProviderFactory.eINSTANCE.createRequestParameter();
				parameter.setObject(arguments[i]);
			}
			parameter.setNumber(i);
			request.getParameter().add(parameter);
		}
		return request;
	}

	/**
	 * Legacy wait method of the Java object stream protocol, only used by {@link #handleSyncCall(Method, Object[])}.
	 * Registers a latch for the id, waits for its release and parses the result stream stored for the id.
	 * 
	 * @param id the request id
	 * @return the parsed result
	 * @throws Throwable a {@link ServiceException} on timeout, interruption, missing or unreadable result
	 */
	private Object waitForResult(String id) throws Throwable {
		CountDownLatch latch = new CountDownLatch(1);
		waitLatch.put(id, latch);
		try {
			if (!latch.await(RESPONSE_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
				throw new ServiceException("Error calling '" + topicAddress + "' call timed out after " + RESPONSE_TIMEOUT_SECONDS + " seconds");
			}
			ObjectInputStream resultStream = dataMap.remove(id);
			if (resultStream == null) {
				throw new ServiceException("Error calling '" + topicAddress + "' call returned no result, what should not happen");
			}
			return parseResult(resultStream);
		} catch (InterruptedException e) {
			// restore the interrupt flag for the code up the call stack
			Thread.currentThread().interrupt();
			throw new ServiceException("Error calling '" + topicAddress + "' because wait lock was interrupted", e);
		} catch (ServiceException e) {
			throw e;
		} catch (Throwable e) {
			throw new ServiceException("Error reading result", e);
		}
	}

	/**
	 * Waits until the given latch is released by the response handler and returns the result
	 * of the response that was attached to the pending request.
	 * 
	 * @param id the request id
	 * @param latch the latch that was registered for the id before the request was published
	 * @return the EObject or plain object result of the response
	 * @throws Throwable a {@link ServiceException} on timeout, interruption or a missing request or response
	 */
	private Object waitForResultNew(String id, CountDownLatch latch) throws Throwable {
		try {
			if (!latch.await(RESPONSE_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
				throw new ServiceException("Error calling '" + topicAddress + "' call timed out after " + RESPONSE_TIMEOUT_SECONDS + " seconds");
			}
			RSARequest request = requestMap.remove(id);
			if (request == null) {
				throw new ServiceException("Error calling '" + topicAddress + "' call returned no result, what should not happen");
			}
			RSAResponse response = request.getResponse();
			if (response == null) {
				throw new ServiceException("Error calling '" + topicAddress + "' call returned response: " + response);
			}
			return response.isEObjectResult() ? response.getEObject() : response.getObject();
		} catch (InterruptedException e) {
			// restore the interrupt flag for the code up the call stack
			Thread.currentThread().interrupt();
			throw new ServiceException("Error calling '" + topicAddress + "' because wait lock was interrupted", e);
		} catch (ServiceException e) {
			throw e;
		} catch (Throwable e) {
			throw new ServiceException("Error reading result", e);
		}
	}

	/**
	 * Reads the next object from the given stream and replaces a version marker by a {@link Version}.
	 * 
	 * @param in the result stream
	 * @return the result object
	 * @throws Throwable if the object cannot be read
	 */
	private Object parseResult(ObjectInputStream in) throws Throwable {
		return readReplaceVersion(in.readObject());
	}

	/**
	 * Converts a {@link VersionMarker} into an OSGi {@link Version}, all other objects are returned as they are.
	 * 
	 * @param readObject the object read from the result stream
	 * @return the {@link Version} for a marker, otherwise the given object
	 */
	private Object readReplaceVersion(Object readObject) {
		if (readObject instanceof VersionMarker) {
			return new Version(((VersionMarker)readObject).getVersion());
		}
		return readObject;
	}

}
