/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.healthmanager.plugins.detectors;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.operators.ResourceSpec;
import org.apache.flink.runtime.healthmanager.HealthMonitor;
import org.apache.flink.runtime.healthmanager.RestServerClient;
import org.apache.flink.runtime.healthmanager.plugins.Detector;
import org.apache.flink.runtime.healthmanager.plugins.Symptom;
import org.apache.flink.runtime.healthmanager.plugins.symptoms.JobVertexTmKilledDueToMemoryExceed;
import org.apache.flink.runtime.jobgraph.ExecutionVertexID;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Detects task managers whose reported exceptions indicate that their container was killed for
 * exceeding memory, and raises a {@link JobVertexTmKilledDueToMemoryExceed} symptom for the job
 * vertices whose tasks were running on those task managers.
 *
 * <p>Three message formats are recognized (see {@link #getExceedTime(String)}): the YARN v3
 * "memory exceeds" message, the YARN v3 "machine memory is too heavy" message, and the YARN v2
 * "running beyond physical memory limits" message.
 */
public class KilledDueToMemoryExceedDetector implements Detector {
    private static final Logger LOGGER = LoggerFactory.getLogger(KilledDueToMemoryExceedDetector.class);
    private static final String ERR_MSG_MEMORY_EXCEED_YARN_V3 = "Container Killed due to memory exceeds ";
    private static final String ERR_MSG_MEMORY_EXCEED_YARN_V2 = "is running beyond physical memory limits. Current usage: ";
    private static final String ERR_MSG_MACHINE_MEMORY_HEAVY_YARN_V3 = "QosContainersMonitor killing, reason: machine memory is too heavy";

    private JobID jobID;
    private HealthMonitor monitor;
    private RestServerClient restServerClient;
    /** Start of the next exception query window, as a {@link System#currentTimeMillis()} value. */
    private long lastDetectTime;
    /** Task manager id -> job vertices of the tasks last reported running on that task manager. */
    private Map<String, List<JobVertexID>> tmTasks;
    /** Health check interval from the monitor config; compared against millisecond timestamps. */
    private long hmInterval;

    @Override
    public void open(HealthMonitor monitor) {
        this.monitor = monitor;
        this.jobID = monitor.getJobID();
        this.restServerClient = monitor.getRestServerClient();
        this.lastDetectTime = System.currentTimeMillis();
        this.tmTasks = new HashMap<>();
        this.hmInterval = monitor.getConfig().getLong(HealthMonitor.HEALTH_CHECK_INTERNAL);
    }

    @Override
    public void close() {
    }

    /**
     * Queries task manager exceptions reported since the previous call (but at most two health
     * check intervals back) and looks for memory-exceed kills among them.
     *
     * <p>Exceptions are attributed to job vertices using the task-manager-to-tasks mapping captured
     * at the end of the previous call; that mapping is refreshed only after this round's exceptions
     * have been processed.
     *
     * @return a {@link JobVertexTmKilledDueToMemoryExceed} holding the maximum utility per affected
     *     vertex, or {@code null} if no memory-exceed kill could be attributed to a known vertex
     * @throws Exception propagated from the monitor / REST server client calls
     */
    @Override
    public Symptom detect() throws Exception {
        LOGGER.debug("Start detecting.");
        long now = System.currentTimeMillis();
        long maxLookBack = this.hmInterval * 2L;
        // Bound the query window so a long gap since the previous call only scans recent exceptions.
        if (now - this.lastDetectTime > maxLookBack) {
            LOGGER.debug("Long time since last detection, detect for recent exceptions.");
            this.lastDetectTime = now - maxLookBack;
        }
        Map<String, List<Exception>> tmExceptions = this.restServerClient.getTaskManagerExceptions(this.lastDetectTime, now);
        this.lastDetectTime = now;

        JobVertexTmKilledDueToMemoryExceed symptom = null;
        if (tmExceptions != null) {
            HashMap<JobVertexID, Double> vertexMaxUtilities =
                    this.computeVertexMaxUtilities(tmExceptions, this.monitor.getJobConfig());
            if (!vertexMaxUtilities.isEmpty()) {
                LOGGER.info("TM killed due to memory exceed detected for vertices with max utility {}.", vertexMaxUtilities);
                symptom = new JobVertexTmKilledDueToMemoryExceed(this.jobID, vertexMaxUtilities);
            }
        }
        // Refresh after processing: this round's exceptions were attributed with the previous mapping.
        this.updateTmTasks();
        return symptom;
    }

    /**
     * Computes, for every job vertex that ran on a task manager killed for exceeding memory, the
     * maximum utility (see {@link #computeUtility}) over all such kills.
     *
     * <p>Exceptions whose message is not a recognized memory-exceed message, task managers with no
     * recorded tasks, and vertices missing from {@code jobConfig} are skipped.
     */
    private HashMap<JobVertexID, Double> computeVertexMaxUtilities(
            Map<String, List<Exception>> tmExceptions, RestServerClient.JobConfig jobConfig) {
        HashMap<JobVertexID, Double> vertexMaxUtilities = new HashMap<>();
        for (Map.Entry<String, List<Exception>> entry : tmExceptions.entrySet()) {
            String tmId = entry.getKey();
            for (Exception exception : entry.getValue()) {
                double exceedTime = this.getExceedTime(exception.getLocalizedMessage());
                if (exceedTime < 0.0) {
                    continue;
                }
                List<JobVertexID> vertices = this.tmTasks.get(tmId);
                LOGGER.debug("TM {} with tasks {} killed due to memory exceed {} times.", tmId, vertices, exceedTime);
                if (vertices == null) {
                    continue;
                }
                for (JobVertexID vertexID : vertices) {
                    // The vertex may no longer be part of the job config; previously this was an NPE.
                    if (!jobConfig.getVertexConfigs().containsKey(vertexID)) {
                        LOGGER.debug("Vertex {} not found in job config, skipping.", vertexID);
                        continue;
                    }
                    ResourceSpec resource = jobConfig.getVertexConfigs().get(vertexID).getResourceSpec();
                    double utility = computeUtility(resource, exceedTime);
                    // Keep the larger value; a candidate that is not strictly greater (incl. NaN) is dropped.
                    vertexMaxUtilities.merge(vertexID, utility,
                            (previous, candidate) -> candidate > previous ? candidate : previous);
                }
            }
        }
        return vertexMaxUtilities;
    }

    /**
     * Utility of one vertex for one kill: the vertex's declared heap + direct + native memory,
     * multiplied by the exceed factor, divided by its declared native memory. A native memory of 0
     * is replaced by 1 to avoid dividing by zero.
     *
     * <p>NOTE(review): the numerator is total declared memory while the denominator is native
     * memory only — presumably intended as a scale factor for native memory; confirm with the
     * consumer of {@link JobVertexTmKilledDueToMemoryExceed}.
     */
    private static double computeUtility(ResourceSpec resource, double exceedTime) {
        double usage = (double) (resource.getHeapMemory() + resource.getDirectMemory() + resource.getNativeMemory()) * exceedTime;
        double capacity = resource.getNativeMemory();
        if (capacity == 0.0) {
            capacity = 1.0;
        }
        return usage / capacity;
    }

    /**
     * Refreshes {@link #tmTasks} from the task managers currently reported by the REST server
     * client. If the client returns {@code null}, the previous mapping is kept unchanged.
     *
     * <p>Entries are only added or overwritten, never removed. A task manager that no longer
     * appears therefore keeps its last known vertices — presumably so that its kill can still be
     * attributed in the next {@link #detect()} call. NOTE(review): this also means the map grows
     * with task manager churn.
     */
    private void updateTmTasks() {
        Map<String, List<ExecutionVertexID>> allTmTasks = this.restServerClient.getAllTaskManagerTasks();
        if (allTmTasks == null) {
            return;
        }
        for (Map.Entry<String, List<ExecutionVertexID>> entry : allTmTasks.entrySet()) {
            this.tmTasks.put(entry.getKey(), entry.getValue().stream()
                    .map(ExecutionVertexID::getJobVertexID)
                    .collect(Collectors.toList()));
        }
    }

    /**
     * Extracts how many times a container exceeded its memory limit from an exception message.
     *
     * <ul>
     *   <li>YARN v3 "memory exceeds" message: the number immediately following the marker.</li>
     *   <li>YARN v3 "machine memory is too heavy" message: {@code 1.0}.</li>
     *   <li>YARN v2 "running beyond physical memory limits" message: current usage divided by the
     *       limit. The text after the marker is expected to look like
     *       {@code "2.5 GB of 2 GB physical memory used; ..."}, which yields {@code 1.25}.</li>
     * </ul>
     *
     * @param msg exception message, may be {@code null}
     * @return the exceed factor, or {@code -1.0} if {@code msg} is {@code null}, matches none of the
     *     formats, or cannot be parsed (parse failures are logged at warn level)
     */
    private double getExceedTime(String msg) {
        if (msg == null) {
            return -1.0;
        }
        try {
            if (msg.contains(ERR_MSG_MEMORY_EXCEED_YARN_V3)) {
                String rest = msg.substring(msg.indexOf(ERR_MSG_MEMORY_EXCEED_YARN_V3) + ERR_MSG_MEMORY_EXCEED_YARN_V3.length());
                return Double.parseDouble(rest.split(" ")[0]);
            }
            if (msg.contains(ERR_MSG_MACHINE_MEMORY_HEAVY_YARN_V3)) {
                return 1.0;
            }
            if (msg.contains(ERR_MSG_MEMORY_EXCEED_YARN_V2)) {
                String rest = msg.substring(msg.indexOf(ERR_MSG_MEMORY_EXCEED_YARN_V2) + ERR_MSG_MEMORY_EXCEED_YARN_V2.length());
                // tokens: [usage, usageUnit, "of", limit, limitUnit, ...]
                String[] tokens = rest.split(" ");
                double usage = scaleByUnit(Double.parseDouble(tokens[0]), tokens[1]);
                double capacity = scaleByUnit(Double.parseDouble(tokens[3]), tokens[4]);
                return usage / capacity;
            }
        } catch (NumberFormatException | IndexOutOfBoundsException e) {
            // Best effort: one malformed message must not abort detection for all task managers.
            LOGGER.warn("Failed to parse memory exceed message: {}", msg, e);
        }
        return -1.0;
    }

    /**
     * Multiplies {@code value} by 1024 once per binary prefix step named by the first character of
     * {@code unit}: K=1, M=2, G=3, T=4, P=5, E=6. Any other leading character leaves it unscaled.
     *
     * @throws StringIndexOutOfBoundsException if {@code unit} is empty
     */
    private static double scaleByUnit(double value, String unit) {
        int exponent;
        switch (unit.charAt(0)) {
            case 'E':
                exponent = 6;
                break;
            case 'P':
                exponent = 5;
                break;
            case 'T':
                exponent = 4;
                break;
            case 'G':
                exponent = 3;
                break;
            case 'M':
                exponent = 2;
                break;
            case 'K':
                exponent = 1;
                break;
            default:
                exponent = 0;
        }
        double scaled = value;
        for (int i = 0; i < exponent; i++) {
            scaled *= 1024.0;
        }
        return scaled;
    }
}

