package cn.com.duiba.nezha.engine.biz.service.advert.ctr.impl;

import cn.com.duiba.nezha.alg.common.util.StringZIP;
import cn.com.duiba.nezha.alg.feature.parse.FeatureParse;
import cn.com.duiba.nezha.alg.feature.vo.FeatureDo;
import cn.com.duiba.nezha.alg.model.CODER;
import cn.com.duiba.nezha.alg.model.ESMM;
import cn.com.duiba.nezha.alg.model.FFM;
import cn.com.duiba.nezha.alg.model.FM;
import cn.com.duiba.nezha.alg.model.IModel;
import cn.com.duiba.nezha.alg.model.enums.MutModelType;
import cn.com.duiba.nezha.alg.model.enums.PredictResultType;
import cn.com.duiba.nezha.alg.model.tf.TFServingClient;
import cn.com.duiba.nezha.engine.api.enums.DeepTfServer;
import cn.com.duiba.nezha.engine.api.enums.ModelKey;
import cn.com.duiba.nezha.engine.api.enums.ModelKeyEnum;
import cn.com.duiba.nezha.engine.api.enums.ModelType;
import cn.com.duiba.nezha.engine.api.support.RecommendEngineException;
import cn.com.duiba.nezha.engine.biz.bo.hbase.ConsumerFeatureBo;
import cn.com.duiba.nezha.engine.biz.domain.ActivityDo;
import cn.com.duiba.nezha.engine.biz.domain.AdvertStatFeatureDo;
import cn.com.duiba.nezha.engine.biz.domain.AppDo;
import cn.com.duiba.nezha.engine.biz.domain.ConsumerDo;
import cn.com.duiba.nezha.engine.biz.domain.FeatureIndex;
import cn.com.duiba.nezha.engine.biz.domain.RequestDo;
import cn.com.duiba.nezha.engine.biz.domain.advert.Advert;
import cn.com.duiba.nezha.engine.biz.domain.advert.Material;
import cn.com.duiba.nezha.engine.biz.domain.advert.OrientationPackage;
import cn.com.duiba.nezha.engine.biz.service.CacheService;
import cn.com.duiba.nezha.engine.biz.service.advert.ctr.AdvertPredictService;
import cn.com.duiba.nezha.engine.biz.service.advert.ctr.AdvertStatFeatureService;
import cn.com.duiba.nezha.engine.biz.vo.advert.AdvertRecommendRequestVo;
import cn.com.duiba.nezha.engine.common.utils.SystemClock;
import cn.com.duiba.wolf.perf.timeprofile.DBTimeProfile;
import cn.com.duibaboot.ext.autoconfigure.core.utils.CatUtils;
import com.alibaba.fastjson.JSON;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListenableFutureTask;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;

import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.partitioningBy;
import static java.util.stream.Collectors.toMap;
import static java.util.stream.Collectors.toSet;

/**
 * Predicts CTR/CVR values for candidate adverts.
 * <p>
 * For every (advert, orientation package[, material]) combination a feature map is assembled from
 * request-level ("static") and advert/material-level ("dynamic") features, and the whole batch is
 * scored by an {@link ESMM} built from the FM/FFM/CODER models cached by this service plus optional
 * TF-Serving deep models.
 * <p>
 * Models are read from Redis via {@code nezhaStringRedisTemplate} (not declared here; presumably
 * inherited from {@link CacheService}) and held in Guava {@link LoadingCache}s whose refreshes run
 * on {@code executorService}; a refresh that finds nothing keeps the previous model.
 *
 * @author ZhouFeng zhoufeng@duiba.com.cn
 * @version $Id: AdvertPredictServiceImpl.java , v 0.1 2017/5/20 11:09 PM ZhouFeng Exp $
 */
@Service
public class AdvertPredictServiceImpl extends CacheService implements AdvertPredictService, InitializingBean {

    protected Logger logger = LoggerFactory.getLogger(getClass());

    /**
     * Offline model cache (every offline key is parsed as {@link FM}); entries become eligible for
     * refresh 30 minutes after being written.
     */
    private final LoadingCache<ModelKeyEnum, IModel> offlineFMModelCache =
            buildModelCache(30, TimeUnit.MINUTES, this::loadFMModel);

    /**
     * Online FFM model cache; entries become eligible for refresh 90 seconds after being written.
     */
    private final LoadingCache<ModelKeyEnum, IModel> onlineFFMModelCache =
            buildModelCache(90, TimeUnit.SECONDS, this::loadFFMModel);

    /**
     * Online encoder (CODER) model cache; entries become eligible for refresh 60 seconds after being written.
     */
    private final LoadingCache<ModelKeyEnum, IModel> onlineCODERModelCache =
            buildModelCache(60, TimeUnit.SECONDS, this::loadCODERModel);

    /**
     * Online FM model cache, also used for online keys of any model type other than FFM/CODER;
     * entries become eligible for refresh 60 seconds after being written.
     */
    private final LoadingCache<ModelKeyEnum, IModel> onlineFMModelCache =
            buildModelCache(60, TimeUnit.SECONDS, this::loadFMModel);

    @Autowired
    private ConsumerFeatureBo consumerFeatureBo;
    @Autowired
    private AdvertStatFeatureService advertStatFeatureService;
    @Value("${spring.profiles.active}")
    private String active;

    /**
     * Builds a model cache backed by {@code modelLoader}.
     * <p>
     * On first load the loader's result is returned, or {@code null} when nothing is found. Guava then
     * throws {@link CacheLoader.InvalidCacheLoadException} from {@code get} and does not cache the key.
     * Refreshes are submitted to {@code executorService}. If the loader finds nothing during a
     * refresh, the old model is kept.
     *
     * @param refreshDuration how long after a write an entry becomes eligible for refresh
     * @param unit            unit of {@code refreshDuration}
     * @param modelLoader     reads and parses the model stored for a key
     * @return the configured cache
     */
    private LoadingCache<ModelKeyEnum, IModel> buildModelCache(long refreshDuration,
                                                               TimeUnit unit,
                                                               Function<ModelKeyEnum, Optional<IModel>> modelLoader) {
        return CacheBuilder.newBuilder()
                .refreshAfterWrite(refreshDuration, unit)
                .build(new CacheLoader<ModelKeyEnum, IModel>() {

                    @Override
                    public IModel load(ModelKeyEnum modelKey) {
                        // First load (or load after eviction): null when the model cannot be found.
                        return modelLoader.apply(modelKey).orElse(null);
                    }

                    @Override
                    public ListenableFuture<IModel> reload(ModelKeyEnum modelKey, IModel oldValue) {
                        // Refresh: keep serving the old model if the new one cannot be found.
                        ListenableFutureTask<IModel> task =
                                ListenableFutureTask.create(() -> modelLoader.apply(modelKey).orElse(oldValue));
                        executorService.submit(task);
                        return task;
                    }
                });
    }

    /**
     * Predicts values for the adverts of every requested result type.
     * <p>
     * Side effects on {@code advertRecommendRequestVo}: sets the advert statistic feature map, the
     * feature map and the prediction cost.
     *
     * @param advertRecommendRequestVo request context; all adverts in its advert map are featurized and scored
     * @param needPredictAdvertMap     adverts whose predictions are wanted, per result type
     * @return per result type, the predicted value of every feature index of its adverts (0 when the
     * estimator produced none); an empty map if anything fails
     */
    @Override
    public Map<PredictResultType, Map<FeatureIndex, Double>> predict(AdvertRecommendRequestVo advertRecommendRequestVo, Map<PredictResultType, Collection<Advert>> needPredictAdvertMap) {

        try {
            DBTimeProfile.enter("predict");
            Collection<Advert> adverts = advertRecommendRequestVo.getAdvertMap().values();

            AppDo appDo = advertRecommendRequestVo.getAppDo();
            RequestDo requestDo = advertRecommendRequestVo.getRequestDo();
            ConsumerDo consumerDo = advertRecommendRequestVo.getConsumerDo();
            ActivityDo activityDo = advertRecommendRequestVo.getActivityDo();

            Long appId = appDo.getId();
            Long slotId = appDo.getSlotId();
            Long activityId = activityDo.getOperatingId();

            Map<Long, AdvertStatFeatureDo> advertStatFeatureMap = advertStatFeatureService.get(appId, slotId, activityId, adverts);
            advertRecommendRequestVo.setAdvertStatFeatureMap(advertStatFeatureMap);

            // Build the features of every advert / package / material combination.
            Map<FeatureIndex, Map<String, String>> featureMap = this.getFeatureMap(adverts,
                    consumerDo, appDo, activityDo, requestDo,
                    advertStatFeatureMap, advertRecommendRequestVo);
            advertRecommendRequestVo.setFeatureMap(featureMap);

            // Assemble the multi-model estimator.
            MutModelType mutModelType = advertRecommendRequestVo.getMutModelType();
            IModel ctrModel = this.getModel(advertRecommendRequestVo.getCtrModelKey());
            IModel cvrModel = this.getModel(advertRecommendRequestVo.getCvrModelKey());
            IModel fusingCtrModel = this.getModel(advertRecommendRequestVo.getFusingCtrModelKey());
            IModel fusingCvrModel = this.getModel(advertRecommendRequestVo.getFusingCvrModelKey());
            // NOTE(review): a new TFServingClient is created per request and never closed here; if it owns
            // a network channel, consider reusing one client per DeepTfServer — confirm against TFServingClient.
            TFServingClient ctrTfServingClient = buildTfServingClient(advertRecommendRequestVo.getDeepCtrModelKey());
            TFServingClient cvrTfServingClient = buildTfServingClient(advertRecommendRequestVo.getDeepCvrModelKey());

            ESMM esmm = new ESMM(ctrModel, cvrModel, fusingCtrModel, fusingCvrModel, null, ctrTfServingClient, cvrTfServingClient, mutModelType);

            long now = SystemClock.now();
            Map<PredictResultType, Map<FeatureIndex, Double>> predictResultTypeMapMap = CatUtils
                    .executeInCatTransaction(() -> esmm.predictCTRsAndCVRsWithTF(featureMap),
                            "esmmPredict",
                            advertRecommendRequestVo.getAdvertAlgEnum().toString());
            advertRecommendRequestVo.setPredictCost(SystemClock.now() - now);

            // Project the estimator's output onto the adverts requested for each result type.
            return needPredictAdvertMap.entrySet().stream().collect(toMap(Map.Entry::getKey, entry -> {
                Map<FeatureIndex, Double> featurePredictValueMap =
                        predictResultTypeMapMap.getOrDefault(entry.getKey(), Collections.emptyMap());
                return collectPredictValues(entry.getValue(), featurePredictValueMap);
            }));

        } catch (Throwable e) {
            // Degrade gracefully: a failed prediction yields no predictions instead of failing the request.
            // The throwable is passed without a placeholder so SLF4J logs its stack trace.
            logger.warn("predict happened error", e);
            return new HashMap<>();
        } finally {
            DBTimeProfile.release();
        }
    }

    /**
     * Creates a TF-Serving client for the given server, or returns {@code null} when no deep model is configured.
     */
    private static TFServingClient buildTfServingClient(DeepTfServer server) {
        if (server == null) {
            return null;
        }
        return new TFServingClient(server.getHost(), server.getPort(), server.getModelKey(), null);
    }

    /**
     * Maps every feature index of {@code adverts} to its predicted value, defaulting to 0 when
     * {@code featurePredictValueMap} has no entry for it.
     * <p>
     * If the collection yields the same feature index twice (e.g. a duplicated advert), both lookups
     * give the same value. The first one is kept instead of letting {@code toMap} throw
     * {@link IllegalStateException}.
     */
    private Map<FeatureIndex, Double> collectPredictValues(Collection<Advert> adverts,
                                                           Map<FeatureIndex, Double> featurePredictValueMap) {
        return adverts.stream()
                .map(this::convertToFeatureIndex)
                .flatMap(Set::stream)
                .collect(toMap(Function.identity(),
                        featureIndex -> featurePredictValueMap.getOrDefault(featureIndex, 0D),
                        (first, second) -> first));
    }

    /**
     * Lists the feature indexes of an advert.
     * <p>
     * Each orientation package contributes one (advert, package, material) index per material, or a
     * single (advert, package) index when it has no materials.
     */
    private Set<FeatureIndex> convertToFeatureIndex(Advert advert) {
        Long advertId = advert.getId();
        Set<OrientationPackage> orientationPackages = advert.getOrientationPackages();

        return orientationPackages.stream().map(orientationPackage -> {
            Set<Material> materials = orientationPackage.getMaterials();
            if (materials.isEmpty()) {
                return Collections.singleton(new FeatureIndex(advertId, orientationPackage.getId()));
            } else {
                return materials.stream().map(material -> new FeatureIndex(advertId, orientationPackage.getId(), material.getId())).collect(toSet());
            }
        }).flatMap(Set::stream).collect(toSet());
    }

    /**
     * Builds the model input features for every advert / orientation package / material combination.
     *
     * @return feature maps keyed by (advert, package) for packages without materials and by
     * (advert, package, material) otherwise
     */
    public Map<FeatureIndex, Map<String, String>> getFeatureMap(Collection<Advert> adverts,
                                                                ConsumerDo consumerDo,
                                                                AppDo appDo,
                                                                ActivityDo activityDo,
                                                                RequestDo requestDo,
                                                                Map<Long, AdvertStatFeatureDo> advertStatFeatureMap,
                                                                AdvertRecommendRequestVo advertRecommendRequestVo) {

        // Request-level (static) features, shared by every candidate.
        FeatureDo staticFeatureDo = this.getFeatureDo(consumerDo, appDo, activityDo, requestDo);
        // Attributes relating the request to the media (app).
        staticFeatureDo.setAppInNewTrade(advertRecommendRequestVo.getAppInNewTrade());
        staticFeatureDo.setAppTagInNewTrade(advertRecommendRequestVo.getAppTagInNewTrade());

        Map<String, String> staticFeatureMap = FeatureParse.generateFeatureMapStatic(staticFeatureDo);

        // Feature maps consumed by the models (FeatureDo -> Map<String, String>).
        Map<FeatureIndex, Map<String, String>> finalFeatureMap = new HashMap<>();

        adverts.forEach(advert -> advert.getOrientationPackages().forEach(orientationPackage -> {

            Long advertId = advert.getId();
            Long packageId = orientationPackage.getId();

            // Materials configured for this orientation package.
            Set<Material> materials = orientationPackage.getMaterials();

            // Without materials the package-level features are used directly.
            if (materials.isEmpty()) {
                FeatureIndex orientationPackageFeatureIndex = FeatureIndex.newBuilder().advertId(advertId).packageId(packageId).build();
                FeatureDo orientationPackageFeatureDo =
                        this.buildPackageFeatureDo(advert, orientationPackage, consumerDo, advertStatFeatureMap);
                Map<String, String> dynamicMapMap = this.getDynamicMapMap(staticFeatureDo, orientationPackageFeatureDo, staticFeatureMap);
                finalFeatureMap.put(orientationPackageFeatureIndex, dynamicMapMap);
                return;
            }

            materials.forEach(materialDto -> {
                Long materialId = materialDto.getId();

                FeatureIndex orientationPackageMaterialFeatureIndex = FeatureIndex.newBuilder().advertId(advertId).packageId(packageId).materialId(materialId).build();

                // Start from the full package-level features so material rows carry exactly the same
                // advert/consumer features as material-less rows, then add the material-specific ones.
                FeatureDo materialFeatureDo =
                        this.buildPackageFeatureDo(advert, orientationPackage, consumerDo, advertStatFeatureMap);
                materialFeatureDo.setMaterialId(materialId.toString());
                materialFeatureDo.setAtmosphere(materialDto.getAtmosphere());
                materialFeatureDo.setBackgroundColour(materialDto.getBackgroundColour());
                materialFeatureDo.setIfPrevalent(String.valueOf(materialDto.getPrevalent()));
                materialFeatureDo.setDescribeKeywords(materialDto.getInterception());
                materialFeatureDo.setDynamicEffect(materialDto.getCarton());
                materialFeatureDo.setBodyElement(materialDto.getBodyElement());
                materialFeatureDo.setMaterialTags(materialDto.getTags().stream().filter(Objects::nonNull).collect(joining(",")));
                Map<String, String> dynamicMap = this.getDynamicMapMap(staticFeatureDo, materialFeatureDo, staticFeatureMap);
                finalFeatureMap.put(orientationPackageMaterialFeatureIndex, dynamicMap);
            });
        }));

        return finalFeatureMap;

    }

    /**
     * Builds a fresh feature object holding the advert-level, consumer-level and statistical features
     * of one orientation package.
     * <p>
     * Both the material-less path and every material row of {@link #getFeatureMap} build from here.
     * Previously material rows used a hand-written copy that had fallen out of sync and dropped six fields.
     */
    private FeatureDo buildPackageFeatureDo(Advert advert,
                                            OrientationPackage orientationPackage,
                                            ConsumerDo consumerDo,
                                            Map<Long, AdvertStatFeatureDo> advertStatFeatureMap) {
        Long advertId = advert.getId();

        // Advert-level features.
        FeatureDo featureDo = new FeatureDo();
        featureDo.setAdvertId(advertId);
        featureDo.setAccountId(advert.getAccountId());
        featureDo.setMatchTagNums(advert.getMatchTags());
        featureDo.setTradeId2(advert.getIndustryTag());
        featureDo.setAdvertTags(advert.getSpreadTags().stream().filter(Objects::nonNull).collect(Collectors.joining(",")));
        featureDo.setTimes(advert.getLaunchCountToUser());
        featureDo.setOperatingResource(advert.getResourceTagNum());
        featureDo.setBankEndType(orientationPackage.getCvrType().toString());
        featureDo.setOperatingNewTrade(advert.getNewTradeTagId());
        featureDo.setNewTradeInAppTag(advert.getNewTradeInAppTag());
        featureDo.setAdvertInAppTag(advert.getAdvertInAppTag());

        // Consumer-level features.
        featureDo.setNewTradeDayOrderRank(consumerDo.getNewTradeDayOrderRank());
        featureDo.setLastOperatingNewTrade(consumerDo.getLastOperatingNewTrade());
        featureDo.setNewTradeLastGmtCreateTime(consumerDo.getNewTradeLastGmtCreateTime());

        // Statistical features, when statistics exist for this advert.
        Optional.ofNullable(advertStatFeatureMap.get(advertId)).ifPresent(advertStatFeatureDo -> {
            featureDo.setAdvertCtr(advertStatFeatureDo.getAdvertCtr());
            featureDo.setAdvertCvr(advertStatFeatureDo.getAdvertCvr());
            featureDo.setAdvertAppCtr(advertStatFeatureDo.getAdvertAppCtr());
            featureDo.setAdvertAppCvr(advertStatFeatureDo.getAdvertAppCvr());
            featureDo.setAdvertSlotCtr(advertStatFeatureDo.getAdvertSlotCtr());
            featureDo.setAdvertSlotCvr(advertStatFeatureDo.getAdvertSlotCvr());
            featureDo.setAdvertActivityCtr(advertStatFeatureDo.getAdvertActivityCtr());
            featureDo.setAdvertActivityCvr(advertStatFeatureDo.getAdvertActivityCvr());
        });
        return featureDo;
    }

    /**
     * Fetches the request-level (static) features of the consumer.
     *
     * @throws RecommendEngineException wrapping any failure of {@code consumerFeatureBo}
     */
    private FeatureDo getFeatureDo(ConsumerDo consumerDo, AppDo appDo, ActivityDo activityDo, RequestDo requestDo) {
        try {
            DBTimeProfile.enter("consumerFeatureBo.getFeatureDo");
            return consumerFeatureBo.getFeatureDo(consumerDo, appDo, activityDo, requestDo);
        } catch (Exception e) {
            throw new RecommendEngineException("consumerFeatureBo.getFeatureDo happened error", e);
        } finally {
            DBTimeProfile.release();
        }
    }

    /**
     * Returns the model for {@code modelKey} from the matching cache.
     *
     * @param modelKey the model's key; may be {@code null} when the request configures no such model
     * @return the model, or {@code null} when the key is absent or the model cannot be loaded
     */
    private IModel getModel(ModelKeyEnum modelKey) {
        if (modelKey == null) {
            return null;
        }
        try {
            return selectModelCache(modelKey).get(modelKey);
        } catch (CacheLoader.InvalidCacheLoadException e) {
            // The loader found no model for this key; log without a stack trace, as before.
            logger.warn("get model happened error ,the model key is {} ,{}", modelKey, e.toString());
            return null;
        } catch (Exception e) {
            // Unexpected load failure: pass the throwable last so SLF4J logs its stack trace.
            logger.warn("get model happened error ,the model key is {}", modelKey, e);
            return null;
        }
    }

    /**
     * Chooses the cache responsible for {@code modelKey}.
     * <p>
     * Offline keys always use the FM cache. Online keys use the FFM or CODER cache by model type and
     * fall back to the FM cache. Used by both request-time lookups and startup preloading so the two
     * always agree.
     */
    private LoadingCache<ModelKeyEnum, IModel> selectModelCache(ModelKeyEnum modelKey) {
        if (!modelKey.getOnline()) {
            return offlineFMModelCache;
        }
        if (ModelType.FFM.equals(modelKey.getModelType())) {
            return onlineFFMModelCache;
        }
        if (ModelType.CODER.equals(modelKey.getModelType())) {
            return onlineCODERModelCache;
        }
        return onlineFMModelCache;
    }

    /**
     * Loads an {@link FM} model from Redis.
     */
    private Optional<IModel> loadFMModel(ModelKeyEnum modelKeyEnum) {
        return loadModel(modelKeyEnum, "loadFMModel", FM.class);
    }

    /**
     * Loads an {@link FFM} model from Redis.
     */
    private Optional<IModel> loadFFMModel(ModelKeyEnum modelKeyEnum) {
        return loadModel(modelKeyEnum, "loadFFMModel", FFM.class);
    }

    /**
     * Loads a {@link CODER} (encoder) model from Redis.
     */
    private Optional<IModel> loadCODERModel(ModelKeyEnum modelKeyEnum) {
        return loadModel(modelKeyEnum, "loadCODERModel", CODER.class);
    }

    /**
     * Reads the latest model stored for {@code modelKeyEnum}, unzips it and parses the JSON as {@code modelClass}.
     *
     * @param modelKeyEnum model key; {@code null} yields an empty result
     * @param profileName  name used for the timing profile and in the error log
     * @param modelClass   concrete model type to parse into
     * @return the model, or empty when the key is null or nothing is stored in Redis
     * @throws RecommendEngineException if reading, unzipping or parsing fails
     */
    private Optional<IModel> loadModel(ModelKeyEnum modelKeyEnum, String profileName, Class<? extends IModel> modelClass) {
        try {
            DBTimeProfile.enter(profileName);
            return Optional.ofNullable(modelKeyEnum)
                    .map(modelKey -> nezhaStringRedisTemplate.opsForValue().get(ModelKey.getLastModelNewKey(modelKey.getIndex())))
                    .map(StringZIP::unzipString)
                    .map(json -> JSON.parseObject(json, modelClass));
        } catch (Exception e) {
            logger.error("load " + profileName + " error,modelKey:{}", modelKeyEnum);
            throw new RecommendEngineException("load model exception", e);
        } finally {
            DBTimeProfile.release();
        }
    }

    /**
     * Builds the full feature map of one candidate: its dynamic features plus all static ones.
     * Static entries overwrite dynamic entries with the same key.
     */
    private Map<String, String> getDynamicMapMap(FeatureDo staticFeatureDo, FeatureDo advertFeatureDo, Map<String, String> staticMap) {
        Map<String, String> dynamicMap = FeatureParse.generateFeatureMapDynamic(advertFeatureDo, staticFeatureDo);
        dynamicMap.putAll(staticMap);
        return dynamicMap;
    }

    /**
     * Warms up the model caches at startup (skipped for the {@code dev} profile).
     * <p>
     * Offline models are loaded first, then online ones. Each key goes into the same cache that
     * {@link #getModel} reads from. A model that fails to load is logged and skipped.
     */
    @Override
    public void afterPropertiesSet() throws Exception {

        if ("dev".equals(active)) {
            return;
        }
        logger.info("start init predict model");
        long start = System.currentTimeMillis();

        Map<Boolean, List<ModelKeyEnum>> modelKeyMap = Arrays.stream(ModelKeyEnum.values())
                .collect(partitioningBy(ModelKeyEnum::getOnline));

        // Offline models first, then online models.
        preloadModels(modelKeyMap.get(false));
        preloadModels(modelKeyMap.get(true));

        logger.info("init predict model finish,spend {} ms", System.currentTimeMillis() - start);

    }

    /**
     * Loads each key into its cache, logging (not propagating) failures so startup continues.
     */
    private void preloadModels(List<ModelKeyEnum> modelKeys) {
        for (ModelKeyEnum modelKey : modelKeys) {
            try {
                selectModelCache(modelKey).get(modelKey);
            } catch (Exception e) {
                logger.info("preload model failed, modelKey:{}, reason:{}", modelKey, e.toString());
            }
        }
    }
}
