package com.alibaba.hbase.client;

import com.alibaba.lindorm.client.TableService;
import com.alibaba.lindorm.client.core.LindormWideColumnService;
import com.alibaba.lindorm.client.core.meta.LColumn;
import com.alibaba.lindorm.client.core.utils.Bytes;
import com.alibaba.lindorm.client.dml.Aggregate;
import com.alibaba.lindorm.client.dml.ColumnValue;
import com.alibaba.lindorm.client.dml.Condition;
import com.alibaba.lindorm.client.dml.ConditionFactory;
import com.alibaba.lindorm.client.dml.ConditionList;
import com.alibaba.lindorm.client.dml.Row;
import com.alibaba.lindorm.client.exception.LindormException;
import com.alibaba.lindorm.client.schema.DataType;
import com.alibaba.lindorm.client.schema.PrimaryKeySchema;
import com.google.protobuf.ByteString;
import com.google.protobuf.Message;
import com.google.protobuf.RpcCallback;
import com.google.protobuf.RpcController;
import org.apache.commons.lang.exception.ExceptionUtils;
import org.apache.hadoop.hbase.HBaseIOException;
import org.apache.hadoop.hbase.HConstants;
import org.apache.hadoop.hbase.KeyValue;
import org.apache.hadoop.hbase.client.Scan;
import org.apache.hadoop.hbase.client.coprocessor.AggregationHelper;
import org.apache.hadoop.hbase.coprocessor.ColumnInterpreter;

import org.apache.hadoop.hbase.filter.BinaryComparator;
import org.apache.hadoop.hbase.filter.ByteArrayComparable;
import org.apache.hadoop.hbase.filter.CompareFilter;
import org.apache.hadoop.hbase.filter.Filter;
import org.apache.hadoop.hbase.filter.FilterList;
import org.apache.hadoop.hbase.filter.InclusiveStopFilter;
import org.apache.hadoop.hbase.filter.RowFilter;
import org.apache.hadoop.hbase.filter.SingleColumnValueFilter;
import org.apache.hadoop.hbase.protobuf.ProtobufUtil;
import org.apache.hadoop.hbase.protobuf.generated.AggregateProtos;
import org.apache.hadoop.hbase.protobuf.generated.ClientProtos;
import org.apache.hadoop.hbase.protobuf.generated.AggregateProtos.AggregateService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.hadoop.hbase.shaded.protobuf.ResponseConverter;

import java.io.IOException;
import java.util.List;
import java.util.NavigableSet;

import static com.alibaba.lindorm.client.core.LindormWideColumnService.UNIFIED_PK_COLUMN_NAME;
import static com.alibaba.lindorm.client.dml.ConditionFactory.and;
import static com.alibaba.lindorm.client.dml.ConditionFactory.compare;
import static com.alibaba.lindorm.client.dml.ConditionFactory.or;

public class AliHBaseUEAggregateService<T, S, P extends Message, Q extends Message, R extends Message>
    extends AggregateService {

  private static final Logger LOG = LoggerFactory.getLogger(AliHBaseUEAggregateService.class);

  private TableService tableService;

  private String table;

  private static byte[] PK_COLUMN_NAME = LindormWideColumnService.UNIFIED_PK_COLUMN_NAME.getBytes();

  private static byte[] COUNT_ALL = Bytes.toBytes("COUNT(*)");


  public AliHBaseUEAggregateService(TableService tableService, String table) {
    this.tableService = tableService;
    this.table = table;
  }

  @Override
  public void getMax(RpcController rpcController, AggregateProtos.AggregateRequest aggregateRequest,
      RpcCallback<AggregateProtos.AggregateResponse> rpcCallback) {
    AggregateProtos.AggregateResponse response = null;
    try {

      LindormAggregate lindormAggregate = toLindormAggregate(aggregateRequest);
      ColumnValue cv = lindormAggregate.max();

      ColumnInterpreter<T, S, P, Q, R> ci = constructColumnInterpreterFromRequest(aggregateRequest);
      byte[] value = cv.getBinary();
      T max = ci.getValue(lindormAggregate.getCfName(), lindormAggregate.getColName(),
          new KeyValue(HConstants.EMPTY_BYTE_ARRAY, lindormAggregate.getCfName(), lindormAggregate.getColName(),
              value));

      response = AggregateProtos.AggregateResponse.newBuilder().addFirstPart(ci.getProtoForCellType(max).toByteString())
          .build();

    } catch (IOException e) {
      ResponseConverter.setControllerException(rpcController, e);
    }
    if (rpcCallback != null) {
      rpcCallback.run(response);
    }
  }

  @Override
  public void getMin(RpcController rpcController, AggregateProtos.AggregateRequest aggregateRequest,
      RpcCallback<AggregateProtos.AggregateResponse> rpcCallback) {
    AggregateProtos.AggregateResponse response = null;
    try {

      LindormAggregate lindormAggregate = toLindormAggregate(aggregateRequest);
      ColumnValue cv = lindormAggregate.min();

      ColumnInterpreter<T, S, P, Q, R> ci = constructColumnInterpreterFromRequest(aggregateRequest);
      byte[] value = cv.getBinary();
      T min = ci.getValue(lindormAggregate.getCfName(), lindormAggregate.getColName(),
          new KeyValue(HConstants.EMPTY_BYTE_ARRAY, lindormAggregate.getCfName(), lindormAggregate.getColName(),
              value));

      response = AggregateProtos.AggregateResponse.newBuilder().addFirstPart(ci.getProtoForCellType(min).toByteString())
          .build();

    } catch (IOException e) {
      ResponseConverter.setControllerException(rpcController, e);
    }
    if (rpcCallback != null) {
      rpcCallback.run(response);
    }

  }

  @Override
  public void getSum(RpcController rpcController, AggregateProtos.AggregateRequest aggregateRequest,
      RpcCallback<AggregateProtos.AggregateResponse> rpcCallback) {
    AggregateProtos.AggregateResponse response = null;
    try {

      LindormAggregate lindormAggregate = toLindormAggregate(aggregateRequest);
      ColumnValue cv = lindormAggregate.sum();

      ColumnInterpreter<T, S, P, Q, R> ci = constructColumnInterpreterFromRequest(aggregateRequest);

      byte[] value = ElementConvertor.toValueBytes(cv, ci);
      T sum = ci.getValue(lindormAggregate.getCfName(), lindormAggregate.getColName(),
          new KeyValue(HConstants.EMPTY_BYTE_ARRAY, lindormAggregate.getCfName(), lindormAggregate.getColName(),
              value));

      response = AggregateProtos.AggregateResponse.newBuilder().addFirstPart(ci.getProtoForCellType(sum).toByteString())
          .build();

    } catch (IOException e) {
      ResponseConverter.setControllerException(rpcController, e);
    }
    if (rpcCallback != null) {
      rpcCallback.run(response);
    }
  }


  @Override
  public void getRowNum(RpcController rpcController, AggregateProtos.AggregateRequest aggregateRequest,
      RpcCallback<AggregateProtos.AggregateResponse> rpcCallback) {
    AggregateProtos.AggregateResponse response = null;
    try {

      LindormAggregate lindormAggregate = toLindormAggregate(aggregateRequest);
      ColumnValue cv = lindormAggregate.count();

      response = AggregateProtos.AggregateResponse.newBuilder()
          .addFirstPart(ByteString.copyFrom(Bytes.toBytes(cv.getLong().longValue()))).build();

    } catch (IOException e) {
      ExceptionUtils.getFullStackTrace(e);
      ResponseConverter.setControllerException(rpcController, e);
    }
    if (rpcCallback != null) {
      rpcCallback.run(response);
    }
  }

  @Override
  public void getAvg(RpcController rpcController, AggregateProtos.AggregateRequest aggregateRequest,
      RpcCallback<AggregateProtos.AggregateResponse> rpcCallback) {

    AggregateProtos.AggregateResponse response = null;
    try {

      LindormAggregate lindormAggregate = toLindormAggregate(aggregateRequest);
      ColumnValue cv = lindormAggregate.avg();

      ColumnInterpreter<T, S, P, Q, R> ci = constructColumnInterpreterFromRequest(aggregateRequest);

      byte[] value = ElementConvertor.toValueBytes(cv, ci);
      T avg = ci.getValue(lindormAggregate.getCfName(), lindormAggregate.getColName(),
          new KeyValue(HConstants.EMPTY_BYTE_ARRAY, lindormAggregate.getCfName(), lindormAggregate.getColName(),
              value));

      AggregateProtos.AggregateResponse.Builder pair = AggregateProtos.AggregateResponse.newBuilder();
      pair.addFirstPart(ci.getProtoForCellType(avg).toByteString());
      pair.setSecondPart(ByteString.copyFrom(Bytes.toBytes(1L)));
      response = pair.build();

    } catch (IOException e) {
      ResponseConverter.setControllerException(rpcController, e);
    }
    if (rpcCallback != null) {
      rpcCallback.run(response);
    }
  }

  @Override
  public void getStd(RpcController rpcController, AggregateProtos.AggregateRequest aggregateRequest,
      RpcCallback<AggregateProtos.AggregateResponse> rpcCallback) {
    throw new UnsupportedOperationException("GetStd unsupported !");
  }

  @Override
  public void getMedian(RpcController rpcController, AggregateProtos.AggregateRequest aggregateRequest,
      RpcCallback<AggregateProtos.AggregateResponse> rpcCallback) {
    throw new UnsupportedOperationException("GetMedian unsupported !");
  }


  private LindormAggregate toLindormAggregate(AggregateProtos.AggregateRequest aggregateRequest) throws IOException {
    ClientProtos.Scan scan = aggregateRequest.getScan();
    Scan scanner = ProtobufUtil.toScan(scan);

    Condition condition = convertRequestToCondition(scanner);

    Aggregate aggregate = tableService.aggregate().from(table);
    if (condition != null) {
      aggregate.where(condition);
    }
    //allow count (*)
    aggregate.allowFiltering(true);

    if (!aggregateRequest.hasScan()) {
      throw new UnsupportedOperationException("Scan is null !");
    }

    byte[] colFamily = null;
    byte[] qualifier = null;
    if(scanner.getFamilies() != null) {
      if(scanner.getFamilies().length > 1){
        LOG.warn("Only support one family in sum, max, min, avg");
      }
      colFamily = scanner.getFamilies()[0];
      NavigableSet<byte[]> qualifiers = scanner.getFamilyMap().get(colFamily);
      if (qualifiers != null && !qualifiers.isEmpty()) {
        qualifier = qualifiers.pollFirst();
      }
    }

    if (!aggregateRequest.hasInterpreterClassName()) {
      throw new UnsupportedOperationException("Must provide interpreter class");
    }
    DataType interpreterDatype = ElementConvertor.toInterpreterDataType(aggregateRequest.getInterpreterClassName());

    return new LindormAggregate(colFamily, qualifier, interpreterDatype, aggregate);
  }


  public Condition convertRequestToCondition(Scan scan) throws HBaseIOException {
    try {
      ConditionList conditionList = and();

      checkAggregateSupport(scan);

      if (scan.hasFilter()) {
        conditionList.add(covertFilterToCondition(scan.getFilter(), scan.isReversed()));
      }
      Condition startRow = null;
      Condition stopRow = null;
      if (scan != null) {
        if (scan.getStartRow() != null) {
          startRow = compare(PK_COLUMN_NAME, ConditionFactory.CompareOp.GREATER, scan.getStartRow());
        }
        if (scan.getStopRow() != null) {
          stopRow = compare(PK_COLUMN_NAME, ConditionFactory.CompareOp.LESS_OR_EQUAL, scan.getStopRow());
        }
      }

      if (startRow != null) {
        conditionList.add(ConditionFactory.and(startRow));
      }
      if (stopRow != null) {
        conditionList.add(ConditionFactory.and(stopRow));
      }

      List<Condition> conditions = conditionList.getConditions();
      if (conditions == null || conditions.isEmpty()) {
        return null;
      } else {
        return conditions.size() > 1 ? conditionList : conditions.get(0);
      }
    }catch(LindormException e){
      throw new HBaseIOException(e);
    }
  }

  public Condition covertFilterToCondition(Filter filter, boolean reversed) throws HBaseIOException{
    try {
      if (filter instanceof FilterList) {
        return generateConditionByFilterList((FilterList) filter, reversed);
      } else {
        return generateConditionBySingleFilter(filter, reversed);
      }
    }catch (LindormException e){
      throw new HBaseIOException(e);
    }
  }

  private Condition generateConditionByFilterList(FilterList filter, boolean reversed) throws HBaseIOException, LindormException {
    ConditionList conditionList = (filter.getOperator() == FilterList.Operator.MUST_PASS_ALL) ? and() : or();
    List<Filter> filters = filter.getFilters();
    for (Filter subFilter : filters) {
      conditionList.add(covertFilterToCondition(subFilter, reversed));
    }
    return conditionList;
  }

  public Condition generateConditionBySingleFilter(Filter filter, boolean reversed) throws HBaseIOException,
    LindormException {
    if (filter instanceof RowFilter){
      return generateCondition((RowFilter)filter);
    }else if (filter instanceof InclusiveStopFilter){
      return generateCondition((InclusiveStopFilter) filter, reversed);
    }else if (filter instanceof SingleColumnValueFilter){
      return generateCondition((SingleColumnValueFilter)filter);
    }else{
      throw new HBaseIOException("Unsupported filter [type: " + filter.toString() + "]");
    }
  }

  private Condition generateCondition(RowFilter filter) throws HBaseIOException, LindormException {
    LColumn column = new LColumn(getFirstPkSchema(), 0);
    return generateCondition(column, filter.getOperator(), filter.getComparator());
  }

  private Condition generateCondition(LColumn column, CompareFilter.CompareOp operator, ByteArrayComparable comparator)
    throws HBaseIOException, LindormException{
    checkComparator(comparator);
    ConditionFactory.CompareOp compareOp = transformOp(operator);
    return compare(column.getFamilyName(), column.getColumnName(), compareOp, comparator.getValue());
  }

  private void checkComparator(ByteArrayComparable comparator) throws HBaseIOException {
    if (!(comparator instanceof BinaryComparator)) {
      throw new HBaseIOException("Compiling the comparator [class: " +
        (comparator == null ? null : comparator.getClass().getName()) +
        "] is not supported.");
    }
  }

  private Condition generateCondition(InclusiveStopFilter filter, boolean reversed) throws LindormException {
    LColumn column = new LColumn(getFirstPkSchema(), 0);
    ConditionFactory.CompareOp op = !reversed ? ConditionFactory.CompareOp.LESS_OR_EQUAL :
      ConditionFactory.CompareOp.GREATER_OR_EQUAL;
    return compare(column.getFamilyName(), column.getColumnName(), op, filter.getStopRowKey());
  }

  private Condition generateCondition(SingleColumnValueFilter filter) throws HBaseIOException, LindormException {
    ByteArrayComparable comparator = filter.getComparator();
    checkComparator(comparator);
    ConditionFactory.CompareOp compareOp = transformOp(filter.getOperator());
    Condition condition = compare(filter.getFamily(), filter.getQualifier(), compareOp, comparator.getValue());
    return condition;
  }

  public static PrimaryKeySchema getFirstPkSchema() {
    return new PrimaryKeySchema(UNIFIED_PK_COLUMN_NAME, DataType.VARBINARY);
  }

  public static ConditionFactory.CompareOp transformOp(CompareFilter.CompareOp operator) {
    switch (operator) {
      case LESS:
        return ConditionFactory.CompareOp.LESS;
      case LESS_OR_EQUAL:
        return ConditionFactory.CompareOp.LESS_OR_EQUAL;
      case EQUAL:
        return ConditionFactory.CompareOp.EQUAL;
      case NOT_EQUAL:
        return ConditionFactory.CompareOp.NOT_EQUAL;
      case GREATER_OR_EQUAL:
        return ConditionFactory.CompareOp.GREATER_OR_EQUAL;
      case GREATER:
        return ConditionFactory.CompareOp.GREATER;
      default:
        throw new IllegalArgumentException("Invalid compare operator " + operator);
    }
  }

  public static void checkAggregateSupport(Scan scan) {
    if (scan == null) {
      return;
    }

    if (scan.hasFilter() && !checkSupportedFilters(scan.getFilter())) {
      throw new UnsupportedOperationException("Filter unsupported !");
    }
    if (scan.getMaxVersions() > 1) {
      throw new UnsupportedOperationException("Versions unsupported ! current : " + scan.getMaxVersions());
    }

  }

  public static boolean checkSupportedFilters(Filter filter) {
    if (filter == null) {
      return false;
    }

    if(filter instanceof FilterList){
      return ((FilterList) filter).getFilters().stream().allMatch(AliHBaseUEAggregateService::checkSupportedFilters);
    }else if (filter instanceof RowFilter){
      return true;
    }else if (filter instanceof InclusiveStopFilter){
      return true;
    }else if (filter instanceof SingleColumnValueFilter){
      return true;
    }else{
      return false;
    }
  }

  ColumnInterpreter<T, S, P, Q, R> constructColumnInterpreterFromRequest(AggregateProtos.AggregateRequest request)
      throws IOException {
    String className = request.getInterpreterClassName();
    Class<?> cls;
    try {
      cls = Class.forName(className);
      ColumnInterpreter<T, S, P, Q, R> ci = (ColumnInterpreter<T, S, P, Q, R>) cls.newInstance();
      if (request.hasInterpreterSpecificBytes()) {
        ByteString b = request.getInterpreterSpecificBytes();
        P initMsg = AggregationHelper.getParsedGenericInstance(ci.getClass(), 2, b);
        ci.initialize(initMsg);
      }
      return ci;
    } catch (Exception e) {
      throw new IOException(e);
    }
  }

  class LindormAggregate {

    private byte[] cfName;

    private byte[] colName;

    private DataType interpreterDatype;

    private Aggregate aggregate;

    public LindormAggregate(byte[] cfName, byte[] colName, DataType interpreterDatype, Aggregate aggregate) {
      this.cfName = cfName;
      this.colName = colName;
      this.interpreterDatype = interpreterDatype;
      this.aggregate = aggregate;
    }


    public byte[] getCfName() {
      return cfName;
    }

    public byte[] getColName() {
      return colName;
    }

    public DataType getInterpreterDatype() {
      return interpreterDatype;
    }

    public Aggregate getAggregate() {
      return aggregate;
    }

    public ColumnValue sum() throws LindormException {
      String colStrName = Bytes.toString(colName);
      aggregate.sumAs(Bytes.toString(cfName), colStrName, colStrName, interpreterDatype);
      Row row = aggregate.execute();
      return row.getColumnValue(colName);
    }


    public ColumnValue max() throws LindormException {
      String colStrName = Bytes.toString(colName);
      aggregate.maxAs(Bytes.toString(cfName), colStrName, colStrName);
      Row row = aggregate.execute();
      return row.getColumnValue(colName);
    }

    public ColumnValue min() throws LindormException {
      String colStrName = Bytes.toString(colName);
      aggregate.minAs(Bytes.toString(cfName), colStrName, colStrName);
      Row row = aggregate.execute();
      return row.getColumnValue(colName);
    }

    public ColumnValue avg() throws LindormException {
      String colStrName = Bytes.toString(colName);
      aggregate.avgAs(Bytes.toString(cfName), colStrName, colStrName, interpreterDatype);
      Row row = aggregate.execute();
      return row.getColumnValue(colName);
    }

    public ColumnValue count() throws LindormException {
      if(colName == null){
        aggregate.count();
      }else {
        String colStrName = Bytes.toString(colName);
        aggregate.countAs(Bytes.toString(cfName), colStrName, colStrName);
      }
      Row row = aggregate.execute();
      if(colName == null){
        return row.getColumnValue(COUNT_ALL);
      }
      return row.getColumnValue(colName);
    }
  }

}