package cn.com.duibaboot.ext.autoconfigure.httpclient.ssre;

import cn.com.duiba.boot.utils.AopTargetUtils;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.NotNull;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanPostProcessor;

import java.lang.reflect.Field;
import java.util.*;

/**
 * {@link BeanPostProcessor} that, after a bean has been initialized, replaces its HTTP-client
 * fields with {@link ClientWrapper} decorators.
 *
 * <p>For every bean whose (AOP-unwrapped) class is <em>not</em> located in one of the configured
 * "white" packages, the declared fields of the class and all of its superclasses are inspected.
 * A field qualifies when its declared type is assignable to one of the wrappers'
 * {@code acceptableClass()}. For a qualifying field, the current value is handed to a fresh
 * wrapper instance via {@code trySetClient(Object)}. If the wrapper accepts it, the field is
 * overwritten with the wrapper. Fields annotated with {@link CanAccessInsideNetwork} are never
 * touched.
 *
 * <p>Note: despite the name, {@code whitePackages} lists packages whose beans are
 * <em>excluded</em> from wrapping (framework and infrastructure code). An entry matches a class's
 * package when it is the same package or a parent package of it. For example, {@code "io"}
 * matches {@code io.netty} but not {@code iot.demo}. An entry that ends with {@code '.'} is
 * treated as a plain string prefix.
 *
 * @author liuyao
 */
@Slf4j
public class SsreBeanPostProcessor implements BeanPostProcessor {

    /** Packages (and their sub-packages) whose beans are never wrapped; insertion order is kept. */
    private final Set<String> whitePackages = Sets.newLinkedHashSet();

    /**
     * Maps a client type to the wrapper class that decorates it. Iteration order equals
     * registration order, so lookups are deterministic when several entries could match.
     */
    private final Map<Class<?>, Class<? extends ClientWrapper>> acceptableClassesMap;

    {
        whitePackages.add("cn.com.duibaboot");
        whitePackages.add("org.springframework");
        whitePackages.add("com.netflix");
        whitePackages.add("org");
        whitePackages.add("io");

        List<ClientWrapper> wrappers = Lists.newArrayList(
                new SsreHttpClientWrapper(),
                new SsreHttpAsyncClientWrapper(),
                new SsreRestTemplateWrapper(),
                new SsreAsyncRestTemplateWrapper());

        // LinkedHashMap: keep registration order so the first registered matching wrapper wins.
        Map<Class<?>, Class<? extends ClientWrapper>> map = new LinkedHashMap<>();
        for (ClientWrapper wrapper : wrappers) {
            map.put(wrapper.acceptableClass(), wrapper.getClass());
        }
        acceptableClassesMap = Collections.unmodifiableMap(map);
    }

    /** No-op; all work happens after initialization. */
    @Override
    public Object postProcessBeforeInitialization(@NotNull Object bean, @NotNull String beanName) throws BeansException {
        return bean;
    }

    /**
     * Wraps the qualifying client fields of the bean's target object.
     *
     * @param bean     the initialized bean, possibly an AOP proxy
     * @param beanName the bean name, used for diagnostics
     * @return the same {@code bean} instance (only its target's fields may have been modified)
     * @throws IllegalStateException if the AOP target of {@code bean} cannot be resolved
     */
    @Override
    public Object postProcessAfterInitialization(@NotNull Object bean, @NotNull String beanName) throws BeansException {
        // The bean may have been proxied by Spring AOP, so resolve and modify the underlying target.
        Object targetBean;
        try {
            targetBean = AopTargetUtils.getTarget(bean);
        } catch (Exception e) {
            throw new IllegalStateException("Failed to resolve the AOP target of bean '" + beanName + "'", e);
        }
        Class<?> clazz = targetBean.getClass();
        if (!shouldProcess(clazz.getPackage())) {
            return bean;
        }
        for (Field field : getFieldList(clazz)) {
            if (field.isAnnotationPresent(CanAccessInsideNetwork.class)) {
                // Explicitly opted out of wrapping.
                continue;
            }
            Class<? extends ClientWrapper> clientWrapperClass = getClientWrapperClass(field.getType());
            if (clientWrapperClass != null) {
                wrapField(targetBean, beanName, field, clientWrapperClass);
            }
        }
        return bean;
    }

    /**
     * Replaces the value of {@code field} on {@code targetBean} with a new wrapper around it.
     * This is best-effort: any failure is logged and the field is left unchanged.
     */
    private void wrapField(Object targetBean, String beanName, Field field,
                           Class<? extends ClientWrapper> clientWrapperClass) {
        try {
            field.setAccessible(true);
            Object client = field.get(targetBean);
            // Skip empty fields and fields already holding a wrapper: wrapping twice would cause
            // infinite recursion.
            if (client == null || clientWrapperClass.isInstance(client)) {
                return;
            }
            ClientWrapper clientWrapper = clientWrapperClass.getDeclaredConstructor().newInstance();
            if (clientWrapper.trySetClient(client)) {
                field.set(targetBean, clientWrapper);
            }
        } catch (Exception e) {
            log.warn("Client代理失败, bean={}, field={}", beanName, field, e);
        }
    }

    /**
     * Returns the wrapper class for a field of the given declared type.
     *
     * @return the first registered wrapper whose acceptable class is a supertype of
     *         {@code clazz}, or {@code null} if none applies (always {@code null} for primitives)
     */
    private Class<? extends ClientWrapper> getClientWrapperClass(Class<?> clazz) {
        if (clazz.isPrimitive()) {
            return null;
        }
        for (Map.Entry<Class<?>, Class<? extends ClientWrapper>> entry : acceptableClassesMap.entrySet()) {
            if (entry.getKey().isAssignableFrom(clazz)) {
                return entry.getValue();
            }
        }
        return null;
    }

    /**
     * Adds packages whose beans must not be wrapped; the built-in defaults are kept.
     *
     * @param whitePackage additional excluded packages; {@code null} is ignored
     */
    public void setWhitePackage(Set<String> whitePackage) {
        if (whitePackage != null) {
            this.whitePackages.addAll(whitePackage);
        }
    }

    /**
     * Collects the declared fields of {@code clazz} and all of its superclasses up to, but not
     * including, {@link Object}. Subclass fields come first.
     */
    private List<Field> getFieldList(Class<?> clazz) {
        List<Field> fieldList = new ArrayList<>();
        for (Class<?> current = clazz; current != null && current != Object.class; current = current.getSuperclass()) {
            Collections.addAll(fieldList, current.getDeclaredFields());
        }
        return fieldList;
    }

    /**
     * Decides whether beans from {@code pkg} are candidates for wrapping.
     *
     * @return {@code false} for a {@code null} package or a package covered by a white package
     *         entry, otherwise {@code true}
     */
    private boolean shouldProcess(Package pkg) {
        if (pkg == null) {
            return false;
        }
        String packageName = pkg.getName();
        for (String whitePackage : whitePackages) {
            if (isSameOrSubPackage(packageName, whitePackage)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns {@code true} if {@code packageName} equals {@code parentPackage} or lies beneath it.
     * The match respects the '.' boundary, so {@code "io"} does not match {@code "iot.demo"}. An
     * entry that already ends with '.' is used as a plain prefix.
     */
    private static boolean isSameOrSubPackage(String packageName, String parentPackage) {
        if (parentPackage == null) {
            return false;
        }
        if (parentPackage.endsWith(".")) {
            return StringUtils.startsWith(packageName, parentPackage);
        }
        return packageName.equals(parentPackage) || StringUtils.startsWith(packageName, parentPackage + ".");
    }

}
