#!/usr/bin/env python
# AUTHOR:
#    Adam Wehmann
#    http://www.adamwehmann.com/
#     Last updated: 1/8/2015
#
#    Produced at The Ohio State University
#    http://www.geography.ohio-state.edu/
#
# PURPOSE:
#    Performs accuracy assessment for a classified raster based on testing data.
#    Calculates error matrix, user's accuracies, producer's accuracies,
#        overall accuracy, and Cohen's kappa statistic.
#    OA and K are returned as derived model outputs.
#    If an output file is specified, full results are also written to it.
#

import collections

import arcpy
import numpy

# Tool parameters, in order: classified raster, testing raster, and the
# (possibly empty) path of a text file that receives the full report.
inputclass, inputtest, outputresults = (arcpy.GetParameterAsText(idx) for idx in (0, 1, 2))

# Installed ArcGIS version reduced to "<major>.<minor>" as a float,
# e.g. '10.2.1' -> 10.2.
arcvers = float('.'.join(arcpy.GetInstallInfo()['Version'].split('.', 2)[:2]))

# Edge length, in cells, of the square windows the rasters are read in.
blocksize = 512

# Open both inputs as Raster objects to inspect their extent and dimensions.
r = arcpy.Raster(inputclass)   # classification
t = arcpy.Raster(inputtest)    # testing data

# Both rasters are later read through the same lower-left corners and
# row/column counts, so they must share extent and dimensions.
# arcpy.AddError only posts a message; the tool must also be stopped, otherwise
# statistics would be computed from misaligned data.
inputs_mismatch = False
# Extent compared after int() truncation of each coordinate, as before.
if (int(r.extent.XMin) != int(t.extent.XMin)) or \
   (int(r.extent.YMin) != int(t.extent.YMin)) or \
   (int(r.extent.XMax) != int(t.extent.XMax)) or \
   (int(r.extent.YMax) != int(t.extent.YMax)):
    arcpy.AddError('Extent of input rasters do not match.')
    inputs_mismatch = True
# Number of columns (width) and rows (height) must be identical.
if (r.width != t.width) or (r.height != t.height):
    arcpy.AddError('Number of rows or columns of input rasters do not match.')
    inputs_mismatch = True
if inputs_mismatch:
    raise arcpy.ExecuteError
    
arcpy.AddMessage("Loading data...")

# One progressor step per window read in the loop below.
totalprogress = numpy.ceil(1.0 * r.width / blocksize) * numpy.ceil(1.0 * r.height / blocksize)
arcpy.SetProgressor("step", "Loading data...", 0, int(totalprogress), 1)

# Count (classified value, test value) pairs over every pixel where BOTH
# rasters are nonzero. RasterToNumPyArray is called with nodata_to_value=0,
# so NoData cells are excluded together with value 0.
values = collections.Counter()
for x in range(0, r.width, blocksize):
    for y in range(0, r.height, blocksize):
        # Window of at most blocksize x blocksize cells; its lower-left corner
        # is derived from the classified raster and used for both reads.
        ncols = min(r.width, x + blocksize) - x
        nrows = min(r.height, y + blocksize) - y
        corner = arcpy.Point(r.extent.XMin + x * r.meanCellWidth,
                             r.extent.YMin + y * r.meanCellHeight)
        dr = arcpy.RasterToNumPyArray(inputclass, corner, ncols, nrows, 0)
        dt = arcpy.RasterToNumPyArray(inputtest, corner, ncols, nrows, 0)
        # Both arrays come back with the same shape and orientation, so
        # element-wise pairing is correct without any version-dependent
        # transpose (transposing both never changed which values pair up).
        both = (dr != 0) & (dt != 0)
        values.update(zip(dr[both].tolist(), dt[both].tolist()))
        arcpy.SetProgressorPosition()
arcpy.ResetProgressor()

# Give every class value ONE index shared by rows and columns, so that the
# diagonal M[i, i] counts exactly the pixels where classification and test
# agree. Indexing the two axes independently (by order of first appearance)
# would place disagreements on the diagonal and corrupt every accuracy.
# Rows: classified values; columns: test values; classes in sorted order.
classes = sorted(set(c for c, _ in values) | set(tv for _, tv in values))
class_index = dict((v, i) for i, v in enumerate(classes))
noclasses = len(classes)
M = numpy.zeros((noclasses, noclasses))
for (val_class, val_train), count in values.items():
    M[class_index[val_class], class_index[val_train]] = count

# Axis labels, in the same order as the matrix rows/columns.
class_names = ['C%d' % v for v in classes]
train_names = list(class_names)
    
arcpy.AddMessage("Calculating...")

# M holds counts with rows = classified values and columns = test values.
# Every statistic below is undefined when there are no samples at all.
total = numpy.sum(M)
if total == 0:
    arcpy.AddError('No pixels are nonzero in both input rasters; nothing to assess.')
    raise arcpy.ExecuteError

# Layout of R, with n = noclasses:
#   R[:n, :n]  error matrix M
#   R[:n, n]   user totals (row sums)        R[:n, n+1]   user accuracies (%)
#   R[n, :n]   producer totals (col sums)    R[n+1, :n]   producer accuracies (%)
#   R[n, n]    grand total                   R[n+1, n+1]  overall accuracy (%)
R = numpy.zeros((noclasses + 2, noclasses + 2))
R[:-2, :-2] = M
R[-2, :-2] = numpy.sum(M, axis=0)     # producer totals
R[:-2, -2] = numpy.sum(M, axis=1)     # user totals
R[-2, -2] = total

# A class present in only one raster has a zero total on the other axis; its
# accuracy is reported as nan rather than raising divide-by-zero warnings.
# Kappa is likewise nan when chance agreement equals 1 (e.g. a single class).
with numpy.errstate(divide='ignore', invalid='ignore'):
    R[-1, :-2] = 100. * numpy.diag(M) / R[-2, :-2]   # producer percents
    R[:-2, -1] = 100. * numpy.diag(M) / R[:-2, -2]   # user percents
    R[-1, -1] = 100. * numpy.trace(M) / total        # overall accuracy

    # Cohen's kappa from the error matrix:
    #   K = (N * sum_i x_ii - sum_i x_i+ * x_+i) / (N^2 - sum_i x_i+ * x_+i)
    chance = numpy.sum(R[:-2, -2] * R[-2, :-2])
    K = (total * numpy.trace(M) - chance) / (total ** 2 - chance)

# Report overall accuracy and kappa, and return them as derived tool outputs.
arcpy.AddMessage("Overall Accuracy: %.2f%%" % R[-1, -1])
arcpy.AddMessage("Kappa: %.2f" % K)
arcpy.SetParameter(3, R[-1, -1])
arcpy.SetParameter(4, K)

# Optionally write the full error matrix and summary statistics to a text file.
if outputresults != '':
    # Concatenate into a new list so class_names itself is left unmodified
    # (an in-place `+=` on an alias would mutate it).
    row_labels = class_names + ['PT', 'PP']
    with open(outputresults, 'w') as f:
        f.write('Error matrix:\n')
        f.write('\t' + '\t'.join(train_names) + '\tUT\tUP\n')
        # Class rows and the producer-total row: counts as integers, the last
        # column (user percent) with two decimals.
        for row_label, row in zip(row_labels[:-1], R[:-1, :]):
            f.write('%s\t%s\t%s\n' % (row_label, '\t'.join('%i' % i for i in row[:-1]), '%.2f' % row[-1]))
        # Producer-percent row: every cell is a percentage.
        f.write('%s\t%s\n' % (row_labels[-1], '\t'.join('%.2f' % i for i in R[-1, :])))
        f.write("\twhere UT is user's total, PT is producer's total, and UP and PP are corresponding percentages.\n\n")
        f.write('Overall Accuracy: %.2f\n' % R[-1, -1])
        f.write('Kappa: %.2f\n\n' % K)
        f.write('Classification: %s\n' % inputclass)
        f.write('Test: %s\n' % inputtest)