package cn.com.duiba.nezha.compute.mllib.model

import cn.com.duiba.nezha.compute.api.point.Point
import cn.com.duiba.nezha.compute.mllib.classification.ClassificationModel
import Point.{FMParams, LabeledSPoint}
import cn.com.duiba.nezha.compute.mllib.optimizing.FMGD
import cn.com.duiba.nezha.compute.mllib.util.{MLUtil, SparseUtil}
import org.apache.spark.mllib.linalg._
import org.apache.spark.rdd.RDD

/**
 * Classification model that scores sparse feature vectors with a fixed set of
 * FM parameters. All scoring is delegated to `SparseFMModel.predict` in the
 * companion object.
 *
 * While a threshold is set (the default is 0.5) every prediction is binarised:
 * 1.0 when the score is strictly greater than the threshold, 0.0 otherwise.
 * After `clearThreshold()` the raw score is returned instead.
 *
 * @param fmParams parameters passed to the scoring function for every prediction
 */
class SparseFMModel(fmParams: FMParams) extends GeneralizedModel with ClassificationModel with Serializable {
  /** Initialised to -1; nothing in this class reads or updates it. */
  var numFeatures: Int = -1

  /** Decision threshold; `None` means predictions are raw scores. */
  var threshold: Option[Double] = Some(0.5)

  /**
   * Sets the decision threshold used to binarise scores.
   *
   * @param threshold scores strictly above this value are predicted as 1.0
   * @return this model, to allow call chaining
   */
  def setThreshold(threshold: Double): this.type = {
    this.threshold = Some(threshold)
    this
  }

  /** Returns the current threshold, or `None` if it has been cleared. */
  def getThreshold: Option[Double] = threshold

  /** Returns the parameters this model was constructed with. */
  def getFMParams: FMParams = fmParams

  /**
   * Removes the threshold so that subsequent predictions return raw scores.
   *
   * @return this model, to allow call chaining
   */
  def clearThreshold(): this.type = {
    threshold = None
    this
  }

  /**
   * Predicts every vector of an RDD.
   *
   * @param testData vectors to score
   * @return one prediction per input vector
   */
  override def predict(testData: RDD[SparseVector]): RDD[Double] = {
    // Snapshot the fields into locals so the closure captures only these two
    // values rather than `this` (which would serialize the whole model).
    val params = fmParams
    val cutoff = threshold
    testData.map(x => SparseFMModel.predict(x, params, cutoff))
  }

  /**
   * Predicts a single vector.
   *
   * @param sv vector to score
   * @return thresholded label or raw score, depending on `threshold`
   */
  override def predict(sv: SparseVector): Double = {
    SparseFMModel.predict(sv, fmParams, threshold)
  }

  /**
   * Predicts every labeled point of an RDD, keeping the label alongside.
   *
   * @param testData labeled points to score
   * @return pairs of (prediction for `x`, original `y`)
   */
  override def predictPoint(testData: RDD[LabeledSPoint]): RDD[(Double, Double)] = {
    // Same local snapshot as in `predict(RDD)`: keep `this` out of the closure.
    val params = fmParams
    val cutoff = threshold
    testData.map(p => (SparseFMModel.predict(p.x, params, cutoff), p.y))
  }

  /**
   * Predicts a single labeled point.
   *
   * @param point labeled point to score
   * @return pair of (prediction for `point.x`, `point.y`)
   */
  override def predictPoint(point: LabeledSPoint): (Double, Double) = {
    (SparseFMModel.predict(point.x, fmParams, threshold), point.y)
  }
}

/**
 * Factory and stateless scoring helper for the `SparseFMModel` class.
 */
object SparseFMModel {

  /**
   * Builds a model around the given parameters.
   *
   * @param fm_params parameters the new model will score with
   * @param vector    not used by this factory; kept so existing call sites keep compiling
   * @return a new model wrapping `fm_params`
   */
  def apply(fm_params: FMParams, vector: SparseVector): SparseFMModel = new SparseFMModel(fm_params)

  /**
   * Scores one vector with `FMGD.h` and optionally binarises the result.
   *
   * @param p         vector to score
   * @param fm_params parameters passed through to `FMGD.h`
   * @param threshold when defined, the result is 1.0 if the score is strictly
   *                  greater than it and 0.0 otherwise; when `None`, the raw
   *                  score is returned
   * @return thresholded label or raw score
   */
  def predict(p: SparseVector, fm_params: FMParams, threshold: Option[Double]): Double = {
    // NOTE(review): the default threshold of 0.5 suggests the score is a
    // probability in [0, 1] -- confirm against FMGD.h.
    val score: Double = FMGD.h(p, fm_params)
    threshold.fold(score)(t => if (score > t) 1.0 else 0.0)
  }

}