package cn.com.duiba.duiba.base.service.api.mybatis.plugins.handler.command.impl;

import cn.com.duiba.duiba.base.service.api.mybatis.plugins.bean.DbEncryptionConstant;
import cn.com.duiba.duiba.base.service.api.mybatis.plugins.config.DbEncryptColumnRule;
import cn.com.duiba.duiba.base.service.api.mybatis.plugins.handler.command.SqlCommandHandler;
import org.apache.commons.lang3.StringUtils;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * Base class for SQL command handlers that extract the table name from a SQL statement and work
 * out which JDBC {@code ?} placeholders are bound to columns that have an encryption rule.
 *
 * <p>Parsing is a lightweight string scan, not a full SQL parser. SQL keywords
 * ({@code set}, {@code where}, {@code and}, {@code in}, {@code case}, {@code when}) are matched
 * case-insensitively and only as whole words, so identifiers such as {@code asset_id},
 * {@code brand_id}, {@code user_info} or {@code showcase} are not mistaken for keywords.
 * Placeholder positions used as keys in the rule maps are 0-based.
 *
 * @author lizhi
 * @date 2024/9/24 15:50
 */
public abstract class AbstractSqlCommandHandler implements SqlCommandHandler {

    protected static final String LOWER_SET = "set";
    protected static final String LOWER_WHERE = "where";
    protected static final String LOWER_FROM = "from ";
    protected static final String LOWER_UPDATE = "update ";
    protected static final String LOWER_INTO = "into ";
    protected static final String SPACE = " ";
    protected static final String LINE_SEPARATOR = System.lineSeparator();
    protected static final String LOWER_AND = "and";
    protected static final String LOWER_IN = "in";
    protected static final String LOWER_CASE = "case";
    protected static final String LOWER_WHEN = "when";

    /**
     * Returns the start index of a table name that does not begin with {@code tb_}.
     *
     * @param sql the SQL string
     * @return start index of the table name; a negative value makes
     *         {@link #getTableName(String, String...)} return {@code null}
     */
    protected abstract int getSpecialTbIndex(String sql);

    /**
     * Returns the index just past the first case-insensitive occurrence of {@code separator}
     * (for example {@link #LOWER_FROM}), i.e. where a table name would start.
     *
     * @param sql       the SQL string
     * @param separator the text that precedes the table name
     * @return the index right after {@code separator}, or {@code -1} if it does not occur
     */
    protected int getSpecialTbIndex(String sql, String separator) {
        int index = getIndexIgnoreCase(sql, separator);
        if (index < 0) {
            // Must not fall through: -1 + separator.length() would be a bogus, non-negative
            // start index and the caller would slice an arbitrary part of the SQL.
            return -1;
        }
        return index + separator.length();
    }

    /**
     * Extracts the table name. It starts at the first {@code tb_} occurrence (or at
     * {@link #getSpecialTbIndex(String)} when there is none) and ends at the nearest of
     * {@code separators}, matched case-insensitively, or at the end of {@code sql}.
     *
     * @param sql        the SQL string
     * @param separators candidate terminators of the table name
     * @return the trimmed table name, or {@code null} if it cannot be located
     */
    protected String getTableName(String sql, String... separators) {
        if (sql == null || sql.isEmpty()) {
            return null;
        }
        int tbIndex = getTbIndex(sql);
        if (tbIndex < 0 || tbIndex > sql.length()) {
            return null;
        }
        int endIndex = getTableNameEndIndex(sql, tbIndex, separators);
        if (endIndex < 0) {
            // Trimmed like the other branch so callers always get the same normalized form.
            return sql.substring(tbIndex).trim();
        }
        if (tbIndex >= endIndex) {
            return null;
        }
        return sql.substring(tbIndex, endIndex).trim();
    }

    /** Returns the index of the first {@code tb_} (case-sensitive), else {@link #getSpecialTbIndex(String)}. */
    private int getTbIndex(String sql) {
        int tbIndex = sql.indexOf("tb_");
        if (tbIndex >= 0) {
            return tbIndex;
        }
        return getSpecialTbIndex(sql);
    }

    /**
     * Returns the smallest index at or after {@code fromIndex} where any of {@code separators}
     * occurs (case-insensitively), or {@code -1} if none occurs.
     */
    private int getTableNameEndIndex(String sql, int fromIndex, String... separators) {
        if (separators == null) {
            return -1;
        }
        int minIndex = -1;
        for (String separator : separators) {
            int index = getIndexIgnoreCase(sql, separator, fromIndex);
            if (index >= 0 && (minIndex < 0 || index < minIndex)) {
                minIndex = index;
            }
        }
        return minIndex;
    }

    /**
     * Case-insensitive {@code indexOf} starting at {@code fromIndex}. Uses
     * {@link String#regionMatches(boolean, int, String, int, int)} so that no lower-cased copy is
     * built: this is locale-independent and indices always refer to the original string.
     */
    private static int getIndexIgnoreCase(String str, String separator, int fromIndex) {
        int length = separator.length();
        int lastStart = str.length() - length;
        for (int i = Math.max(fromIndex, 0); i <= lastStart; i++) {
            if (str.regionMatches(true, i, separator, 0, length)) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Case-insensitive {@code indexOf} from the start of {@code str}.
     *
     * @return index of the first occurrence of {@code separator}, or {@code -1}
     */
    protected static int getIndexIgnoreCase(String str, String separator) {
        return getIndexIgnoreCase(str, separator, 0);
    }

    /**
     * Finds {@code keyword} as a whole word (not preceded or followed by an identifier
     * character), ignoring case, at or after {@code fromIndex}.
     *
     * @return index of the keyword, or {@code -1}
     */
    private static int indexOfKeyword(String str, String keyword, int fromIndex) {
        int index = getIndexIgnoreCase(str, keyword, fromIndex);
        while (index >= 0) {
            int end = index + keyword.length();
            boolean wordStart = index == 0 || !isIdentifierChar(str.charAt(index - 1));
            boolean wordEnd = end >= str.length() || !isIdentifierChar(str.charAt(end));
            if (wordStart && wordEnd) {
                return index;
            }
            index = getIndexIgnoreCase(str, keyword, index + 1);
        }
        return -1;
    }

    /** Returns whether {@code c} can be part of an unquoted SQL identifier. */
    private static boolean isIdentifierChar(char c) {
        return Character.isLetterOrDigit(c) || c == '_' || c == '$';
    }

    /** Splits {@code str} around whole-word, case-insensitive occurrences of {@code keyword}. */
    private static List<String> splitByKeyword(String str, String keyword) {
        List<String> parts = new ArrayList<>();
        int start = 0;
        int keywordIndex = indexOfKeyword(str, keyword, start);
        while (keywordIndex >= 0) {
            parts.add(str.substring(start, keywordIndex));
            start = keywordIndex + keyword.length();
            keywordIndex = indexOfKeyword(str, keyword, start);
        }
        parts.add(str.substring(start));
        return parts;
    }

    /**
     * Maps 0-based placeholder positions in {@code sql} to the encryption rule of the column the
     * placeholder is bound to. The SET clause is scanned first, then the WHERE clause.
     *
     * @param sql     the SQL string with {@code ?} placeholders
     * @param columns encryption rules keyed by column name
     * @return placeholder position to rule, in discovery order; empty if nothing matches
     */
    @Override
    public Map<Integer, DbEncryptColumnRule> getNeedEncryptParamIndexRule(String sql, Map<String, DbEncryptColumnRule> columns) {
        Map<Integer, DbEncryptColumnRule> indexRuleMap = new LinkedHashMap<>();
        if (StringUtils.isBlank(sql) || columns == null || columns.isEmpty()) {
            return indexRuleMap;
        }
        int index = putIndexRuleBySet(sql, columns, indexRuleMap);
        putIndexRuleByWhere(sql, index, columns, indexRuleMap);
        return indexRuleMap;
    }

    /**
     * Records rules for placeholders in the SET clause, which runs from the {@code set} keyword
     * to the {@code where} keyword (or to the end of the SQL when there is no WHERE).
     * Assignments are split on {@code DbEncryptionConstant.COMMA}.
     * <ul>
     *   <li>For {@code col = ?}, the placeholder gets the rule of {@code col}.</li>
     *   <li>For {@code col = case x when ? then ? ... end}, placeholders alternate between the
     *       rule of {@code x} and that of {@code col}.</li>
     * </ul>
     *
     * @return the number of placeholders counted in the SET clause
     */
    protected static int putIndexRuleBySet(String sql, Map<String, DbEncryptColumnRule> columns, Map<Integer, DbEncryptColumnRule> indexRuleMap) {
        int index = 0;
        int setIndex = indexOfKeyword(sql, LOWER_SET, 0);
        if (setIndex < 0) {
            return index;
        }
        int setStart = setIndex + LOWER_SET.length();
        int whereIndex = indexOfKeyword(sql, LOWER_WHERE, 0);
        if (whereIndex >= 0 && whereIndex < setStart) {
            // The "set" found lies after WHERE (e.g. inside a literal), so there is no SET clause.
            return index;
        }
        // An UPDATE without WHERE still binds SET values: scan to the end instead of skipping them.
        int setEnd = whereIndex < 0 ? sql.length() : whereIndex;
        String setSql = sql.substring(setStart, setEnd);
        for (String assignment : setSql.split(DbEncryptionConstant.COMMA)) {
            int questionMarkIndex = assignment.indexOf(DbEncryptionConstant.QUESTION_MARK);
            if (questionMarkIndex < 0) {
                continue;
            }
            index++;
            index += putSetValue(assignment, index, questionMarkIndex, columns, indexRuleMap);
        }
        return index;
    }

    /**
     * Records rules for one SET assignment whose first placeholder is at 1-based position
     * {@code index}.
     *
     * @return the number of additional placeholders in {@code s} beyond the first one
     */
    private static int putSetValue(String s, int index, int questionMarkIndex, Map<String, DbEncryptColumnRule> columns, Map<Integer, DbEncryptColumnRule> indexRuleMap) {
        // Every '?' in the fragment is a bound parameter, so count them all to keep later
        // positions aligned whatever branch is taken below.
        int extraPlaceholders = getQuestionCount(s) - 1;
        int equalSignIndex = s.indexOf(DbEncryptionConstant.SEPARATOR_EQUAL_SIGN);
        if (equalSignIndex < 0) {
            return extraPlaceholders;
        }
        int caseIndex = indexOfKeyword(s, LOWER_CASE, equalSignIndex + 1);
        int whenIndex = caseIndex < 0 ? -1 : indexOfKeyword(s, LOWER_WHEN, caseIndex + LOWER_CASE.length());
        if (whenIndex < 0) {
            putEqualSign(s, index, questionMarkIndex, equalSignIndex, columns, indexRuleMap);
        } else {
            putCase(s, index, questionMarkIndex, equalSignIndex, caseIndex, whenIndex, columns, indexRuleMap);
        }
        return extraPlaceholders;
    }

    /** Counts the {@code ?} placeholders in {@code str}. */
    private static int getQuestionCount(String str) {
        int count = 0;
        char question = DbEncryptionConstant.QUESTION_MARK.charAt(0);
        for (char c : str.toCharArray()) {
            if (c == question) {
                count++;
            }
        }
        return count;
    }

    /**
     * Records rules for placeholders after the first whole-word {@code where}. Conditions are
     * split on whole-word {@code and} (any case).
     * <ul>
     *   <li>For {@code col = ?}, the first placeholder gets the rule of {@code col}.</li>
     *   <li>For {@code col in (?, ?, ...)}, every placeholder in the list gets that rule.</li>
     * </ul>
     *
     * @param index number of placeholders that precede the WHERE clause
     * @return {@code index} plus the number of placeholders after WHERE
     */
    protected static int putIndexRuleByWhere(String sql, int index, Map<String, DbEncryptColumnRule> columns, Map<Integer, DbEncryptColumnRule> indexRuleMap) {
        int whereIndex = indexOfKeyword(sql, LOWER_WHERE, 0);
        if (whereIndex < 0) {
            return index;
        }
        String sqlAfterWhere = sql.substring(whereIndex + LOWER_WHERE.length());
        for (String condition : splitByKeyword(sqlAfterWhere, LOWER_AND)) {
            int questionMarkIndex = condition.indexOf(DbEncryptionConstant.QUESTION_MARK);
            if (questionMarkIndex < 0) {
                continue;
            }
            // 1-based position of this condition's first placeholder.
            int firstParam = index + 1;
            int equalSignIndex = condition.indexOf(DbEncryptionConstant.SEPARATOR_EQUAL_SIGN);
            if (equalSignIndex >= 0) {
                putEqualSign(condition, firstParam, questionMarkIndex, equalSignIndex, columns, indexRuleMap);
            } else {
                int inIndex = indexOfKeyword(condition, LOWER_IN, 0);
                if (inIndex >= 0) {
                    putIn(condition, firstParam, questionMarkIndex, inIndex, columns, indexRuleMap);
                }
            }
            // Advance past every placeholder in the condition, even when no rule matched, so
            // that later conditions map to the correct parameter positions.
            index += getQuestionCount(condition);
        }
        return index;
    }

    /**
     * Maps every placeholder of {@code col in (?, ?, ...)} to the rule of {@code col}. The first
     * placeholder is at 1-based position {@code index}. The list ends at the first closing bracket
     * after the first placeholder, or at the end of {@code s} if there is none.
     */
    private static void putIn(String s, int index, int questionMarkIndex, int inIndex, Map<String, DbEncryptColumnRule> columns, Map<Integer, DbEncryptColumnRule> indexRuleMap) {
        if (questionMarkIndex < inIndex) {
            return;
        }
        DbEncryptColumnRule rule = columns.get(s.substring(0, inIndex).trim());
        if (rule == null) {
            return;
        }
        int rightBracketIndex = s.indexOf(DbEncryptionConstant.RIGHT_BRACKET, questionMarkIndex);
        int listEnd = rightBracketIndex < 0 ? s.length() : rightBracketIndex;
        int paramIndex = index - 1;
        for (int q = questionMarkIndex; q >= 0 && q < listEnd; q = s.indexOf(DbEncryptionConstant.QUESTION_MARK, q + 1)) {
            indexRuleMap.put(paramIndex++, rule);
        }
    }

    /**
     * Maps the placeholder at 1-based position {@code index} to the rule of the column left of
     * the equals sign. Does nothing if the placeholder precedes the equals sign or the column has
     * no rule.
     */
    private static void putEqualSign(String s, int index, int questionMarkIndex, int equalSignIndex, Map<String, DbEncryptColumnRule> columns, Map<Integer, DbEncryptColumnRule> indexRuleMap) {
        if (questionMarkIndex < equalSignIndex) {
            return;
        }
        DbEncryptColumnRule rule = columns.get(s.substring(0, equalSignIndex).trim());
        if (rule != null) {
            indexRuleMap.put(index - 1, rule);
        }
    }

    /**
     * Handles {@code col = case x when ? then ? ... end}. Placeholders alternate between the rule
     * of {@code x} (WHEN values) and that of {@code col} (THEN values). The first placeholder is at
     * 1-based position {@code index}.
     */
    private static void putCase(String s, int index, int questionMarkIndex, int equalSignIndex, int caseIndex, int whenIndex, Map<String, DbEncryptColumnRule> columns, Map<Integer, DbEncryptColumnRule> indexRuleMap) {
        if (questionMarkIndex < equalSignIndex) {
            return;
        }
        DbEncryptColumnRule rule = columns.get(s.substring(0, equalSignIndex).trim());
        DbEncryptColumnRule whenRule = columns.get(s.substring(caseIndex + LOWER_CASE.length(), whenIndex).trim());
        if (rule == null && whenRule == null) {
            return;
        }
        int paramIndex = index - 1;
        boolean isWhen = true;
        for (int q = questionMarkIndex; q >= 0; q = s.indexOf(DbEncryptionConstant.QUESTION_MARK, q + 1)) {
            DbEncryptColumnRule current = isWhen ? whenRule : rule;
            if (current != null) {
                indexRuleMap.put(paramIndex, current);
            }
            paramIndex++;
            isWhen = !isWhen;
        }
    }
}
