package cn.com.duibaboot.ext.autoconfigure.perftest.datasource;

import cn.com.duiba.boot.perftest.PerfTestUtils;
import cn.com.duiba.boot.utils.AopTargetUtils;
import cn.com.duibaboot.ext.autoconfigure.core.SpecifiedBeanPostProcessor;
import cn.com.duibaboot.ext.autoconfigure.perftest.core.PerfTestFootMarker;
import com.google.common.base.Throwables;
import com.zaxxer.hikari.HikariDataSource;
import org.apache.commons.dbcp2.BasicDataSource;
import org.apache.shardingsphere.driver.jdbc.core.datasource.ShardingSphereDataSource;
import org.apache.shardingsphere.mode.manager.ContextManager;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.core.env.Environment;
import org.springframework.util.ReflectionUtils;

import javax.annotation.Resource;
import javax.sql.DataSource;
import java.lang.reflect.Field;
import java.util.Map;

/**
 * Bean post-processor that, when the application runs in a perf-test environment
 * ({@link PerfTestUtils#isPerfTestEnv()}), wraps {@link DataSource} beans in a
 * {@link PerfTestRoutingDataSource}.
 *
 * <p>Handled bean types:
 * <ul>
 *     <li>DBCP2 {@link BasicDataSource} and {@link HikariDataSource}: the bean is replaced by a routing
 *     wrapper around it;</li>
 *     <li>ShardingSphere 5 {@link ShardingSphereDataSource}: the bean itself is kept, and each
 *     {@link HikariDataSource} in its {@link ContextManager}'s data source map is replaced in place by a
 *     routing wrapper;</li>
 *     <li>beans that already are a {@link PerfTestRoutingDataSource}: returned unchanged.</li>
 * </ul>
 * Any other {@link DataSource} type makes bean post-processing fail with an {@link IllegalStateException}.
 * Outside a perf-test environment every bean is returned untouched.
 *
 * <p>Created by guoyanfei, 2022/2/22.
 */
public class PerfTestDataSourcePostProcessor implements SpecifiedBeanPostProcessor<DataSource> {

    private static final String DBCP2_DATA_SOURCE_CLASS = "org.apache.commons.dbcp2.BasicDataSource";

    private static final String HIKARI_DATA_SOURCE_CLASS = "com.zaxxer.hikari.HikariDataSource";

    private static final String SHARDING_SPHERE_DATA_SOURCE_CLASS =
            "org.apache.shardingsphere.driver.jdbc.core.datasource.ShardingSphereDataSource";

    @Resource
    private Environment environment;

    @Resource
    private ApplicationContext applicationContext;

    /** Looked up lazily from the application context on first use; see {@link #getPerfTestFootMarker()}. */
    private PerfTestFootMarker perfTestFootMarker;

    @Override
    public Class<DataSource> getBeanType() {
        return DataSource.class;
    }

    /** No-op: data sources are only wrapped after initialization. */
    @Override
    public Object postProcessBeforeInitialization(DataSource bean, String beanName) throws BeansException {
        return bean;
    }

    /**
     * Replaces {@code bean} with a perf-test routing wrapper when running in a perf-test environment.
     *
     * @param bean     the initialized data source bean
     * @param beanName the bean's name in the context
     * @return {@code bean} itself (not a perf-test env, a ShardingSphere data source, or already wrapped),
     *         or a new, initialized {@link PerfTestRoutingDataSource} wrapping it
     * @throws IllegalStateException if the bean is neither a DBCP2 nor a Hikari pool (nor otherwise handled)
     */
    @Override
    public Object postProcessAfterInitialization(DataSource bean, String beanName) throws BeansException {
        // Not running in a perf-test container: leave the data source alone.
        if (!PerfTestUtils.isPerfTestEnv()) {
            return bean;
        }
        if (processShardingJdbc5DataSource(bean, beanName)) {
            return bean;
        }
        if (bean instanceof PerfTestRoutingDataSource) {
            return bean;
        }

        // Class presence is checked before each instanceof: evaluating instanceof against a pool class
        // that is missing from the classpath would fail with NoClassDefFoundError.
        if (isClassPresent(DBCP2_DATA_SOURCE_CLASS) && bean instanceof BasicDataSource) {
            try {
                return initialized(new PerfTestRoutingDataSourceForDbcp2(
                        (BasicDataSource) bean, environment, getPerfTestFootMarker()));
            } catch (Exception e) {
                // Propagate the real cause; swallowing it here would only leave the misleading
                // "unsupported data source type" error thrown below.
                throw Throwables.propagate(e);
            }
        }
        if (isClassPresent(HIKARI_DATA_SOURCE_CLASS) && bean instanceof HikariDataSource) {
            return initialized(new PerfTestRoutingDataSourceForHikari(
                    (HikariDataSource) bean, environment, getPerfTestFootMarker()));
        }

        throw new IllegalStateException("数据源必须定义为HikariDataSource或者BasicDataSource，否则无法支持压测，如果你需要临时迁移数据库，请暂时把spring-boot-starter-perftest包去掉，并在迁移完成后加回来");
    }

    @Override
    public int getOrder() {
        return 0;
    }

    /**
     * Returns the {@link PerfTestFootMarker} bean, fetching it from the application context on first call
     * and caching it in a field afterwards.
     */
    private PerfTestFootMarker getPerfTestFootMarker() {
        if (perfTestFootMarker == null) {
            perfTestFootMarker = applicationContext.getBean(PerfTestFootMarker.class);
        }
        return perfTestFootMarker;
    }

    /**
     * Handles ShardingSphere 5 data sources by wrapping the Hikari pools held inside them in place.
     *
     * @param bean     the data source bean being post-processed
     * @param beanName the bean's name (currently unused)
     * @return {@code false} if ShardingSphere is not on the classpath or {@code bean} is not a ShardingSphere
     *         {@code AbstractDataSourceAdapter}; {@code true} once its inner data sources have been wrapped
     * @throws IllegalStateException if the adapter is not a {@link ShardingSphereDataSource}, an inner data
     *                               source is not a {@link HikariDataSource}, or the {@link ContextManager}
     *                               cannot be read reflectively
     */
    private boolean processShardingJdbc5DataSource(DataSource bean, String beanName) {
        if (!isClassPresent(SHARDING_SPHERE_DATA_SOURCE_CLASS)) {
            return false;
        }
        if (!(bean instanceof org.apache.shardingsphere.driver.jdbc.adapter.AbstractDataSourceAdapter)) {
            return false;
        }
        // This is a sharding-jdbc data source: run the extra checks below.
        if (!(bean instanceof ShardingSphereDataSource)) {
            throw new IllegalStateException("[NOTIFYME]ShardingSphere目前只支持ShardingSphereDataSource的压测，新增的数据源暂时不支持，如需支持，请联系中间件组");
        }

        ShardingSphereDataSource ds = (ShardingSphereDataSource) bean;
        try {
            ds = AopTargetUtils.getTarget(ds);
        } catch (Exception e) {
            throw Throwables.propagate(e);
        }
        ContextManager contextManager = readContextManager(ds, beanName);
        Map<String, DataSource> dataSourceMap = contextManager.getDataSourceMap(((ShardingSphereDataSource) bean).getSchemaName());

        boolean hikariPresent = isClassPresent(HIKARI_DATA_SOURCE_CLASS);
        for (Map.Entry<String, DataSource> entry : dataSourceMap.entrySet()) {
            DataSource innerDs = entry.getValue();
            if (innerDs instanceof PerfTestRoutingDataSource) {
                continue;
            }
            if (!(hikariPresent && innerDs instanceof HikariDataSource)) {
                throw new IllegalStateException("压测暂不支持 HikariDataSource 以外的DataSource用于ShardingSphere的 `spring.shardingsphere.datasource.${schema}.type`");
            }
            // The HikariDataSource instances inside a ShardingSphereDataSource are not Spring beans, so they
            // are wrapped here and the wrapper replaces the original entry. The key already exists, so this
            // put only replaces a value and does not add a new entry during iteration.
            PerfTestRoutingDataSource ts = initialized(new PerfTestRoutingDataSourceForHikari(
                    (HikariDataSource) innerDs, environment, getPerfTestFootMarker()));
            dataSourceMap.put(entry.getKey(), ts);
        }
        return true;
    }

    /**
     * Reads the private {@code contextManager} field of a (non-proxy) {@link ShardingSphereDataSource}.
     *
     * @throws IllegalStateException if the field does not exist or holds {@code null}
     */
    private static ContextManager readContextManager(ShardingSphereDataSource target, String beanName) {
        Field field = ReflectionUtils.findField(ShardingSphereDataSource.class, "contextManager");
        if (field == null) {
            // NOTE(review): presumably an unsupported ShardingSphere version — confirm when upgrading.
            throw new IllegalStateException("Field 'contextManager' not found on "
                    + ShardingSphereDataSource.class.getName() + " (bean '" + beanName + "')");
        }
        ReflectionUtils.makeAccessible(field);
        ContextManager contextManager = (ContextManager) ReflectionUtils.getField(field, target);
        if (contextManager == null) {
            throw new IllegalStateException("ShardingSphereDataSource bean '" + beanName
                    + "' has a null contextManager");
        }
        return contextManager;
    }

    /** Calls {@code afterPropertiesSet()} on a freshly built routing data source and returns it. */
    private static PerfTestRoutingDataSource initialized(PerfTestRoutingDataSource routing) {
        routing.afterPropertiesSet();
        return routing;
    }

    /**
     * Returns whether {@code className} can be loaded via {@link Class#forName(String)}.
     * Only {@link ClassNotFoundException} means "absent"; any other failure propagates.
     */
    private static boolean isClassPresent(String className) {
        try {
            Class.forName(className);
            return true;
        } catch (ClassNotFoundException e) {
            return false;
        }
    }

}
