/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.http;

import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.MultiThreadIoEventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.nio.NioIoHandler;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.CodecException;
import io.netty.handler.codec.PrematureChannelClosureException;
import io.netty.util.CharsetUtil;
import io.netty.util.NetUtil;
import org.junit.jupiter.api.Test;

import java.net.InetSocketAddress;
import java.util.concurrent.CountDownLatch;

import static io.netty.util.ReferenceCountUtil.release;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

/**
 * Tests for {@link HttpClientCodec}: request/response pairing, premature channel closure detection,
 * CONNECT and protocol-upgrade handling, informational (1xx) responses and outbound pass-through
 * after an upgrade.
 */
public class HttpClientCodecTest {

    private static final String EMPTY_RESPONSE = "HTTP/1.0 200 OK\r\nContent-Length: 0\r\n\r\n";
    private static final String RESPONSE = "HTTP/1.0 200 OK\r\n" + "Date: Fri, 31 Dec 1999 23:59:59 GMT\r\n" +
            "Content-Type: text/html\r\n" + "Content-Length: 28\r\n" + "\r\n"
            + "<html><body></body></html>\r\n";
    // Chunked response whose terminating zero-length chunk is not followed by the final CRLF.
    private static final String INCOMPLETE_CHUNKED_RESPONSE = "HTTP/1.1 200 OK\r\n" + "Content-Type: text/plain\r\n" +
            "Transfer-Encoding: chunked\r\n" + "\r\n" +
            "5\r\n" + "first\r\n" + "6\r\n" + "second\r\n" + "0\r\n";
    private static final String CHUNKED_RESPONSE = INCOMPLETE_CHUNKED_RESPONSE + "\r\n";
    private static final String SWITCHING_PROTOCOLS_RESPONSE = "HTTP/1.1 101 Switching Protocols\r\n" +
            "Connection: Upgrade\r\n" +
            "Upgrade: TLS/1.2, HTTP/1.1\r\n\r\n";

    @Test
    public void testConnectWithResponseContent() {
        HttpClientCodec codec = new HttpClientCodec(4096, 8192, 8192, true);
        EmbeddedChannel ch = new EmbeddedChannel(codec);

        sendRequestAndReadResponse(ch, HttpMethod.CONNECT, RESPONSE);
        ch.finish();
    }

    @Test
    public void testFailsNotOnRequestResponseChunked() {
        HttpClientCodec codec = new HttpClientCodec(4096, 8192, 8192, true);
        EmbeddedChannel ch = new EmbeddedChannel(codec);

        sendRequestAndReadResponse(ch, HttpMethod.GET, CHUNKED_RESPONSE);
        ch.finish();
    }

    /**
     * With {@code failOnMissingResponse} enabled, closing the channel after a request was sent but
     * before any response arrived must raise {@link PrematureChannelClosureException}.
     */
    @Test
    public void testFailsOnMissingResponse() {
        HttpClientCodec codec = new HttpClientCodec(4096, 8192, 8192, true);
        EmbeddedChannel ch = new EmbeddedChannel(codec);

        assertTrue(ch.writeOutbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET,
                "http://localhost/")));
        ByteBuf buffer = ch.readOutbound();
        assertNotNull(buffer);
        buffer.release();
        try {
            ch.finish();
            fail();
        } catch (CodecException e) {
            assertInstanceOf(PrematureChannelClosureException.class, e);
        }
    }

    /**
     * Closing the channel while a chunked response is still incomplete (the final CRLF is missing)
     * must raise {@link PrematureChannelClosureException}, after the chunks that did arrive have been
     * delivered.
     */
    @Test
    public void testFailsOnIncompleteChunkedResponse() {
        HttpClientCodec codec = new HttpClientCodec(4096, 8192, 8192, true);
        EmbeddedChannel ch = new EmbeddedChannel(codec);

        assertTrue(ch.writeOutbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET,
                "http://localhost/")));
        ByteBuf buffer = ch.readOutbound();
        assertNotNull(buffer);
        buffer.release();
        assertNull(ch.readInbound());
        ch.writeInbound(Unpooled.copiedBuffer(INCOMPLETE_CHUNKED_RESPONSE, CharsetUtil.ISO_8859_1));
        assertInstanceOf(HttpResponse.class, ch.readInbound());
        ((HttpContent) ch.readInbound()).release(); // Chunk 'first'
        ((HttpContent) ch.readInbound()).release(); // Chunk 'second'
        assertNull(ch.readInbound());

        try {
            ch.finish();
            fail();
        } catch (CodecException e) {
            assertInstanceOf(PrematureChannelClosureException.class, e);
        }
    }

    /**
     * A real socket test: the server answers without Content-Length or Transfer-Encoding and then
     * shuts down its output. The client (half-closure allowed) must still aggregate and deliver the
     * full response.
     */
    @Test
    public void testServerCloseSocketInputProvidesData() throws InterruptedException {
        // Keep the groups in locals so the finally block shuts each one down exactly once.
        MultiThreadIoEventLoopGroup serverGroup = new MultiThreadIoEventLoopGroup(2, NioIoHandler.newFactory());
        MultiThreadIoEventLoopGroup clientGroup = new MultiThreadIoEventLoopGroup(1, NioIoHandler.newFactory());
        ServerBootstrap sb = new ServerBootstrap();
        Bootstrap cb = new Bootstrap();
        final CountDownLatch serverChannelLatch = new CountDownLatch(1);
        final CountDownLatch responseReceivedLatch = new CountDownLatch(1);
        try {
            sb.group(serverGroup);
            sb.channel(NioServerSocketChannel.class);
            sb.childHandler(new ChannelInitializer<Channel>() {
                @Override
                protected void initChannel(Channel ch) throws Exception {
                    // Don't use the HttpServerCodec, because we don't want to have content-length or anything added.
                    ch.pipeline().addLast(new HttpRequestDecoder(4096, 8192, 8192, true));
                    ch.pipeline().addLast(new HttpObjectAggregator(4096));
                    ch.pipeline().addLast(new SimpleChannelInboundHandler<FullHttpRequest>() {
                        @Override
                        protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest msg) {
                            // This is just a simple demo...don't block in IO
                            final SocketChannel sChannel = assertInstanceOf(SocketChannel.class, ctx.channel());
                            /*
                             * The point of this test is to not add any content-length or content-encoding headers
                             * and the client should still handle this.
                             * See <a href="https://tools.ietf.org/html/rfc7230#section-3.3.3">RFC 7230, 3.3.3</a>.
                             */
                            sChannel.writeAndFlush(Unpooled.wrappedBuffer(("HTTP/1.0 200 OK\r\n" +
                            "Date: Fri, 31 Dec 1999 23:59:59 GMT\r\n" +
                            "Content-Type: text/html\r\n\r\n").getBytes(CharsetUtil.ISO_8859_1)))
                                    .addListener(new ChannelFutureListener() {
                                @Override
                                public void operationComplete(ChannelFuture future) throws Exception {
                                    assertTrue(future.isSuccess());
                                    sChannel.writeAndFlush(Unpooled.wrappedBuffer(
                                            "<html><body>hello half closed!</body></html>\r\n"
                                            .getBytes(CharsetUtil.ISO_8859_1)))
                                            .addListener(new ChannelFutureListener() {
                                        @Override
                                        public void operationComplete(ChannelFuture future) throws Exception {
                                            assertTrue(future.isSuccess());
                                            // Half-close: the end of the body is signalled only by EOF.
                                            sChannel.shutdownOutput();
                                        }
                                    });
                                }
                            });
                        }
                    });
                    serverChannelLatch.countDown();
                }
            });

            cb.group(clientGroup);
            cb.channel(NioSocketChannel.class);
            cb.option(ChannelOption.ALLOW_HALF_CLOSURE, true);
            cb.handler(new ChannelInitializer<Channel>() {
                @Override
                protected void initChannel(Channel ch) throws Exception {
                    ch.pipeline().addLast(new HttpClientCodec(4096, 8192, 8192, true, true));
                    ch.pipeline().addLast(new HttpObjectAggregator(4096));
                    ch.pipeline().addLast(new SimpleChannelInboundHandler<FullHttpResponse>() {
                        @Override
                        protected void channelRead0(ChannelHandlerContext ctx, FullHttpResponse msg) {
                            responseReceivedLatch.countDown();
                        }
                    });
                }
            });

            Channel serverChannel = sb.bind(new InetSocketAddress(0)).sync().channel();
            int port = ((InetSocketAddress) serverChannel.localAddress()).getPort();

            ChannelFuture ccf = cb.connect(new InetSocketAddress(NetUtil.LOCALHOST, port));
            assertTrue(ccf.awaitUninterruptibly().isSuccess());
            Channel clientChannel = ccf.channel();
            assertTrue(serverChannelLatch.await(5, SECONDS));
            clientChannel.writeAndFlush(new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"));
            assertTrue(responseReceivedLatch.await(5, SECONDS));
        } finally {
            serverGroup.shutdownGracefully().syncUninterruptibly();
            clientGroup.shutdownGracefully().syncUninterruptibly();
        }
    }

    @Test
    public void testContinueParsingAfterConnect() throws Exception {
        testAfterConnect(true);
    }

    @Test
    public void testPassThroughAfterConnect() throws Exception {
        testAfterConnect(false);
    }

    /**
     * Sends a CONNECT request/response pair, then a GET request/response pair, through the same codec.
     *
     * @param parseAfterConnect if {@code true}, the messages after CONNECT must be decoded as
     *                          {@link HttpObject}s; if {@code false}, they must not be
     */
    private static void testAfterConnect(final boolean parseAfterConnect) throws Exception {
        EmbeddedChannel ch = new EmbeddedChannel(new HttpClientCodec(4096, 8192, 8192, true, true, parseAfterConnect));

        Consumer connectResponseConsumer = new Consumer();
        sendRequestAndReadResponse(ch, HttpMethod.CONNECT, EMPTY_RESPONSE, connectResponseConsumer);
        assertTrue(connectResponseConsumer.getReceivedCount() > 0, "No connect response messages received.");
        Consumer responseConsumer = new Consumer() {
            @Override
            void accept(Object object) {
                if (parseAfterConnect) {
                    assertInstanceOf(HttpObject.class, object);
                } else {
                    assertThat(object).isNotInstanceOf(HttpObject.class);
                }
            }
        };
        sendRequestAndReadResponse(ch, HttpMethod.GET, RESPONSE, responseConsumer);
        assertTrue(responseConsumer.getReceivedCount() > 0, "No response messages received.");
        assertFalse(ch.finish(), "Channel finish failed.");
    }

    /** Same as the four-argument overload, using a no-op {@link Consumer}. */
    private static void sendRequestAndReadResponse(EmbeddedChannel ch, HttpMethod httpMethod, String response) {
        sendRequestAndReadResponse(ch, httpMethod, response, new Consumer());
    }

    /**
     * Writes a full request with {@code httpMethod} outbound and {@code response} (as ISO-8859-1 bytes)
     * inbound. It then drains and releases all outbound messages, and hands every inbound message to
     * {@code responseConsumer} before releasing it.
     */
    private static void sendRequestAndReadResponse(EmbeddedChannel ch, HttpMethod httpMethod, String response,
                                                   Consumer responseConsumer) {
        assertTrue(ch.writeOutbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, httpMethod, "http://localhost/")),
                "Channel outbound write failed.");
        assertTrue(ch.writeInbound(Unpooled.copiedBuffer(response, CharsetUtil.ISO_8859_1)),
                "Channel inbound write failed.");

        for (;;) {
            Object msg = ch.readOutbound();
            if (msg == null) {
                break;
            }
            release(msg);
        }
        for (;;) {
            Object msg = ch.readInbound();
            if (msg == null) {
                break;
            }
            responseConsumer.onResponse(msg);
            release(msg);
        }
    }

    /**
     * Counts the inbound messages it receives and passes each one to {@link #accept(Object)}.
     * Subclasses override {@code accept} to add per-message assertions.
     */
    private static class Consumer {

        private int receivedCount;

        final void onResponse(Object object) {
            receivedCount++;
            accept(object);
        }

        void accept(Object object) {
            // Default noop.
        }

        int getReceivedCount() {
            return receivedCount;
        }
    }

    /** After a 101 Switching Protocols response, the codec must still decode a following HTTP response. */
    @Test
    public void testDecodesFinalResponseAfterSwitchingProtocols() {
        HttpClientCodec codec = new HttpClientCodec(4096, 8192, 8192, true);
        EmbeddedChannel ch = new EmbeddedChannel(codec, new HttpObjectAggregator(1024));

        HttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "http://localhost/");
        request.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE);
        request.headers().set(HttpHeaderNames.UPGRADE, "TLS/1.2");
        assertTrue(ch.writeOutbound(request), "Channel outbound write failed.");

        assertTrue(ch.writeInbound(Unpooled.copiedBuffer(SWITCHING_PROTOCOLS_RESPONSE, CharsetUtil.ISO_8859_1)),
                "Channel inbound write failed.");
        Object switchingProtocolsResponse = ch.readInbound();
        assertNotNull(switchingProtocolsResponse, "No response received");
        assertInstanceOf(FullHttpResponse.class, switchingProtocolsResponse);
        ((FullHttpResponse) switchingProtocolsResponse).release();

        assertTrue(ch.writeInbound(Unpooled.copiedBuffer(RESPONSE, CharsetUtil.ISO_8859_1)),
                "Channel inbound write failed");
        Object finalResponse = ch.readInbound();
        assertNotNull(finalResponse, "No response received");
        assertInstanceOf(FullHttpResponse.class, finalResponse);
        ((FullHttpResponse) finalResponse).release();
        assertTrue(ch.finishAndReleaseAll(), "Channel finish failed");
    }

    /** The bytes that follow a WebSocket-00 style 101 response must be surfaced as content (16 bytes here). */
    @Test
    public void testWebSocket00Response() {
        byte[] data = ("HTTP/1.1 101 WebSocket Protocol Handshake\r\n" +
                "Upgrade: WebSocket\r\n" +
                "Connection: Upgrade\r\n" +
                "Sec-WebSocket-Origin: http://localhost:8080\r\n" +
                "Sec-WebSocket-Location: ws://localhost/some/path\r\n" +
                "\r\n" +
                "1234567812345678").getBytes(CharsetUtil.ISO_8859_1);
        EmbeddedChannel ch = new EmbeddedChannel(new HttpClientCodec());
        assertTrue(ch.writeInbound(Unpooled.wrappedBuffer(data)));

        HttpResponse res = ch.readInbound();
        assertSame(HttpVersion.HTTP_1_1, res.protocolVersion());
        assertEquals(HttpResponseStatus.SWITCHING_PROTOCOLS, res.status());
        HttpContent content = ch.readInbound();
        assertEquals(16, content.content().readableBytes());
        content.release();

        assertFalse(ch.finish());

        assertNull(ch.readInbound());
    }

    /** A 102 Processing response must be decoded with empty content, even when bytes follow its headers. */
    @Test
    public void testWebDavResponse() {
        byte[] data = ("HTTP/1.1 102 Processing\r\n" +
                       "Status-URI: Status-URI:http://status.com; 404\r\n" +
                       "\r\n" +
                       "1234567812345678").getBytes(CharsetUtil.ISO_8859_1);
        EmbeddedChannel ch = new EmbeddedChannel(new HttpClientCodec());
        assertTrue(ch.writeInbound(Unpooled.wrappedBuffer(data)));

        HttpResponse res = ch.readInbound();
        assertSame(HttpVersion.HTTP_1_1, res.protocolVersion());
        assertEquals(HttpResponseStatus.PROCESSING, res.status());
        HttpContent content = ch.readInbound();
        // HTTP 102 is not allowed to have content.
        assertEquals(0, content.content().readableBytes());
        content.release();

        assertFalse(ch.finish());
    }

    /**
     * An informational (1xx) response to a HEAD request must not throw off request/response pairing.
     * The 200 response to the next GET request must still be decoded with its 8-byte body.
     */
    @Test
    public void testInformationalResponseKeepsPairsInSync() {
        byte[] data = ("HTTP/1.1 102 Processing\r\n" +
                "Status-URI: Status-URI:http://status.com; 404\r\n" +
                "\r\n").getBytes(CharsetUtil.ISO_8859_1);
        byte[] data2 = ("HTTP/1.1 200 OK\r\n" +
                "Content-Length: 8\r\n" +
                "\r\n" +
                "12345678").getBytes(CharsetUtil.ISO_8859_1);
        EmbeddedChannel ch = new EmbeddedChannel(new HttpClientCodec());
        assertTrue(ch.writeOutbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.HEAD, "/")));
        ByteBuf buffer = ch.readOutbound();
        buffer.release();
        assertNull(ch.readOutbound());
        assertTrue(ch.writeInbound(Unpooled.wrappedBuffer(data)));
        HttpResponse res = ch.readInbound();
        assertSame(HttpVersion.HTTP_1_1, res.protocolVersion());
        assertEquals(HttpResponseStatus.PROCESSING, res.status());
        HttpContent content = ch.readInbound();
        // HTTP 102 is not allowed to have content.
        assertEquals(0, content.content().readableBytes());
        assertInstanceOf(LastHttpContent.class, content);
        content.release();

        assertTrue(ch.writeOutbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/")));
        buffer = ch.readOutbound();
        buffer.release();
        assertNull(ch.readOutbound());
        assertTrue(ch.writeInbound(Unpooled.wrappedBuffer(data2)));

        res = ch.readInbound();
        assertSame(HttpVersion.HTTP_1_1, res.protocolVersion());
        assertEquals(HttpResponseStatus.OK, res.status());
        content = ch.readInbound();
        // HTTP 200 has content.
        assertEquals(8, content.content().readableBytes());
        assertInstanceOf(LastHttpContent.class, content);
        content.release();

        assertFalse(ch.finish());
    }

    /** Two responses arriving after a single request must both be decoded successfully. */
    @Test
    public void testMultipleResponses() {
        String response = "HTTP/1.1 200 OK\r\n" +
                "Content-Length: 0\r\n\r\n";

        HttpClientCodec codec = new HttpClientCodec(4096, 8192, 8192, true);
        EmbeddedChannel ch = new EmbeddedChannel(codec, new HttpObjectAggregator(1024));

        HttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "http://localhost/");
        assertTrue(ch.writeOutbound(request));

        assertTrue(ch.writeInbound(Unpooled.copiedBuffer(response, CharsetUtil.UTF_8)));
        assertTrue(ch.writeInbound(Unpooled.copiedBuffer(response, CharsetUtil.UTF_8)));
        FullHttpResponse resp = ch.readInbound();
        assertTrue(resp.decoderResult().isSuccess());
        resp.release();

        resp = ch.readInbound();
        assertTrue(resp.decoderResult().isSuccess());
        resp.release();
        assertTrue(ch.finishAndReleaseAll());
    }

    /**
     * After {@code prepareUpgradeFrom}, an outbound {@link ByteBuf} must pass through unchanged: it is
     * the same instance and its reference count is untouched.
     */
    @Test
    public void testWriteThroughAfterUpgrade() {
        HttpClientCodec codec = new HttpClientCodec();
        EmbeddedChannel ch = new EmbeddedChannel(codec);
        codec.prepareUpgradeFrom(null);

        ByteBuf buffer = ch.alloc().buffer();
        assertEquals(1, buffer.refCnt());
        assertTrue(ch.writeOutbound(buffer));
        // buffer should pass through unchanged
        assertSame(buffer, ch.<ByteBuf>readOutbound());
        assertEquals(1, buffer.refCnt());

        buffer.release();
    }
}
