/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.streaming.api.driver;

import org.apache.flink.api.common.cache.DistributedCache;
import org.apache.flink.client.program.ClusterClient;
import org.apache.flink.client.program.DriverContextEnvironment;
import org.apache.flink.client.program.PackagedProgram;
import org.apache.flink.client.program.ProgramInvocationException;
import org.apache.flink.client.program.StandaloneClusterClient;
import org.apache.flink.client.program.rest.RestClusterClient;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.CoreOptions;
import org.apache.flink.configuration.DriverConfigConstants;
import org.apache.flink.configuration.JobManagerOptions;
import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings;
import org.apache.flink.streaming.api.functions.source.RichSourceFunction;

import org.apache.commons.io.FilenameUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * DriverSourceFunction.
 * As a flink job, Driver will call the main method of user packaged program to submit user jobs via a flink Source.
 * before calling the main method, Driver needs to prepare a DriverStreamEnvironment, then gets user blobs from distributed
 * cache and set into prepared environment with other managed configuration, cluster connection information, etc.
 */
public class DriverSourceFunction extends RichSourceFunction<Byte> {

	private final Logger logger = LoggerFactory.getLogger(DriverSourceFunction.class);

	private final String clusterIP;

	private final int clusterPort;

	private final Class<?> userMainClass;

	private DriverStreamEnvironment driverStreamEnvironment;

	private final List<URL> userJar;

	private final List<URL> libJarURLs;

	private final List<URL> externalFileURLs;

	private final List<URL> classPath;

	private final Configuration configuration;

	private final String driverName;

	private final String[] args;

	public DriverSourceFunction(
		String driverName,
		Class<?> userMainClass,
		String[] args,
		List<URL> userJar,
		List<URL> classPath,
		List<URL> libjars,
		List<URL> externalFiles,
		Configuration configuration) {
		this.clusterIP = configuration.getString(JobManagerOptions.ADDRESS);
		this.clusterPort = configuration.getInteger(JobManagerOptions.PORT);
		this.userJar = userJar;
		this.classPath = classPath;
		this.driverName = driverName;
		this.userMainClass = userMainClass;
		this.args = args;
		this.libJarURLs = libjars;
		this.externalFileURLs = externalFiles;
		this.configuration = configuration;
	}

	@Override
	public void open(Configuration parameters) throws Exception {
		super.open(parameters);

		List<URI> jarFiles = getBlods(userJar, DriverEntry.BLOB_TYPE_JARFILE);
		List<String> jarFileStrings = jarFiles.stream().flatMap(uri -> Stream.of(uri.getPath())).collect(Collectors.toList());
		List<URI> globalClassPath = getBlods(classPath, DriverEntry.BLOB_TYPE_CLASSPATH);
		List<URI> libJars = getBlods(libJarURLs, DriverEntry.BLOB_TYPE_LIB_JARS);
		List<URI> externalFiles = getBlods(externalFileURLs, DriverEntry.BLOB_TYPE_EXTERNAL_FILE);

		List<URI> addUpGlobalClassPath = new ArrayList<>();
		addUpGlobalClassPath.addAll(globalClassPath);
		addUpGlobalClassPath.addAll(libJars);
		addUpGlobalClassPath.addAll(externalFiles);
		URL[] addUpGlobalClassPathURL = transURItoURL(addUpGlobalClassPath).toArray(new URL[0]);
		driverStreamEnvironment = new DriverStreamEnvironment(
			clusterIP,
			clusterPort,
			driverName,
			jarFileStrings.toArray(new String[0]),
			addUpGlobalClassPathURL,
			configuration
		);

		driverStreamEnvironment.setParallelism(configuration.getInteger(DriverConfigConstants.FLINK_DRIVER_PARALLELISM,
			configuration.getInteger(CoreOptions.DEFAULT_PARALLELISM)));
		driverStreamEnvironment.setAsContext();

		ClusterClient client = prepareClusterClient(false);
		DriverContextEnvironment driverContextEnvironment = new DriverContextEnvironment(client, driverName, transURItoURL(jarFiles),
			transURItoURL(globalClassPath), libJars, externalFiles, getRuntimeContext().getUserCodeClassLoader(), getJobSavePointSettingsFromConfiguration());

		driverContextEnvironment.setParallelism(configuration.getInteger(DriverConfigConstants.FLINK_DRIVER_PARALLELISM,
			configuration.getInteger(CoreOptions.DEFAULT_PARALLELISM)));
		driverContextEnvironment.setAsContext();
	}

	@Override
	public void run(SourceContext<Byte> ctx) throws Exception {
		PackagedProgram.callMainMethod(userMainClass, args);
	}

	@Override
	public void close() throws Exception {
		driverStreamEnvironment.resetContextEnvironments();
	}

	@Override
	public void cancel() {

	}

	protected static List<URL> transURItoURL(List<URI> uris){
		return uris.stream().flatMap(uri -> {
			try {
				return Stream.of(uri.toURL());
			} catch (MalformedURLException e) {
				e.printStackTrace();
			}
			return Stream.empty();
		}).collect(Collectors.toList());
	}

	private List<URI> getBlods(List<URL> urls, String blobType){
		List<URI> result = new ArrayList<>();
		DistributedCache distributedCache = getRuntimeContext().getDistributedCache();
		for (URL url : urls) {
			String fileName = FilenameUtils.getName(url.getPath());
			File file = distributedCache.getFile(blobType + DriverEntry.BLOB_TYPE_SEPARATOR + fileName);
			if (file != null) {
				result.add(file.toURI());
			}
		}
		return result;
	}

	private ClusterClient<?> prepareClusterClient(boolean detached) throws Exception {

		Configuration configuration = new Configuration();
		configuration.addAll(this.configuration);
		final ClusterClient<?> client;
		try {
			if (CoreOptions.LEGACY_MODE.equals(configuration.getString(CoreOptions.MODE))) {
				client = new StandaloneClusterClient(configuration);
			} else {
				client = new RestClusterClient<>(configuration, "DriverContextEnvironment");
			}

			logger.info(String.format("connection info: host: %s, port: %d", clusterIP, clusterPort));
			client.setDetached(detached);
		} catch (Exception e) {
			throw new ProgramInvocationException("Cannot establish connection to JobManager: " + e.getMessage(), e);
		}

		client.setPrintStatusDuringExecution(getRuntimeContext().getExecutionConfig().isSysoutLoggingEnabled());

		return client;
	}

	private SavepointRestoreSettings getJobSavePointSettingsFromConfiguration() {
		String savepointRestorePath = configuration.getString(DriverConfigConstants.FLINK_DRIVER_SAVEPOINT_RESTORE_SETTINGS_PATH, null);
		boolean allowNonRestoredState = configuration.getBoolean(DriverConfigConstants.FLINK_DRIVER_SAVEPOINT_RESTORE_SETTINGS_ALLOWNONRESTORESTATE, false);
		boolean resumeFromLatestCheckpoint = configuration.getBoolean(DriverConfigConstants.FLINK_DRIVER_SAVEPOINT_RESTORE_SETTINGS_RESUMEFROMLATESTCHECKPOINT, false);
		if (savepointRestorePath == null) {
			return SavepointRestoreSettings.none();
		}
		if (resumeFromLatestCheckpoint) {
			return SavepointRestoreSettings.forResumePath(savepointRestorePath, allowNonRestoredState);
		} else {
			return SavepointRestoreSettings.forPath(savepointRestorePath, allowNonRestoredState);
		}

	}
}
