/*
 * Decompiled with CFR 0.152.
 */
package tigase.auth.impl;

import java.io.IOException;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.Locale;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.AuthorizeCallback;
import tigase.auth.AuthRepositoryAware;
import tigase.auth.DomainAware;
import tigase.auth.MechanismNameAware;
import tigase.auth.SessionAware;
import tigase.auth.callbacks.AuthorizationIdCallback;
import tigase.auth.callbacks.ChannelBindingCallback;
import tigase.auth.callbacks.PBKDIterationsCallback;
import tigase.auth.callbacks.SaltCallback;
import tigase.auth.callbacks.SaltedPasswordCallback;
import tigase.auth.callbacks.XMPPSessionCallback;
import tigase.auth.credentials.Credentials;
import tigase.auth.credentials.entries.PlainCredentialsEntry;
import tigase.auth.credentials.entries.ScramCredentialsEntry;
import tigase.auth.mechanisms.AbstractSasl;
import tigase.auth.mechanisms.AbstractSaslSCRAM;
import tigase.db.AuthRepository;
import tigase.util.Base64;
import tigase.xmpp.XMPPResourceConnection;
import tigase.xmpp.jid.BareJID;

/**
 * {@link CallbackHandler} backing the SCRAM SASL mechanisms (including their {@code -PLUS}
 * channel-binding variants).
 *
 * <p>It resolves the authenticating JID from the {@link NameCallback} and an optional
 * authorization id. It supplies salt, iteration count and salted password from the
 * {@link AuthRepository}, provides channel-binding data taken from the XMPP session, and answers
 * the final {@link AuthorizeCallback}.
 *
 * <p>Instances keep per-exchange state (current JID, credential id, fetched credentials), so one
 * instance should serve a single authentication exchange.
 */
public class ScramCallbackHandler
implements CallbackHandler,
AuthRepositoryAware,
SessionAware,
DomainAware,
MechanismNameAware {
    private static final Logger log = Logger.getLogger(ScramCallbackHandler.class.getCanonicalName());
    /** Session-data key under which the JID currently being authenticated is stored. */
    private static final String AUTHENTICATION_JID_KEY = "authentication-jid";
    /** Credential id used when the user authenticates with their own (default) credentials. */
    private static final String DEFAULT_CREDENTIAL_ID = "default";
    /** Mechanism-name suffix marking the channel-binding variant of a SCRAM mechanism. */
    private static final String PLUS_SUFFIX = "-PLUS";

    private boolean accountDisabled = false;
    private ScramCredentialsEntry credentialsEntry;
    private boolean credentialsFetched;
    private String domain;
    private BareJID jid = null;
    private String mechanismName;
    private AuthRepository repo;
    private XMPPResourceConnection session;
    private String username = null;

    /**
     * Processes each callback in order, delegating to {@link #handleCallback(Callback)}.
     *
     * @param callbacks callbacks requested by the SASL mechanism
     * @throws UnsupportedCallbackException if a callback type is not recognised
     */
    @Override
    public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
        for (Callback callback : callbacks) {
            if (log.isLoggable(Level.FINEST)) {
                log.log(Level.FINEST, "Callback: {0}", callback.getClass().getSimpleName());
            }
            this.handleCallback(callback);
        }
    }

    @Override
    public void setMechanismName(String mechanismName) {
        this.mechanismName = mechanismName;
    }

    @Override
    public void setAuthRepository(AuthRepository repo) {
        this.repo = repo;
    }

    @Override
    public void setDomain(String domain) {
        this.domain = domain;
    }

    @Override
    public void setSession(XMPPResourceConnection session) {
        this.session = session;
    }

    /**
     * Authorizes the exchange unless the account is disabled or has no stored credentials.
     * On success, clears the temporary authentication JID from the session.
     */
    protected void handleAuthorizeCallback(AuthorizeCallback authCallback) {
        String authenId = authCallback.getAuthenticationID();
        if (log.isLoggable(Level.FINEST)) {
            log.log(Level.FINEST, "AuthorizeCallback: authenId: {0}", authenId);
        }
        this.fetchCredentials();
        if (this.accountDisabled) {
            authCallback.setAuthorized(false);
            if (log.isLoggable(Level.FINEST)) {
                log.log(Level.FINEST, "User {0} is disabled", this.jid);
            }
            return;
        }
        String authorId = authCallback.getAuthorizationID();
        if (log.isLoggable(Level.FINEST)) {
            log.log(Level.FINEST, "AuthorizeCallback: authorId: {0}", authorId);
        }
        authCallback.setAuthorized(true);
        this.session.removeSessionData(AUTHENTICATION_JID_KEY);
    }

    /**
     * Dispatches a single callback to its type-specific handler.
     *
     * @throws UnsupportedCallbackException if {@code callback} is of an unrecognised type
     */
    protected void handleCallback(Callback callback) throws UnsupportedCallbackException, IOException {
        if (callback instanceof XMPPSessionCallback) {
            ((XMPPSessionCallback) callback).setSession(this.session);
        } else if (callback instanceof ChannelBindingCallback) {
            this.handleChannelBindingCallback((ChannelBindingCallback) callback);
        } else if (callback instanceof PBKDIterationsCallback) {
            this.handlePBKDIterationsCallback((PBKDIterationsCallback) callback);
        } else if (callback instanceof SaltedPasswordCallback) {
            this.handleSaltedPasswordCallbackCallback((SaltedPasswordCallback) callback);
        } else if (callback instanceof NameCallback) {
            this.handleNameCallback((NameCallback) callback);
        } else if (callback instanceof AuthorizationIdCallback) {
            this.handleAuthorizationIdCallback((AuthorizationIdCallback) callback);
        } else if (callback instanceof SaltCallback) {
            this.handleSaltCallback((SaltCallback) callback);
        } else if (callback instanceof AuthorizeCallback) {
            this.handleAuthorizeCallback((AuthorizeCallback) callback);
        } else {
            throw new UnsupportedCallbackException(callback, "Unrecognized Callback " + callback);
        }
    }

    /**
     * Builds the authenticating JID from the callback's default name and the configured domain.
     * Also selects the default credential id and echoes the full bare JID back as the name.
     */
    protected void handleNameCallback(NameCallback nc) throws IOException {
        this.username = DEFAULT_CREDENTIAL_ID;
        this.setJid(BareJID.bareJIDInstanceNS(nc.getDefaultName(), this.domain));
        nc.setName(this.jid.toString());
        if (log.isLoggable(Level.FINEST)) {
            log.log(Level.FINEST, "NameCallback: {0}", this.username);
        }
    }

    /** Supplies the PBKDF2 iteration count of the stored SCRAM entry, if one is available. */
    protected void handlePBKDIterationsCallback(PBKDIterationsCallback callback) {
        if (log.isLoggable(Level.FINEST)) {
            log.log(Level.FINEST, "PBKDIterationsCallback: {0}", this.jid);
        }
        this.fetchCredentials();
        if (this.credentialsEntry != null) {
            callback.setInterations(this.credentialsEntry.getIterations());
        }
    }

    /** Supplies the stored salt, or {@code null} when no usable credentials were found. */
    protected void handleSaltCallback(SaltCallback callback) {
        if (log.isLoggable(Level.FINEST)) {
            log.log(Level.FINEST, "SaltCallback: {0}", this.jid);
        }
        this.fetchCredentials();
        callback.setSalt(this.credentialsEntry != null ? this.credentialsEntry.getSalt() : null);
    }

    /** Supplies the stored salted password, or {@code null} when no usable credentials were found. */
    protected void handleSaltedPasswordCallbackCallback(SaltedPasswordCallback callback) {
        if (log.isLoggable(Level.FINEST)) {
            log.log(Level.FINEST, "PasswordCallback: {0}", this.jid);
        }
        this.fetchCredentials();
        callback.setSaltedPassword(this.credentialsEntry != null ? this.credentialsEntry.getSaltedPassword() : null);
    }

    /**
     * Applies a requested authorization id.
     *
     * <p>When it is honoured and differs from the authenticating JID, the authenticating user's
     * localpart becomes the credential id and the JID switches to the authorization id. Otherwise
     * the default credential id is used and the authorization id is set to the current JID.
     */
    private void handleAuthorizationIdCallback(AuthorizationIdCallback callback) {
        String authzId = callback.getAuthzId();
        if (!AbstractSasl.isAuthzIDIgnored() && authzId != null && !authzId.equals(this.jid.toString())) {
            try {
                this.username = this.jid.getLocalpart();
                this.setJid(BareJID.bareJIDInstance(authzId));
            }
            catch (Exception ex) {
                throw new RuntimeException(ex);
            }
        } else {
            this.username = DEFAULT_CREDENTIAL_ID;
            callback.setAuthzId(this.jid.toString());
        }
    }

    /**
     * Provides channel-binding data for the requested binding type.
     *
     * <p>For {@code tls-unique} the value is read from the session. For
     * {@code tls-server-end-point} it is the hash of the DER-encoded local certificate, computed as
     * RFC 5929 section 4.1 specifies.
     *
     * @throws IllegalStateException if tls-server-end-point is requested but no local certificate
     *     is stored in the session
     * @throws RuntimeException if the certificate cannot be encoded or hashed
     */
    private void handleChannelBindingCallback(ChannelBindingCallback callback) {
        AbstractSaslSCRAM.BindType bindType = callback.getRequestedBindType();
        if (bindType == AbstractSaslSCRAM.BindType.tls_unique) {
            callback.setBindingData((byte[]) this.session.getSessionData("TLS_UNIQUE_ID_KEY"));
        } else if (bindType == AbstractSaslSCRAM.BindType.tls_server_end_point) {
            Certificate cert = (Certificate) this.session.getSessionData("LOCAL_CERTIFICATE_KEY");
            if (cert == null) {
                throw new IllegalStateException("No local certificate available for tls-server-end-point channel binding");
            }
            try {
                MessageDigest md = MessageDigest.getInstance(tlsServerEndPointDigestAlgorithm(cert));
                callback.setBindingData(md.digest(cert.getEncoded()));
            }
            catch (GeneralSecurityException e) {
                throw new RuntimeException(e);
            }
        }
        if (log.isLoggable(Level.FINEST)) {
            byte[] data = callback.getBindingData();
            log.log(Level.FINEST, "Channel binding {0}: {1} in session-id {2}", new Object[]{bindType, data == null ? "null" : Base64.encode(data), this.session});
        }
    }

    /**
     * Picks the {@link MessageDigest} algorithm for tls-server-end-point (RFC 5929 section 4.1).
     *
     * <p>The hash is taken from the certificate's signature algorithm, NOT from its public-key
     * algorithm (that yields names like "RSA"/"EC", which are not digests). MD5 and SHA-1 are
     * upgraded to SHA-256 as the RFC requires.
     *
     * @throws IllegalArgumentException if the certificate is not X.509 or its signature algorithm
     *     does not name a single hash (e.g. RSASSA-PSS, Ed25519), for which the RFC leaves the
     *     binding undefined
     */
    private static String tlsServerEndPointDigestAlgorithm(Certificate cert) {
        if (!(cert instanceof X509Certificate)) {
            throw new IllegalArgumentException("tls-server-end-point requires an X.509 certificate, got " + cert.getType());
        }
        String sigAlgName = ((X509Certificate) cert).getSigAlgName();
        String upper = sigAlgName == null ? "" : sigAlgName.toUpperCase(Locale.ROOT);
        // JCA signature names have the form "<digest>with<encryption>", e.g. "SHA256withRSA".
        int withIdx = upper.indexOf("WITH");
        if (withIdx <= 0) {
            throw new IllegalArgumentException("Cannot determine hash function of certificate signature algorithm " + sigAlgName);
        }
        String hash = upper.substring(0, withIdx);
        if (hash.equals("MD5") || hash.equals("SHA1") || hash.equals("SHA-1")) {
            return "SHA-256";
        }
        if (hash.startsWith("SHA") && !hash.startsWith("SHA-") && !hash.startsWith("SHA3")) {
            // Signature names use "SHA256"/"SHA512/256"; MessageDigest expects "SHA-256"/"SHA-512/256".
            return "SHA-" + hash.substring(3);
        }
        return hash;
    }

    /**
     * Loads credentials for the current JID and credential id once, caching the outcome.
     *
     * <p>If the repository has no credentials the account is treated as disabled. When no SCRAM
     * entry exists for the mechanism, a PLAIN entry is converted into a SCRAM entry. Failures are
     * logged at FINE and leave {@code credentialsEntry} unset; the fetch is not retried.
     */
    private void fetchCredentials() {
        if (this.credentialsFetched) {
            return;
        }
        try {
            Credentials credentials = this.repo.getCredentials(this.jid, this.username);
            if (credentials == null) {
                this.accountDisabled = true;
            } else {
                // Channel-binding variants share credentials with the base mechanism.
                String mech = this.mechanismName.endsWith(PLUS_SUFFIX)
                        ? this.mechanismName.substring(0, this.mechanismName.length() - PLUS_SUFFIX.length())
                        : this.mechanismName;
                Credentials.Entry entry = credentials.getEntryForMechanism(mech);
                if (entry == null) {
                    entry = credentials.getEntryForMechanism("PLAIN");
                }
                if (entry instanceof ScramCredentialsEntry) {
                    this.credentialsEntry = (ScramCredentialsEntry) entry;
                } else if (entry instanceof PlainCredentialsEntry) {
                    this.credentialsEntry = new ScramCredentialsEntry(mech.replace("SCRAM-", ""), (PlainCredentialsEntry) entry);
                }
                this.accountDisabled = credentials.isAccountDisabled();
            }
        }
        catch (Exception ex) {
            log.log(Level.FINE, "Could not retrieve credentials for user " + this.jid + " with username " + this.username, ex);
        }
        this.credentialsFetched = true;
    }

    /** Updates the current JID and, when non-null, records it in the session under the authentication-JID key. */
    private void setJid(BareJID jid) {
        this.jid = jid;
        if (jid != null) {
            this.session.putSessionData(AUTHENTICATION_JID_KEY, jid);
        }
    }
}

