/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.runtime.state.gemini.engine.memstore;

import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.state.gemini.engine.GRegionContext;
import org.apache.flink.runtime.state.gemini.engine.page.GValueType;
import org.apache.flink.runtime.state.gemini.engine.page.PageSerdeFlink2Key;
import org.apache.flink.runtime.state.gemini.engine.page.bmap.GComparator;
import org.apache.flink.runtime.state.gemini.engine.utils.SeqIDUtils;

import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;

/**
 * {@link SegmentKMap} implementation that buffers, per key, a map of map-key to {@link GSValue}.
 *
 * <p>The per-key map is a {@link TreeMap} ordered by the supplied {@link GComparator} when a
 * comparator is given, and a {@link HashMap} otherwise, so both sorted and hashed map state are
 * supported.
 *
 * <p>When the configuration flag {@code isWriteCopy()} is set, keys, map keys and map values are
 * copied through their serializers before being stored in the segment.
 *
 * @param <K> type of the key.
 * @param <MK> type of the map key.
 * @param <MV> type of the map value.
 */
public class SegmentKMapImpl<K, MK, MV> implements SegmentKMap<K, MK, MV> {
	private final long segmentID;
	private final GRegionContext gRegionContext;
	/** Version read from the context at construction time. */
	private final long version;
	private final Map<K, GSValueMap<MK, MV>> dataMap;
	/** Counter maintained by the add/put paths; see the individual methods for how it moves. */
	private int recordCount = 0;
	/** Ordering for map keys; {@code null} selects hash-based per-key maps. */
	private final GComparator<MK> comparator;

	/** Whether keys and values are copied via their serializers before being stored. */
	private final boolean writeCopy;

	private final TypeSerializer<K> keySerializer;

	private final TypeSerializer<MK> mkSerializer;

	private final TypeSerializer<MV> mvSerializer;

	/**
	 * Creates an empty segment backed by a new {@link HashMap}.
	 *
	 * @param segmentID id of this segment.
	 * @param gRegionContext context supplying version, configuration, serializers and sequence ids.
	 * @param comparator ordering of map keys, or {@code null} for hash-based per-key maps.
	 */
	public SegmentKMapImpl(long segmentID, GRegionContext gRegionContext, GComparator<MK> comparator) {
		this(segmentID, gRegionContext, new HashMap<K, GSValueMap<MK, MV>>(), comparator);
	}

	/**
	 * Creates a segment on top of the given data map. The map is used as-is, not copied.
	 *
	 * @param segmentID id of this segment.
	 * @param gRegionContext context supplying version, configuration, serializers and sequence ids.
	 * @param dataMap backing map from key to its buffered value map.
	 * @param comparator ordering of map keys, or {@code null} for hash-based per-key maps.
	 * @throws ClassCastException if the context's page serde is not a {@link PageSerdeFlink2Key}.
	 */
	@SuppressWarnings("unchecked")
	public SegmentKMapImpl(
		long segmentID, GRegionContext gRegionContext, Map<K, GSValueMap<MK, MV>> dataMap, GComparator<MK> comparator) {
		this.segmentID = segmentID;
		this.gRegionContext = gRegionContext;
		this.dataMap = dataMap;
		this.version = gRegionContext.getGContext().getCurVersion();
		this.comparator = comparator;
		this.writeCopy = gRegionContext.getGContext().getGConfiguration().isWriteCopy();
		this.keySerializer = gRegionContext.getPageSerdeFlink().getKeySerde();
		this.mkSerializer = ((PageSerdeFlink2Key) gRegionContext.getPageSerdeFlink()).getKey2Serde();
		this.mvSerializer = gRegionContext.getPageSerdeFlink().getValueSerde();
	}

	/**
	 * Adds a single map entry under {@code key} and stamps the key's map with a new sequence id.
	 *
	 * @param key the key.
	 * @param mkey the map key.
	 * @param mvalue the map value.
	 */
	@Override
	public void add(K key, MK mkey, MV mvalue) {
		GSValueMap<MK, MV> gmap = getOrCreateMap(key);
		long seqID = gRegionContext.getNextSeqID();
		internalAdd(mkey, mvalue, gmap, seqID);
		updateMapSeqID(gmap, seqID);
	}

	/**
	 * Stores one PutValue entry in {@code gmap}, copying map key and value if write-copy is enabled.
	 * The record counters are only bumped when the put did not replace an existing entry.
	 */
	private void internalAdd(MK mkey, MV mvalue, GSValueMap<MK, MV> gmap, long seqID) {
		// TODO no need to copy key if the mapping has existed
		if (gmap.getValue().put(copyMKIfNeeded(mkey),
			GSValue.of(copyMVIfNeeded(mvalue), GValueType.PutValue, seqID)) == null) {
			recordCount++;
			gRegionContext.getWriteBufferStats().addTotalRecordCount(1);
		}
	}

	/**
	 * Adds all entries of {@code map} under {@code key}; all entries share one sequence id.
	 *
	 * @param key the key.
	 * @param map entries to add.
	 */
	@Override
	public void add(K key, Map<MK, MV> map) {
		final GSValueMap<MK, MV> gmap = getOrCreateMap(key);
		final long seqID = gRegionContext.getNextSeqID();
		// internalAdd performs the write-copy itself; copying here as well would copy every entry twice.
		map.forEach((mk, mv) -> internalAdd(mk, mv, gmap, seqID));
		updateMapSeqID(gmap, seqID);
	}

	/**
	 * Records a Delete marker for {@code mapKey} under {@code key}. If the key has no buffered map
	 * yet, an AddMap is created for it first. If the key itself is buffered as deleted (its map has
	 * no value container), nothing is recorded.
	 *
	 * <p>NOTE(review): this path neither updates {@code recordCount} nor checks
	 * {@code getRemoveAllSeqID()} the way {@code getOrCreateMap} does - confirm this is intended.
	 *
	 * @param key the key.
	 * @param mapKey the map key to delete.
	 */
	@Override
	public void remove(K key, MK mapKey) {
		// Look up first so the key is only copied when a new mapping is actually inserted.
		GSValueMap<MK, MV> gmap = dataMap.get(key);
		if (gmap == null) {
			gmap = createAddMap();
			dataMap.put(copyKeyIfNeeded(key), gmap);
		}
		// The sequence id is drawn before the early return, as before, so id consumption is unchanged.
		long seqID = gRegionContext.getNextSeqID();
		if (gmap.getValue() == null) {
			return;
		}
		gmap.getValue().put(copyMKIfNeeded(mapKey),
			GSValue.of(null, GValueType.Delete, seqID));
		updateMapSeqID(gmap, seqID);
	}

	/**
	 * Replaces whatever is buffered for {@code key} with a PutMap holding the given entries. The
	 * map's sequence id becomes the largest sequence id among the entries
	 * ({@code SeqIDUtils.INVALID_SEQID} is the starting value of that maximum). The record counters
	 * are adjusted by the size difference between the new and the replaced map.
	 *
	 * <p>Note that the passed {@link GSValue} instances are stored directly and their values are
	 * overwritten in place with the (possibly copied) value.
	 *
	 * @param key the key.
	 * @param value entries to store for the key.
	 */
	@Override
	public void put(K key, Map<MK, GSValue<MV>> value) {
		GSValueMap<MK, MV> gmap = createPutMap();
		long seqID = SeqIDUtils.INVALID_SEQID;
		for (Map.Entry<MK, GSValue<MV>> entry : value.entrySet()) {
			GSValue<MV> gsValue = entry.getValue();
			seqID = Math.max(seqID, gsValue.getSeqID());
			// TODO better way is to replace put(K, Map<MK, GSValue<MV>>) with put(K, Map<MK, MV>)
			gsValue.setValue(copyMVIfNeeded(gsValue.getValue()));
			gmap.getValue().put(copyMKIfNeeded(entry.getKey()), gsValue);
		}
		GSValueMap<MK, MV> old = dataMap.put(copyKeyIfNeeded(key), gmap);
		updateMapSeqID(gmap, seqID);
		int oldSize = (old == null || old.getValue() == null) ? 0 : old.getValue().size();
		int delta = value.size() - oldSize;
		recordCount = recordCount + delta;
		gRegionContext.getWriteBufferStats().addTotalRecordCount(delta);
	}

	/**
	 * Returns the buffered map for {@code key}, or {@code null} if this segment holds nothing for it.
	 */
	@Override
	public GSValueMap<MK, MV> get(K key) {
		return dataMap.get(key);
	}

	/**
	 * Looks up a single map entry.
	 *
	 * @param key the key.
	 * @param mapKey the map key.
	 * @return {@code null} if nothing is buffered for the key, or if the key's map is not a PutMap
	 *     and has no entry for {@code mapKey}; a Delete marker carrying the map's sequence id if
	 *     the key is buffered as deleted or its PutMap lacks {@code mapKey}; otherwise the
	 *     buffered entry.
	 */
	@Override
	public GSValue<MV> get(K key, MK mapKey) {
		GSValueMap<MK, MV> gsValueMap = get(key);
		if (gsValueMap == null) {
			return null;
		}
		if (gsValueMap.getValueType() == GValueType.Delete) {
			return new GSValue<>(null, GValueType.Delete, gsValueMap.getSeqID());
		}
		GSValue<MV> result = gsValueMap.getValue().get(mapKey);
		if (result == null && gsValueMap.getValueType() == GValueType.PutMap) {
			//because this is a PutMap, so prior data is useless. we return DELETE to indicate this case.
			return new GSValue<>(null, GValueType.Delete, gsValueMap.getSeqID());
		}
		return result;
	}

	/**
	 * Replaces whatever is buffered for {@code key} with a Delete marker stamped with a new
	 * sequence id. {@code recordCount} is not changed by this method.
	 *
	 * @param key the key to delete.
	 */
	@Override
	public void removeKey(K key) {
		GSValueMap<MK, MV> gmap = createDeleteMap();
		long seqID = gRegionContext.getNextSeqID();
		updateMapSeqID(gmap, seqID);
		// TODO no need to copy key if the mapping has existed
		dataMap.put(copyKeyIfNeeded(key), gmap);
	}

	@Override
	public long getSegmentID() {
		return this.segmentID;
	}

	@Override
	public int getRecordCount() {
		return this.recordCount;
	}

	@Override
	public long getVersion() {
		return this.version;
	}

	/**
	 * Returns a new segment with id {@code -1} whose data map is a fresh {@link HashMap} holding
	 * {@code copyGSValueMap()} copies of every per-key map. Keys are shared, not copied.
	 *
	 * <p>NOTE(review): the copy starts with {@code recordCount == 0} and re-reads the version from
	 * the context rather than inheriting this segment's values - confirm callers expect that.
	 */
	@Override
	public Segment<K, Map<MK, GSValue<MV>>> copySegment() {
		Map<K, GSValueMap<MK, MV>> copyMap = new HashMap<>();
		for (Map.Entry<K, GSValueMap<MK, MV>> entry : dataMap.entrySet()) {
			copyMap.put(entry.getKey(), entry.getValue().copyGSValueMap());
		}
		return new SegmentKMapImpl<>(-1L, this.gRegionContext, copyMap, this.comparator);
	}

	/**
	 * Returns the live backing map (not a copy).
	 * NOTE(review): the raw return type is presumably dictated by the interface - verify before
	 * tightening it.
	 */
	@Override
	public Map getData() {
		return dataMap;
	}

	/**
	 * Returns the map that additions for {@code key} should go into, installing a new one when
	 * needed: an AddMap if nothing is buffered or the buffered map's sequence id is below the
	 * context's remove-all sequence id, a PutMap if the key is currently buffered as deleted.
	 */
	private GSValueMap<MK, MV> getOrCreateMap(K key) {
		GSValueMap<MK, MV> gmap = dataMap.get(key);
		if (gmap == null || gmap.getSeqID() < gRegionContext.getRemoveAllSeqID()) {
			gmap = createAddMap();
			dataMap.put(copyKeyIfNeeded(key), gmap);
		} else if (gmap.getValueType() == GValueType.Delete) {
			gmap = createPutMap();
			dataMap.put(copyKeyIfNeeded(key), gmap);
		}
		return gmap;
	}

	/** Creates an empty AddMap container (sorted if a comparator is set) with an invalid sequence id. */
	private GSValueMap<MK, MV> createAddMap() {
		return GSValueMap.of(comparator != null ? new TreeMap<>(comparator.getJDKCompactor()) : new HashMap<>(),
			GValueType.AddMap, SeqIDUtils.INVALID_SEQID);
	}

	/** Creates an empty PutMap container (sorted if a comparator is set) with an invalid sequence id. */
	private GSValueMap<MK, MV> createPutMap() {
		return GSValueMap.of(comparator != null ? new TreeMap<>(comparator.getJDKCompactor()) : new HashMap<>(),
			GValueType.PutMap, SeqIDUtils.INVALID_SEQID);
	}

	/** Creates a Delete marker; it carries no value container ({@code null}). */
	private GSValueMap<MK, MV> createDeleteMap() {
		return GSValueMap.of(null, GValueType.Delete, SeqIDUtils.INVALID_SEQID);
	}

	private void updateMapSeqID(GSValueMap<MK, MV> gsValueMap, long seqID) {
		gsValueMap.setSeqID(seqID);
	}

	/** Returns the live backing map (not a copy), with its full generic type. */
	public Map<K, GSValueMap<MK, MV>> getDataMap() {
		return dataMap;
	}

	/** Copies the key through its serializer when write-copy is enabled, else returns it as-is. */
	private K copyKeyIfNeeded(K key) {
		return writeCopy ? keySerializer.copy(key) : key;
	}

	/** Copies the map key through its serializer when write-copy is enabled, else returns it as-is. */
	private MK copyMKIfNeeded(MK mk) {
		return writeCopy ? mkSerializer.copy(mk) : mk;
	}

	/** Copies the map value through its serializer when write-copy is enabled, else returns it as-is. */
	private MV copyMVIfNeeded(MV mv) {
		return writeCopy ? mvSerializer.copy(mv) : mv;
	}
}
