/*
 * Decompiled with CFR 0.152.
 */
package io.scalecube.security.tokens.jwt;

import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import io.jsonwebtoken.JwsHeader;
import io.jsonwebtoken.LocatorAdapter;
import io.scalecube.security.tokens.jwt.JwkInfoList;
import io.scalecube.security.tokens.jwt.JwtUnavailableException;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.math.BigInteger;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.net.http.HttpTimeoutException;
import java.security.GeneralSecurityException;
import java.security.Key;
import java.security.KeyFactory;
import java.security.PublicKey;
import java.security.spec.RSAPublicKeySpec;
import java.time.Duration;
import java.util.Base64;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantLock;

public class JwksKeyLocator
extends LocatorAdapter<Key> {
    private static final ObjectMapper OBJECT_MAPPER = JwksKeyLocator.newObjectMapper();
    private final URI jwksUri;
    private final Duration connectTimeout;
    private final Duration requestTimeout;
    private final int keyTtl;
    private final Map<String, CachedKey> keyResolutions = new ConcurrentHashMap<String, CachedKey>();
    private final ReentrantLock cleanupLock = new ReentrantLock();

    private JwksKeyLocator(Builder builder) {
        this.jwksUri = Objects.requireNonNull(builder.jwksUri, "jwksUri");
        this.connectTimeout = Objects.requireNonNull(builder.connectTimeout, "connectTimeout");
        this.requestTimeout = Objects.requireNonNull(builder.requestTimeout, "requestTimeout");
        this.keyTtl = builder.keyTtl;
    }

    public static Builder builder() {
        return new Builder();
    }

    protected Key locate(JwsHeader header) {
        try {
            Key key = this.keyResolutions.computeIfAbsent(header.getKeyId(), kid -> {
                PublicKey key = JwksKeyLocator.findKeyById(this.computeKeyList(), kid);
                if (key == null) {
                    throw new JwtUnavailableException("Cannot find key by kid: " + kid);
                }
                return new CachedKey(key, System.currentTimeMillis() + (long)this.keyTtl);
            }).key();
            return key;
        }
        finally {
            this.tryCleanup();
        }
    }

    private JwkInfoList computeKeyList() {
        HttpResponse<InputStream> httpResponse;
        try {
            httpResponse = HttpClient.newBuilder().connectTimeout(this.connectTimeout).build().send(HttpRequest.newBuilder(this.jwksUri).GET().timeout(this.requestTimeout).build(), HttpResponse.BodyHandlers.ofInputStream());
        }
        catch (HttpTimeoutException e) {
            throw new JwtUnavailableException("Failed to retrive jwk keys", e);
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
        int statusCode = httpResponse.statusCode();
        if (statusCode != 200) {
            throw new RuntimeException("Failed to retrive jwk keys, status: " + statusCode);
        }
        return JwksKeyLocator.toJwkInfoList(httpResponse.body());
    }

    private static JwkInfoList toJwkInfoList(InputStream stream) {
        JwkInfoList jwkInfoList;
        BufferedInputStream inputStream = new BufferedInputStream(stream);
        try {
            jwkInfoList = (JwkInfoList)OBJECT_MAPPER.readValue((InputStream)inputStream, JwkInfoList.class);
        }
        catch (Throwable throwable) {
            try {
                try {
                    inputStream.close();
                }
                catch (Throwable throwable2) {
                    throwable.addSuppressed(throwable2);
                }
                throw throwable;
            }
            catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
        inputStream.close();
        return jwkInfoList;
    }

    private static PublicKey findKeyById(JwkInfoList jwkInfoList, String kid) {
        if (jwkInfoList.keys() != null) {
            return jwkInfoList.keys().stream().filter(jwkInfo -> kid.equals(jwkInfo.kid())).map(jwkInfo -> JwksKeyLocator.toRsaPublicKey(jwkInfo.modulus(), jwkInfo.exponent())).findFirst().orElse(null);
        }
        return null;
    }

    private static PublicKey toRsaPublicKey(String n, String e) {
        Base64.Decoder decoder = Base64.getUrlDecoder();
        BigInteger modulus = new BigInteger(1, decoder.decode(n));
        BigInteger exponent = new BigInteger(1, decoder.decode(e));
        RSAPublicKeySpec keySpec = new RSAPublicKeySpec(modulus, exponent);
        try {
            return KeyFactory.getInstance("RSA").generatePublic(keySpec);
        }
        catch (Exception ex) {
            throw new RuntimeException(e);
        }
    }

    private static ObjectMapper newObjectMapper() {
        ObjectMapper mapper = new ObjectMapper();
        mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
        mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
        mapper.configure(DeserializationFeature.READ_UNKNOWN_ENUM_VALUES_AS_NULL, true);
        mapper.configure(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS, false);
        mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
        return mapper;
    }

    private void tryCleanup() {
        if (this.cleanupLock.tryLock()) {
            long now = System.currentTimeMillis();
            try {
                this.keyResolutions.entrySet().removeIf(entry -> ((CachedKey)entry.getValue()).hasExpired(now));
            }
            finally {
                this.cleanupLock.unlock();
            }
        }
    }

    public static class Builder {
        private URI jwksUri;
        private Duration connectTimeout = Duration.ofSeconds(10L);
        private Duration requestTimeout = Duration.ofSeconds(10L);
        private int keyTtl = 60000;

        private Builder() {
        }

        public Builder jwksUri(String jwksUri) {
            this.jwksUri = URI.create(jwksUri);
            return this;
        }

        public Builder connectTimeout(Duration connectTimeout) {
            this.connectTimeout = connectTimeout;
            return this;
        }

        public Builder requestTimeout(Duration requestTimeout) {
            this.requestTimeout = requestTimeout;
            return this;
        }

        public Builder keyTtl(int keyTtl) {
            this.keyTtl = keyTtl;
            return this;
        }

        public JwksKeyLocator build() {
            return new JwksKeyLocator(this);
        }
    }

    private record CachedKey(Key key, long expirationDeadline) {
        boolean hasExpired(long now) {
            return now >= this.expirationDeadline;
        }
    }
}

