#!/usr/bin/env python
# AUTHOR:
#    Adam Wehmann
#    http://www.adamwehmann.com/
#    Last Updated: 1/8/2015
#
# IMPLEMENTS:
# Chang, C-C., Lin, C-J. 2011. LIBSVM: a library for support vector machines, ACM
#    Transactions on Intelligent Systems and Technology, 2(27), pp. 1-27. Software
#    available at: http://www.csie.ntu.edu.tw/~cjlin/libsvm

import svmutil
import arcpy
import numpy
import sys
import os
from cl_svm import *

# Tool parameters, in the order of the script tool's parameter list.
# Paths are read as text; switches and numeric settings as typed values.
imagepath = arcpy.GetParameterAsText(0)   # image raster to classify
blscale = arcpy.GetParameter(1)           # scale band values to [0, 1]?
trainpath = arcpy.GetParameterAsText(2)   # training raster (0 = unlabelled)
outputpath = arcpy.GetParameterAsText(3)  # classified raster ('' = skip)
probpath = arcpy.GetParameterAsText(4)    # probability raster ('' = skip)
modelpath = arcpy.GetParameterAsText(5)   # LIBSVM model file ('' = skip)
cspec = arcpy.GetParameter(6)             # C value handed to train()
gspec = arcpy.GetParameter(7)             # g value handed to train()
blcv = arcpy.GetParameter(8)              # run cross-validation?
fold = arcpy.GetParameter(9)              # fold count, used only when blcv
# Search-grid settings for C and g, handed to train() as tuples.
cmin = arcpy.GetParameter(10)
cmax = arcpy.GetParameter(11)
gmin = arcpy.GetParameter(12)
gmax = arcpy.GetParameter(13)
cnostep = arcpy.GetParameter(14)
gnostep = arcpy.GetParameter(15)
cbase = arcpy.GetParameter(16)
gbase = arcpy.GetParameter(17)
blocksize = arcpy.GetParameter(18)        # tile edge length, in cells

# Output rasters are carried as (workspace, dataset name) pairs from here on.
if outputpath != '':
    outputpath, outputname = os.path.split(outputpath)

blprob = probpath != ''
probval = 1 if blprob else 0   # integer flag form passed to train()/predict()
if blprob:
    probpath, probname = os.path.split(probpath)

# cvval is 0 whenever cross-validation is switched off.
cvval = fold if blcv else 0

# ---------------------------------------------------------------------

# Installed ArcGIS release truncated to major.minor ('10.2.1' -> 10.2). The
# raster I/O below switches between a single multiband read (>= 10.2) and a
# band-by-band read (older releases) on this value.
arcvers = float('.'.join(
    arcpy.GetInstallInfo()['Version'].split('.')[:2]))

if arcvers < 10.0:
    # NOTE(review): AddError only records a message; the script keeps running.
    # A version string such as '3.1' (ArcGIS Pro numbering) would also land
    # here and select the pre-10.2 code paths -- confirm which products this
    # tool is meant to support.
    arcpy.AddError('Must use ArcGIS 10.0 or later.')

# Probability output is intentionally not restricted by version: on releases
# older than 10.2 the classifier writes one raster per probability band and
# composites them at the end.

arcpy.AddMessage("Opening raster objects...")

# Image raster: the Raster object supplies size/extent/cell size; the Describe
# object lists the bands read individually on pre-10.2 installs.
r = arcpy.Raster(imagepath)
r_desc = arcpy.Describe(imagepath)

# Training raster, compared against the image below.
t = arcpy.Raster(trainpath)

# Extents are compared on integer-truncated coordinates and a mismatch is only
# reported, not treated as fatal.
if any(int(getattr(r.extent, side)) != int(getattr(t.extent, side))
       for side in ('XMin', 'YMin', 'XMax', 'YMax')):
    arcpy.AddWarning('Extent of input rasters do not match.')

# Running extremes of the image values over every block. Both the training
# and the classification data are scaled with (value - smin) / (smax - smin).
# -numpy.inf replaces numpy.NINF, an alias removed in NumPy 2.0.
smax = -numpy.inf
smin = numpy.inf

arcpy.AddMessage("Loading training data...")

totalprogress = numpy.ceil(1.0 * r.width / blocksize) * numpy.ceil(1.0 * r.height / blocksize)
arcpy.SetProgressor("step", "Loading training data...", 0, int(totalprogress), 1)

# One [label, band values...] array per block that contains labelled pixels.
# They are stacked once after the loop; stacking inside the loop re-copied the
# growing array on every block (quadratic in the number of samples).
trblocks = []

# Walk the image in blocksize x blocksize tiles (x = column, y = row offset).
for x in range(0, r.width, blocksize):
    for y in range(0, r.height, blocksize):
        # Exclusive end column/row of this tile, clipped to the raster size.
        lx = min([r.width, x + blocksize])
        ly = min([r.height, y + blocksize])
        # RasterToNumPyArray is positioned by the tile's lower-left corner.
        mx = r.extent.XMin + x * r.meanCellWidth
        my = r.extent.YMax - ly * r.meanCellHeight
        if arcvers >= 10.2:
            # Single multiband read; transposing reverses the axis order so the
            # band axis (if any) ends up last, like the dstack layout below.
            dr = arcpy.RasterToNumPyArray(imagepath, arcpy.Point(mx,my), lx-x, ly-y)
            dr = numpy.transpose(dr)
        else:
            # Older releases: read each band separately and stack them along a
            # third axis.
            blfirst = True
            for band in r_desc.children:
                tdr = arcpy.RasterToNumPyArray(os.path.join(imagepath, band.name), arcpy.Point(mx,my), lx-x, ly-y)
                if blfirst:
                    dr = tdr
                    blfirst = False
                else:
                    dr = numpy.dstack((dr,tdr))
        # Training labels; NoData cells are read as 0, i.e. unlabelled.
        dt = arcpy.RasterToNumPyArray(trainpath,arcpy.Point(mx,my),lx-x,ly-y,0)
        if arcvers >= 10.2:
            # Keep the same axis order as the transposed image block.
            dt = numpy.transpose(dt)
        # Scaling statistics cover every image cell, labelled or not.
        if blscale:
            smax = numpy.amax([smax,numpy.amax(dr)])
            smin = numpy.amin([smin,numpy.amin(dr)])
        # Keep only labelled pixels: one row per pixel, label first.
        nz = numpy.where(dt != 0)
        dt = dt[nz]
        if dt.size != 0:
            dr = dr[nz]
            if len(dr.shape) == 1:
                # Single-band image: make the band values a column.
                dr = dr[:,None]
            trblocks.append(numpy.hstack((dt[:,None],dr)))
        arcpy.SetProgressorPosition()
arcpy.ResetProgressor()

# Without any labelled pixel there is nothing to train on; stop with a clear
# message instead of failing later on an undefined training array.
if not trblocks:
    arcpy.AddError('No training samples found: every training pixel is 0 or NoData.')
    sys.exit(1)

tr = numpy.vstack(trblocks)
del trblocks

arcpy.AddMessage("%i training samples found for %i classes." % (tr.shape[0], numpy.unique(tr[:,0]).size))

arcpy.AddMessage("Scaling training data...")

# Work in floating point from here on (column 0 = class label, remaining
# columns = band values).
tr = tr.astype(float)

if blscale:
    # Min-max scale the band columns in place, using the image-wide extremes
    # gathered while loading; the label column is left untouched.
    features = tr[:, 1:]
    features -= smin
    features /= (smax - smin)

# Search-grid settings for C and g, packed as (min, max, nostep, base) in the
# order train() receives them.
ct = (cmin, cmax, cnostep, cbase)
gt = (gmin, gmax, gnostep, gbase)

arcpy.AddMessage("Training SVM...")

# train() presumably comes from the cl_svm star import: labels in column 0,
# scaled band values in the remaining columns.
model, bestcv, bestc, bestg = train(tr[:, 0], tr[:, 1:], probval, cvval,
                                    ctup=ct, gtup=gt, c=cspec, g=gspec)

# Persist the model only when a path was supplied, and echo that path to the
# derived output parameter.
if modelpath != '':
    svmutil.svm_save_model(modelpath, model)
    arcpy.SetParameterAsText(21, modelpath)

# Environment for every raster written from here on: allow overwriting, and
# take the processing extent, spatial reference and cell size from the image.
arcpy.env.overwriteOutput = True
arcpy.env.Extent = imagepath
arcpy.env.outputCoordinateSystem = imagepath
arcpy.env.cellSize = imagepath

def _tile_name(name, *indices):
    """Return *name* with one ``_<index>`` suffix per index, before its extension.

    ``os.path.splitext`` keeps the leading dot on the extension, so
    ``_tile_name('out.tif', 3)`` gives ``'out_3.tif'`` and
    ``_tile_name('grid', 3, 0)`` gives ``'grid_3_0'``. The previous
    ``'_%i.'.join(os.path.splitext(name))`` produced ``'out_3..tif'``.
    """
    root, ext = os.path.splitext(name)
    return root + ''.join(['_%i' % i for i in indices]) + ext


if outputpath != '':

    arcpy.AddMessage("Classifying...")

    totalprogress = numpy.ceil(1.0 * r.width / blocksize) * numpy.ceil(1.0 * r.height / blocksize)
    arcpy.SetProgressor("step", "Classifying...", 0, int(totalprogress), 1)
    progress = 0

    # Paths of the per-tile rasters written below; they are mosaicked and then
    # deleted. For pre-10.2 probability output, outputplist holds one list of
    # tile paths per probability band.
    outputlist = []
    outputplist = []
    blfirstoutput = True

    # Predict tile by tile, using the same tiling as the training pass.
    for x in range(0, r.width, blocksize):
        for y in range(0, r.height, blocksize):
            # Tile bounds (exclusive ends) and lower-left corner in map units.
            lx = min([r.width, x + blocksize])
            ly = min([r.height, y + blocksize])
            mx = r.extent.XMin + x * r.meanCellWidth
            my = r.extent.YMax - ly * r.meanCellHeight
            if arcvers >= 10.2:
                dr = arcpy.RasterToNumPyArray(imagepath,arcpy.Point(mx,my),lx-x,ly-y)
                dr = numpy.transpose(dr)
            else:
                blfirst = True
                for band in r_desc.children:
                    tdr = arcpy.RasterToNumPyArray(os.path.join(imagepath, band.name), arcpy.Point(mx,my), lx-x, ly-y)
                    if blfirst:
                        dr = tdr
                        blfirst = False
                    else:
                        dr = numpy.dstack((dr,tdr))
            # Same float conversion and scaling as the training data.
            dr = dr.astype(float)
            if blscale:
                dr = (dr - smin) / (smax - smin)
            # predict() presumably comes from the cl_svm star import.
            (c,p,plabels) = predict(dr,model,probval)
            # Undo the read-time transpose before writing rasters.
            if arcvers >= 10.2:
                c = numpy.transpose(c)
                if blprob:
                    p = numpy.transpose(p)
            # Save this tile's classification.
            outc = arcpy.NumPyArrayToRaster(c, arcpy.Point(mx,my), r.meanCellWidth, r.meanCellHeight)
            tilepath = os.path.join(outputpath, _tile_name(outputname, progress))
            outc.save(tilepath)
            outputlist.append(tilepath)
            # Save this tile's probabilities.
            if blprob:
                if arcvers >= 10.2:
                    # One multiband raster per tile.
                    outp = arcpy.NumPyArrayToRaster(p, arcpy.Point(mx,my), r.meanCellWidth, r.meanCellHeight)
                    tilepath = os.path.join(probpath, _tile_name(probname, progress))
                    outp.save(tilepath)
                    outputplist.append(tilepath)
                else:
                    # One single-band raster per tile and probability band.
                    for pc in range(p.shape[2]):
                        outp = arcpy.NumPyArrayToRaster(p[:,:,pc], arcpy.Point(mx,my), r.meanCellWidth, r.meanCellHeight)
                        tilepath = os.path.join(probpath, _tile_name(probname, progress, pc))
                        outp.save(tilepath)
                        if blfirstoutput:
                            outputplist.append([tilepath])
                        else:
                            outputplist[pc].append(tilepath)
                    blfirstoutput = False

            progress = progress + 1
            arcpy.SetProgressorPosition()
        arcpy.AddMessage("%3.2f%%" % (100.0 * progress / totalprogress))
    arcpy.ResetProgressor()

    # Drop references to the last tile objects/arrays before mosaicking
    # (presumably so ArcGIS releases its handles on the tile datasets).
    outc = None
    outp = None
    tdr = None
    dr = None

    arcpy.AddMessage("Mosaicking...")

    # Mosaic all classification tiles into the first tile, then rename that
    # tile to the requested output name.
    arcpy.Mosaic_management(';'.join(outputlist), outputlist[0])
    arcpy.Rename_management(outputlist[0], os.path.join(outputpath, outputname))
    del outputlist[0]
    if blprob:
        noutputplist = []
        if arcvers >= 10.2:
            arcpy.Mosaic_management(';'.join(outputplist), outputplist[0])
            arcpy.Rename_management(outputplist[0], os.path.join(probpath, probname))
            del outputplist[0]
        else:
            # Mosaic each probability band separately, then composite the
            # mosaicked bands into the final multiband raster.
            for bandtiles in outputplist:
                arcpy.Mosaic_management(';'.join(bandtiles), bandtiles[0])
                noutputplist.append(bandtiles[0])
                del bandtiles[0]
            arcpy.CompositeBands_management(';'.join(noutputplist), os.path.join(probpath, probname))
        arcpy.AddMessage("Probability band class order: " + str(plabels))

    # Remove the remaining temporary tiles (and, pre-10.2, the per-band
    # mosaics that were composited).
    for f in outputlist:
        arcpy.Delete_management(f)
    if blprob:
        if arcvers >= 10.2:
            for f in outputplist:
                arcpy.Delete_management(f)
        else:
            for flist in outputplist:
                for f in flist:
                    arcpy.Delete_management(f)
            for f in noutputplist:
                arcpy.Delete_management(f)

    # Derived output parameters: final classification / probability rasters.
    arcpy.SetParameterAsText(19, os.path.join(outputpath, outputname))

    if blprob:
        arcpy.SetParameterAsText(20, os.path.join(probpath, probname))

# Derived output parameters: best cross-validation accuracy, then the C and g
# values that achieved it.
for _param_index, _param_value in ((22, bestcv), (23, bestc), (24, bestg)):
    arcpy.SetParameter(_param_index, _param_value)

arcpy.AddMessage("Finished!")