#!/usr/bin/env python
# AUTHOR:
#    Adam Wehmann
#    http://www.adamwehmann.com/
#    Last updated: 8/31/2013
#
#    Produced at The Ohio State University
#    http://www.geography.ohio-state.edu/
#
# INTERFACES:
# Chang, C-C., Lin, C-J. 2011. LIBSVM: a library for support vector machines, ACM
#    Transactions on Intelligent Systems and Technology, 2(27), pp. 1-27. Software
#    available at: http://www.csie.ntu.edu.tw/~cjlin/libsvm

from numpy import linspace
from numpy import zeros
from numpy import zeros_like
from numpy import reshape
from numpy import array
from numpy import transpose
from numpy import ceil
from numpy import argsort as npargsort
from numpy import sort as npsort
from svmutil import svm_predict
from svmutil import svm_train
import arcpy

# PURPOSE: Trains an SVM with an RBF kernel.  Allows the C and G parameters to be selected via cross-validation.
# PARAMETERS:
#     y = 1D array of correct class labels
#     x = 2D array of corresponding data instances (by row)
#     useprob = [True|False] whether to train s.t. class conditional probabilities may be generated
#     usecv = # of folds or 0 for no C.V.
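#     ctup = (cmin,cmax,cnostep,cbase); C candidates are cbase**e for cnostep
#            exponents e evenly spaced from cmin to cmax (inclusive)
#     gtup = (gmin,gmax,gnostep,gbase); G candidates are gbase**e for gnostep
#            exponents e evenly spaced from gmin to gmax (inclusive)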
#     c = value for C parameter (used if usecv = 0)
#     g = value for G parameter (used if usecv = 0)
# RETURNS:
#    Tuple (model,bestcv,bestc,bestg)
#        where model = LIBSVM model
#              bestcv = best CV accuracy
#              bestc = best C parameter value found
#              bestg = best G parameter value found
def train(y,x,useprob,usecv,ctup=(-5,15,10,2),gtup=(-15,3,10,2),c=1,g=1):
    """Train a LIBSVM model, optionally grid-searching C and G by cross-validation.

    Parameters:
        y: 1D NumPy array of class labels (converted with ``tolist()``).
        x: 2D NumPy array of data instances, one per row.
        useprob: truthy to train with ``-b 1`` so that the model can later
            produce probability estimates.
        usecv: number of cross-validation folds, or 0 to skip the search.
        ctup: (cmin, cmax, cnostep, cbase); candidate C values are
            ``cbase**e`` for ``cnostep`` exponents evenly spaced in
            [cmin, cmax].
        gtup: (gmin, gmax, gnostep, gbase); same scheme for G.
        c, g: parameter values used directly when ``usecv`` is 0.

    Returns:
        Tuple ``(model, bestcv, bestc, bestg)``. ``bestcv`` is the best value
        returned by the cross-validation runs, or 100.0 when no search was
        performed.

    Raises:
        ValueError: if a search is requested but the grid is empty.
    """
    daty = y.tolist()
    datx = x.tolist()

    if usecv:
        cmin, cmax, cnostep, cbase = ctup
        gmin, gmax, gnostep, gbase = gtup
        # Compute each candidate once so the value that is cross-validated is
        # exactly the value that is returned and used for the final training.
        cvalues = [cbase**e for e in linspace(cmin,cmax,num=int(cnostep),endpoint=True)]
        gvalues = [gbase**e for e in linspace(gmin,gmax,num=int(gnostep),endpoint=True)]
        if not cvalues or not gvalues:
            raise ValueError("cross-validation grid is empty: cnostep and "
                             "gnostep must both be at least 1")

        nocv = len(cvalues) * len(gvalues)
        arcpy.SetProgressor("step", "Cross Validation Grid Search...", 0, nocv, 1)
        bestcv = 0
        try:
            for cval in cvalues:
                for gval in gvalues:
                    # %.17g keeps full double precision; %f would truncate to
                    # six decimals and turn small values (e.g. 2**-20) into 0.
                    param = '-q -v %d -c %.17g -g %.17g -e 0.1' % (usecv, cval, gval)
                    cv = svm_train(daty, datx, param)
                    # ">=" means that on ties the later grid point wins.
                    if cv >= bestcv:
                        bestcv, bestc, bestg = cv, cval, gval
                    arcpy.SetProgressorPosition()
        finally:
            # Always restore the progressor, even if a training run fails.
            arcpy.ResetProgressor()
        arcpy.AddMessage("Best ACC %f%%, C %f, G %f"  % (bestcv,bestc,bestg))
    else:
        bestcv = 100.
        bestc, bestg = c, g

    # Train with the selected parameters
    arcpy.SetProgressorLabel('Training...')

    param = '-q -c %.17g -g %.17g -b %i' % (bestc, bestg, useprob)
    model = svm_train(daty, datx, param)

    return (model,bestcv,bestc,bestg)

# PURPOSE: Classifies a raster based on an input LIBSVM model.
# PARAMETERS:
#     d = NumPy array with dimensions (x,y,b)
#    model = LIBSVM model
#    useprob = [True|False] whether to produce class conditional probability estimates
# RETURNS:
#    Tuple (rast,prob,plabels)
#        where rast = NumPy array containing classification with dimensions (x, y)
#              prob = NumPy array containing CCP estimates with dimensions (x, y, c)
#                        where c is equal to the number of classes; if not produced,
#                        the value will be None.
#              plabels = list of class labels with order corresponding to the 3rd
#                        dimension of the prob output variable
def predict(d,model,useprob):
    """Classify every pixel of a raster array with a LIBSVM model.

    Parameters:
        d: NumPy array with dimensions (x, y, b). A 2D array (x, y) is
            treated as a single-band raster.
        model: LIBSVM model.
        useprob: truthy to request probability estimates. Reset to 0 when it
            equals 1 and the model reports that it is not a probability model.

    Returns:
        Tuple ``(r, p, plabels)``:
            r: array of predicted labels with dimensions (x, y).
            p: array of probability estimates with dimensions (x, y, c),
                with the class axis ordered by ascending label, or None when
                probabilities were not produced.
            plabels: list of class labels; sorted ascending when ``p`` is
                produced, otherwise in the order given by
                ``model.get_labels()``.
    """
    # Get attributes
    labels = model.get_labels()
    nolabels = len(labels)
    if useprob == 1 and not model.is_probability_model():
        useprob = 0

    # Resolve the output shape up front so a bad rank fails before the
    # (potentially slow) prediction instead of after it.
    outshape = (d.shape[0], d.shape[1])

    # One row per pixel, one column per band.  A 2D input has a single band;
    # each instance must still be a sequence of features, not a bare scalar.
    nbands = d.shape[2] if len(d.shape) >= 3 else 1
    x = reshape(d, (-1, nbands)).tolist()

    # Placeholder labels: the true labels are unknown at prediction time.
    y = [0.0]*len(x)

    # Predict SVM
    l, a, p = svm_predict(y, x, model, '-b %i' % (useprob))

    r = reshape(array(l), outshape)

    if not useprob:
        return (r, None, list(labels))

    # Reorder the probability axis from model.get_labels() order to
    # ascending label order so that it lines up with plabels.
    p = reshape(array(p), outshape + (nolabels,))
    p[:,:,:] = p[:,:,npargsort(labels)]
    plabels = npsort(labels).tolist()

    return (r, p, plabels)