from __future__ import absolute_import, division, print_function

import sys,os
from warnings import filterwarnings, warn
# Substrings of warning messages that are silenced process-wide as soon as
# this module is imported.
warnings_ignore = [
    'is a low contrast image',
]

for msg in warnings_ignore:
    # `message` is a regex matched against the start of the warning text,
    # so the substring is wrapped in '.*' on both sides
    pattern = '.*%s.*' % (msg,)
    filterwarnings("ignore", message=pattern)
del pattern

import numpy as np

from os.path import join as pathjoin, exists as pathexists, split as pathsplit
from os.path import splitext, abspath, getsize, realpath, isabs, expandvars
from os.path import expanduser

mkdir = os.makedirs


def pathfile(path):
    """Return the last component of *path*: /path/to/file.ext -> file.ext"""
    return pathsplit(path)[1]


def pathbase(path):
    """Return the last component of *path* without its extension:
    /path/to/file.ext -> file"""
    return splitext(pathfile(path))[0]

from collections import OrderedDict
from spectral.io.envi import open as envi_open_file
import LatLongUTMconversion as LLUTMConv


# Reference-ellipsoid id passed as the first argument of LLUTMConv.UTMtoLL /
# LLUTMConv.LLtoUTM (see utm2latlon, latlon2utm); this module uses 23 for WGS-84.
DATUM_WGS84 = 23 # WGS-84
# degrees -> radians conversion factor (used by rotxy, sl2xy)
DEG2RAD = np.pi/180.0

# fill value: written as the ENVI 'data ignore value' by array2img and used
# for nodata pixels in filtdet outputs
NODATA = -9999
    
# reused constants
# interpolation orders, passed as `order` to skimage resize (see imresize)
ORDER_NN    = 0
ORDER_BILIN = 1
ORDER_QUAD  = 2
ORDER_CUBIC = 3

# `connectivity` values passed to skimage labeling/boundary functions;
# per their names, 1 -> 4-neighborhood and 2 -> 8-neighborhood for 2D images
CONN4 = 1
CONN8 = 2

# label values in plume label images (see rgb2labimg, loadlabimg)
POINTSRC = 1
DIFFSRC = 2

# NOTE(review): not referenced in the visible portion of this file
use_absmf=False

# default filtdet parameters:
#   kernel      -> gaussian sigma (pixels) used by kde()
#   mfmin/mfmax -> matched-filter range (ppmm, per filtdet's log messages)
#                  mapped linearly to [0,1]
#   minarea     -> min_size passed to remove_small_objects
#   mfminsmall  -> small components containing a pixel >= this (ppmm) are kept
kernel=50
mfmin,mfmax = 500,1500
minarea=9
mfminsmall  = 1250

def imsave(fname,img,**kwargs):
    '''
    Save img to fname with skimage.io.imsave.

    A 'verbose' keyword (default False) is consumed here; when true the
    output path is printed first. All other keywords go to skimage.
    '''
    from skimage.io import imsave as _skimage_imsave
    verbose = kwargs.pop('verbose',False)
    if verbose:
        print('saving',fname)
    return _skimage_imsave(fname,img,**kwargs)

def array2rgba(a,**kwargs):
    '''
    converts an array into a set of rgba values according to
    the default or user-provided colormap

    Arguments:
    - a: numeric array (any shape); NaN entries are treated as missing

    Keyword Arguments:
    - cmap: colormap name (default=rcParams['image.cmap'])
    - lut: number of colormap entries (default=rcParams['image.lut'])
    - vmin, vmax: values mapped to the ends of the colormap
      (default/None = min/max of the non-NaN values of a)

    Returns:
    - uint8 array of shape a.shape+(4,); NaN entries get rgba=(0,0,0,0).
      If a has no non-NaN values, or vmax<=vmin, a warning is issued and an
      all-zero array is returned.
    '''
    shape = list(np.shape(a))
    # np.array always copies, so filling NaNs below never touches the input
    aflat = np.array(a,dtype=np.float32).ravel()
    nanmask = np.isnan(aflat)
    avals = aflat[~nanmask]
    nvalid = avals.size

    # data-derived defaults are only computed when there is valid data:
    # min()/max() of an empty array raise
    vmin = kwargs.pop('vmin',None)
    vmax = kwargs.pop('vmax',None)
    if nvalid==0:
        warn('all values nan')
        return np.zeros(shape+[4],dtype=np.uint8)
    vmin = float(avals.min() if vmin is None else vmin)
    vmax = float(avals.max() if vmax is None else vmax)
    if vmax<=vmin:
        warn('vmax <= vmin')
        return np.zeros(shape+[4],dtype=np.uint8)

    # plotting libraries are only needed when something is actually mapped
    from pylab import rcParams
    try:
        from matplotlib.cm import get_cmap
    except ImportError: # matplotlib >= 3.9 dropped matplotlib.cm.get_cmap
        from matplotlib.pyplot import get_cmap
    kwargs.setdefault('cmap',rcParams['image.cmap'])
    kwargs.setdefault('lut',rcParams['image.lut'])
    cm = get_cmap(kwargs['cmap'],lut=kwargs['lut'])

    aflat[nanmask] = vmin # use vmin to map to zero, below
    aflat = np.clip((aflat-vmin)/(vmax-vmin),0.,1.)
    rgba = np.uint8(cm(aflat)*255)
    rgba[nanmask] = 0 # NaN entries -> (0,0,0,0)
    return rgba.reshape(shape+[4])

def float2rgba(img,cmap='binary',vmin=0.0,vmax=1.0,alpha=0):
    """
    float2rgba(img,cmap='binary',vmin=0.0,vmax=1.0,alpha=0)

    Summary: encode a single band float image as a 4-band uint8 rgba image

    Arguments:
    - img: [n,m] or [n,m,1] single band float image, values in [vmin,vmax]

    Keyword Arguments:
    - cmap: 'binary' packs each value into the 24 rgb bits (see Notes); any
      other name renders img through that colormap with array2rgba
    - vmin,vmax: expected value range of img (checked with an assert)
    - alpha: value written to every pixel of the alpha band (default=0)

    Output:
    - [n,m,4] uint8 rgba image

    Notes:
    - 'binary' mode multiplies img by 2**24-1 without using vmin/vmax, so it
      expects img in [0,1]; the uint32 result is reinterpreted as 4 bytes in
      host byte order, and the last byte is then overwritten with alpha
    - default alpha=0 produces transparent images; set alpha=255 or ignore
      the alpha band to visualize the output
    """
    in_range = (img.min()>=vmin) & (img.max()<=vmax)
    assert(in_range)
    nrows,ncols = img.shape[0],img.shape[1]
    if cmap!='binary':
        rgbavec = np.uint8(array2rgba(np.float32(img.ravel()),cmap=cmap,
                                      vmin=vmin,vmax=vmax,lut=4096))
    else:
        # largest value that fits in 3 x 8-bit rgb bands
        max24 = (2**24)-1
        rgbavec = np.uint32(max24*img).view(dtype=np.uint8)
    rgba = rgbavec.reshape([nrows,ncols,4])
    rgba[...,-1] = np.uint8(alpha)
    return rgba

def rgba2float(img,cmap='binary',alpha=0):
    '''
    inverse of float2rgba: decode an [n,m,4] uint8 rgba image to floats

    Arguments:
    - img: [n,m,4] uint8 rgba image

    Keyword Arguments:
    - cmap: 'binary' unpacks each value from the 24 rgb bits; any other
      colormap name maps every pixel to the index of its nearest entry
      (squared rgb distance, first match on ties) in a 4096-entry lut of
      that colormap, divided by 4095
    - alpha: value the alpha band is set to before unpacking ('binary' only)

    Returns:
    - float image with values in [0,1] ('binary' output is squeezed)
    '''
    if cmap=='binary':
        imgc = img.copy()
        imgc[...,-1] = np.uint8(alpha)
        imgc = imgc.view(np.uint32) / np.float32((2**24)-1)
        return imgc.squeeze()

    try:
        from matplotlib.cm import get_cmap
    except ImportError: # matplotlib >= 3.9 dropped matplotlib.cm.get_cmap
        from matplotlib.pyplot import get_cmap
    cm = get_cmap(cmap,lut=4096)
    # sample the colormap directly: this works for every colormap type,
    # not only ListedColormap instances that expose a .colors array
    lut = np.asarray(cm(np.linspace(0.0,1.0,cm.N)))[:,:3]

    # match each distinct color once, in chunks, so the temporary distance
    # array stays bounded instead of growing to (len(lut), n, m, 3)
    rgb = np.asarray(img)[...,:3].reshape(-1,3)
    ucolors,inverse = np.unique(rgb,axis=0,return_inverse=True)
    inverse = np.asarray(inverse).reshape(-1)
    ucolors = ucolors/255.0
    nearest = np.empty(len(ucolors),dtype=np.intp)
    step = max(1,(1<<20)//len(lut))
    for start in range(0,len(ucolors),step):
        chunk = ucolors[start:start+step]
        d = ((chunk[None,:,:]-lut[:,None,:])**2).sum(axis=-1)
        nearest[start:start+step] = d.argmin(axis=0)
    f = nearest[inverse].reshape(np.shape(img)[:-1])/np.float32(len(lut)-1)
    return f

def rgbdet2ql(rgb,det,detmask):
    '''
    build a uint8 rgb quicklook with detections overlaid

    Arguments:
    - rgb: [n,m,c] float image (c>=3), scaled by 255 and cast to uint8, so
      presumably in [0,1]
    - det: [n,m] detection values, colored with the 'YlOrRd' colormap
    - detmask: [n,m] boolean mask of pixels to overlay

    Returns:
    - [n,m,c] uint8 image; the first 3 bands of masked pixels hold the
      colormapped detection values
    '''
    rgbimg = np.uint8(255*np.asarray(rgb))
    ch4idx = np.where(detmask)
    if len(ch4idx[0])==0:
        # nothing to overlay (array2rgba cannot scale an empty selection)
        return rgbimg
    ch4rgba = array2rgba(det[ch4idx],cmap='YlOrRd')
    rgbimg[ch4idx[0],ch4idx[1],:3] = ch4rgba[:,:3]
    return rgbimg

def progressbar(caption,maxval=None):
    """
    progressbar(caption,maxval=None)

    Summary: build a ProgressBar from the `progressbar` package

    Arguments:
    - caption: text shown before the bar (omitted when empty or None)
    - maxval: maximum value; None gives an open-ended counter display

    Output:
    - ProgressBar instance (pbar.update(i) to step, pbar.finish() to close)
    """
    from progressbar import ProgressBar, Bar, UnknownLength
    from progressbar import Percentage, Counter, ETA

    widgets = [caption + ': ' if caption else '']
    if maxval is None:
        maxval = UnknownLength
        widgets += [Counter(), ' ', Bar('=')]
    else:
        widgets += [Percentage(), ' ', Bar('='), ' ', ETA()]
    return ProgressBar(widgets=widgets, maxval=maxval)

def findboundaries(labimg,**kwargs):
    '''
    Boolean boundary mask of a label image (skimage find_boundaries).

    `connectivity` defaults to CONN8; every keyword is forwarded.
    '''
    from skimage.segmentation import find_boundaries as _find_boundaries
    if 'connectivity' not in kwargs:
        kwargs['connectivity'] = CONN8
    return _find_boundaries(labimg,**kwargs)

def thickboundaries(labimg):
    '''boundary mask of labimg computed with mode='thick' (see findboundaries)'''
    mask = findboundaries(labimg, mode='thick')
    return mask

def innerboundaries(labimg):
    '''boundary mask of labimg computed with mode='inner' (see findboundaries)'''
    mask = findboundaries(labimg, mode='inner')
    return mask

def outerboundaries(labimg):
    '''boundary mask of labimg computed with mode='outer' (see findboundaries)'''
    mask = findboundaries(labimg, mode='outer')
    return mask

def createimg(hdrf,metadata,**kwargs):
    '''
    Create a new ENVI image with spectral's create_image for header hdrf.

    Always called with ext='' and force=True. Presumably ext='' makes the
    data file hdrf without '.hdr' and force=True overwrites existing files
    (array2img pairs data path outf with header outf+'.hdr') — confirm
    against spectral's docs. Extra keywords are accepted but ignored.
    '''
    from spectral.io.envi import create_image
    options = dict(ext='', force=True)
    return create_image(hdrf, metadata, **options)

def openimgmm(img,interleave='source',writable=False):
    '''
    Return img.open_memmap(interleave=..., writable=...) for a spectral
    image object (defaults: source interleave, read-only).
    '''
    options = {'interleave': interleave, 'writable': writable}
    return img.open_memmap(**options)

def imlabel(img,**kwargs):
    '''
    Label connected regions of img with skimage.measure.label.

    `connectivity` defaults to CONN8; other keywords are forwarded.
    '''
    from skimage.measure import label as _sklabel
    options = dict(kwargs)
    options.setdefault('connectivity',CONN8)
    return _sklabel(img,**options)

def disk(radius,**kwargs):
    '''disk-shaped structuring element of the given radius (skimage.morphology.disk)'''
    from skimage.morphology import disk as _skdisk
    footprint = _skdisk(radius,**kwargs)
    return footprint

def bwdilate(bwimg,**kwargs):
    '''
    binary dilation of bwimg (skimage.morphology.binary_dilation)

    The structuring element defaults to disk(3). It may be passed as either
    `selem` or `footprint`: scikit-image 0.19 renamed `selem` to `footprint`,
    so it is forwarded under whichever name the installed version accepts.
    Other keywords are forwarded unchanged.
    '''
    import inspect
    from skimage.morphology import binary_dilation as _bwd
    try:
        params = inspect.signature(_bwd).parameters
    except (TypeError, ValueError): # no signature available: assume old API
        params = ()
    key = 'footprint' if 'footprint' in params else 'selem'
    alias = 'selem' if key=='footprint' else 'footprint'
    if alias in kwargs:
        kwargs[key] = kwargs.pop(alias)
    kwargs.setdefault(key,disk(3))
    return _bwd(bwimg,**kwargs)

def bwdist(bwimg,**kwargs):
    '''
    distance transform of bwimg with scipy.ndimage: for each nonzero pixel,
    the distance to the nearest zero pixel

    Keyword Arguments:
    - metric: 'euclidean' (default, distance_transform_edt) or
      'chessboard'/'taxicab' (distance_transform_cdt)
    - return_distances (default True), return_indices (default False) and
      any other keywords are forwarded to the scipy function

    Raises:
    - ValueError for any other metric
    '''
    from scipy import ndimage as ndi
    metric = kwargs.get('metric','euclidean')
    if metric=='euclidean':
        kwargs.pop('metric',None) # edt takes no metric argument
        _bwdist = ndi.distance_transform_edt
    elif metric in ('chessboard','taxicab'):
        _bwdist = ndi.distance_transform_cdt
    else:
        raise ValueError('unsupported metric %r'%(metric,))
    kwargs.setdefault('return_distances',True)
    kwargs.setdefault('return_indices',False)
    return _bwdist(bwimg,**kwargs)

def mergelabels(labimg,mergedist,return_merged=False,doplot=False):
    '''
    merge labeled regions whose foreground, grown by mergedist pixels
    (chessboard distance), becomes connected

    Arguments:
    - labimg: integer label image, 0 = background
    - mergedist: growth distance in pixels

    Keyword Arguments:
    - return_merged: also return {merged label: array of input labels}
    - doplot: show before/after images with pylab

    Returns:
    - outimg (same shape/dtype as labimg), merged regions labeled 1..N
    - outmerged (only if return_merged)
    '''
    fgmask = (labimg!=0)
    # grow the foreground and label the grown mask; regions whose grown
    # footprints touch share a component
    mergecomp = imlabel(bwdist(~fgmask,metric='chessboard')<=mergedist)
    mergelab = np.unique(mergecomp)
    mergelab = mergelab[mergelab!=0]
    outimg = np.zeros_like(labimg)
    outmerged = {}
    for i,ml in enumerate(mergelab):
        mlmask = (mergecomp==ml) & fgmask
        outimg[mlmask] = i+1
        if return_merged:
            outmerged[ml] = np.unique(labimg[mlmask])
    # count nonzero labels only, so the numbers stay right when there are
    # no background pixels
    nbefore = len(np.unique(labimg[fgmask]))
    print(nbefore,'labels before merge',len(mergelab),'after merge')
    if doplot:
        import pylab as pl
        fig,ax = pl.subplots(1,2,sharex=True,sharey=True)
        ax[0].imshow(labimg)
        ax[0].set_title('before')
        ax[1].imshow(outimg)
        ax[1].set_title('after')
        pl.show()
    if return_merged:
        return outimg, outmerged
    return outimg

def imread(fname,**kwargs):
    '''
    read an image file with skimage.io.imread; `plugin` defaults to None
    (letting skimage pick), other keywords are forwarded
    '''
    from skimage.io import imread as _skimread
    if 'plugin' not in kwargs:
        kwargs['plugin'] = None
    return _skimread(fname,**kwargs)

def imresize(img,output_shape,**kwargs):
    '''
    resize img to output_shape with skimage.transform.resize

    Defaults differ from skimage's: order=ORDER_NN, clip=False,
    preserve_range=True.

    Keyword Arguments consumed here (never forwarded):
    - anti_alias: if true, gaussian-smooth img before resizing
    - anti_alias_sigma: smoothing sigma; default per axis is
      max(0,(factor-1)/2) with factor = input size / output size
    Other keywords go to skimage; `cval` and `mode` (default 0/'constant')
    are also used for the smoothing.
    '''
    from skimage.transform import resize as _imresize

    kwargs.setdefault('order',ORDER_NN)
    kwargs.setdefault('clip',False)
    kwargs.setdefault('preserve_range',True)
    anti_alias = kwargs.pop('anti_alias',False)
    sigma = kwargs.pop('anti_alias_sigma',None)
    if not anti_alias:
        return _imresize(img,output_shape,**kwargs)

    from scipy import ndimage as ndi
    cval = kwargs.get('cval',0)
    mode = kwargs.get('mode','constant')
    if sigma is None:
        # dims not given in output_shape (e.g. channels) are kept by
        # skimage; treat them as unscaled so the shapes line up
        outshape = tuple(output_shape)
        outshape = outshape + tuple(img.shape[len(outshape):])
        factors = (np.asarray(img.shape, dtype=float) /
                   np.asarray(outshape, dtype=float))
        sigma = np.maximum(0, (factors - 1) / 2)
    imgrs = ndi.gaussian_filter(img, sigma, cval=cval, mode=mode)
    return _imresize(imgrs,output_shape,**kwargs)

def filename2flightid(filename):
    '''
    get flight id from filename: the part of the base name (directory and
    extension removed) before the first underscore
    ang20160922t184215_cmf_v1g_img -> ang20160922t184215
    '''
    base = pathbase(filename)
    flightid, _, _ = base.partition('_')
    return flightid

def filename2flightdate(filename):
    '''
    get flight date (yyyy,mm,dd) from filename, using the flight id text
    before its first 't'

    - default (e.g. AVIRIS-NG): last 8 chars are yyyymmdd
      ang20160922t184215_cmf_v1g_img -> (2016,9,22)
    - ids starting with 'f' (avcl): 'f' followed by yymmdd, assumed 20yy
      f160922t01p00r02_... -> (2016,9,22)
    '''
    flightid = filename2flightid(filename)
    datepart = flightid.split('t')[0]
    if flightid.startswith('f'): # avcl: f + yymmdd (6 digits)
        datestr = '20'+datepart[1:7]
    else:
        datestr = datepart[-8:] # ...yyyymmdd
    yyyy,mm,dd = map(int,[datestr[:4],datestr[4:6],datestr[6:8]])
    return yyyy,mm,dd

def shortstr(s,width=20,placeholder='...',quote=False):
    '''
    truncate s to its first `width` characters followed by `placeholder`
    when it is longer than `width`; optionally wrap the result in double
    quotes
    '''
    if len(s) > width:
        out = s[:width]+placeholder
    else:
        out = s
    if quote:
        return '"%s"'%out
    return out

def filename2calid(infile):
    '''
    get the calibration id from a file name (directory removed, fields
    split on '_'):
    - default: the third field
      ./ort/ang20160915t194328_cmf_v1n2_img -> v1n2
    - names starting with 'f' (avcl): second and third fields joined by '_'
    '''
    filen = pathsplit(infile)[1]
    fields = filen.split('_')
    if not filen.startswith('f'):
        return fields[2]
    return str(fields[1]+'_'+fields[2]) # avcl

l1calid = filename2calid

def extrema(a,**kwargs):
    '''
    extreme values of a

    Keyword Arguments:
    - p: 1.0 (default) -> (min,max); 0.0 -> (max,min); 0<p<1 -> the
      nan-aware (1-p)*100 and p*100 percentiles with nearest-rank selection
      (e.g. p=0.99 -> 1st and 99th percentiles)
    - axis: reduction axis (default None = whole array)
    When p is 0 or 1, all remaining keywords go to np.amin/np.amax; for
    percentiles only `axis` is used.

    Raises:
    - AssertionError if p is outside [0,1]
    '''
    p = kwargs.pop('p',1.0)
    if p==1.0:
        return np.amin(a,**kwargs),np.amax(a,**kwargs)
    if p==0.0:
        return np.amax(a,**kwargs),np.amin(a,**kwargs)
    assert(p>0.0 and p<1.0)
    axis = kwargs.pop('axis',None)

    def apercent(q):
        try: # numpy >= 1.22 renamed `interpolation` to `method`
            return np.nanpercentile(a,q=q,axis=axis,method='nearest')
        except TypeError: # older numpy
            return np.nanpercentile(a,q=q,axis=axis,interpolation='nearest')

    return apercent((1-p)*100),apercent(p*100)

def envitypecode(np_dtype):
    '''
    ENVI 'data type' code for a numpy dtype, looked up in spectral's
    dtype_to_envi table by the dtype's one-character code
    '''
    from spectral.io.envi import dtype_to_envi
    typechar = np.dtype(np_dtype).char
    return dtype_to_envi[typechar]

def extract_tile(img,ul,tdim,verbose=False,transpose=None,cval=0):
    '''
    extract a tile of dims (tdim,tdim,nb) whose upper-left corner is at
    (row,col) = ul in img; tile pixels outside the image extent are cval

    Arguments:
    - img: [nr,nc] or [nr,nc,nb] array (2D images give nb=1)
    - ul: (row,col) upper-left offset; may be negative or past the image
    - tdim: tile size in rows and columns

    Keyword Arguments:
    - verbose: print the offsets used
    - transpose: optional axes order applied to the tile
    - cval: fill value for tile pixels outside img (default 0)

    Returns:
    - [tdim,tdim,nb] tile, transposed if requested

    Raises:
    - Exception if img is not 2D or 3D
    '''
    ndim = img.ndim
    if ndim==3:
        nr,nc,nb = img.shape
    elif ndim==2:
        nr,nc = img.shape
        nb = 1
    else:
        raise Exception('invalid number of image dims %d'%ndim)

    lr = (ul[0]+tdim,ul[1]+tdim)
    # overlap of tile and image, in image coordinates
    ibeg,iend = max(0,ul[0]),min(nr,lr[0])
    jbeg,jend = max(0,ul[1]),min(nc,lr[1])
    # the same overlap in tile coordinates
    padt,padb = ibeg-ul[0], iend-ul[0]
    padl,padr = jbeg-ul[1], jend-ul[1]

    if verbose:
        print(ul,nr,nc)
        print(padt,padb,padl,padr)
        print(ibeg,iend,jbeg,jend)

    imgtile = cval*np.ones([tdim,tdim,nb],dtype=img.dtype)
    if iend>ibeg and jend>jbeg: # skip tiles that do not touch the image
        imgtile[padt:padb,padl:padr] = np.atleast_3d(img[ibeg:iend,jbeg:jend])
    if transpose is not None:
        imgtile = imgtile.transpose(transpose)
    return imgtile

def rotxy(x,y,adeg,xc,yc):
    """
    rotxy(x,y,adeg,xc,yc)

    Summary: rotate point(s) x,y about xc,yc by adeg degrees, using the
    matrix [[cos,-sin],[sin,cos]] (counterclockwise for x right, y up)

    Arguments:
    - x: x coord(s) to rotate
    - y: y coord(s) to rotate (x and y are broadcast together, so arrays
      of any matching shape, e.g. meshgrid outputs, are supported)
    - adeg: angle of rotation in degrees
    - xc: center x coord
    - yc: center y coord

    Output:
    rotated x,y; a single point (any size-1 input) yields scalars
    """
    arad = np.deg2rad(adeg)
    sinr,cosr = np.sin(arad),np.cos(arad)
    dx,dy = np.broadcast_arrays(np.asarray(x)-xc,np.asarray(y)-yc)

    # handle scalar outputs
    if dx.size==1:
        dx,dy = dx.reshape(()),dy.reshape(())
    xr = cosr*dx - sinr*dy
    yr = sinr*dx + cosr*dy
    return xr+xc,yr+yc
        
def meshpositions(*arrs):
    '''
    coordinates of every point of the grid np.meshgrid(*arrs), stacked as
    a [len(arrs), npoints] array (one row per input array)
    '''
    grid = np.meshgrid(*arrs)
    # np.vstack requires a sequence; recent numpy rejects a bare map()
    positions = np.vstack([g.ravel() for g in grid])
    return positions

def chull(p,**kwargs):
    '''
    convex hull vertices of the points p (scipy.spatial.ConvexHull)

    Arguments:
    - p: [N x d] points, ndarray or nested sequence

    Keyword Arguments:
    - return_indices: also return the hull vertex indices into p
    Other keywords are forwarded to ConvexHull.

    Returns:
    - hullp (ndarray of hull vertices), plus hullidx if return_indices
    '''
    from scipy.spatial import ConvexHull as _chull
    return_indices = kwargs.pop('return_indices',False)
    p = np.asarray(p) # lists cannot be fancy-indexed with hullidx
    hullidx = _chull(p,**kwargs).vertices
    hullp = p[hullidx]
    if return_indices:
        return hullp,hullidx
    return hullp

def utm2latlon(easting,northing,zone,hemi='North',alpha=None,datum=DATUM_WGS84):
    '''
    convert UTM coordinates to (lat,lon) with LLUTMConv.UTMtoLL

    Arguments:
    - easting, northing: forwarded positionally, in this order, after datum.
      NOTE(review): sl2latlon calls utm2latlon(y,x,...), i.e. northing
      first, so these parameter names may be swapped relative to
      LLUTMConv.UTMtoLL's argument order — confirm.
    - zone: UTM zone number
    - hemi: 'North' or 'South'; anything else prints a message and returns
      (None,None)
    - alpha: zone letter appended to zone (default 'N' for North, 'M' for
      South)
    - datum: reference ellipsoid id (default DATUM_WGS84)
    '''
    if hemi in ('North','South'):
        if alpha:
            zone_alpha = alpha
        elif hemi=='North':
            zone_alpha = 'N'
        else:
            zone_alpha = 'M'
        zonestr = str(zone)+zone_alpha
        lat,lon = LLUTMConv.UTMtoLL(datum,easting,northing,zonestr)
        return lat,lon
    print('invalid hemisphere value=',hemi)
    return None,None

def sl2xy(s,l,**kwargs):
    """
    sl2xy(s,l,ulx=None,uly=None,xps=None,yps=xps,rot=0,mapinfo=None)

    Given pixel coordinates (s,l) convert to map coordinates (x,y)

    Arguments:
    - s,l: sample, line indices (scalars or arrays)

    Keyword Arguments:
    - ulx,uly: upper left map coordinate (required, here or in mapinfo)
    - xps: x map pixel size (required, here or in mapinfo)
    - yps: y map pixel size (default=xps; 0 also means xps)
    - rot: map rotation in degrees about (ulx,uly) (default=0)
    - mapinfo: envi map info dict (keys ulx, uly, xps, yps, rotation); the
      keywords above take precedence over its entries

    Returns:
    - x,y: x,y map coordinates of sample s, line l

    Raises:
    - ValueError if ulx/uly or xps/yps are undefined
    """
    mapinfo = kwargs.pop('mapinfo',None) or {}
    x0 = kwargs.pop('ulx',mapinfo.get('ulx',None))
    y0 = kwargs.pop('uly',mapinfo.get('uly',None))
    xps = kwargs.pop('xps',mapinfo.get('xps',None))
    yps = kwargs.pop('yps',mapinfo.get('yps',xps))
    rot = kwargs.pop('rot',mapinfo.get('rotation',0))

    # explicit `is None` tests also work when coordinates are numpy arrays
    if x0 is None or y0 is None:
        raise ValueError("ulx or uly undefined")

    if xps is None or yps is None:
        raise ValueError("xps or yps undefined")

    if yps == 0: # zero y pixel size means "same as x"
        yps = xps

    xp,yp = x0+xps*s, y0-yps*l
    if rot == 0:
        return xp,yp

    # rotate about the upper-left map coordinate
    return rotxy(xp,yp,rot,x0,y0)

def sl2latlon(s,l,**kwargs):
    '''
    convert pixel (sample,line) coordinates to (lat,lon)

    All keywords are forwarded to sl2xy; kwargs['mapinfo']['proj'] selects
    the conversion:
    - 'Geographic Lat/Lon': the map coordinates are returned as (y,x)
    - 'UTM': map coordinates go through utm2latlon with the mapinfo 'zone'
      and 'hemi' entries

    Raises ValueError when no projection is given; prints a message and
    returns None for any other projection.
    '''
    mapinfo = kwargs.get('mapinfo',{})
    proj = mapinfo.get('proj',None)
    if not proj:
        raise ValueError("proj undefined")
    if proj not in ('UTM','Geographic Lat/Lon'):
        print('unknown projection:',proj)
        return None

    x,y = sl2xy(s,l,**kwargs)
    if proj=='UTM':
        return utm2latlon(y,x,zone=mapinfo['zone'],hemi=mapinfo['hemi'])
    return y,x

def xy2sl(x,y,**kwargs):
    """
    xy2sl(x,y,ulx=None,uly=None,xps=None,yps=xps,rot=0,mapinfo=None)

    Given an orthocorrected image, find the (s,l) values for a given (x,y);
    the inverse of sl2xy

    Arguments:
    - x,y: map coordinates (scalars or arrays)

    Keyword Arguments:
    - ulx,uly: upper left map coordinate (required, here or in mapinfo)
    - xps: x map pixel size (required, here or in mapinfo)
    - yps: y map pixel size (default=xps; 0/None also mean xps)
    - rot: map rotation in degrees (default=0)
    - mapinfo: envi map info dict (keys ulx, uly, xps, yps, rotation); the
      keywords above take precedence over its entries

    Returns:
    - s,l: (fractional) sample, line coordinates of x,y

    Raises:
    - ValueError if ulx/uly or xps are undefined
    """
    mapinfo = kwargs.pop('mapinfo',None) or {}

    x0 = kwargs.pop('ulx',mapinfo.get('ulx',None))
    y0 = kwargs.pop('uly',mapinfo.get('uly',None))
    xps = kwargs.pop('xps',mapinfo.get('xps',None))
    yps = kwargs.pop('yps',mapinfo.get('yps',xps))
    rot = kwargs.pop('rot',mapinfo.get('rotation',0))

    if x0 is None or y0 is None:
        raise ValueError("ulx or uly undefined")

    if xps is None:
        raise ValueError("pixel size undefined")

    yps = yps or xps
    # offsets from the upper-left corner, with y flipped to point down
    xp, yp = (x-x0), (y0-y)
    if rot!=0:
        # the y flip reverses the rotation sense, so rotating by +rot here
        # undoes sl2xy's +rot rotation
        xp, yp = rotxy(xp,yp,rot,0,0)

    xp,yp = xp/xps,yp/yps
    return xp,yp

def latlon2utm(lat,lon,zone=None,datum=DATUM_WGS84):
    """
    latlon2utm(lat,lon,zone=None,datum=DATUM_WGS84)

    Convert geographic coordinates to UTM with LLUTMConv.LLtoUTM.

    Arguments:
    - lat: latitude coordinate(s)
    - lon: longitude coordinate(s)

    Keyword Arguments:
    - zone:   UTM zone number forwarded as ZoneNumber (default=None, left
              to LLUTMConv)
    - datum:  reference ellipsoid id (default=DATUM_WGS84)

    Returns:
    - easting:  UTM easting map coordinate
    - northing: UTM northing map coordinate
    - zone:     UTM zone number (int)
    - hemi:     zone letter (hemi >= 'N' -> Northern hemisphere)
    """
    utm = LLUTMConv.LLtoUTM(datum,lat,lon,ZoneNumber=zone)
    zonealpha,easting,northing = utm
    zonenum,letter = int(zonealpha[:-1]),zonealpha[-1]
    return easting,northing,zonenum,letter

def latlon2sl(lat,lon,**kwargs):
    '''
    convert (lat,lon) to pixel (sample,line) coordinates via xy2sl

    Only kwargs['mapinfo'] is used. Its 'proj' selects the path:
    - 'Geographic Lat/Lon': (lon,lat) are used directly as map (x,y)
    - 'UTM': lat/lon are projected with latlon2utm, forcing the mapinfo
      'zone' when present

    Raises ValueError when no projection is given; prints a message and
    returns None for any other projection.
    '''
    mapinfo = kwargs.get('mapinfo',{})
    proj = mapinfo.get('proj',None)
    if not proj:
        raise ValueError("proj undefined")
    if proj not in ('UTM','Geographic Lat/Lon'):
        print('Unknown projection:',proj)
        return None

    if proj!='Geographic Lat/Lon':
        zone = int(mapinfo['zone']) if 'zone' in mapinfo else None
        easting,northing,_,_ = latlon2utm(lat,lon,zone=zone)
        return xy2sl(easting,northing,mapinfo=mapinfo)
    return xy2sl(lon,lat,mapinfo=mapinfo)

def mapdict2str(mapdict):
    '''
    format a map info dict (as built by mapinfo) as an ENVI 'map info' string

    The first 10 values (proj=='UTM') or 7 values (other projections) are
    written positionally in dict order, the remaining entries as key=value,
    then any unparsed items stored under 'metadata'. For example:
    '{ UTM, 1.0, 1.0, 100.0, 200.0, 1.0, 1.0, 11, North, WGS-84, rotation=0.0 }'

    mapdict itself is not modified.
    '''
    items = [(k,v) for k,v in mapdict.items() if k!='metadata']
    mapmeta = list(mapdict.get('metadata',[]))
    nargs = 10 if mapdict['proj']=='UTM' else 7
    maplist = [str(v) for _,v in items[:nargs]]
    mapkw = [str(k)+'='+str(v) for k,v in items[nargs:]]
    mapstr = '{ '+(', '.join(maplist+mapkw+mapmeta))+' }'
    return mapstr

def findhdr(img_file):
    '''
    locate the ENVI header for img_file; returns None if none exists

    Candidates, in order:
    - img_file itself, when it ends in '.hdr' and is an existing file
    - img_file+'.hdr' (returned as an absolute path)
    - img_file with its extension replaced by '.hdr' (returned as built)
    '''
    from os import path
    stem,ext = path.splitext(img_file)
    if ext=='.hdr' and path.isfile(img_file):
        return img_file

    appended = img_file+'.hdr' # [img_file.img].hdr or [img_file].hdr
    if path.isfile(appended):
        return path.abspath(appended)

    replaced = stem+'.hdr' # [img_file.img] -> [img_file].hdr
    return replaced if path.isfile(replaced) else None

def mapinfo(img,astype=dict):
    '''
    parse the ENVI 'map info' header entry of an image

    Arguments:
    - img: spectral SpyFile, or an image/header path (header located with
      findhdr)

    Keyword Arguments:
    - astype:
      list -> the raw 'map info' list from the header
      str  -> a map info string rebuilt with mapdict2str
      dict (default) -> OrderedDict with proj, xtie, ytie, ulx, uly, xps, yps
        (floats except proj), zone/hemi/datum (UTM only, strings), any
        key=value items (strings), 'rotation' (float, default 0.0) and
        'metadata' (list of items without '=', only when present)

    Returns:
    - None if the header has no 'map info' entry

    Raises:
    - IOError if img is a path and no header file can be found
    '''
    from spectral import SpyFile
    if isinstance(img,SpyFile):
        _img = img
    else:
        hdrf = findhdr(img)
        if hdrf is None:
            raise IOError('no ENVI header found for %s'%str(img))
        _img = envi_open_file(hdrf)

    maplist = _img.metadata.get('map info',None)
    if maplist is None:
        return None
    if astype==list:
        return maplist

    mapdict = OrderedDict()
    mapdict['proj'] = maplist[0]
    for i,key in enumerate(('xtie','ytie','ulx','uly','xps','yps')):
        mapdict[key] = float(maplist[i+1])

    if mapdict['proj'] == 'UTM':
        mapdict['zone']  = maplist[7]
        mapdict['hemi']  = maplist[8]
        mapdict['datum'] = maplist[9]

    # remaining items follow the positional ones already consumed
    mapmeta = []
    for mapitem in maplist[len(mapdict):]:
        if '=' in mapitem:
            # split on the first '=' only so values may contain '='
            key,val = [s.strip() for s in mapitem.split('=',1)]
            mapdict[key] = val
        else:
            mapmeta.append(mapitem)

    mapdict['rotation'] = float(mapdict.get('rotation','0'))
    if len(mapmeta)!=0:
        print('unparsed metadata:',mapmeta)
        mapdict['metadata'] = mapmeta

    if astype==str:
        return mapdict2str(mapdict)

    return mapdict

def prob2geotiff(proboutf,probimg,probimgf):
    '''
    write probimg ([rows,cols,bands]) to the float32 GeoTIFF proboutf,
    copying the pixel-size/rotation terms of the geotransform and the
    projection from the raster probimgf

    NOTE(review): the output geotransform origin is set to (1,1) instead of
    probimgf's origin (geo_t[0],geo_t[3]) — confirm this is intended.

    Raises:
    - IOError if probimgf cannot be opened by GDAL
    '''
    try:
        from osgeo.gdal import Open, GetDriverByName, GA_ReadOnly, GDT_Float32
    except ImportError: # legacy top-level gdal module (GDAL < 3.2)
        from gdal import Open, GetDriverByName, GA_ReadOnly, GDT_Float32

    print('saving',proboutf)
    gtifdrv = GetDriverByName('Gtiff')
    nl_prob,ns_prob,nb_prob = probimg.shape
    g = Open(probimgf, GA_ReadOnly)
    if g is None: # gdal.Open returns None on failure unless exceptions are on
        raise IOError('unable to open %s'%probimgf)
    geo_t = g.GetGeoTransform()
    geo_p = g.GetProjectionRef()
    g = None # release the reference raster

    prob_geo_t = (1, geo_t[1], geo_t[2], 1, geo_t[4], geo_t[5])
    gsub = gtifdrv.Create(proboutf, ns_prob, nl_prob, nb_prob, GDT_Float32)
    gsub.SetGeoTransform(prob_geo_t)
    gsub.SetProjection(geo_p)
    for i in range(nb_prob):
        gsub.GetRasterBand(i+1).WriteArray(probimg[:,:,i])
    gsub = None # dereference to flush and close the output file
    
    
def tile2geotiff(tileimg,tilepos,tilef,baseimgf,validate_coords=True):
    '''
    write tileimg ([rows,cols,bands]) as a byte GeoTIFF tilef whose
    geotransform is baseimgf's, offset to tilepos=(line,sample) in baseimgf

    If validate_coords is true, each tile corner is converted to map
    coordinates with both GDAL's geotransform and this module's
    mapinfo/sl2xy/rotxy path. A difference above 1 map unit prints a warning
    and waits for Enter. Corners outside the base image footprint are
    reported.

    Raises:
    - IOError if baseimgf cannot be opened by GDAL
    '''
    try:
        from osgeo.gdal import Open, GetDriverByName, GA_ReadOnly, GDT_Byte
    except ImportError: # legacy top-level gdal module (GDAL < 3.2)
        from gdal import Open, GetDriverByName, GA_ReadOnly, GDT_Byte
    try:
        pause = raw_input # python 2
    except NameError:
        pause = input # python 3

    print('saving',tilef)
    gtifdrv = GetDriverByName('Gtiff')
    nl_tile,ns_tile,nb_tile = tileimg.shape
    g = Open(baseimgf, GA_ReadOnly)
    if g is None: # gdal.Open returns None on failure unless exceptions are on
        raise IOError('unable to open %s'%baseimgf)
    geo_t = g.GetGeoTransform()
    geo_p = g.GetProjectionRef()
    nl_img,ns_img = g.RasterYSize,g.RasterXSize
    g = None # release the base raster

    l0,s0 = tilepos
    # map coordinates of the tile's upper-left pixel (affine geotransform)
    x_tile = geo_t[0]+s0*geo_t[1]+l0*geo_t[2]
    y_tile = geo_t[3]+s0*geo_t[4]+l0*geo_t[5]

    if validate_coords:
        from skimage.measure import points_in_poly

        def ct(s,l,mapdict):
            # (s,l) -> map (x,y) via sl2xy/rotxy; yps is taken equal to xps
            ulx,uly,xps = [mapdict[k] for k in ('ulx','uly','xps')]
            rot = mapdict.get('rotation',0)
            x,y = sl2xy(s,l,mapinfo=dict(ulx=ulx,uly=uly,xps=xps,yps=xps,rot=0))
            if rot==0:
                return x,y
            return rotxy(x,y,rot,ulx,uly)

        mapdict = mapinfo(baseimgf)
        # tile corners in base-image (sample,line) coordinates, ordered
        # upper left, lower left, upper right, lower right
        tile_bbox_s = [s+s0 for s in [0,     0, ns_tile, ns_tile]]
        tile_bbox_l = [l+l0 for l in [0, nl_tile, 0,     nl_tile]]
        tile_bbox_sl = list(zip(tile_bbox_s,tile_bbox_l))
        tile_bbox_xy = [ct(s,l,mapdict) for s,l in tile_bbox_sl]

        img_bbox_s = [0,0,ns_img,ns_img]
        img_bbox_l = [0,nl_img,nl_img,0]
        img_bbox_xy = []
        print(pathsplit(baseimgf)[1],'(s,l) -> (x,y) bounding box')
        for s,l in zip(img_bbox_s,img_bbox_l):
            x = geo_t[0]+s*geo_t[1]+l*geo_t[2]
            y = geo_t[3]+s*geo_t[4]+l*geo_t[5]
            img_bbox_xy.append((x,y))
            print((s,l),'->\t',(x,y))

        print('tile (s,l) -> (x,y) bounding box')
        off_lab = ['Upper Left','Lower Left','Upper Right','Lower Right']

        in_img = np.zeros(len(tile_bbox_sl),dtype=bool)
        for i,(st,lt) in enumerate(tile_bbox_sl):
            xt = geo_t[0]+st*geo_t[1]+lt*geo_t[2]
            yt = geo_t[3]+st*geo_t[4]+lt*geo_t[5]
            xydist = (np.float32((xt,yt))-np.float32(tile_bbox_xy[i]))**2
            xydist = np.sqrt(xydist.sum())
            if xydist > 1.0:
                # if gdal/numpy differ by more than 1m, wait
                print('WARNING:','xydist=',xydist,'gdal (x,y)=',(xt,yt),
                      'numpy (x,y)=',tile_bbox_xy[i])
                pause()
            in_img[i] = points_in_poly([(xt,yt)],img_bbox_xy)[0]
            print(off_lab[i],(st,lt),'->\t',(xt,yt),'in image bbox=',in_img[i])

        if not in_img.any():
            print('WARNING: tile outside of bounding box')
        elif not in_img.all():
            nzo = (in_img==0).sum()
            print('WARNING: tile overlaps bounding box (%d points outside)'%nzo)

    tile_geo_t = (x_tile, geo_t[1], geo_t[2], y_tile, geo_t[4], geo_t[5])
    gsub = gtifdrv.Create(tilef, ns_tile, nl_tile, nb_tile, GDT_Byte)
    gsub.SetGeoTransform(tile_geo_t)
    gsub.SetProjection(geo_p)
    for i in range(nb_tile):
        gsub.GetRasterBand(i+1).WriteArray(tileimg[:,:,i])
    gsub = None # dereference to flush and close the output file

def rdn2rgb(rdn):
    '''
    stretch rdn linearly so its 1st/99th percentiles (extrema with p=0.99)
    map to 0/1, clipping values outside that range

    A flat input (both percentiles equal) returns zeros instead of the
    NaNs a 0/0 division would produce.
    '''
    rdnmin,rdnmax = extrema(rdn,p=0.99)
    offset = rdn-rdnmin
    if rdnmax==rdnmin:
        return np.zeros_like(offset)
    return np.clip(offset/(rdnmax-rdnmin),0.0,1.0)

def array2img(outf,img,mapinfostr=None,bandnames=None,**kwargs):
    '''
    write img to an ENVI bip image: data file outf, header outf+'.hdr'

    Arguments:
    - outf: output data file path
    - img: [rows,cols] or [rows,cols,bands] array

    Keyword Arguments:
    - mapinfostr: value for the header 'map info' field
    - bandnames: sequence of names for the header 'band names' field
    - overwrite: if false (default) and outf exists, print a message and
      return without writing

    The header 'data ignore value' is NODATA and 'byte order' is 0.
    '''
    overwrite = kwargs.pop('overwrite',False)
    if pathexists(outf) and not overwrite:
        print('Cannot write image to path',outf)
        return

    img = np.atleast_3d(img)
    outmeta = {
        'samples': img.shape[1],
        'lines': img.shape[0],
        'bands': img.shape[2],
        'interleave': 'bip',
        'file type': 'ENVI',
        'byte order': 0,
        'header offset': 0,
        'data type': envitypecode(img.dtype),
    }
    if mapinfostr:
        outmeta['map info'] = mapinfostr
    if bandnames:
        outmeta['band names'] = '{%s}'%", ".join(bandnames)
    outmeta['data ignore value'] = NODATA

    outmm = openimgmm(createimg(outf+'.hdr',outmeta),writable=True)
    outmm[:] = img
    del outmm # drop the memmap so the data is flushed to disk
    print('saved %s array to %s'%(str(img.shape),outf))

def mad(a,axis=0,medval=None,c=0.67448975019608171):
    '''
    median absolute deviation of a along axis, computed with
    statsmodels.robust.scale.mad: mad = median(abs(a - center))/c

    Keyword Arguments:
    - axis: reduction axis (default 0)
    - medval: center value; None (default) uses np.median of a
    - c: normalization constant (default ~0.6745, the standard normal 0.75
      quantile, making mad comparable to a standard deviation for normal data)
    '''
    from statsmodels.robust.scale import mad as _mad
    # compare with None explicitly: a center of 0 is valid but falsy
    center = np.median if medval is None else medval
    return _mad(a, c=c, axis=axis, center=center)

    
def kde(img,k):
    '''
    weight img by a smoothed, [0,1]-scaled copy of itself

    img is gaussian-smoothed (sigma=k, kernel truncated at 1 sigma), the
    smoothed image is min-max scaled to [0,1], and img*scaled is returned.
    A constant smoothed image (no range to scale) gives all-zero weights
    rather than 0/0 NaNs.
    '''
    from scipy.ndimage import gaussian_filter
    imgkde = gaussian_filter(img,sigma=k,truncate=1)
    kmin,kmax = imgkde.min(),imgkde.max()
    if kmax==kmin:
        imgkde = np.zeros_like(imgkde)
    else:
        imgkde = (imgkde-kmin)/(kmax-kmin)
    return img*imgkde

def absnorm(img,mask):
    '''
    scale a 2D image symmetrically about zero into [0,1]

    imax is the largest |value| over pixels where mask is False; values are
    mapped linearly from [-imax,imax] to [0,1] and clipped.

    Returns:
    - (imgn, imin, imax): float32 scaled image, -imax, imax
    '''
    assert((len(img.shape)==2))
    print('normalizing image to absolute range')
    vals = np.float32(img)
    imax = np.abs(vals[~mask]).max()
    imin = -imax
    span = imax-imin
    imgn = np.clip((vals-imin)/span,0.0,1.0)
    return imgn,imin,imax

def smoothbil(img, mask, d, sigmaColor, sigmaSpace, normalize=True):
    '''
    edge-preserving smoothing of a 2D image with cv2.bilateralFilter

    The image is scaled to [0,1] (float32) before filtering and mapped back
    with the same linear transform afterwards:
    - normalize=True: symmetric scaling about zero (absnorm)
    - normalize=False: min/max scaling over pixels where mask is False

    Arguments:
    - img: 2D image; mask: boolean mask of pixels excluded from the range
    - d, sigmaColor, sigmaSpace: cv2.bilateralFilter parameters
      (sigmaColor refers to the [0,1]-scaled values)
    '''
    from cv2 import bilateralFilter
    print('running bilateralFilter')
    if normalize:
        imgn,imin,imax  = absnorm(img,mask)
    else:
        imin,imax = extrema(img[~mask])
        # scale with the same [imin,imax] used to map back below
        span = (imax-imin) if imax>imin else 1.0
        imgn = np.float32(np.clip((img-imin)/span,0.0,1.0))
    imgn = bilateralFilter(imgn, d, sigmaColor, sigmaSpace)
    imgn = imin+(imgn*(imax-imin))
    return imgn

def relabel_sequential(labimg,**kwargs):
    '''
    thin wrapper around skimage.segmentation.relabel_sequential; returns
    whatever skimage returns (callers here unpack three values)
    '''
    from skimage.segmentation import relabel_sequential as _skrelabel
    result = _skrelabel(labimg,**kwargs)
    return result

def remove_small_objects(labimg,**kwargs):
    '''
    remove connected objects smaller than min_size pixels
    (skimage.morphology.remove_small_objects)

    Defaults: min_size=9, in_place=False. scikit-image 0.19 deprecated
    `in_place` in favor of `out`. in_place=False is therefore not forwarded
    (skimage returns a new array by default), and in_place=True is
    translated to out=labimg when the installed skimage no longer accepts
    `in_place`. Other keywords are forwarded.
    '''
    import inspect
    from skimage.morphology import remove_small_objects as _rso
    kwargs.setdefault('min_size',9)
    in_place = kwargs.pop('in_place',False)
    if in_place:
        try:
            params = inspect.signature(_rso).parameters
        except (TypeError, ValueError): # no signature: assume old API
            params = ('in_place',)
        if 'in_place' in params:
            kwargs['in_place'] = True
        else:
            kwargs['out'] = labimg
    return _rso(labimg,**kwargs)

def filtdet(ch4mf,nodata_mask,mfmapinfo,minarea=minarea,mfmin=mfmin,mfmax=mfmax,
            k=kernel,mfminsmall=mfminsmall,skip_kde=False,use_abs=False,
            kde_outf=None,ccomp_outf=None,det_outf=None):
    '''
    filter a matched-filter image into labeled candidate detections

    Steps:
    1. take ch4mf (or |ch4mf| if use_abs), optionally weight it with
       kde(...,k), and scale [mfmin,mfmax] -> [0,1] (clipped); pixels > 0
       are candidates
    2. drop candidate components smaller than minarea pixels, then (when
       mfminsmall >= mfmin) re-add dropped components that contain a pixel
       with ch4mf >= mfminsmall
    3. label the kept mask, zero pixels with ch4mf < mfmin, relabel 1..N

    Arguments:
    - ch4mf: 2D matched filter image (ppmm, per the log messages)
    - nodata_mask: boolean mask of nodata pixels
    - mfmapinfo: map info dict, written to output headers via mapdict2str

    Keyword Arguments:
    - kde_outf, ccomp_outf, det_outf: optional ENVI output paths for the
      scaled image, the component labels and the thresholded ch4mf

    Returns:
    - detkde: scaled image, zeroed below mfmin and at nodata pixels
    - detcomp: component label image, 0 = background
    '''
    print('filtering weakly-connected detections below minppm=%.2fppmm'%mfmin)
    mfmapstr=mapdict2str(mfmapinfo)
    detkde = np.abs(ch4mf) if use_abs else ch4mf.copy()
    ch4min = ch4mf >= mfmin
    if not skip_kde:
        detkde = kde(detkde,k=k)
    detkde = np.clip((detkde-mfmin)/(mfmax-mfmin),0.0,1.0)

    if not skip_kde and kde_outf:
        array2img(kde_outf,detkde,mapinfostr=mfmapstr,overwrite=True)
    detmask = (detkde>0)
    print('%d candidate components'%imlabel(detmask).max())

    print('removing detections with <= minarea=%d pixels'%(minarea))
    detsmall = detmask.copy()
    detmask = remove_small_objects(detmask,min_size=minarea,in_place=False)
    print('%d remaining detections'%imlabel(detmask).max())
    if mfminsmall >= mfmin:
        print('adding small detections with mf>=%f ppmm'%(mfminsmall))
        # components removed above, kept if any pixel reaches mfminsmall
        smallcc = imlabel(detsmall!=detmask)
        smallkeep = np.unique(smallcc[(ch4mf>=mfminsmall)])
        smallkeep = smallkeep[smallkeep!=0]
        print(len(smallkeep),'small detections found')
        smallmask = np.isin(smallcc,smallkeep)
        detmask = (detmask | smallmask)

    # exclude nodata+ch4min afterward to exclude interior holes from ccimg
    detcomp = imlabel(detmask)
    detcomp[~ch4min] = 0
    detcomp,_,_ = relabel_sequential(detcomp)
    print('%d final detections'%detcomp.max())

    if ccomp_outf:
        detcomp[nodata_mask] = NODATA
        array2img(ccomp_outf,detcomp,mapinfostr=mfmapstr,overwrite=True)

    detkde[~ch4min] = 0
    detkde[nodata_mask] = 0
    detcomp[nodata_mask] = 0

    print('detected %d components'%detcomp.max())
    if det_outf:
        detfilt = ch4mf.copy()
        detfilt[~ch4min]=0
        detfilt[nodata_mask] = NODATA
        array2img(det_outf,detfilt,mapinfostr=mfmapstr,overwrite=True)

    return detkde, detcomp

def loadmaskedimage(maskedimgf,rgb_bands=None,masked_value=np.nan,
                    load_bands=None,memmap=False,verbose=0):
    '''
    load an ENVI image (header maskedimgf+'.hdr') as float32 with nodata
    pixels identified

    Arguments:
    - maskedimgf: path of the ENVI data file

    Keyword Arguments:
    - rgb_bands: band indices used to build outdata['rgb'] (default none)
    - masked_value: value written to nodata pixels (default NaN); not
      applied when memmap is true
    - load_bands: band indices to read (default all); ignored with memmap
    - memmap: open the data with open_memmap instead of reading it
    - verbose: print the image dimensions

    Returns dict with keys:
    - mapinfo, nodata_value (header 'data ignore value', NaN if absent)
    - nodata_mask: [rows,cols], true where any band equals nodata_value
    - rgb: rdn2rgb stretch of rgb_bands, when bands>=3 with 3 rgb bands, or
      bands==2 with a single distinct rgb band index
    - image: the non-rgb bands (squeezed), or all bands otherwise
    '''
    rgb_bands = [] if rgb_bands is None else rgb_bands
    load_bands = [] if load_bands is None else load_bands
    maskedimg = envi_open_file(maskedimgf+'.hdr',image=maskedimgf)
    rows,cols,bands = maskedimg.shape
    if verbose:
        print('Loading [%d,%d,%d] masked input image: "%s"'%(rows,cols,bands,maskedimgf))
    if memmap:
        # NOTE(review): 'source' keeps the file's interleave, while the
        # any(axis=2) nodata test below assumes bands last — confirm for
        # bil/bsq inputs
        maskeddata = maskedimg.open_memmap(interleave='source',writable=False)
    elif len(load_bands)>1:
        maskeddata = maskedimg.read_bands(load_bands)
    elif len(load_bands)==1:
        maskeddata = maskedimg.read_bands([load_bands[0]])
    else:
        maskeddata = maskedimg.load() # load everything

    # NOTE(review): this conversion may materialize memmapped data in memory
    maskeddata = np.float32(maskeddata)
    if maskeddata.ndim == 2:
        maskeddata = maskeddata[...,np.newaxis]
    nodata_value = float(maskedimg.metadata.get('data ignore value',np.nan))
    nodata_mask = (maskeddata==nodata_value).any(axis=2)
    if not memmap:
        maskeddata[nodata_mask] = masked_value

    print(np.count_nonzero(nodata_mask),'nodata pixels in image',maskedimgf)

    outdata = dict(mapinfo=mapinfo(maskedimg,astype=dict),
                   nodata_mask=nodata_mask,
                   nodata_value=nodata_value)

    if bands>=3 and len(rgb_bands)==3:
        # ang cmf image: 3 rgb + 1 cmf band
        image_bands = list(set(range(bands))-set(rgb_bands))
        outdata['rgb'] = rdn2rgb(maskeddata[:,:,rgb_bands])
        if len(image_bands)!=0:
            outdata['image'] = maskeddata[:,:,image_bands].squeeze()
    elif bands==2 and len(set(rgb_bands))==1:
        # avcl cmf image: 1 cmf band + 1 grayscale
        image_bands = list(set(range(bands))-set(rgb_bands))
        outdata['rgb'] = rdn2rgb(maskeddata[:,:,rgb_bands])
        if len(image_bands)!=0:
            outdata['image'] = maskeddata[:,:,image_bands].squeeze()
    else:
        outdata['image'] = maskeddata.squeeze()

    return outdata

def rgb2labimg(rgbimg):
    '''
    convert an [n,m,3] rgb annotation image to an [n,m] uint8 label image

    Pixels whose channels sum to 255 with red==255 get POINTSRC, those with
    blue==255 get DIFFSRC, all others 0 (for unsigned input: pure red
    (255,0,0) -> point source, pure blue (0,0,255) -> diffuse source).
    '''
    assert(rgbimg.shape[2]==3)
    red,blue = rgbimg[:,:,0],rgbimg[:,:,2]
    onechannel = rgbimg.sum(axis=2)==255 # only 1 channel==255

    labimg = np.zeros(rgbimg.shape[:2],dtype=np.uint8)
    labimg[onechannel & (red==255)] = POINTSRC
    labimg[onechannel & (blue==255)] = DIFFSRC
    return labimg


def loadlabimg(labf):
    '''
    load a label image with values in {0, POINTSRC, DIFFSRC} as uint8

    Supported inputs:
    - .png: single-band labels are used as-is; rgb/rgba annotations are
      converted with rgb2labimg (alpha ignored)
    - extension-less ENVI class map whose name ends with 'class'
      (header labf+'.hdr')

    Raises:
    - Exception for any other format
    - AssertionError if labels other than 0/POINTSRC/DIFFSRC are present
    '''
    print('loading label mask image: "%s"'%labf)
    filebase,fileext = splitext(pathsplit(labf)[1])
    if fileext=='.png':
        labimg = imread(labf)
        # single-band PNGs come back 2D and need no conversion
        if labimg.ndim==3 and labimg.shape[2] in (3,4):
            labimg = rgb2labimg(labimg[:,:,:3]).squeeze()

    elif fileext=='' and filebase.endswith('class'):
        # envi class map
        imgdat = envi_open_file(labf+'.hdr',image=labf)
        labimg = imgdat.load().squeeze()
    else:
        raise Exception('Unrecognized format %s'%labf)

    labimg = np.uint8(labimg)
    ulab = np.unique(labimg)
    assert(((ulab==0)|(ulab==POINTSRC)|(ulab==DIFFSRC)).all())

    return labimg

def loadfiltdet(detfilt_imgf):
    """
    loadfiltdet(detfilt_imgf)
    loads a filtered detection ENVI image and zeroes its nodata pixels

    Arguments:
    - detfilt_imgf: path to the ENVI image (header read from detfilt_imgf+'.hdr')

    Returns:
    - dict with keys:
      - ch4det: float32 detection image, nodata pixels set to 0
      - mapinfo: result of mapinfo(image, astype=dict)
      - nodata_mask: boolean mask of pixels equal to the header's
        'data ignore value'
      - nodata_value: that 'data ignore value', as a float
    """
    print('loading filtered detection image: "%s"'%detfilt_imgf)
    envi_img = envi_open_file(detfilt_imgf+'.hdr',image=detfilt_imgf)
    detection = np.float32(envi_img.load().squeeze())
    ignore_value = float(envi_img.metadata['data ignore value'])
    ignored = (detection==ignore_value)
    detection[ignored] = 0
    return {'ch4det':detection,
            'mapinfo':mapinfo(envi_img,astype=dict),
            'nodata_mask':ignored,
            'nodata_value':ignore_value}

def loaddetids(detid_imgf):
    """
    loaddetids(detid_imgf)
    loads a detection id ENVI image and zeroes its nodata pixels

    Arguments:
    - detid_imgf: path to the ENVI image (header read from detid_imgf+'.hdr')

    Returns:
    - dict with keys:
      - detids: float32 id image, nodata pixels set to 0
      - mapinfo: result of mapinfo(image, astype=dict)
      - nodata_mask: boolean mask of pixels equal to the header's
        'data ignore value'
      - nodata_value: that 'data ignore value', as a float
    """
    print('loading detection id image: "%s"'%detid_imgf)
    idimg = envi_open_file(detid_imgf+'.hdr',image=detid_imgf)
    ids = np.float32(idimg.load().squeeze())
    ignore_value = float(idimg.metadata['data ignore value'])
    ignored = (ids==ignore_value)
    ids[ignored] = 0
    return {'detids':ids,
            'mapinfo':mapinfo(idimg,astype=dict),
            'nodata_mask':ignored,
            'nodata_value':ignore_value}

def loadsaliencemap(salience_imgf):
    """
    loadsaliencemap(salience_imgf)
    loads a salience map stored as an ENVI image

    Arguments:
    - salience_imgf: path to the ENVI image (header read from salience_imgf+'.hdr')

    Returns:
    - dict with keys:
      - saliencemap: float32 salience image (squeezed)
      - mapinfo: result of mapinfo(image, astype=dict)
    """
    print('loading',salience_imgf)
    envi_img = envi_open_file(salience_imgf+'.hdr',image=salience_imgf)
    info = mapinfo(envi_img,astype=dict)
    return {'saliencemap':np.float32(envi_img.load().squeeze()),
            'mapinfo':info}

def bbox(points,border=0,imgshape=None):
    """
    bbox(points,border=0,imgshape=None)
    computes bounding box of extrema in points array

    Arguments:
    - points: [N x 2] array of [rows, cols]

    Keyword Arguments:
    - border: padding added around the extrema (default=0), one of:
      - [rborder, cborder] list/tuple/array: per-axis padding; a value > 1 is
        an absolute pixel count, a value <= 1 is a fraction of the extent
        along that axis (truncated to int)
      - 'mindiff' / 'maxdiff': pad both axes by the smaller / larger extent
      - scalar < 1: fraction of the mean extent (rounded to nearest int)
      - scalar >= 1: absolute padding for both axes
    - imgshape: (nrows, ncols) used to clip the upper bounds; if None or
      empty, the upper bounds are not clipped (default=None)

    Returns:
    - (rowmin,colmin),(rowmax,colmax) with exclusive max values

    Raises:
    - ValueError if border is an unrecognized string
    """
    from numpy import atleast_2d
    points = atleast_2d(points)
    minv = points.min(axis=0)
    maxv = points.max(axis=0)
    difv = maxv-minv

    # sequence and string forms are dispatched before any numeric comparison:
    # comparing a str/tuple against a number raises TypeError on Python 3
    if isinstance(border,(list,tuple,np.ndarray)):
        rborder,cborder = border
        rborder = rborder if rborder > 1 else int(rborder*difv[0])
        cborder = cborder if cborder > 1 else int(cborder*difv[1])
    elif isinstance(border,str):
        if border == 'mindiff':
            rborder = cborder = min(difv)
        elif border == 'maxdiff':
            rborder = cborder = max(difv)
        else:
            raise ValueError('unrecognized border "%s"'%border)
    elif border < 1:
        rborder = cborder = int(border*difv.mean()+0.5)
    else:
        rborder = cborder = border

    if imgshape is None or len(imgshape)==0:
        # large enough that the upper bounds below are never clipped
        imgshape = maxv+max(rborder,cborder)+1

    rmin,rmax = max(0,minv[0]-rborder),min(imgshape[0],maxv[0]+rborder+1)
    cmin,cmax = max(0,minv[1]-cborder),min(imgshape[1],maxv[1]+cborder+1)

    return (rmin,cmin),(rmax,cmax)
    
def maskbbox(mask,border=0):
    """
    maskbbox(mask,border=0)
    computes the bounding box of the nonzero pixels of a mask

    Arguments:
    - mask: [n x m] binary image mask

    Keyword Arguments:
    - border: padding passed through to bbox (default=0)

    Returns:
    - bbox = (rowmin,colmin),(rowmax,colmax), clipped to mask.shape
    """
    nonzero_coords = np.column_stack(np.nonzero(mask))
    return bbox(nonzero_coords,border=border,imgshape=mask.shape)

def region_maxima(img,mask,return_index=False):
    """
    region_maxima(img,mask,return_index=False)
    computes the maximum of img within each connected component of mask

    Arguments:
    - img: intensity image, indexed by (row,col) region coordinates
    - mask: binary mask whose connected components (via imlabel) define
      the regions

    Keyword Arguments:
    - return_index: also return the (row,col) of each region's maximum
      (default=False)

    Returns:
    - rcmax: array (dtype of img) with one maximum per region, ordered as
      returned by regionprops
    - rcidx (only if return_index): [n_regions x 2] int array of (row,col)
      coordinates; [0 x 2] when there are no regions
    """
    from skimage.measure import regionprops
    ccimg = imlabel(mask)
    rcidx,rcmax = [],[]
    for r in regionprops(ccimg,intensity_image=img):
        rcmax.append(r.max_intensity)
        if return_index:
            rc = r.coords
            # np.argmax picks the first pixel attaining the region maximum
            rcidx.append(rc[np.argmax(img[rc[:,0],rc[:,1]])])
    rcmax = np.array(rcmax,dtype=img.dtype)
    if return_index:
        # np.int was removed in NumPy 1.24; the builtin int is its exact
        # replacement. reshape keeps a [0 x 2] shape when no regions exist.
        return rcmax,np.array(rcidx,dtype=int).reshape(-1,2)
    return rcmax

def local_maxima(im,rad,image_max=None):
    """
    local_maxima(im,rad,image_max=None)
    finds the coordinates of local intensity peaks in im

    Arguments:
    - im: input image
    - rad: peak radius; 2*rad is passed to peak_local_max as min_distance

    Keyword Arguments:
    - image_max: unused; kept only for backward compatibility (default=None)

    Returns:
    - peak coordinates as returned by skimage.feature.peak_local_max
    """
    from skimage.feature import peak_local_max
    # the former disabled (if 0:) block precomputed a maximum-filtered image
    # that was never passed on; peak_local_max does its own filtering
    return peak_local_max(im, min_distance=2*rad)

def runcmd(cmd,verbose=0):
    """
    runcmd(cmd,verbose=0)
    runs a command (without a shell), optionally redirecting output to a file

    A simple redirection in the command string is honored: 'cmd > file' and
    'cmd >& file' overwrite file, 'cmd >> file' appends to it. In all three
    cases both stdout and stderr are written to the file.

    Arguments:
    - cmd: command string, or list of arguments (joined with spaces)

    Keyword Arguments:
    - verbose: print the command before running it (default=0)

    Returns:
    - (out, err, retcode): captured stdout/stderr (None when redirected to
      a file) and the process return code
    """
    from subprocess import Popen, PIPE
    cmdstr = ' '.join(cmd) if isinstance(cmd,list) else cmd
    if verbose:
        print("running command:",cmdstr)

    outfile = None
    for rstr in ['>>','>&','>']:
        if rstr in cmdstr:
            cmdstr,outpath = [s.strip() for s in cmdstr.split(rstr)]
            if outfile is not None:
                # a second redirection was parsed; don't leak the first handle
                outfile.close()
            outfile = open(outpath,'a' if rstr=='>>' else 'w')

    cmdout = PIPE if outfile is None else outfile
    try:
        # NOTE: arguments are split on whitespace, so quoted arguments
        # containing spaces are not supported
        p = Popen(cmdstr.split(), stdout=cmdout, stderr=cmdout)
        out, err = p.communicate()
    finally:
        # close the redirect file even if Popen fails (e.g. command not found)
        if outfile is not None:
            outfile.close()

    return out,err,p.returncode

def retrieve_rgb(rgbf):
    """
    retrieve_rgb(rgbf)
    downloads the RGB quicklook of an AVIRIS-NG flightline with wget

    Arguments:
    - rgbf: local destination path; the flightline id is derived from it
      with filename2flightid and must start with 'ang'

    Returns:
    - 0 if rgbf already exists or wget succeeded; otherwise the nonzero wget
      return code, or 1 if the download could not be attempted

    Raises:
    - Exception if rgbf does not refer to an AVIRIS-NG flightline
    """
    wgetbin='wget --no-verbose'
    wgeturl='https://avirisng.jpl.nasa.gov'
    wgetrgb='{wgetbin} -O {rgbf} {wgeturl}/aviris_locator/y{rgbyear}_RGB/{rgbfile}'
    wgetgeo='{wgetbin} -O {rgbf} {wgeturl}/ql/{rgbyear}qlook/{lid}_geo.jpeg'

    wgetretc = 1
    lid = filename2flightid(rgbf)
    if not lid.startswith('ang'):
        raise Exception('retrieve_rgb only works with AVIRIS-NG flightlines')

    if pathexists(rgbf):
        return 0

    try:
        # two-digit year, e.g. 'ang20160922t183143' -> '16'
        (rgbdir,rgbfile),rgbyear = pathsplit(rgbf),lid[5:7]
        print(rgbdir,rgbfile,rgbyear)
        # 2017 flightlines use the geo quicklook URL instead
        wgetcmd = wgetrgb if rgbyear!='17' else wgetgeo
        # -O uses rgbf directly: '{rgbdir}/{rgbfile}' pointed at the
        # filesystem root when rgbf had no directory component
        wgetcmd = wgetcmd.format(wgetbin=wgetbin,wgeturl=wgeturl,rgbf=rgbf,
                                 rgbfile=rgbfile,rgbyear=rgbyear,lid=lid)
        print(wgetcmd)
        wgetout = runcmd(wgetcmd)
        wgetretc = wgetout[-1]
        if wgetretc != 0:
            print(rgbf,'wget failure, skipping')
            print('wget output:', wgetout[1])

    except Exception as e:
        print(rgbf,'not found and unable to retreive, skipping')
        print('error:', e)

    if wgetretc != 0 and pathexists(rgbf):
        # wget -O creates/truncates the output file before downloading, so a
        # failed transfer leaves an empty or partial file that the early
        # pathexists() return above would otherwise treat as already fetched
        try:
            os.remove(rgbf)
        except OSError as e:
            print(rgbf,'unable to remove partial download:', e)

    return wgetretc



if __name__ == '__main__':
    # Ad-hoc manual test / visualization script. It depends on hard-coded
    # local data paths and interactive matplotlib windows.
    # NOTE(review): raw_input exists only in Python 2; under Python 3 these
    # pauses raise NameError. Confirm the intended interpreter.

    # smoke test: fetch one AVIRIS-NG RGB quicklook, then pause
    rgbf= 'tiles/tag_mini_intensive/rgb/ang20160922t183143_RGB.jpeg'
    print(retrieve_rgb(rgbf))
    raw_input()
    # round-trip check: random float image -> RGBA via colormap -> float,
    # then report and plot the absolute per-pixel reconstruction error
    cmap = 'inferno'
    tilefloat = np.random.rand(100,100)
    tilergba = float2rgba(tilefloat,cmap=cmap)
    rgbafloat = rgba2float(tilergba,cmap=cmap)
    diff = np.sqrt((tilefloat-rgbafloat)**2)
    print(rgbafloat.shape,diff.mean())
    import pylab as pl
    fig,ax = pl.subplots(1,4,sharex=True,sharey=True)
    ax[0].imshow(tilefloat,cmap=cmap,vmin=0,vmax=1)
    ax[1].imshow(tilergba[...,:3])
    ax[2].imshow(rgbafloat,cmap=cmap,vmin=0,vmax=1)
    ax[3].imshow(diff,vmin=0,vmax=0.01)
    pl.show()
    
    # disabled: geotiff tile export and array2rgba NaN-handling checks
    if 0:
        baseimg='./data/y16/cmf/ort/ang20160913t205656_cmf_v1n2_img'
        #baseimg='/lustre/ang/y16/cmf/ort/ang20160913t205656_cmf_v1n2_img'    
        tileimg = 255*np.ones([100,100,4],dtype=np.uint8)
        tilepos=(12360,728)
        tilefile='test.tif'
        tile2geotiff(tileimg,tilepos,tilefile,baseimg)
        raw_input()
        z=np.random.rand(10)
        z[3:6] = np.nan
        y=np.random.rand(10,10)
        y[1:5] = np.nan
        r = array2rgba(z)
        print(r,r.shape)
        r = array2rgba(y)
        print(r,r.shape)
        raw_input()
    import pylab as pl
    datadir='./data'    
    #datadir='/lustre/ang'
    # pick the training set: 1 -> y16 thorpe_training, 0 -> y15 thompson_training
    if 1:
        cmfdir=datadir+'/y16/cmf/ort'
        labdir=pathjoin(cmfdir,'thorpe_training')
        detdir='tiles/thorpe_training/100'

        labimgf=pathjoin(labdir,'ang20160914t180328_cmf_v1n2_img_class')
        cmfimgf=pathjoin(cmfdir,'ang20160914t180328_cmf_v1n2_img')
        #filtdetimgf=pathjoin(detdir,'ang20160914t180328_cmf_v1n2_filt_det_500_1500')
        filtdetimgf='./tiles/thorpe_training/150/ang20160914t180328_cmf_v1n2_filt_det_350_1500'
        
    else:
        cmfdir=datadir+'/y15/cmf/ort'
        labdir=pathjoin(cmfdir,'thompson_training')
        detdir='tiles/thompson_training/100/'
        labimgf=pathjoin(labdir,'ang20150419t161445_cmf_v1f_img_mask.png')
        cmfimgf=pathjoin(cmfdir,'ang20150419t161445_cmf_v1f_img')        
        filtdetimgf = pathjoin(detdir,'ang20150419t161445_cmf_v1f_filt_det_500_1500')

    # compare a saved tile against the matching tiledim x tiledim window of
    # the full-scene detection png and of the label image
    tiledim=150
    tilepos = 4678,1115
    tileimgf = './tiles/thorpe_training/150/ang20160914t180328/tp/tp_det4678_1115.png'
    pngimgf=filtdetimgf+'.png'
    # NOTE(review): developer-specific absolute path for tilepredictor_util
    sys.path.insert(0,'/Users/bbue/Research/src/python/util/tilepredictor/')
    from tilepredictor_util import imread_image,imread_tile
    
    tileimg = imread_tile(tileimgf,tile_shape=[tiledim,tiledim])
    pngimg = imread_image(pngimgf)
    tilepng = pngimg[tilepos[0]:tilepos[0]+tiledim,tilepos[1]:tilepos[1]+tiledim]
    cmfdata = loadmaskedimage(cmfimgf,rgb_bands=range(3))        
    detdata = loadfiltdet(filtdetimgf)    
    labimg = loadlabimg(labimgf)
    tilelab = labimg[tilepos[0]:tilepos[0]+tiledim,tilepos[1]:tilepos[1]+tiledim]
    # make sure nodata masks match 
    nodata_mask = detdata['nodata_mask']
    assert(cmfdata['nodata_mask'].sum()==nodata_mask.sum())

    ch4det = detdata['ch4det']    
    rgbimg = cmfdata['rgb']
    # remove labeled pixels in nodata regions
    # (detections below 500 and unlabeled pixels are also set to NaN so they
    # are treated as missing when plotted; labimg is cast to float32 first
    # because uint8 cannot hold NaN)
    labimg = np.float32(labimg)
    ch4det[nodata_mask | (ch4det<500)] = np.nan
    labimg[~(labimg>0)] = np.nan


    
    fig,ax = pl.subplots(3,1,sharex=True,sharey=True,num=1)
    ax[0].imshow(tileimg)
    ax[1].imshow(tilepng)
    ax[2].imshow(tilelab)
    pl.show()

    # NOTE(review): pl.clf() removes the axes in `ax` from the figure, so the
    # imshow calls below likely draw onto detached axes and display nothing;
    # consider creating fresh subplots after clearing.
    pl.clf()
    ax[0].imshow(rgbimg.swapaxes(0,1))
    ax[1].imshow(pngimg.swapaxes(0,1))
    pl.show()
    
    # overlay detections on the RGB image, and labels under the detections
    ax[0].imshow(rgbimg.swapaxes(0,1))
    ax[0].imshow(ch4det.swapaxes(0,1),vmin=500,vmax=1500,cmap='YlOrRd')
    ax[1].imshow(labimg.swapaxes(0,1),vmin=0,vmax=1,cmap='Spectral',alpha=0.7)
    ax[1].imshow(ch4det.swapaxes(0,1),vmin=500,vmax=1500,cmap='YlOrRd')
    pl.show()

    
