/*
 * Tigase Push - Push notifications component for Tigase
 * Copyright (C) 2017 Tigase, Inc. (office@tigase.com) - All Rights Reserved
 * Unauthorized copying of this file, via any medium is strictly prohibited
 * Proprietary and confidential
 */
package tigase.push.apns;

import groovy.json.JsonSlurper;
import tigase.xmpp.Authorization;

import javax.net.ssl.*;
import java.io.*;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpHeaders;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.*;
import java.security.cert.CertificateEncodingException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.security.spec.InvalidKeySpecException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TreeMap;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Client for the Apple Push Notification service (APNs) HTTP/2 provider API.
 * <p>
 * Requests are authorized either with a TLS client certificate (configured through the {@code withCert*} /
 * {@link Builder#withCertificateKeyStore(KeyStore, String)} builder methods) or with an ES256-signed provider
 * token (configured through the {@code withEncryptionKey*} builder methods). Instances are obtained from
 * {@link #newBuilder()}.
 */
public class ApnsService {

	private static final Logger log = Logger.getLogger(ApnsService.class.getCanonicalName());

	/** Placeholder written to the log instead of the value of the {@code authorization} header. */
	private static final String REDACTED = "<redacted>";

	private final String apnsEndoint;
	private final HttpClient httpClient;
	private final APNSTokenManager apnsTokenManager;

	/**
	 * Creates an {@link HttpClient.Builder} matching the authorization mode configured in {@code builder}.
	 * <p>
	 * With a client certificate, the builder gets a dedicated {@link SSLContext}. That context presents the
	 * certificate, validates the server with {@link SSLTrustManager} and negotiates ALPN {@code h2}. With
	 * token authorization a plain builder is returned.
	 *
	 * @throws IllegalArgumentException if neither a certificate nor an encryption key was configured
	 */
	private static HttpClient.Builder createClientBuilder(Builder builder)
			throws NoSuchAlgorithmException, KeyStoreException, KeyManagementException {
		if (builder.keyManagerFactory != null) {
			SSLContext sslContext = SSLContext.getInstance("TLS");
			TrustManagerFactory tmf = TrustManagerFactory.getInstance("sunx509");
			tmf.init((KeyStore) null);
			TrustManager[] origTrustManagers = tmf.getTrustManagers();
			// SSLContext.init only uses the first trust manager of each type, so the pinning manager goes first
			// and delegates to the platform managers itself.
			TrustManager[] trustManagers = new TrustManager[origTrustManagers.length + 1];
			trustManagers[0] = new SSLTrustManager(origTrustManagers);
			System.arraycopy(origTrustManagers, 0, trustManagers, 1, origTrustManagers.length);
			sslContext.init(builder.keyManagerFactory.getKeyManagers(), trustManagers, null);
			// Start from the *default* parameters: the "supported" set also contains protocols and cipher suites
			// that the JDK intentionally leaves disabled by default.
			SSLParameters sslParameters = sslContext.getDefaultSSLParameters();
			sslParameters.setApplicationProtocols(new String[] {"h2"});
			return HttpClient.newBuilder()
					.sslContext(sslContext)
					.sslParameters(sslParameters);
		} else if (builder.encryptionKey != null) {
			return HttpClient.newBuilder();
		} else {
			throw new IllegalArgumentException("Missing encryption key or certificate configuration!") ;
		}
	}

	/**
	 * Builds the HTTP/2 client and, when an encryption key is configured, the provider-token manager.
	 *
	 * @throws IOException wrapping any failure that occurs while initializing the client or the token manager
	 */
	private ApnsService(Builder builder) throws IOException {
		try {
			apnsEndoint = builder.apnsEndpoint;

			HttpClient.Builder clientBuilder = createClientBuilder(builder);
			if (builder.executor != null) {
				clientBuilder.executor(builder.executor);
			}
			httpClient = clientBuilder.version(HttpClient.Version.HTTP_2).build();
			if (builder.encryptionKey == null) {
				apnsTokenManager = null;
			} else {
				apnsTokenManager = new APNSTokenManager(builder.encryptionKey);
			}
		} catch (Exception ex) {
			throw new IOException("Could not initialize HttpClient", ex);
		}
	}

	/** How requests sent by an {@link ApnsService} are authorized. */
	public enum AuthorizationType {
		certificate,
		token
	}

	/**
	 * @return {@link AuthorizationType#token} when an encryption key was configured, otherwise
	 * 		{@link AuthorizationType#certificate}
	 */
	public AuthorizationType getAuthorizationType() {
		if (apnsTokenManager != null) {
			return AuthorizationType.token;
		}
		return AuthorizationType.certificate;
	}

	/**
	 * Sends {@code notification} asynchronously.
	 *
	 * @return a future completed with the {@code apns-id} of the delivered notification. It is completed
	 * 		exceptionally with an {@link ApnsServiceException} for APNs error responses, or with the underlying
	 * 		exception for transport/authorization failures.
	 */
	public CompletableFuture<String> push(ApnsNotification notification) {
		CompletableFuture<String> future = new CompletableFuture<>();
		push(notification, future);
		return future;
	}

	/**
	 * Sends {@code notification} asynchronously and reports the outcome through {@code future}.
	 * <p>
	 * On a successful (HTTP 200) response the future is completed with the {@code apns-id}. On an APNs error
	 * status it is completed exceptionally with an {@link ApnsServiceException}. If the provider token cannot be
	 * generated or the request fails at transport level, it is completed exceptionally with that exception.
	 */
	public void push(ApnsNotification notification, CompletableFuture<String> future) {
		HttpRequest.Builder builder = HttpRequest.newBuilder(
				URI.create("https://" + apnsEndoint + "/3/device/" + notification.getDeviceId()))
				.POST(HttpRequest.BodyPublishers.ofString(notification.getPayload().toPayloadString()))
				.header("apns-id", notification.getId())
				.header("apns-priority", String.valueOf(notification.getPriority()));
		if (apnsTokenManager != null) {
			try {
				// The token is a bearer credential, so its value is deliberately never written to the log.
				builder.header("authorization", "bearer " + apnsTokenManager.getToken());
				if (log.isLoggable(Level.FINEST)) {
					log.finest("authenticating request with provider token");
				}
			} catch (GeneralSecurityException ex) {
				if (log.isLoggable(Level.FINEST)) {
					log.log(Level.FINEST, "failed to send notification: " + notification.toString() + " at " + this +
							", exception thrown during auth token generation", ex);
				}
				future.completeExceptionally(ex);
				return;
			}
		}
		builder.header("apns-push-type", notification.getPushType().name());
		if (notification.getCollapseId() != null) {
			builder.header("apns-collapse-id", notification.getCollapseId());
		}
		if (notification.getTopic() != null) {
			builder.header("apns-topic", notification.getTopic());
		}

		HttpRequest request = builder.build();
		if (log.isLoggable(Level.FINEST)) {
			log.log(Level.FINEST, "trying to send push notification " + notification + " at " + this + ", request: " +
					request.toString() + ", headers: " + redactAuthorization(request.headers()));
		}
		httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8))
				.thenAccept(response -> handleResponse(notification, response, future))
				.exceptionally(throwable -> {
					if (log.isLoggable(Level.FINEST)) {
						log.log(Level.FINEST, "failed to send notification: " + notification.toString() + " at " + this +
								", exception thrown!", throwable);
					}
					future.completeExceptionally(throwable);
					return null;
				});
	}

	/**
	 * Translates an APNs HTTP response into the completion of {@code future}.
	 */
	private void handleResponse(ApnsNotification notification, HttpResponse<String> response,
								CompletableFuture<String> future) {
		// NOTE(review): APNs is expected to return the apns-id sent in the request; if the header is missing,
		// fall back to that id instead of failing an otherwise delivered notification.
		String apnsId = response.headers().firstValue("apns-id").orElse(notification.getId());
		if (log.isLoggable(Level.FINEST)) {
			var apnsUniqueId = response.headers().firstValue("apns-unique-id");
			log.log(Level.FINEST,
					"GOT response " + response.statusCode() + " " + response.body() + " apns-id " +
							apnsId + ", apns-unique-id " + apnsUniqueId);
		}
		ErrorCode errorCode = ErrorCode.fromStatus(response.statusCode());
		if (errorCode == null) {
			if (log.isLoggable(Level.FINEST)) {
				log.log(Level.FINEST, "sent notification: " + notification.toString() + " at " + this);
			}
			future.complete(apnsId);
			return;
		}

		ErrorType errorType = parseErrorType(response.body());
		if (log.isLoggable(Level.FINEST)) {
			log.log(Level.FINEST, "failed to send notification: " + notification.toString() + " at " + this + ", error: " +
					errorCode.name() + ", response: " + response.body());
		}
		future.completeExceptionally(new ApnsServiceException(errorCode, errorType));
	}

	/**
	 * Extracts the {@code reason} field from an APNs JSON error body.
	 * <p>
	 * Returns {@link ErrorType#unknownError} when the body is empty, is not a JSON object, cannot be parsed or
	 * has no {@code reason}. This way a malformed body never hides the status-based {@link ErrorCode}.
	 */
	private static ErrorType parseErrorType(String body) {
		if (body == null || body.isBlank()) {
			return ErrorType.unknownError;
		}
		try {
			Object parsed = new JsonSlurper().parse(new StringReader(body));
			if (parsed instanceof Map) {
				return Optional.ofNullable(((Map<?, ?>) parsed).get("reason"))
						.map(Object::toString)
						.map(ErrorType::from)
						.orElse(ErrorType.unknownError);
			}
		} catch (RuntimeException ex) {
			log.log(Level.FINEST, "could not parse APNs error response: " + body, ex);
		}
		return ErrorType.unknownError;
	}

	/**
	 * Returns a copy of {@code headers} suitable for logging, with the {@code authorization} value replaced.
	 */
	private static Map<String, List<String>> redactAuthorization(HttpHeaders headers) {
		Map<String, List<String>> redacted = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
		redacted.putAll(headers.map());
		redacted.computeIfPresent("authorization", (name, values) -> List.of(REDACTED));
		return redacted;
	}

	/** Stops the service. Currently there are no resources to release. */
	public void stop() {
		// nothing to do..
	}
	
	/**
	 * Caches the ES256-signed provider token and regenerates it once it is older than 50 minutes.
	 * <p>
	 * The current token lives in a {@code volatile} field. Regeneration is guarded by re-checking
	 * {@link APNSToken#shouldRegenerate()} inside a {@code synchronized} block, so concurrent callers generate it
	 * only once.
	 */
	public static class APNSTokenManager {
		private final EncryptionKey encryptionKey;
		// Instant.MIN makes the placeholder token immediately eligible for regeneration.
		private volatile APNSToken token = new APNSToken(Instant.MIN, "");

		/**
		 * @throws GeneralSecurityException if the initial token cannot be signed with {@code encryptionKey}
		 */
		public APNSTokenManager(EncryptionKey encryptionKey) throws GeneralSecurityException  {
			this.encryptionKey = encryptionKey;
			regenerateIfNeeded();
		}

		/**
		 * @return the current provider token, regenerating it first if it is older than 50 minutes
		 * @throws GeneralSecurityException if a new token has to be signed and signing fails
		 */
		public String getToken() throws GeneralSecurityException  {
			regenerateIfNeeded();
			return token.getToken();
		}

		/**
		 * Builds {@code base64url(header) + "." + base64url(payload) + "." + base64url(signature)}. The header is
		 * {@code {"alg":"ES256","kid":<key id>}}, the payload is {@code {"iss":<team id>,"iat":<epoch seconds>}},
		 * and the signature is SHA256withECDSA over the first two parts.
		 */
		private void regenerateIfNeeded() throws GeneralSecurityException {
			if (token.shouldRegenerate()) {
				synchronized (this) {
					if (token.shouldRegenerate()) {
						Instant createdAt = Instant.now();
						long iat = createdAt.getEpochSecond();
						String header = "{\"alg\":\"ES256\",\"kid\":\"" + encryptionKey.getId() + "\"}";
						String payload = "{\"iss\":\"" + encryptionKey.getTeamId() + "\",\"iat\":" + iat + "}";
						String tokenPrefix = Base64.getUrlEncoder().withoutPadding().encodeToString(header.getBytes(StandardCharsets.UTF_8)) +
								"." + Base64.getUrlEncoder().withoutPadding().encodeToString(payload.getBytes(StandardCharsets.UTF_8));
						Signature ecdsa = Signature.getInstance("SHA256withECDSA");
						ecdsa.initSign(encryptionKey.getKey());
						ecdsa.update(tokenPrefix.getBytes(StandardCharsets.UTF_8));
						String signature = Base64.getUrlEncoder().withoutPadding().encodeToString(ecdsa.sign());
						token = new APNSToken(createdAt, tokenPrefix + "." + signature);
					}
				}
			}
		}
	}

	/** An immutable provider token together with its creation time. */
	public static class APNSToken {
		private final Instant createdAt;
		private final String token;
		
		public APNSToken(Instant createdAt, String token) {
			this.createdAt = createdAt;
			this.token = token;
		}

		public String getToken() {
			return token;
		}

		/** @return {@code true} while the token is less than one hour old */
		public boolean isValid() {
			return createdAt.isAfter(Instant.now().minus(1L, ChronoUnit.HOURS));
		}

		/** @return {@code true} once the token is more than 50 minutes old, i.e. before {@link #isValid()} lapses */
		public boolean shouldRegenerate() {
			return createdAt.isBefore(Instant.now().minus(50L, ChronoUnit.MINUTES));
		}
	}

	@Override
	public String toString() {
		return "ApnsService{" + "apnsEndoint='" + apnsEndoint + '\'' + '}';
	}

	/** @return a new builder targeting the production APNs endpoint by default */
	public static Builder newBuilder() {
		return new Builder();
	}

	/**
	 * Configures an {@link ApnsService}. Either a client certificate or an encryption key must be supplied
	 * before {@link #build()} is called.
	 */
	public static class Builder {

		private static final String APNS_DEVELOPMENT = "api.sandbox.push.apple.com";
		private static final String APNS_PRODUCTION = "api.push.apple.com";
		private String apnsEndpoint = APNS_PRODUCTION;
		private ApnsDelegate delegate;
		private EncryptionKey encryptionKey;
		private Executor executor;
		private KeyManagerFactory keyManagerFactory;

		private Builder() {

		}

		/**
		 * @param value {@code true} for the production endpoint ({@value #APNS_PRODUCTION}), {@code false} for the
		 * 		sandbox endpoint ({@value #APNS_DEVELOPMENT})
		 */
		public Builder withAppleDestination(boolean value) {
			apnsEndpoint = value ? APNS_PRODUCTION : APNS_DEVELOPMENT;
			return this;
		}

		/**
		 * Loads a client certificate from either a base64 string (preferred when non-null) or a file path.
		 * Leaves the builder unchanged when both are {@code null}.
		 */
		public Builder withCert(String certificatePath, String base64certificate, String certificatePassword)
				throws IOException {
			if (base64certificate != null) {
				return withBase64Cert(base64certificate, certificatePassword);
			} else if (certificatePath != null) {
				return withCert(certificatePath, certificatePassword);
			} else {
				return this;
			}
		}

		/** Loads a client certificate key store from the file at {@code certPath}. */
		public Builder withCert(String certPath, String certPass) throws IOException {
			try (FileInputStream fis = new FileInputStream(certPath)) {
				return withCert(fis, certPass);
			}
		}

		/** Loads a client certificate key store from {@code is} using {@code APNSUtil.loadCertificate}. */
		public Builder withCert(InputStream is, String certPass) throws IOException {
			KeyStore keyStore = APNSUtil.loadCertificate(is, certPass);
			return withCertificateKeyStore(keyStore, certPass);
		}

		/**
		 * Configures token authorization from either an inline key (preferred when non-null) or a key file.
		 * Leaves the builder unchanged when both are {@code null}.
		 */
		public Builder withEncryptionKey(String encryptionKeyId, String encryptionKey, String encryptionKeyPath, String teamId)
				throws InvalidKeySpecException, NoSuchAlgorithmException, IOException {
			if (encryptionKey != null) {
				return withEncryptionKey(encryptionKeyId, encryptionKey, teamId);
			} else if (encryptionKeyPath != null) {
				return withEncryptionKeyFile(encryptionKeyId, encryptionKeyPath, teamId);
			} else {
				return this;
			}
		}

		/** Parses {@code encryptionKey} with {@code APNSUtil.loadPrivateKey} and configures token authorization. */
		public Builder withEncryptionKey(String encryptionKeyId, String encryptionKey, String teamId)
				throws InvalidKeySpecException, NoSuchAlgorithmException {
			PrivateKey key = APNSUtil.loadPrivateKey(encryptionKey);
			return withEncryptionKey(encryptionKeyId, key, teamId);
		}

		/** Reads the key file as UTF-8 text and configures token authorization from its contents. */
		public Builder withEncryptionKeyFile(String encryptionKeyId, String encryptionKeyPath, String teamId)
				throws InvalidKeySpecException, NoSuchAlgorithmException, IOException {
			try(FileInputStream fis = new FileInputStream(encryptionKeyPath)) {
				return withEncryptionKey(encryptionKeyId, new String(fis.readAllBytes(), StandardCharsets.UTF_8), teamId);
			}
		}

		/** Configures token authorization with an already-parsed private key. */
		public Builder withEncryptionKey(String encryptionKeyId, PrivateKey encryptionKey, String teamId) {
			this.encryptionKey = new EncryptionKey(encryptionKeyId, encryptionKey, teamId);
			return this;
		}

		/**
		 * Configures certificate authorization from {@code keyStore}.
		 *
		 * @throws IOException wrapping any failure that occurs while initializing the key manager factory
		 */
		public Builder withCertificateKeyStore(KeyStore keyStore, String keyStorePassword) throws IOException {
			try {
				keyManagerFactory = KeyManagerFactory.getInstance("sunx509");
				keyManagerFactory.init(keyStore, keyStorePassword.toCharArray());
				return this;
			} catch (Exception ex) {
				throw new IOException("Could not initialize key manager factory", ex);
			}
		}
		
		public ApnsService build() throws IOException {
			return new ApnsService(this);
		}

		/** Loads a client certificate key store from its base64-encoded form. */
		public Builder withBase64Cert(String base64certificate, String certificatePassword) throws IOException {
			return withCert(APNSUtil.inputStreamFromBase64(base64certificate), certificatePassword);
		}
	}
	
	/** Classification of a non-200 APNs HTTP status code. */
	public enum ErrorCode {
		badRequest,
		authenticationFailure,
		invalidMethod,
		deviceTokenInactive,
		payloadTooLarge,
		tooManyRequestsForDeviceToken,
		internalServerError,
		serverShutdown;

		/**
		 * @return {@code null} for HTTP 200 (success). Known error statuses map to their constant; any other
		 * 		status maps to {@link #internalServerError}.
		 */
		public static ErrorCode fromStatus(int statusCode) {
			switch (statusCode) {
				case 200:
					return null;
				case 400:
					return ErrorCode.badRequest;
				case 403:
					return ErrorCode.authenticationFailure;
				case 405:
					return ErrorCode.invalidMethod;
				case 410:
					return ErrorCode.deviceTokenInactive;
				case 413:
					return ErrorCode.payloadTooLarge;
				case 429:
					return ErrorCode.tooManyRequestsForDeviceToken;
				case 500:
					return ErrorCode.internalServerError;
				case 503:
					return ErrorCode.serverShutdown;
				default:
					return ErrorCode.internalServerError;
			}
		}
	}

	/** The {@code reason} reported in an APNs error response body. */
	public enum ErrorType {
		badCollapseId,
		badDeviceToken,
		badExpirationDate,
		badMessageId,
		badPriority,
		badTopic,
		deviceTokenNotForTopic,
		duplicatedHeaders,
		idleTimeout,
		invalidPushType,
		missingDeviceToken,
		missingTopic,
		payloadEmpty,
		topicDisallowed,

		badCertificate,
		badCertificateEnvironment,
		expiredProviderToken,
		forbidden,
		invalidProviderToken,
		missingProviderToken,

		badPath,

		methodNotAllowed,

		unregistred,

		payloadTooLarge,

		tooManyProviderTokenUpdates,
		tooManyRequests,

		internalServerError,

		serviceUnavailable,
		shutdown,

		unknownError;

		/**
		 * Maps an APNs {@code reason} string to its constant.
		 *
		 * @return {@link #unknownError} for {@code null} or unrecognized values
		 */
		public static ErrorType from(String type) {
			if (type == null) {
				return unknownError;
			}
			switch (type) {
				case "BadCollapseId":
					return badCollapseId;
				case "BadDeviceToken":
					return badDeviceToken;
				case "BadExpirationDate":
					return badExpirationDate;
				case "BadMessageId":
					return badMessageId;
				case "BadPriority":
					return badPriority;
				case "BadTopic":
					return badTopic;
				case "DeviceTokenNotForTopic":
					return deviceTokenNotForTopic;
				case "DuplicateHeaders":
					return duplicatedHeaders;
				case "IdleTimeout":
					return idleTimeout;
				case "InvalidPushType":
					return invalidPushType;
				case "MissingDeviceToken":
					return missingDeviceToken;
				case "MissingTopic":
					return missingTopic;
				case "PayloadEmpty":
					return payloadEmpty;
				case "TopicDisallowed":
					return topicDisallowed;
				case "BadCertificate":
					return badCertificate;
				case "BadCertificateEnvironment":
					return badCertificateEnvironment;
				case "ExpiredProviderToken":
					return expiredProviderToken;
				case "Forbidden":
					return forbidden;
				case "InvalidProviderToken":
					return invalidProviderToken;
				case "MissingProviderToken":
					return missingProviderToken;
				case "BadPath":
					return badPath;
				case "MethodNotAllowed":
					return methodNotAllowed;
				case "Unregistered":
					return unregistred;
				case "PayloadTooLarge":
					return payloadTooLarge;
				case "TooManyProviderTokenUpdates":
					return tooManyProviderTokenUpdates;
				case "TooManyRequests":
					return tooManyRequests;
				case "InternalServerError":
					return internalServerError;
				case "ServiceUnavailable":
					return serviceUnavailable;
				case "Shutdown":
					return shutdown;
				default:
					return unknownError;
			}
		}

		/** @return the XMPP error condition reported for this APNs error; anything unmapped is an internal error */
		public Authorization getErrorCondition() {
			switch (this) {
				case badDeviceToken:
				case deviceTokenNotForTopic:
				case unregistred:
					return Authorization.ITEM_NOT_FOUND;
				case topicDisallowed:
					return Authorization.NOT_ALLOWED;
				case payloadTooLarge:
					return Authorization.POLICY_VIOLATION;
				case tooManyRequests:
					return Authorization.RESOURCE_CONSTRAINT;
				default:
					return Authorization.INTERNAL_SERVER_ERROR;
			}
		}

		/** @return {@code true} for server-side errors ({@code internalServerError}, {@code serviceUnavailable}, {@code shutdown}) */
		public boolean shouldRetry() {
			switch (this) {
				case internalServerError:
				case serviceUnavailable:
				case shutdown:
					return true;
				default:
					return false;
			}
		}
	}

	/** Key id, signing key and team id used to create provider tokens. */
	public static class EncryptionKey {
		private final String id;
		private final PrivateKey key;
		private final String teamId;

		public EncryptionKey(String id, PrivateKey key, String teamId) {
			this.id = id;
			this.key = key;
			this.teamId = teamId;
		}

		public String getId() {
			return id;
		}
		
		public PrivateKey getKey() {
			return key;
		}

		public String getTeamId() {
			return teamId;
		}
	}

	/**
	 * Trust manager that accepts server chains anchored at a pinned (whitelisted) public key. Any other chain is
	 * delegated to the first platform {@link X509TrustManager}.
	 */
	public static class SSLTrustManager implements X509TrustManager {

		// whitelisted as mentioned in https://wiki.mozilla.org/CA/Additional_Trust_Changes
		// Values are SHA-256 hashes of the encoded SubjectPublicKeyInfo (see checkCerificateFingerprint).
		private static final String[] WHITELISTED_HASHES = {
				"c0554bde87a075ec13a61f275983ae023957294b454caf0a9724e3b21b7935bc",
				"56e98deac006a729afa2ed79f9e419df69f451242596d2aaf284c74a855e352e",
				"7289c06dedd16b71a7dcca66578572e2e109b11d70ad04c2601b6743bc66d07b",
				"fae46000d8f7042558541e98acf351279589f83b6d3001c18442e4403d111849",
				"b5cf82d47ef9823f9aa78f123186c52e8879ea84b0f822c91d83e04279b78fd5",
				"e24f8e8c2185da2f5e88d4579e817c47bf6eafbc8505f0f960fd5a0df4473ad3",
				"3174d9092f9531c06026ba489891016b436d5ec02623f9aafe2009ecc3e4d557"
		};

		private static final byte[][] WHITELISTED = Arrays.stream(WHITELISTED_HASHES).map(SSLTrustManager::hexDecode).toArray(byte[][]::new);

		private final TrustManager[] trustManagers;

		public SSLTrustManager(TrustManager[] trustManagers) {
			this.trustManagers = trustManagers;
		}

		@Override
		public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException {
			parentTrustManager().checkClientTrusted(chain, authType);
		}

		@Override
		public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
			if (!checkCertificateChainFingerprint(chain)) {
				parentTrustManager().checkServerTrusted(chain, authType);
			}
		}

		@Override
		public X509Certificate[] getAcceptedIssuers() {
			return new X509Certificate[0];
		}

		/**
		 * @return the first {@link X509TrustManager} among the wrapped trust managers
		 * @throws CertificateException if there is none
		 */
		private X509TrustManager parentTrustManager() throws CertificateException {
			return Arrays.stream(trustManagers)
					.filter(X509TrustManager.class::isInstance)
					.map(X509TrustManager.class::cast)
					.findFirst()
					.orElseThrow(() -> new CertificateException("Could not verify certificate validity"));
		}

		/**
		 * Returns {@code true} only if the chain, walked from the leaf, reaches a certificate with a whitelisted
		 * public key. Every certificate up to and including that one must be currently valid. Each certificate
		 * before it must name, and be signed by, the next certificate in the chain.
		 * <p>
		 * Requiring the signature linkage matters. Whitelisted certificates are public, so merely finding one
		 * somewhere in a presented chain proves nothing: an attacker could append it to a self-signed chain.
		 * Any failure results in {@code false}, so the caller falls back to the platform trust manager.
		 */
		protected boolean checkCertificateChainFingerprint(X509Certificate[] chain) {
			if (chain == null) {
				return false;
			}
			try {
				for (int i = 0; i < chain.length; i++) {
					X509Certificate certificate = chain[i];
					certificate.checkValidity();
					if (checkCerificateFingerprint(certificate)) {
						return true;
					}
					if (i + 1 < chain.length) {
						X509Certificate issuer = chain[i + 1];
						if (!certificate.getIssuerX500Principal().equals(issuer.getSubjectX500Principal())) {
							return false;
						}
						// Throws if this certificate was not signed with the next certificate's key.
						certificate.verify(issuer.getPublicKey());
					}
				}
			} catch (GeneralSecurityException | RuntimeException ex) {
				log.log(Level.FINEST, "Could not check certificate fingerprint", ex);
			}
			return false;
		}

		/** @return {@code true} if the SHA-256 of the certificate's encoded public key is whitelisted */
		protected boolean checkCerificateFingerprint(X509Certificate certificate)
				throws NoSuchAlgorithmException, CertificateEncodingException {
			byte[] hash = sha256(certificate.getPublicKey().getEncoded());
			return Arrays.stream(WHITELISTED).anyMatch(entry -> Arrays.equals(entry, hash));
		}

		/**
		 * Decodes a hexadecimal string (either case) into bytes.
		 *
		 * @throws IllegalArgumentException if the length is odd or a character is not a hex digit
		 */
		public static byte[] hexDecode(String hex) {
			int length = hex.length();
			if (length % 2 != 0) {
				throw new IllegalArgumentException("Hex string must have an even length: " + hex);
			}
			byte[] data = new byte[length / 2];
			for (int i = 0; i < length; i += 2) {
				int high = Character.digit(hex.charAt(i), 16);
				int low = Character.digit(hex.charAt(i + 1), 16);
				if (high < 0 || low < 0) {
					throw new IllegalArgumentException("Invalid hex digit in: " + hex);
				}
				data[i / 2] = (byte) ((high << 4) + low);
			}
			return data;
		}

		/** @return the SHA-256 digest of {@code data} */
		public static byte[] sha256(byte[] data) throws NoSuchAlgorithmException {
			MessageDigest md = MessageDigest.getInstance("SHA-256");
			return md.digest(data);
		}

	}
}
