package io.pythagoras.messagebus.core;

import io.pythagoras.messagebus.annotations.MessageContract;
import io.pythagoras.messagebus.core.config.MessageBusProperties;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.annotation.ClassPathScanningCandidateComponentProvider;
import org.springframework.core.type.filter.AnnotationTypeFilter;
import org.springframework.stereotype.Component;

import javax.annotation.PostConstruct;
import java.util.*;

/**
 * Registry of message contract classes, keyed by contract code and version.
 *
 * <p>After the application context is injected, {@link #init()} scans every bean annotated with
 * {@link MessageContract} (only when the message bus is enabled). Each bean's class must implement
 * {@link IMessageContract}. An instance is built through the {@link MessageFactory} to read its code
 * and version, and the class is registered under that (code, version) pair.
 *
 * @param <A> contract type handled by the injected {@link MessageFactory}
 */
@Component
public class MessageContractProvider<A extends IMessageContract> implements ApplicationContextAware {

    private final MessageFactory<A> messageFactory;

    private ApplicationContext context;

    /** code -> (version -> contract class); populated by {@link #init()}. */
    private final Map<String, Map<Integer, Class<IMessageContract>>> classes = new HashMap<>();

    private final Boolean enabled;

    @Autowired
    public MessageContractProvider(MessageFactory<A> messageFactory, MessageBusProperties properties) {
        this.messageFactory = messageFactory;
        this.enabled = properties.isEnabled();
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        this.context = applicationContext;
    }

    /** Populates the registry from the application context once the bean is constructed. */
    @PostConstruct
    public void init() {
        this.load();
    }

    /**
     * Reports whether a contract class is registered for the given code and version.
     *
     * @param code    contract code
     * @param version contract version
     * @return {@code true} if a class is registered under both {@code code} and {@code version}
     */
    public boolean has(String code, Integer version) {
        Map<Integer, Class<IMessageContract>> byVersion = classes.get(code);
        return byVersion != null && byVersion.containsKey(version);
    }

    /**
     * Returns every registered contract code.
     *
     * @return a new, mutable list of codes (each code appears once; order follows the backing
     *         {@link HashMap} and is therefore unspecified)
     */
    public List<String> getCodeList() {
        // Map keys are already unique, so a straight copy is enough.
        return new ArrayList<>(classes.keySet());
    }

    /**
     * Looks up the contract class registered for the given code and version.
     *
     * @param code    contract code
     * @param version contract version
     * @return the registered contract class
     * @throws MessageConversionException if no class is registered for {@code code}/{@code version}
     */
    public Class<IMessageContract> get(String code, Integer version) throws MessageConversionException {
        if (!has(code, version)) {
            throw new MessageConversionException();
        }
        return classes.get(code).get(version);
    }

    /**
     * Builds an instance of {@code klass} to read its code and version, then registers the class.
     *
     * @param klass contract class to register
     * @throws MessageConversionException if thrown by the message factory
     */
    private void add(Class<IMessageContract> klass) throws MessageConversionException {
        IMessageContract contract = messageFactory.make(klass);
        // NOTE(review): first registration wins. A second class with the same code and version
        // is silently ignored; consider failing fast with MessageBusInitializationException.
        classes.computeIfAbsent(contract.getCode(), code -> new HashMap<>())
                .putIfAbsent(contract.getVersion(), klass);
    }

    /**
     * Registers every {@link MessageContract}-annotated bean from the application context.
     * Does nothing unless the message bus is enabled.
     *
     * @throws MessageBusInitializationException if an annotated bean does not implement
     *                                           {@link IMessageContract}
     */
    private void load() {
        // Treat a missing (null) flag as "disabled" instead of failing with an unboxing NPE.
        if (!Boolean.TRUE.equals(enabled)) {
            return;
        }

        for (Object bean : this.context.getBeansWithAnnotation(MessageContract.class).values()) {
            Class<?> beanClass = bean.getClass();

            if (!IMessageContract.class.isAssignableFrom(beanClass)) {
                throw new MessageBusInitializationException("Class '"+beanClass.getName()+"' must implement IMessageContract.");
            }

            // Safe: assignability to IMessageContract was verified just above.
            @SuppressWarnings("unchecked")
            Class<IMessageContract> contractClass = (Class<IMessageContract>) beanClass;

            // add() instantiates the class through the factory, which also proves it can be built.
            add(contractClass);
        }
    }

}
