/*
 * Decompiled with CFR 0.152.
 */
package tigase.auth.impl;

import java.io.IOException;
import java.security.MessageDigest;
import java.security.cert.X509Certificate;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.AuthorizeCallback;
import javax.security.sasl.SaslException;
import tigase.auth.AuthRepositoryAware;
import tigase.auth.DomainAware;
import tigase.auth.MechanismNameAware;
import tigase.auth.SessionAware;
import tigase.auth.XmppSaslException;
import tigase.auth.callbacks.AuthorizationIdCallback;
import tigase.auth.callbacks.ChannelBindingCallback;
import tigase.auth.callbacks.PBKDIterationsCallback;
import tigase.auth.callbacks.SaltCallback;
import tigase.auth.callbacks.ServerKeyCallback;
import tigase.auth.callbacks.StoredKeyCallback;
import tigase.auth.callbacks.XMPPSessionCallback;
import tigase.auth.credentials.Credentials;
import tigase.auth.credentials.entries.PlainCredentialsEntry;
import tigase.auth.credentials.entries.ScramCredentialsEntry;
import tigase.auth.mechanisms.AbstractSasl;
import tigase.auth.mechanisms.AbstractSaslSCRAM;
import tigase.db.AuthRepository;
import tigase.util.Base64;
import tigase.xmpp.XMPPResourceConnection;
import tigase.xmpp.jid.BareJID;

public class ScramCallbackHandler
implements CallbackHandler,
AuthRepositoryAware,
SessionAware,
DomainAware,
MechanismNameAware {
    private static final Logger log = Logger.getLogger(ScramCallbackHandler.class.getCanonicalName());
    private String credentialId = null;
    private ScramCredentialsEntry credentialsEntry;
    private boolean credentialsFetched;
    private String domain;
    private BareJID jid = null;
    private boolean loggingInForbidden = false;
    private String mechanismName;
    private AuthRepository repo;
    private XMPPResourceConnection session;

    @Override
    public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
        for (int i = 0; i < callbacks.length; ++i) {
            if (log.isLoggable(Level.FINEST)) {
                log.log(Level.FINEST, "Callback: {0}", callbacks[i].getClass().getSimpleName());
            }
            this.handleCallback(callbacks[i]);
        }
    }

    @Override
    public void setAuthRepository(AuthRepository repo) {
        this.repo = repo;
    }

    @Override
    public void setDomain(String domain) {
        this.domain = domain;
    }

    @Override
    public void setMechanismName(String mechanismName) {
        this.mechanismName = mechanismName;
    }

    @Override
    public void setSession(XMPPResourceConnection session) {
        this.session = session;
    }

    protected void handleAuthorizeCallback(AuthorizeCallback authCallback) throws SaslException {
        String authenId = authCallback.getAuthenticationID();
        if (log.isLoggable(Level.FINEST)) {
            log.log(Level.FINEST, "AuthorizeCallback: authenId: {0}", authenId);
        }
        this.fetchCredentials();
        if (this.loggingInForbidden) {
            authCallback.setAuthorized(false);
            if (log.isLoggable(Level.FINEST)) {
                log.log(Level.FINEST, "User {0} is disabled", this.jid);
            }
            return;
        }
        String authorId = authCallback.getAuthorizationID();
        if (log.isLoggable(Level.FINEST)) {
            log.log(Level.FINEST, "AuthorizeCallback: authorId: {0}", authorId);
        }
        authCallback.setAuthorized(true);
        this.session.removeSessionData("authentication-jid");
    }

    protected void handleCallback(Callback callback) throws UnsupportedCallbackException, IOException {
        if (callback instanceof XMPPSessionCallback) {
            ((XMPPSessionCallback)callback).setSession(this.session);
        } else if (callback instanceof ChannelBindingCallback) {
            this.handleChannelBindingCallback((ChannelBindingCallback)callback);
        } else if (callback instanceof PBKDIterationsCallback) {
            this.handlePBKDIterationsCallback((PBKDIterationsCallback)callback);
        } else if (callback instanceof NameCallback) {
            this.handleNameCallback((NameCallback)callback);
        } else if (callback instanceof AuthorizationIdCallback) {
            this.handleAuthorizationIdCallback((AuthorizationIdCallback)callback);
        } else if (callback instanceof SaltCallback) {
            this.handleSaltCallback((SaltCallback)callback);
        } else if (callback instanceof ServerKeyCallback) {
            ServerKeyCallback serverKeyCallback = (ServerKeyCallback)callback;
            this.handleServerKeyCallback(serverKeyCallback);
        } else if (callback instanceof StoredKeyCallback) {
            StoredKeyCallback storedKeyCallback = (StoredKeyCallback)callback;
            this.handleStoredKeyCallback(storedKeyCallback);
        } else if (callback instanceof AuthorizeCallback) {
            this.handleAuthorizeCallback((AuthorizeCallback)callback);
        } else {
            throw new UnsupportedCallbackException(callback, "Unrecognized Callback " + callback);
        }
    }

    protected void handleNameCallback(NameCallback nc) throws IOException {
        this.credentialId = "default";
        BareJID jid = BareJID.bareJIDInstanceNS((String)nc.getDefaultName());
        if (jid.getLocalpart() == null || !this.domain.equalsIgnoreCase(jid.getDomain())) {
            jid = BareJID.bareJIDInstanceNS((String)nc.getDefaultName(), (String)this.domain);
        }
        this.setJid(jid);
        nc.setName(jid.toString());
        if (log.isLoggable(Level.FINEST)) {
            log.log(Level.FINEST, "NameCallback: {0}", this.credentialId);
        }
    }

    protected void handlePBKDIterationsCallback(PBKDIterationsCallback callback) throws SaslException {
        if (log.isLoggable(Level.FINEST)) {
            log.log(Level.FINEST, "PBKDIterationsCallback: {0}", this.jid);
        }
        this.fetchCredentials();
        if (this.credentialsEntry != null) {
            callback.setInterations(this.credentialsEntry.getIterations());
        }
    }

    protected void handleSaltCallback(SaltCallback callback) throws SaslException {
        if (log.isLoggable(Level.FINEST)) {
            log.log(Level.FINEST, "SaltCallback: {0}", this.jid);
        }
        this.fetchCredentials();
        if (this.credentialsEntry != null) {
            callback.setSalt(this.credentialsEntry.getSalt());
        } else {
            callback.setSalt(null);
        }
    }

    private void fetchCredentials() throws SaslException {
        if (this.credentialsFetched) {
            return;
        }
        try {
            Credentials credentials = this.repo.getCredentials(this.jid, this.credentialId);
            log.log(Level.FINE, "Fetched credentials for: " + this.jid + " with credentialsId: " + this.credentialId + ", credentials: " + credentials);
            if (credentials == null) {
                this.loggingInForbidden = true;
            } else {
                String mech = this.mechanismName.endsWith("-PLUS") ? this.mechanismName.substring(0, this.mechanismName.length() - "-PLUS".length()) : this.mechanismName;
                Credentials.Entry entry = credentials.getEntryForMechanism(mech);
                if (entry == null) {
                    entry = credentials.getEntryForMechanism("PLAIN");
                }
                if (entry instanceof ScramCredentialsEntry) {
                    this.credentialsEntry = (ScramCredentialsEntry)entry;
                } else if (entry instanceof PlainCredentialsEntry) {
                    this.credentialsEntry = new ScramCredentialsEntry(mech.replace("SCRAM-", ""), (PlainCredentialsEntry)entry);
                }
                boolean bl = this.loggingInForbidden = !credentials.canLogin();
                if (this.loggingInForbidden) {
                    throw XmppSaslException.getExceptionFor(credentials.getAccountStatus());
                }
            }
        }
        catch (SaslException e) {
            throw e;
        }
        catch (Exception ex) {
            log.log(Level.FINE, "Could not retrieve credentials for user " + this.jid + " with credentialId " + this.credentialId, ex);
        }
        this.credentialsFetched = true;
    }

    private void handleAuthorizationIdCallback(AuthorizationIdCallback callback) {
        if (!AbstractSasl.isAuthzIDIgnored() && callback.getAuthzId() != null && !callback.getAuthzId().equals(this.jid.toString())) {
            try {
                this.credentialId = this.jid.getLocalpart();
                this.setJid(BareJID.bareJIDInstance((String)callback.getAuthzId()));
            }
            catch (Exception ex) {
                throw new RuntimeException(ex);
            }
        } else {
            this.credentialId = "default";
            callback.setAuthzId(this.jid.toString());
        }
    }

    private void handleChannelBindingCallback(ChannelBindingCallback callback) {
        if (callback.getRequestedBindType() == AbstractSaslSCRAM.BindType.tls_exporter) {
            callback.setBindingData((byte[])this.session.getSessionData("TLS_EXPORTER_KEY"));
        } else if (callback.getRequestedBindType() == AbstractSaslSCRAM.BindType.tls_unique) {
            callback.setBindingData((byte[])this.session.getSessionData("TLS_UNIQUE_ID_KEY"));
        } else if (callback.getRequestedBindType() == AbstractSaslSCRAM.BindType.tls_server_end_point) {
            try {
                X509Certificate cert = (X509Certificate)this.session.getSessionData("LOCAL_CERTIFICATE_KEY");
                String algo = cert.getSigAlgName();
                int withIdx = algo.indexOf("with");
                if (withIdx <= 0) {
                    throw new RuntimeException("Unable to parse SigAlgName: " + algo);
                }
                String usealgo = algo.substring(0, withIdx);
                if (usealgo.equalsIgnoreCase("MD5") || usealgo.equalsIgnoreCase("SHA1")) {
                    usealgo = "SHA-256";
                }
                MessageDigest md = MessageDigest.getInstance(usealgo);
                byte[] der = cert.getEncoded();
                md.update(der);
                callback.setBindingData(md.digest());
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
        if (log.isLoggable(Level.FINEST)) {
            log.log(Level.FINEST, "Channel binding {0}: {1} in session-id {2}", new Object[]{callback.getRequestedBindType(), callback.getBindingData() == null ? "null" : Base64.encode((byte[])callback.getBindingData()), this.session});
        }
    }

    private void handleServerKeyCallback(ServerKeyCallback callback) throws SaslException {
        if (log.isLoggable(Level.FINEST)) {
            log.log(Level.FINEST, "ServerKeyCallback: {0}", this.jid);
        }
        this.fetchCredentials();
        if (this.credentialsEntry != null) {
            callback.setServerKey(this.credentialsEntry.getServerKey());
        } else {
            callback.setServerKey(null);
        }
    }

    private void handleStoredKeyCallback(StoredKeyCallback callback) throws SaslException {
        if (log.isLoggable(Level.FINEST)) {
            log.log(Level.FINEST, "StoredKeyCallback: {0}", this.jid);
        }
        this.fetchCredentials();
        if (this.credentialsEntry != null) {
            callback.setStoredKey(this.credentialsEntry.getStoredKey());
        } else {
            callback.setStoredKey(null);
        }
    }

    private void setJid(BareJID jid) {
        this.jid = jid;
        if (jid != null) {
            this.session.putSessionData("authentication-jid", jid);
        }
    }
}

