package net.hasor.neta.handler.ssl;

import java.io.IOException;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.Objects;
import java.util.Set;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.TrustManagerFactory;
import net.hasor.cobble.ArrayUtils;
import net.hasor.cobble.StringUtils;
import net.hasor.cobble.logging.Logger;
import net.hasor.neta.channel.PipeContext;

/* loaded from: input_file:net/hasor/neta/handler/ssl/JdkSslContext.class */
public class JdkSslContext extends SslContextBasic {
    private static final Logger logger = Logger.getLogger(JdkSslContext.class);

    /** Protocol name passed to {@link SSLContext#getInstance(String)}. */
    protected static final String PROTOCOL = "TLS";

    /** Protocols enabled when the config does not specify any (config value is {@code null}). */
    private static final String[] DEFAULT_PROTOCOLS;
    /** Every cipher suite the default JDK engine accepts, including probed "TLS_" aliases of "SSL_" names. */
    private static final Set<String> SUPPORTED_CIPHERS;
    /** Cipher suites enabled by default when TLSv1.3 is among the enabled protocols. */
    private static final String[] DEFAULT_CIPHERS;
    /** {@link #DEFAULT_CIPHERS} minus the TLSv1.3-only suites, used when TLSv1.3 is not enabled. */
    private static final String[] DEFAULT_CIPHERS_NON_TLSV13;
    /** {@link #SUPPORTED_CIPHERS} minus the TLSv1.3-only suites, used when TLSv1.3 is not enabled. */
    private static final Set<String> SUPPORTED_CIPHERS_NON_TLSV13;

    /**
     * Creates a JDK-backed SSL context; all arguments are forwarded to {@link SslContextBasic}.
     *
     * @param channelId   first constructor argument of the base class; presumably the owning channel id
     *                    (the base class exposes {@code channelID}) — TODO confirm
     * @param pipeContext pipeline context of the channel
     * @param sslConfig   SSL settings (protocols, ciphers, client auth, ALPN)
     * @param clientMode  last constructor argument of the base class; presumably selects client vs server
     *                    mode (see {@code isClient()}) — TODO confirm
     * @throws Exception whatever the base class constructor throws
     */
    public JdkSslContext(long channelId, PipeContext pipeContext, SslConfig sslConfig, boolean clientMode) throws Exception {
        super(channelId, sslConfig, pipeContext, clientMode);
    }

    /**
     * Returns the ALPN protocol negotiated on this context's engine.
     *
     * @return the negotiated protocol, or {@code null} if there is no engine, ALPN is unsupported by the
     *         running JDK, or nothing (or an empty string) was negotiated
     */
    @Override // net.hasor.neta.handler.ssl.SslContext
    public String getApplicationProtocol() {
        SSLEngine engine = getEngine().unwrap();
        if (engine == null || !JdkAlpnSslUtils.supportsAlpn()) {
            return null;
        }
        String negotiated = JdkAlpnSslUtils.getApplicationProtocol(engine);
        return (negotiated == null || negotiated.isEmpty()) ? null : negotiated;
    }

    /**
     * Builds a {@code TLS} {@link SSLContext} from the configured key store.
     * The same key store feeds both the key manager factory and the trust manager factory.
     *
     * @return an initialised SSL context (default {@code SecureRandom})
     * @throws GeneralSecurityException if key/trust material or the TLS context cannot be created
     * @throws IOException              if the key store cannot be read
     */
    @Override // net.hasor.neta.handler.ssl.SslContextBasic
    protected SSLContext createSSLContext() throws GeneralSecurityException, IOException {
        KeyStore keyStore = createKeyStore();
        KeyManagerFactory keyManagerFactory = createKeyManagerFactory(keyStore);
        TrustManagerFactory trustManagerFactory = getTrustManagers(keyStore);
        if (this.sslLog) {
            logger.info("ssl(" + this.channelID + ") create JdkSslContext.");
        }
        SSLContext context = SSLContext.getInstance(PROTOCOL);
        context.init(keyManagerFactory.getKeyManagers(), trustManagerFactory.getTrustManagers(), null);
        return context;
    }

    /**
     * Applies the settings from {@code sslConfig} to a freshly created engine.
     * Settings covered: client/server mode, enabled protocols, enabled cipher suites, the client-auth
     * policy (server side only) and ALPN.
     *
     * @param sSLContext the context the engine was created from (unused here)
     * @param sSLEngine  the engine to configure; it is mutated and returned
     * @return {@code sSLEngine}
     */
    @Override // net.hasor.neta.handler.ssl.SslContextBasic
    protected SSLEngine configSslEngine(SSLContext sSLContext, SSLEngine sSLEngine) {
        if (this.sslLog) {
            logger.info("ssl(" + this.channelID + ") create SSLEngine on " + (isClient() ? "client" : "server"));
        }
        sSLEngine.setUseClientMode(isClient());

        String[] configuredProtocols = this.sslConfig.getProtocols();
        String[] protocols = configuredProtocols == null ? DEFAULT_PROTOCOLS : configuredProtocols;
        sSLEngine.setEnabledProtocols(protocols);

        // TLSv1.3 suites are only kept when TLSv1.3 itself is enabled.
        String[] configuredCiphers = this.sslConfig.getCiphers();
        String[] cipherSuites;
        if (isTlsV13Supported(protocols)) {
            cipherSuites = filterCipherSuites(Arrays.asList(configuredCiphers == null ? DEFAULT_CIPHERS : configuredCiphers), DEFAULT_CIPHERS, SUPPORTED_CIPHERS);
        } else {
            cipherSuites = filterCipherSuites(Arrays.asList(configuredCiphers == null ? DEFAULT_CIPHERS_NON_TLSV13 : configuredCiphers), DEFAULT_CIPHERS_NON_TLSV13, SUPPORTED_CIPHERS_NON_TLSV13);
        }
        if (this.sslLog) {
            logger.info("ssl(" + this.channelID + ") enabled CipherSuites [" + StringUtils.join(cipherSuites, ", ") + "]");
        }
        sSLEngine.setEnabledCipherSuites(cipherSuites);

        SslClientAuth clientAuth = this.sslConfig.getClientAuth();
        if (isServer() && clientAuth != null) {
            if (this.sslLog) {
                logger.info("ssl(" + this.channelID + ") clientAuth = " + clientAuth);
            }
            switch (clientAuth) {
                case OPTIONAL:
                    sSLEngine.setWantClientAuth(true);
                    break;
                case REQUIRE:
                    sSLEngine.setNeedClientAuth(true);
                    break;
                case NONE:
                    break;
                default:
                    // Only reachable if SslClientAuth gains a new constant that is not handled here.
                    throw new IllegalStateException("Unknown auth " + clientAuth);
            }
        }

        String[] appProtocols = this.sslConfig.getAppProtocol();
        JdkAlpnSslUtils.setApplicationProtocols(sSLEngine, appProtocols == null ? ArrayUtils.EMPTY_STRING_ARRAY : appProtocols);
        SslAppProtocolSelector appProtocolSelector = this.sslConfig.getAppProtocolSelector();
        if (appProtocolSelector != null) {
            // The channel is looked up lazily, at handshake time, rather than captured now.
            JdkAlpnSslUtils.setHandshakeApplicationProtocolSelector(sSLEngine, (handshakeEngine, offered) -> {
                return appProtocolSelector.selector(this.soContext.findChannel(this.channelID), handshakeEngine, offered);
            });
        }
        return sSLEngine;
    }

    /**
     * Picks the default protocol list.
     * The result is the subset of TLSv1.3/1.2/1.1/1 that the context enables by default, in that
     * preference order. If that subset is empty, the engine's enabled protocols are used instead.
     */
    private static String[] defaultProtocols(SSLContext context, SSLEngine engine) {
        Set<String> contextDefaults = new LinkedHashSet<>(Arrays.asList(context.getDefaultSSLParameters().getProtocols()));
        Set<String> preferred = new LinkedHashSet<>();
        addIfSupported(contextDefaults, preferred, SslProtocol.TLS_v1_3, SslProtocol.TLS_v1_2, SslProtocol.TLS_v1_1, SslProtocol.TLS_v1);
        return preferred.isEmpty() ? engine.getEnabledProtocols() : preferred.toArray(ArrayUtils.EMPTY_STRING_ARRAY);
    }

    /**
     * Collects every cipher suite name the engine supports.
     * For each suite reported with the legacy "SSL_" prefix, the "TLS_" spelling is also added when the
     * engine accepts it.
     * <p>
     * Probing works by calling {@link SSLEngine#setEnabledCipherSuites(String[])}. The engine's enabled
     * suites are therefore snapshotted and restored afterwards. Otherwise later callers of
     * {@code getEnabledCipherSuites()} on the same engine (see {@link #defaultCiphers}) would see only
     * the last probed suite.
     */
    private static Set<String> supportedCiphers(SSLEngine engine) {
        String[] supported = engine.getSupportedCipherSuites();
        Set<String> result = new LinkedHashSet<>(supported.length);
        String[] originallyEnabled = engine.getEnabledCipherSuites();
        try {
            for (String cipher : supported) {
                result.add(cipher);
                if (cipher.startsWith("SSL_")) {
                    String tlsAlias = "TLS_" + cipher.substring("SSL_".length());
                    try {
                        engine.setEnabledCipherSuites(new String[] { tlsAlias });
                        result.add(tlsAlias);
                    } catch (IllegalArgumentException ignored) {
                        // The engine rejects the TLS_ alias; only the SSL_ name is supported.
                    }
                }
            }
        } finally {
            engine.setEnabledCipherSuites(originallyEnabled);
        }
        return result;
    }

    /**
     * Picks the default cipher list.
     * The result is the members of {@code SslUtils.DEFAULT_CIPHER_SUITES} found in {@code supported}.
     * If none of them are found, the engine's enabled suites are used instead, excluding "SSL_"-prefixed
     * and RC4 suites.
     */
    private static String[] defaultCiphers(SSLEngine engine, Set<String> supported) {
        Set<String> chosen = new LinkedHashSet<>();
        addIfSupported(supported, chosen, SslUtils.DEFAULT_CIPHER_SUITES);
        if (chosen.isEmpty()) {
            for (String cipher : engine.getEnabledCipherSuites()) {
                if (!cipher.startsWith("SSL_") && !cipher.contains("_RC4_")) {
                    chosen.add(cipher);
                }
            }
        }
        return chosen.toArray(ArrayUtils.EMPTY_STRING_ARRAY);
    }

    /** Adds each of {@code candidates} to {@code target} if it is contained in {@code supported}, preserving order. */
    private static void addIfSupported(Set<String> supported, Set<String> target, String... candidates) {
        for (String candidate : candidates) {
            if (supported.contains(candidate)) {
                target.add(candidate);
            }
        }
    }

    /** Returns {@code true} if {@code protocols} contains {@code SslProtocol.TLS_v1_3}. */
    private static boolean isTlsV13Supported(String[] protocols) {
        for (String protocol : protocols) {
            if (SslProtocol.TLS_v1_3.equals(protocol)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Keeps only the entries of {@code ciphers} that appear in {@code supportedCiphers}, in their original order.
     *
     * @param ciphers          requested suites; if {@code null}, {@code defaultCiphers} is used instead.
     *                         {@code null} entries are skipped rather than ending the list.
     * @param defaultCiphers   fallback list, must not be {@code null}
     * @param supportedCiphers allowed suites, must not be {@code null}
     * @return the filtered suites (possibly empty)
     */
    private static String[] filterCipherSuites(Iterable<String> ciphers, String[] defaultCiphers, Set<String> supportedCiphers) {
        Objects.requireNonNull(defaultCiphers, "defaultCiphers");
        Objects.requireNonNull(supportedCiphers, "supportedCiphers");
        Iterable<String> candidates = ciphers;
        ArrayList<String> accepted;
        if (candidates == null) {
            candidates = Arrays.asList(defaultCiphers);
            accepted = new ArrayList<>(defaultCiphers.length);
        } else {
            accepted = new ArrayList<>(supportedCiphers.size());
        }
        for (String cipher : candidates) {
            if (cipher != null && supportedCiphers.contains(cipher)) {
                accepted.add(cipher);
            }
        }
        return accepted.toArray(new String[0]);
    }

    static {
        try {
            // Probe a default-initialised JDK context/engine once to derive the class-wide defaults.
            SSLContext context = SSLContext.getInstance(PROTOCOL);
            context.init(null, null, null);
            SSLEngine engine = context.createSSLEngine();
            DEFAULT_PROTOCOLS = defaultProtocols(context, engine);
            SUPPORTED_CIPHERS = Collections.unmodifiableSet(supportedCiphers(engine));
            DEFAULT_CIPHERS = defaultCiphers(engine, SUPPORTED_CIPHERS);

            Set<String> defaultsWithoutV13 = new LinkedHashSet<>(Arrays.asList(DEFAULT_CIPHERS));
            defaultsWithoutV13.removeAll(Arrays.asList(SslUtils.DEFAULT_TLSV13_CIPHER_SUITES));
            DEFAULT_CIPHERS_NON_TLSV13 = defaultsWithoutV13.toArray(ArrayUtils.EMPTY_STRING_ARRAY);

            Set<String> supportedWithoutV13 = new LinkedHashSet<>(SUPPORTED_CIPHERS);
            supportedWithoutV13.removeAll(Arrays.asList(SslUtils.DEFAULT_TLSV13_CIPHER_SUITES));
            SUPPORTED_CIPHERS_NON_TLSV13 = Collections.unmodifiableSet(supportedWithoutV13);

            if (logger.isDebugEnabled()) {
                logger.debug("Default protocols (JDK): " + StringUtils.join(DEFAULT_PROTOCOLS, ","));
                logger.debug("Default cipher suites (JDK): " + StringUtils.join(DEFAULT_CIPHERS, ","));
            }
        } catch (Exception e) {
            throw new Error("failed to initialize the default SSL context", e);
        }
    }
}
