package org.tio.http.mcp.server;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tio.http.common.HttpRequest;
import org.tio.http.common.HttpResponse;
import org.tio.http.common.HttpResponseStatus;
import org.tio.http.jsonrpc.JsonRpcMessage;
import org.tio.http.jsonrpc.JsonRpcNotification;
import org.tio.http.jsonrpc.JsonRpcRequest;
import org.tio.http.jsonrpc.JsonRpcResponse;
import org.tio.http.mcp.schema.McpCallToolRequest;
import org.tio.http.mcp.schema.McpCallToolResult;
import org.tio.http.mcp.schema.McpImplementation;
import org.tio.http.mcp.schema.McpInitializeRequest;
import org.tio.http.mcp.schema.McpInitializeResult;
import org.tio.http.mcp.schema.McpListToolsResult;
import org.tio.http.mcp.schema.McpRoot;
import org.tio.http.mcp.schema.McpSchema;
import org.tio.http.mcp.schema.McpServerCapabilities;
import org.tio.http.mcp.schema.McpTool;
import org.tio.http.sse.SseEmitter;
import org.tio.utils.hutool.StrUtil;
import org.tio.utils.json.JsonUtil;

/**
 * Minimal MCP (Model Context Protocol) server speaking the HTTP + SSE transport on top of t-io.
 *
 * <p>A client first opens {@link #sseEndpoint(HttpRequest)}. Once that SSE response has been
 * written, a new {@link McpServerSession} is registered under a random id and an
 * {@value #ENDPOINT_EVENT_TYPE} event is emitted telling the client where to POST its JSON-RPC
 * messages ({@code messageEndpoint?sessionId=<id>}). Messages POSTed to
 * {@link #sseMessageEndpoint(HttpRequest)} are dispatched and any response is pushed back through
 * the session's SSE stream.
 *
 * <p>Only {@code initialize}, {@code ping}, {@code tools/list} and {@code tools/call} are
 * dispatched by this class; registered resources, resource templates, prompts and roots-change
 * handlers are stored but not consulted here.
 *
 * <p>The registration methods mutate plain, unsynchronized collections; presumably they are meant
 * to be called during setup, before requests arrive (TODO(review): confirm).
 */
public class McpServer {
    /** SSE event name used to tell the client which URL to POST its messages to. */
    public static final String ENDPOINT_EVENT_TYPE = "endpoint";
    /** Default path of the SSE stream endpoint. */
    public static final String DEFAULT_SSE_ENDPOINT = "/sse";
    /** Default path of the message (POST) endpoint. */
    public static final String DEFAULT_MESSAGE_ENDPOINT = "/sse/message";

    // Keyed by the nano id handed to the client in the "endpoint" event.
    // TODO(review): entries are never removed, so closed SSE connections keep their sessions forever.
    private final ConcurrentHashMap<String, McpServerSession> sessions;
    private final String sseEndpoint;
    private final String messageEndpoint;
    private McpImplementation serverInfo;
    private McpServerCapabilities serverCapabilities;
    private final List<McpToolSpecification> tools;
    private final Map<String, McpResourceSpecification> resources;
    private final Map<String, McpResourceTemplateSpecification> resourceTemplates;
    private final Map<String, McpPromptSpecification> prompts;
    private final List<BiConsumer<McpServerSession, List<McpRoot>>> rootsChangeHandlers;
    private static final Logger log = LoggerFactory.getLogger(McpServer.class);
    private static final McpImplementation DEFAULT_SERVER_INFO = new McpImplementation("mcp-server", "1.0.0");

    /** Creates a server using {@link #DEFAULT_SSE_ENDPOINT} and {@link #DEFAULT_MESSAGE_ENDPOINT}. */
    public McpServer() {
        this(DEFAULT_SSE_ENDPOINT, DEFAULT_MESSAGE_ENDPOINT);
    }

    /**
     * Creates a server with custom endpoint paths.
     *
     * @param sseEndpoint path of the SSE stream endpoint
     * @param messageEndpoint path advertised to clients for POSTing JSON-RPC messages
     * @throws NullPointerException if either path is {@code null}
     */
    public McpServer(String sseEndpoint, String messageEndpoint) {
        this.sessions = new ConcurrentHashMap<>();
        this.serverInfo = DEFAULT_SERVER_INFO;
        this.tools = new ArrayList<>();
        this.resources = new HashMap<>();
        this.resourceTemplates = new HashMap<>();
        this.prompts = new HashMap<>();
        this.rootsChangeHandlers = new ArrayList<>();
        this.sseEndpoint = Objects.requireNonNull(sseEndpoint, "SSE endpoint must not be null");
        this.messageEndpoint = Objects.requireNonNull(messageEndpoint, "Message endpoint must not be null");
    }

    /**
     * Sets the implementation info reported in the {@code initialize} result.
     *
     * @throws NullPointerException if {@code mcpImplementation} is {@code null}
     */
    public McpServer serverInfo(McpImplementation mcpImplementation) {
        Objects.requireNonNull(mcpImplementation, "Server info must not be null");
        this.serverInfo = mcpImplementation;
        return this;
    }

    /**
     * Sets the implementation name and version reported in the {@code initialize} result.
     *
     * @throws IllegalArgumentException if either value is blank
     */
    public McpServer serverInfo(String name, String version) {
        if (StrUtil.isBlank(name)) {
            throw new IllegalArgumentException("Server info name must not be blank");
        }
        if (StrUtil.isBlank(version)) {
            throw new IllegalArgumentException("Server info version must not be blank");
        }
        this.serverInfo = new McpImplementation(name, version);
        return this;
    }

    /**
     * Sets the capabilities reported in the {@code initialize} result.
     *
     * @throws NullPointerException if {@code mcpServerCapabilities} is {@code null}
     */
    public McpServer capabilities(McpServerCapabilities mcpServerCapabilities) {
        Objects.requireNonNull(mcpServerCapabilities, "Server capabilities must not be null");
        this.serverCapabilities = mcpServerCapabilities;
        return this;
    }

    /**
     * Registers a tool and the handler invoked for {@code tools/call} requests naming it.
     *
     * @throws NullPointerException if either argument is {@code null}
     */
    public McpServer tool(McpTool mcpTool, BiFunction<McpServerSession, Map<String, Object>, McpCallToolResult> biFunction) {
        Objects.requireNonNull(mcpTool, "Tool must not be null");
        Objects.requireNonNull(biFunction, "Handler must not be null");
        this.tools.add(new McpToolSpecification(mcpTool, biFunction));
        return this;
    }

    /** Registers every tool specification in {@code list}. */
    public McpServer tools(List<McpToolSpecification> list) {
        Objects.requireNonNull(list, "Tool handlers list must not be null");
        this.tools.addAll(list);
        return this;
    }

    /** Registers every given tool specification. */
    public McpServer tools(McpToolSpecification... mcpToolSpecificationArr) {
        Objects.requireNonNull(mcpToolSpecificationArr, "Tool handlers list must not be null");
        return tools(Arrays.asList(mcpToolSpecificationArr));
    }

    /** Registers resources keyed exactly as given in {@code map}. */
    public McpServer resources(Map<String, McpResourceSpecification> map) {
        Objects.requireNonNull(map, "Resource handlers map must not be null");
        this.resources.putAll(map);
        return this;
    }

    /** Registers resources keyed by their resource URI; later entries replace earlier ones. */
    public McpServer resources(List<McpResourceSpecification> list) {
        Objects.requireNonNull(list, "Resource handlers list must not be null");
        for (McpResourceSpecification spec : list) {
            this.resources.put(spec.getResource().getUri(), spec);
        }
        return this;
    }

    /** Registers resources keyed by their resource URI; later entries replace earlier ones. */
    public McpServer resources(McpResourceSpecification... mcpResourceSpecificationArr) {
        Objects.requireNonNull(mcpResourceSpecificationArr, "Resource handlers list must not be null");
        return resources(Arrays.asList(mcpResourceSpecificationArr));
    }

    /** Registers resource templates keyed by their URI template. */
    public McpServer resourceTemplates(List<McpResourceTemplateSpecification> list) {
        Objects.requireNonNull(list, "Resource templates must not be null");
        for (McpResourceTemplateSpecification spec : list) {
            this.resourceTemplates.put(spec.getResource().getUriTemplate(), spec);
        }
        return this;
    }

    /** Registers resource templates keyed by their URI template. */
    public McpServer resourceTemplates(McpResourceTemplateSpecification... mcpResourceTemplateSpecificationArr) {
        Objects.requireNonNull(mcpResourceTemplateSpecificationArr, "Resource templates must not be null");
        return resourceTemplates(Arrays.asList(mcpResourceTemplateSpecificationArr));
    }

    /** Registers prompts keyed exactly as given in {@code map}. */
    public McpServer prompts(Map<String, McpPromptSpecification> map) {
        Objects.requireNonNull(map, "Prompts map must not be null");
        this.prompts.putAll(map);
        return this;
    }

    /** Registers prompts keyed by prompt name. */
    public McpServer prompts(List<McpPromptSpecification> list) {
        Objects.requireNonNull(list, "Prompts list must not be null");
        for (McpPromptSpecification spec : list) {
            this.prompts.put(spec.getPrompt().getName(), spec);
        }
        return this;
    }

    /** Registers prompts keyed by prompt name. */
    public McpServer prompts(McpPromptSpecification... mcpPromptSpecificationArr) {
        Objects.requireNonNull(mcpPromptSpecificationArr, "Prompts list must not be null");
        return prompts(Arrays.asList(mcpPromptSpecificationArr));
    }

    /** Adds a roots-change handler (stored only; not invoked by this class). */
    public McpServer rootsChangeHandler(BiConsumer<McpServerSession, List<McpRoot>> biConsumer) {
        Objects.requireNonNull(biConsumer, "Consumer must not be null");
        this.rootsChangeHandlers.add(biConsumer);
        return this;
    }

    /** Adds every roots-change handler in {@code list} (stored only; not invoked by this class). */
    public McpServer rootsChangeHandlers(List<BiConsumer<McpServerSession, List<McpRoot>>> list) {
        Objects.requireNonNull(list, "Handlers list must not be null");
        this.rootsChangeHandlers.addAll(list);
        return this;
    }

    /** Adds every given roots-change handler (stored only; not invoked by this class). */
    @SafeVarargs
    public final McpServer rootsChangeHandlers(BiConsumer<McpServerSession, List<McpRoot>>... biConsumerArr) {
        Objects.requireNonNull(biConsumerArr, "Handlers list must not be null");
        return rootsChangeHandlers(Arrays.asList(biConsumerArr));
    }

    /**
     * Handles a client connecting to the SSE endpoint.
     *
     * <p>When the packet listener reports {@code true} (presumably "sent successfully" — TODO
     * confirm against t-io's packet-listener contract), a new session is registered and an
     * {@value #ENDPOINT_EVENT_TYPE} event carrying the message URL is pushed to the client.
     */
    public HttpResponse sseEndpoint(HttpRequest httpRequest) {
        HttpResponse httpResponse = new HttpResponse(httpRequest);
        SseEmitter emitter = SseEmitter.getEmitter(httpRequest, httpResponse);
        httpResponse.setPacketListener((channelContext, packet, sentOk) -> {
            if (sentOk) {
                String sessionId = StrUtil.getNanoId();
                this.sessions.put(sessionId, new McpServerSession(sessionId, emitter));
                emitter.send(ENDPOINT_EVENT_TYPE, this.messageEndpoint + "?sessionId=" + sessionId);
            }
        });
        return httpResponse;
    }

    /**
     * Handles a JSON-RPC message POSTed by the client for an existing session.
     *
     * <p>Requests are answered asynchronously over the session's SSE stream; the HTTP response
     * itself carries no payload on success. Responds 404 for a missing/unknown session id and
     * 400 when the body is not a recognizable JSON-RPC message.
     */
    public HttpResponse sseMessageEndpoint(HttpRequest httpRequest) {
        String sessionId = httpRequest.getParam("sessionId");
        HttpResponse httpResponse = new HttpResponse(httpRequest);
        if (StrUtil.isBlank(sessionId)) {
            httpResponse.setStatus(HttpResponseStatus.C404);
            httpResponse.setBody("Session ID missing in message endpoint".getBytes(StandardCharsets.UTF_8));
            return httpResponse;
        }
        McpServerSession session = this.sessions.get(sessionId);
        if (session == null) {
            httpResponse.setStatus(HttpResponseStatus.C404);
            httpResponse.setBody("Session is null".getBytes(StandardCharsets.UTF_8));
            log.error("Session is null sessionId:{}", sessionId);
            return httpResponse;
        }

        JsonRpcMessage message;
        try {
            message = deserializeJsonRpcMessage(httpRequest.getBody());
        } catch (RuntimeException e) {
            // Malformed client input is a bad request, not a server failure.
            log.warn("Invalid JSON-RPC message for sessionId:{}", sessionId, e);
            httpResponse.setStatus(HttpResponseStatus.C400);
            httpResponse.setBody("Invalid JSON-RPC message".getBytes(StandardCharsets.UTF_8));
            return httpResponse;
        }

        if (message instanceof JsonRpcRequest) {
            JsonRpcRequest request = (JsonRpcRequest) message;
            JsonRpcResponse response = handleIncomingRequest(session, request);
            if (response != null) {
                session.sendMessage(response);
            } else {
                // TODO(review): JSON-RPC expects a -32601 "Method not found" error response here.
                log.warn("Unsupported MCP method: {} sessionId:{}", request.getMethod(), sessionId);
            }
        } else if (message instanceof JsonRpcNotification) {
            log.info("Received notification: {}", message);
        } else {
            log.debug("Ignoring JSON-RPC response from client: {}", message);
        }
        return httpResponse;
    }

    /** Asks every registered session to send a heartbeat. */
    public void sendHeartbeat() {
        for (McpServerSession session : this.sessions.values()) {
            session.sendHeartbeat();
        }
    }

    /** @return the message endpoint path advertised to clients */
    public String getMessageEndpoint() {
        return this.messageEndpoint;
    }

    /** @return the SSE endpoint path */
    public String getSseEndpoint() {
        return this.sseEndpoint;
    }

    /**
     * Dispatches a JSON-RPC request to the matching MCP method handler.
     *
     * @return the response to send, or {@code null} if {@code method} is not supported
     * @throws IllegalArgumentException if required params are missing or the named tool is unknown
     */
    private JsonRpcResponse handleIncomingRequest(McpServerSession session, JsonRpcRequest request) {
        String method = request.getMethod();
        if (McpSchema.METHOD_INITIALIZE.equals(method)) {
            return newResponse(request, initialize(request));
        }
        if (McpSchema.METHOD_PING.equals(method)) {
            return newResponse(request, Collections.emptyMap());
        }
        if (McpSchema.METHOD_TOOLS_LIST.equals(method)) {
            return newResponse(request, listTools());
        }
        if (McpSchema.METHOD_TOOLS_CALL.equals(method)) {
            return newResponse(request, callTool(session, request));
        }
        return null;
    }

    /** Builds the {@code initialize} result, echoing the client's protocol version. */
    private McpInitializeResult initialize(JsonRpcRequest request) {
        McpInitializeRequest initializeRequest =
                (McpInitializeRequest) JsonUtil.convertValue(request.getParams(), McpInitializeRequest.class);
        if (initializeRequest == null) {
            throw new IllegalArgumentException("initialize request params must not be null");
        }
        McpInitializeResult result = new McpInitializeResult();
        result.setProtocolVersion(initializeRequest.getProtocolVersion());
        result.setCapabilities(this.serverCapabilities);
        result.setServerInfo(this.serverInfo);
        return result;
    }

    /** Lists the descriptors of every registered tool, in registration order. */
    private McpListToolsResult listTools() {
        List<McpTool> toolList = new ArrayList<>(this.tools.size());
        for (McpToolSpecification spec : this.tools) {
            toolList.add(spec.getTool());
        }
        McpListToolsResult result = new McpListToolsResult();
        result.setTools(toolList);
        return result;
    }

    /**
     * Invokes the first registered tool whose name matches the request.
     *
     * @throws IllegalArgumentException if params are missing or no tool has the requested name
     * @throws IllegalStateException if the matched tool handler returns {@code null}
     */
    private McpCallToolResult callTool(McpServerSession session, JsonRpcRequest request) {
        McpCallToolRequest callRequest =
                (McpCallToolRequest) JsonUtil.convertValue(request.getParams(), McpCallToolRequest.class);
        if (callRequest == null) {
            throw new IllegalArgumentException("tools/call request params must not be null");
        }
        String name = callRequest.getName();
        for (McpToolSpecification spec : this.tools) {
            if (Objects.equals(name, spec.getTool().getName())) {
                McpCallToolResult result =
                        spec.getCall().apply(session, getCallToolArguments(callRequest.getArguments()));
                if (result == null) {
                    // Distinguish a misbehaving handler from an unknown tool name.
                    throw new IllegalStateException("Tool " + name + " returned a null result");
                }
                return result;
            }
        }
        throw new IllegalArgumentException("Cannot find tool with name " + name);
    }

    /** Wraps {@code result} in a JSON-RPC response carrying the request's id. */
    private static JsonRpcResponse newResponse(JsonRpcRequest request, Object result) {
        JsonRpcResponse response = new JsonRpcResponse();
        response.setJsonrpc(McpSchema.JSONRPC_VERSION);
        response.setId(request.getId());
        response.setResult(result);
        return response;
    }

    /**
     * Classifies a raw JSON body as a request (method + id), notification (method, no id) or
     * response (result or error).
     *
     * @throws IllegalArgumentException if the body is empty, JSON {@code null}, or matches none of
     *     the shapes above
     */
    private static JsonRpcMessage deserializeJsonRpcMessage(byte[] body) {
        if (body == null || body.length == 0) {
            throw new IllegalArgumentException("Cannot deserialize JsonRpcMessage: empty body");
        }
        String json = new String(body, StandardCharsets.UTF_8);
        log.debug("Received JSON message: {}", json);
        Map<?, ?> map = (Map<?, ?>) JsonUtil.readValue(body, Map.class);
        if (map == null) {
            throw new IllegalArgumentException("Cannot deserialize JsonRpcMessage: " + json);
        }
        if (map.containsKey("method")) {
            if (map.containsKey("id")) {
                return (JsonRpcMessage) JsonUtil.convertValue(map, JsonRpcRequest.class);
            }
            return (JsonRpcMessage) JsonUtil.convertValue(map, JsonRpcNotification.class);
        }
        if (map.containsKey("result") || map.containsKey("error")) {
            return (JsonRpcMessage) JsonUtil.convertValue(map, JsonRpcResponse.class);
        }
        throw new IllegalArgumentException("Cannot deserialize JsonRpcMessage: " + json);
    }

    /**
     * Normalizes {@code tools/call} arguments to a map.
     *
     * @return {@code null} for null/blank input; the map itself if already a map; otherwise a
     *     converted/parsed map
     */
    @SuppressWarnings("unchecked") // JSON objects are decoded with String keys.
    private static Map<String, Object> getCallToolArguments(Object arguments) {
        if (arguments == null) {
            return null;
        }
        if (arguments instanceof Map) {
            return (Map<String, Object>) arguments;
        }
        if (arguments instanceof String) {
            String text = (String) arguments;
            if (StrUtil.isBlank(text)) {
                return null;
            }
            // A String holding JSON text is parsed rather than handed to convertValue, which
            // (assuming JsonUtil is Jackson-backed — TODO confirm) maps a string token instead of
            // parsing it and so cannot produce a Map.
            return (Map<String, Object>) JsonUtil.readValue(text.getBytes(StandardCharsets.UTF_8), Map.class);
        }
        return (Map<String, Object>) JsonUtil.convertValue(arguments, Map.class);
    }
}
