package io.pythagoras.messagebus.core;

import io.pythagoras.messagebus.annotations.MessageHandler;
import io.pythagoras.messagebus.core.config.MessageBusProperties;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;

import javax.annotation.PostConstruct;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

@Component
public class MessageHandlerProvider implements ApplicationContextAware {

    private ApplicationContext context;

    private MessageContractProvider messageContractProvider;

    private MessageFactory messageFactory;

    private HashMap<String, List<IMessageHandler>> classes = new HashMap<>();

    private Boolean enabled;

    public MessageHandlerProvider(MessageContractProvider messageContractProvider, MessageFactory messageFactory, MessageBusProperties properties) {
        this.messageContractProvider = messageContractProvider;
        this.messageFactory = messageFactory;
        this.enabled = properties.isEnabled();
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        this.context = applicationContext;
    }

    @PostConstruct
    public void init() {
        this.load();
    }

    public List<String> getHandlerContractCodes() {
        List<String> codes = new ArrayList<>();
        for(String contractClass : this.classes.keySet()) {
            try {
                IMessageContract message = this.messageFactory.make(Class.forName(contractClass));
                codes.add(message.getCode());
            } catch (ClassNotFoundException e) {
                continue;
            }
        }

        return codes;
    }

    public IMessageHandler getMessageHandlerForCodeAndVersion(String code, Integer version) {
        Class klass = messageContractProvider.get(code, version);
        if(!(classes.containsKey(klass.getName()))) {
            throw new HandleMessageFailureException();
        }
        List<IMessageHandler> handlers = classes.get(klass.getName());

        // Make sure version is allowed.
        List<IMessageHandler> matchingListeners = new ArrayList<>();
        for (IMessageHandler listener : handlers) {
            if (listener.allowedVersion(version)) {
                matchingListeners.add(listener);
            }
        }

        // Throw Exception if not EXACTLY one allowed listener is found
        if (matchingListeners.size() != 1) {
            throw new HandleMessageFailureException("There can only be one listener per contract and version.");
        }

        return matchingListeners.get(0);
    }

    private void load() throws MessageBusInitializationException {

        if(!enabled) {
            return;
        }

        for (Object o : this.context.getBeansWithAnnotation(MessageHandler.class).values()) {

            // Get the class
            Class<IMessageHandler> cl = (Class<IMessageHandler>) o.getClass();

            if (!(IMessageHandler.class.isAssignableFrom(cl))) {
                throw new MessageBusInitializationException("Class '"+cl.getName()+"' must implement IMessageHandler.");
            }

            // Make sure it can be built
            IMessageHandler handler = this.context.getBean(cl);

            // Add it to the registry
            String allowedClass = handler.getAllowedClass().getName();
            this.classes.putIfAbsent(allowedClass, new ArrayList<>());
            this.classes.get(handler.getAllowedClass().getName()).add(handler);
        }
    }

}
