package cn.com.duibaboot.ext.autoconfigure.perftest.datasource;

import cn.com.duiba.boot.utils.AopTargetUtils;
import cn.com.duibaboot.ext.autoconfigure.core.SpecifiedBeanPostProcessor;
import cn.com.duibaboot.ext.autoconfigure.perftest.core.PerfTestFootMarker;
import cn.com.duiba.boot.perftest.PerfTestUtils;
import com.google.common.base.Throwables;
import com.zaxxer.hikari.HikariDataSource;
import io.shardingjdbc.core.jdbc.adapter.AbstractDataSourceAdapter;
import io.shardingjdbc.core.jdbc.core.ShardingContext;
import io.shardingjdbc.core.jdbc.core.datasource.MasterSlaveDataSource;
import io.shardingjdbc.core.jdbc.core.datasource.ShardingDataSource;
import io.shardingjdbc.core.rule.MasterSlaveRule;
import org.apache.commons.dbcp2.BasicDataSource;
import org.apache.shardingsphere.shardingjdbc.jdbc.core.datasource.EncryptDataSource;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.core.env.Environment;
import org.springframework.util.ReflectionUtils;

import javax.annotation.Resource;
import javax.sql.DataSource;
import java.lang.reflect.Field;
import java.util.Map;

/**
 * Bean post-processor that, inside a perf-test container, replaces plain {@link DataSource} beans
 * with a {@link PerfTestRoutingDataSource} wrapper so the perf-test framework can handle them.
 *
 * <p>Directly wrappable pools are dbcp2 {@link BasicDataSource} and {@link HikariDataSource}.
 * Composite data sources from sharding-jdbc 2.x ({@link ShardingDataSource},
 * {@link MasterSlaveDataSource}) and ShardingSphere 4 ({@link EncryptDataSource}) are left as-is;
 * instead their inner data sources are verified to already be {@link PerfTestRoutingDataSource}
 * instances. That holds when the inner data sources are declared as Spring beans, because inner
 * beans are created (and therefore post-processed) first.
 *
 * <p>Created by guoyanfei on 2022/2/22.
 */
public class PerfTestDataSourcePostProcessor implements SpecifiedBeanPostProcessor<DataSource> {

    @Resource
    private Environment environment;

    @Resource
    private ApplicationContext applicationContext;

    /** Resolved lazily from {@link #applicationContext} on first use. */
    private PerfTestFootMarker perfTestFootMarker;

    /**
     * @return {@link DataSource}, the bean type this post-processor handles
     */
    @Override
    public Class<DataSource> getBeanType() {
        return DataSource.class;
    }

    /**
     * No-op: returns the bean unchanged.
     */
    @Override
    public Object postProcessBeforeInitialization(DataSource bean, String beanName) throws BeansException {
        return bean;
    }

    /**
     * Wraps the given data source into a {@link PerfTestRoutingDataSource} when running in a
     * perf-test environment.
     *
     * @param bean     the initialized data source bean
     * @param beanName the bean name, used in error messages
     * @return the original bean when not in a perf-test environment, when it is a (validated)
     *         sharding data source, or when it already is a {@link PerfTestRoutingDataSource};
     *         otherwise the new, initialized {@link PerfTestRoutingDataSource} wrapping it
     * @throws IllegalStateException if the data source type is unsupported, if a sharding data
     *         source holds inner data sources that were not wrapped, or if initializing the
     *         wrapper fails
     */
    @Override
    public Object postProcessAfterInitialization(DataSource bean, String beanName) throws BeansException {
        // Not a perf-test container: leave the data source untouched.
        if (!PerfTestUtils.isPerfTestEnv()) {
            return bean;
        }
        if (isBeanInstanceOfShardingJdbc2DataSource(bean, beanName)) {
            return bean;
        }
        if (isBeanInstanceOfShardingSphere4DataSource(bean, beanName)) {
            return bean;
        }
        if (bean instanceof PerfTestRoutingDataSource) {
            return bean;
        }

        // dbcp2 / Hikari may be absent from the classpath, so probe for the class before the
        // instanceof check references it. Only the probe is allowed to fail quietly; failures
        // while building the wrapper must propagate instead of being masked by the
        // "unsupported data source" error below.
        if (isClassPresent("org.apache.commons.dbcp2.BasicDataSource") && bean instanceof BasicDataSource) {
            return initialize(new PerfTestRoutingDataSourceForDbcp2((BasicDataSource) bean, environment, getPerfTestFootMarker()), beanName);
        }
        if (isClassPresent("com.zaxxer.hikari.HikariDataSource") && bean instanceof HikariDataSource) {
            return initialize(new PerfTestRoutingDataSourceForHikari((HikariDataSource) bean, environment, getPerfTestFootMarker()), beanName);
        }

        throw new IllegalStateException("数据源必须定义为HikariDataSource或者BasicDataSource，否则无法支持压测，如果你需要临时迁移数据库，请暂时把spring-boot-starter-perftest包去掉，并在迁移完成后加回来");
    }

    /**
     * @return the ordering value of this post-processor (always 0)
     */
    @Override
    public int getOrder() {
        return 0;
    }

    /**
     * Returns the {@link PerfTestFootMarker} bean, looking it up in the application context on
     * first call and caching it afterwards.
     */
    private PerfTestFootMarker getPerfTestFootMarker() {
        if (perfTestFootMarker == null) {
            perfTestFootMarker = applicationContext.getBean(PerfTestFootMarker.class);
        }
        return perfTestFootMarker;
    }

    /**
     * Calls {@code afterPropertiesSet()} on a freshly created routing data source, rethrowing
     * any failure with the bean name attached and the original cause preserved.
     *
     * @return the same, now initialized, routing data source
     */
    private static PerfTestRoutingDataSource initialize(PerfTestRoutingDataSource routingDataSource, String beanName) {
        try {
            routingDataSource.afterPropertiesSet();
        } catch (Exception e) {
            throw new IllegalStateException("压测数据源初始化失败，id为[" + beanName + "]", e);
        }
        return routingDataSource;
    }

    /**
     * Returns whether the named class can be loaded; used to guard references to optional
     * libraries before any {@code instanceof} check touches them.
     */
    private static boolean isClassPresent(String className) {
        try {
            Class.forName(className);
            return true;
        } catch (ClassNotFoundException e) {
            return false;
        }
    }

    /**
     * Unwraps a possible AOP proxy via {@link AopTargetUtils#getTarget}, rethrowing any failure
     * as an unchecked exception.
     */
    private static <T> T unwrapAopTarget(T candidate) {
        try {
            return AopTargetUtils.getTarget(candidate);
        } catch (Exception e) {
            throw Throwables.propagate(e);
        }
    }

    /**
     * Reads a (possibly private) field by reflection, searching {@code declaringType} and its
     * superclasses.
     *
     * @throws IllegalStateException if the field does not exist, e.g. because the library
     *         version changed its internals
     */
    private static Object readField(Class<?> declaringType, String fieldName, Object target) {
        Field field = ReflectionUtils.findField(declaringType, fieldName);
        if (field == null) {
            throw new IllegalStateException("[NOTIFYME]在" + declaringType.getName() + "中找不到字段[" + fieldName + "]，当前版本的数据源暂不支持压测，请联系架构组添加支持");
        }
        ReflectionUtils.makeAccessible(field);
        return ReflectionUtils.getField(field, target);
    }

    /**
     * Checks whether the bean is a ShardingSphere 4 {@link EncryptDataSource}. If so, verifies
     * that every inner data source is already a {@link PerfTestRoutingDataSource}; if one is not,
     * an error asks for it to be declared as a Spring bean (inner beans are created first, so
     * they will already have been converted to {@link PerfTestRoutingDataSource}).
     *
     * @param bean     the data source bean
     * @param beanName the bean name, used in error messages
     * @return {@code true} if the bean is a validated {@link EncryptDataSource}, {@code false} if
     *         ShardingSphere 4 is absent or the bean is of another type
     * @throws IllegalStateException if an inner data source was not wrapped
     */
    private static boolean isBeanInstanceOfShardingSphere4DataSource(DataSource bean, String beanName) {
        if (!isClassPresent("org.apache.shardingsphere.shardingjdbc.jdbc.core.datasource.EncryptDataSource")) {
            return false;
        }
        if (!(bean instanceof EncryptDataSource)) {
            return false;
        }

        EncryptDataSource ds = unwrapAopTarget((EncryptDataSource) bean);
        @SuppressWarnings("unchecked")
        Map<String, DataSource> dataSourceMap = (Map<String, DataSource>) readField(EncryptDataSource.class, "dataSourceMap", ds);
        for (DataSource innerDs : dataSourceMap.values()) {
            if (!(innerDs instanceof PerfTestRoutingDataSource)) {
                throw new IllegalStateException("请把id为[" + beanName + "]的EncryptDataSource数据源内部使用的数据源注册为spring的bean，以让线上压测框架有机会处理内部的数据源（内部数据源必须为dbcp2/Hikari）");
            }
        }
        return true;
    }

    /**
     * Checks whether the bean is a sharding-jdbc 2.x data source. If so, verifies that every
     * inner data source (including master/slave members) is already a
     * {@link PerfTestRoutingDataSource}; if one is not, an error asks for it to be declared as a
     * Spring bean (inner beans are created first, so they will already have been converted).
     *
     * @param bean     the data source bean
     * @param beanName the bean name, used in error messages
     * @return {@code true} if the bean is a validated sharding-jdbc data source, {@code false} if
     *         sharding-jdbc 2.x is absent or the bean is of another type
     * @throws IllegalStateException if an inner data source was not wrapped, or if the bean is an
     *         unsupported sharding-jdbc data source type
     */
    private static boolean isBeanInstanceOfShardingJdbc2DataSource(DataSource bean, String beanName) {
        if (!isClassPresent("io.shardingjdbc.core.jdbc.adapter.AbstractDataSourceAdapter")) {
            return false;
        }
        if (!(bean instanceof AbstractDataSourceAdapter)) {
            return false;
        }

        if (bean instanceof ShardingDataSource) {
            ShardingDataSource ds = unwrapAopTarget((ShardingDataSource) bean);
            ShardingContext shardingContext = (ShardingContext) readField(ShardingDataSource.class, "shardingContext", ds);
            Map<String, DataSource> dataSourceMap = shardingContext.getShardingRule().getDataSourceMap();
            for (DataSource innerDs : dataSourceMap.values()) {
                if (innerDs instanceof MasterSlaveDataSource) {
                    processMasterSlaveDataSource(innerDs, beanName);
                } else if (!(innerDs instanceof PerfTestRoutingDataSource)) {
                    onShardingJdbcError(beanName);
                }
            }
        } else if (bean instanceof MasterSlaveDataSource) {
            processMasterSlaveDataSource(bean, beanName);
        } else {
            throw new IllegalStateException("[NOTIFYME]sharding jdbc新增的数据源暂时不支持，如遇到此问题，请联系架构组添加支持");
        }

        return true;
    }

    /**
     * Verifies that the master and all slave data sources of a sharding-jdbc
     * {@link MasterSlaveDataSource} are {@link PerfTestRoutingDataSource} instances.
     *
     * <p>The parameter is declared as {@link DataSource} rather than {@link MasterSlaveDataSource}
     * so the method signature does not trigger class-not-found errors when sharding-jdbc is absent.
     *
     * @param ds       a {@link MasterSlaveDataSource}
     * @param beanName the bean name, used in error messages
     * @throws IllegalStateException if the master or any slave was not wrapped
     */
    private static void processMasterSlaveDataSource(DataSource ds, String beanName) {
        MasterSlaveRule rule = ((MasterSlaveDataSource) ds).getMasterSlaveRule();
        if (!(rule.getMasterDataSource() instanceof PerfTestRoutingDataSource)) {
            onShardingJdbcError(beanName);
        }
        Map<String, DataSource> slaveDataSourceMap = rule.getSlaveDataSourceMap();
        for (DataSource slave : slaveDataSourceMap.values()) {
            if (!(slave instanceof PerfTestRoutingDataSource)) {
                onShardingJdbcError(beanName);
            }
        }
    }

    /**
     * Always throws: reports that a sharding-jdbc data source contains inner data sources that
     * were not registered as Spring beans (and therefore not wrapped).
     *
     * @throws IllegalStateException always
     */
    private static void onShardingJdbcError(String beanName) {
        throw new IllegalStateException("请把id为[" + beanName + "]的sharding-jdbc数据源内部使用的数据源注册为spring的bean，以让线上压测框架有机会处理内部的数据源（内部数据源必须为dbcp2/Hikari）");
    }

}
