from __future__ import absolute_import, division, print_function

import sys,os
from warnings import filterwarnings, warn
# Warning-message fragments silenced globally at import time
# (presumably the low-contrast warning emitted by skimage.io.imsave -- confirm).
warnings_ignore = [
    'is a low contrast image', 
]

# each fragment is wrapped as the regex '.*<msg>.*' so it matches anywhere
# in the warning text (the fragment itself is not regex-escaped)
for msg in warnings_ignore:
    filterwarnings("ignore", message='.*%s.*'%msg)

import numpy as np

from os.path import join as pathjoin, exists as pathexists, split as pathsplit
from os.path import splitext, abspath, getsize, realpath, isabs, expandvars
from os.path import expanduser

# alias: create a directory, including any missing parent directories
mkdir = os.makedirs

# filename helpers
pathfile   = lambda path: pathsplit(path)[1] # /path/to/file.ext -> file.ext
pathbase   = lambda path: splitext(pathfile(path))[0] # /path/to/file.ext -> file

from collections import OrderedDict
from spectral.io.envi import open as envi_open_file
import LatLongUTMconversion as LLUTMConv


# datum/ellipsoid id passed to LLUTMConv (23 = WGS-84)
DATUM_WGS84 = 23 # WGS-84
# degrees -> radians conversion factor
DEG2RAD = np.pi/180.0

# fill value written as 'data ignore value' / nodata in output images
NODATA = -9999
    
# reused constants
# interpolation orders (e.g. the `order` default used by imresize)
ORDER_NN    = 0
ORDER_BILIN = 1
ORDER_QUAD  = 2
ORDER_CUBIC = 3

# pixel connectivity passed to the labeling/boundary helpers
# (1 = 4-connected, 2 = 8-connected for 2D images)
CONN4 = 1
CONN8 = 2

# label values for point vs. diffuse sources in label images
POINTSRC = 1
DIFFSRC = 2

# NOTE(review): not referenced in this file chunk
use_absmf=False

# default filtdet parameters:
# - kernel: gaussian sigma used for the kde smoothing step
kernel=50
# - mfmin,mfmax: matched-filter range rescaled to [0,1]; pixels below mfmin
#   are removed from the final detections
mfmin,mfmax = 500,1500
# - minarea: min_size passed to remove_small_objects
minarea=9
# - mfminsmall: removed small components containing a pixel >= this value
#   are added back
mfminsmall  = 1250

def imsave(fname,img,**kwargs):
    """
    Save img to fname with skimage.io.imsave.

    Keyword Arguments:
    - verbose: print the output filename before saving (default False);
      removed from kwargs before forwarding
    - any other keyword is passed through to skimage.io.imsave

    Returns:
    - whatever skimage.io.imsave returns
    """
    from skimage.io import imsave as _imsave
    verbose = kwargs.pop('verbose',False)
    if verbose:
        print('saving',fname)
    return _imsave(fname,img,**kwargs)

def array2rgba(a,**kwargs):
    '''
    converts an array into a set of rgba values according to
    the default or user-provided colormap

    Arguments:
    - a: numeric array of any shape; NaN entries are treated as missing

    Keyword Arguments:
    - cmap: colormap name (default = matplotlib rcParams['image.cmap'])
    - vmin, vmax: color limits (default = min/max of the non-NaN values;
      None also selects the default)

    Returns:
    - uint8 array of shape a.shape+(4,). NaN entries map to (0,0,0,0).
      If there are no valid values, or vmax <= vmin, a warning is issued
      and an all-zero array is returned.
    '''
    from numpy import isnan,clip,uint8
    aflat = np.float32(a.copy().ravel())
    nanmask = isnan(aflat)
    nvalid = np.count_nonzero(~nanmask)
    avals = aflat[~nanmask]

    # Only fall back to the data range when there is data: min()/max() of an
    # empty array raise ValueError, which used to make the all-NaN warning
    # branch below unreachable.
    vmin = kwargs.pop('vmin',None)
    vmax = kwargs.pop('vmax',None)
    if vmin is None:
        vmin = avals.min() if nvalid>0 else np.nan
    if vmax is None:
        vmax = avals.max() if nvalid>0 else np.nan
    vmin,vmax = float(vmin),float(vmax)

    if nvalid>0 and vmax>vmin:
        # matplotlib is only needed when we actually map colors
        from pylab import get_cmap,rcParams
        cm = get_cmap(kwargs.get('cmap',rcParams['image.cmap']))
        aflat[nanmask] = vmin # use vmin to map to zero, below
        aflat = clip(((aflat-vmin)/(vmax-vmin)),0.,1.)
        rgba = uint8(cm(aflat)*255)
        if nanmask.any():
            nanr = np.where(nanmask)
            rgba[nanr[0],:] = 0
        rgba = rgba.reshape(list(a.shape)+[4])
    else:
        if nvalid==0:
            warn('all values nan')
        elif vmax<=vmin:
            warn('vmax <= vmin')
        rgba = np.zeros(list(a.shape)+[4],dtype=uint8)
    return rgba

def rgbdet2ql(rgb,det,detmask):
    """
    Build a uint8 quicklook: the RGB image with detection values painted
    over the pixels selected by detmask using the 'YlOrRd' colormap.

    Arguments:
    - rgb: RGB image, presumably float in [0,1] (it is scaled by 255)
    - det: detection values, indexed with the same coordinates as detmask
    - detmask: mask of pixels to overlay

    Returns:
    - uint8 copy of rgb*255 with the first 3 channels of masked pixels
      replaced by their colormapped detection values
    """
    maskidx = np.where(detmask)
    overlay = array2rgba(det[maskidx],cmap='YlOrRd')
    quicklook = np.uint8(rgb.copy()*255)
    quicklook[maskidx[0],maskidx[1],:3] = overlay[:,:3]
    return quicklook

def progressbar(caption,maxval=None):
    """
    progressbar(caption,maxval=None)

    Summary: build a progressbar.ProgressBar with a standard widget layout

    Arguments:
    - caption: text shown before the bar (a ': ' separator is appended);
      falsy values mean no caption

    Keyword Arguments:
    - maxval: maximum value. When given, the bar shows percentage and ETA;
      when None, it uses UnknownLength and shows a running counter instead.

    Output:
    - progress bar instance (pbar.update(i) to step, pbar.finish() to close)
    """
    from progressbar import ProgressBar, Bar, UnknownLength
    from progressbar import Percentage, Counter, ETA

    prefix = (caption + ': ') if caption else ''
    if maxval is None:
        unknown_widgets = [prefix, Counter(), ' ', Bar('=')]
        return ProgressBar(widgets=unknown_widgets, maxval=UnknownLength)

    known_widgets = [prefix, Percentage(), ' ', Bar('='), ' ', ETA()]
    return ProgressBar(widgets=known_widgets, maxval=maxval)

def findboundaries(labimg,**kwargs):
    """
    Thin wrapper around skimage.segmentation.find_boundaries.

    Arguments:
    - labimg: label image

    Keyword Arguments:
    - connectivity: defaults to CONN8 when not supplied
    - any other keyword accepted by find_boundaries (e.g. mode)

    Returns:
    - the result of find_boundaries
    """
    from skimage.segmentation import find_boundaries as _find_boundaries
    if 'connectivity' not in kwargs:
        kwargs['connectivity'] = CONN8
    return _find_boundaries(labimg,**kwargs)

def thickboundaries(labimg):
    """Return findboundaries(labimg) computed with mode='thick'."""
    boundary_mode = 'thick'
    return findboundaries(labimg,mode=boundary_mode)

def innerboundaries(labimg):
    """Return findboundaries(labimg) computed with mode='inner'."""
    boundary_mode = 'inner'
    return findboundaries(labimg,mode=boundary_mode)

def outerboundaries(labimg):
    """Return findboundaries(labimg) computed with mode='outer'."""
    boundary_mode = 'outer'
    return findboundaries(labimg,mode=boundary_mode)

def createimg(hdrf,metadata,**kwargs):
    """
    Create a new ENVI image from header path hdrf and metadata dict via
    spectral.io.envi.create_image, with ext='' and force=True.

    NOTE: extra keyword arguments are accepted but currently ignored.

    Returns:
    - the image object returned by create_image
    """
    from spectral.io.envi import create_image as _create_image
    create_opts = {'ext': '', 'force': True}
    return _create_image(hdrf, metadata, **create_opts)

def openimgmm(img,interleave='source',writable=False):
    """
    Return img.open_memmap(interleave=interleave, writable=writable).

    Arguments:
    - img: object providing open_memmap (e.g. a spectral image)

    Keyword Arguments:
    - interleave: memmap interleave (default 'source')
    - writable: open the memmap for writing (default False)
    """
    return img.open_memmap(writable=writable, interleave=interleave)

def imlabel(img,**kwargs):
    """
    Label connected components with skimage.measure.label.

    Keyword Arguments:
    - connectivity: defaults to CONN8 when not supplied
    - any other keyword accepted by skimage.measure.label
    """
    from skimage.measure import label as _label
    opts = dict(connectivity=CONN8)
    opts.update(kwargs)
    return _label(img,**opts)

def disk(radius,**kwargs):
    """Return skimage.morphology.disk(radius, **kwargs)."""
    import skimage.morphology as _morph
    return _morph.disk(radius,**kwargs)

def bwdilate(bwimg,**kwargs):
    """
    Binary dilation via skimage.morphology.binary_dilation.

    Keyword Arguments:
    - selem / footprint: structuring element (default disk(3)). Either name
      is accepted and forwarded under whichever keyword the installed
      scikit-image supports: 'selem' was renamed 'footprint' in 0.19 and
      later removed.
    - any other keyword accepted by binary_dilation
    """
    import inspect
    from skimage.morphology import binary_dilation as _bwd

    try:
        params = inspect.signature(_bwd).parameters
    except (TypeError, ValueError):
        params = {}
    se_keyword = 'footprint' if 'footprint' in params else 'selem'

    _missing = object()
    se = kwargs.pop('footprint',_missing)
    if se is _missing:
        se = kwargs.pop('selem',_missing)
    else:
        kwargs.pop('selem',None)
    if se is _missing:
        se = disk(3)
    kwargs[se_keyword] = se
    return _bwd(bwimg,**kwargs)

def bwdist(bwimg,**kwargs):
    """
    Distance transform: for each nonzero element of bwimg, the distance to
    the nearest zero element.

    Keyword Arguments:
    - metric: 'euclidean' (default, scipy distance_transform_edt) or
      'chessboard'/'taxicab' (scipy distance_transform_cdt)
    - return_distances (default True), return_indices (default False) and
      any other keyword of the selected scipy function

    Raises:
    - ValueError for any other metric (previously an UnboundLocalError)
    """
    # import from scipy.ndimage directly; scipy.ndimage.morphology is a
    # deprecated namespace
    from scipy import ndimage as _ndi
    metric = kwargs.get('metric','euclidean')
    if metric=='euclidean':
        kwargs.pop('metric',None) # edt does not take a metric argument
        _bwdist = _ndi.distance_transform_edt
    elif metric in ('chessboard','taxicab'):
        _bwdist = _ndi.distance_transform_cdt
    else:
        raise ValueError('unsupported metric %r'%(metric,))
    kwargs.setdefault('return_distances',True)
    kwargs.setdefault('return_indices',False)
    return _bwdist(bwimg,**kwargs)

def mergelabels(labimg,mergedist,return_merged=False,doplot=False):
    """
    Merge labeled regions that lie within mergedist pixels (chessboard
    distance) of each other.

    Foreground (labimg != 0) is grown by mergedist pixels, and each
    connected component of the grown mask becomes one output label
    (1..K, in the order of the component labels).

    Arguments:
    - labimg: label image (0 = background)
    - mergedist: merge distance in pixels

    Keyword Arguments:
    - return_merged: also return {component label: input labels merged}
    - doplot: show before/after images with pylab

    Returns:
    - outimg (same dtype as labimg), plus the merge dict if return_merged
    """
    fgmask = (labimg!=0)
    mergecomp = imlabel(bwdist(~fgmask,metric='chessboard')<=mergedist)
    mergelab = np.unique(mergecomp)
    mergelab = mergelab[mergelab!=0]
    outimg = np.zeros_like(labimg)
    outmerged = {}
    for i,ml in enumerate(mergelab):
        mlmask = (mergecomp==ml) & fgmask
        outimg[mlmask] = i+1
        if return_merged:
            # NOTE: keyed by component label ml, which equals i+1 when the
            # component labels are sequential
            outmerged[ml] = np.unique(labimg[mlmask])
    # count nonzero labels explicitly: label 0 is not guaranteed to be present
    nbefore = np.count_nonzero(np.unique(labimg))
    print(nbefore,'labels before merge',len(mergelab),'after merge')
    if doplot:
        import pylab as pl
        fig,ax = pl.subplots(1,2,sharex=True,sharey=True)
        ax[0].imshow(labimg)
        ax[0].set_title('before')
        ax[1].imshow(outimg)
        ax[1].set_title('after')
        pl.show()
    if return_merged:
        return outimg, outmerged
    return outimg

def imread(fname,**kwargs):
    """
    Read an image with skimage.io.imread; plugin defaults to None when
    not supplied.
    """
    import skimage.io as _skio
    if 'plugin' not in kwargs:
        kwargs['plugin'] = None
    return _skio.imread(fname,**kwargs)

def imresize(img,output_shape,**kwargs):
    """
    Resize img with skimage.transform.resize. Defaults: nearest-neighbor
    (order=ORDER_NN), clip=False, preserve_range=True.

    Keyword Arguments:
    - anti_alias: if True, pre-smooth img with a gaussian before resizing
      (this replaces resize's own anti-aliasing logic)
    - anti_alias_sigma: gaussian sigma; default max(0,(factor-1)/2) per
      axis, where factor = input size / output size
    - cval, mode: boundary handling for the pre-smoothing (also forwarded to
      resize). The numpy.pad-style modes 'edge'/'symmetric' are translated
      to their scipy.ndimage names for the gaussian step.
    - any other keyword accepted by skimage.transform.resize
    """
    from skimage.transform import resize as _imresize

    kwargs.setdefault('order',ORDER_NN)
    kwargs.setdefault('clip',False)
    kwargs.setdefault('preserve_range',True)
    if kwargs.pop('anti_alias',False):
        from scipy import ndimage as ndi
        cval = kwargs.get('cval',0)
        mode = kwargs.get('mode','constant')
        # ndi.gaussian_filter does not understand these numpy.pad names
        ndi_mode = {'edge':'nearest','symmetric':'reflect'}.get(mode,mode)
        sigma = kwargs.pop('anti_alias_sigma',None)
        if sigma is None:
            out_shape = tuple(output_shape)
            # treat omitted trailing dims (e.g. channels) as unchanged so the
            # per-axis factors broadcast against img.shape
            out_shape += tuple(img.shape[len(out_shape):])
            factors = (np.asarray(img.shape, dtype=float) /
                       np.asarray(out_shape, dtype=float))
            sigma = np.maximum(0, (factors - 1) / 2)
        imgrs = ndi.gaussian_filter(img, sigma, cval=cval, mode=ndi_mode)
    else:
        imgrs = img

    return _imresize(imgrs,output_shape,**kwargs)

def filename2flightid(filename):
    '''
    get flight id from filename: the part of the base filename (directory
    and extension removed) before the first '_'

    ang20160922t184215_cmf_v1g_img -> ang20160922t184215
    '''
    basename = pathbase(filename)
    flightid,_,_ = basename.partition('_')
    return flightid

def filename2flightdate(filename):
    '''
    get flight date from filename
    ang20160922t184215_cmf_v1g_img -> (2016,9,22)

    The date is read from the 8 characters immediately preceding the last
    't' in the flight id (or the last 8 characters if there is no 't').

    Returns:
    - (yyyy,mm,dd) tuple of ints

    Raises:
    - ValueError if those characters are not digits
    '''
    flightid = filename2flightid(filename)
    # rsplit: a prefix that itself contains 't' must not truncate the date
    datestr = flightid.rsplit('t',1)[0][-8:]
    yyyy,mm,dd = map(int,[datestr[:4],datestr[4:6],datestr[6:]])
    return yyyy,mm,dd

def shortstr(s,width=20,placeholder='...',quote=False):
    """
    Truncate s to width characters, appending placeholder when truncated.

    Keyword Arguments:
    - width: maximum number of characters kept from s (default 20)
    - placeholder: suffix added to truncated strings (default '...')
    - quote: wrap the result in double quotes (default False)
    """
    result = s
    if len(result) > width:
        result = result[:width]+placeholder
    if not quote:
        return result
    return '"%s"'%result

def filename2calid(infile):
    """
    Return the third '_'-separated field of the file name (directory
    removed), e.g. ./ort/ang20160915t194328_cmf_v1n2_img -> v1n2
    """
    fields = pathsplit(infile)[1].split('_')
    return fields[2]

# alias for filename2calid
l1calid = filename2calid

def extrema(a,**kwargs):
    """
    Return (low, high) extrema of array a.

    Keyword Arguments:
    - p: 1.0 (default) -> (amin, amax); 0.0 -> (amax, amin);
      0 < p < 1 -> NaN-aware percentiles at (1-p)*100 and p*100 using
      'nearest' interpolation
    - axis and other kwargs: forwarded to amin/amax; in percentile mode only
      axis is used

    Raises:
    - AssertionError if p is outside [0, 1]
    """
    p = kwargs.pop('p',1.0)
    if p==1.0:
        return np.amin(a,**kwargs),np.amax(a,**kwargs)
    elif p==0.0:
        return np.amax(a,**kwargs),np.amin(a,**kwargs)
    assert(p>0.0 and p<1.0)
    axis = kwargs.pop('axis',None)

    def apercent(q):
        try:
            return np.nanpercentile(a,q,axis=axis,method='nearest')
        except TypeError:
            # numpy < 1.22 only knows the (since deprecated) 'interpolation'
            return np.nanpercentile(a,q,axis=axis,interpolation='nearest')

    return apercent((1-p)*100),apercent(p*100)

def envitypecode(np_dtype):
    """
    Map a numpy dtype (or anything np.dtype accepts) to its ENVI
    'data type' code using spectral.io.envi.dtype_to_envi.
    """
    from spectral.io.envi import dtype_to_envi as _dtype_to_envi
    typechar = np.dtype(np_dtype).char
    return _dtype_to_envi[typechar]

def extract_tile(img,ul,tdim,verbose=False,transpose=None,cval=0):
    '''
    extract a tile of dims (tdim,tdim,nb) whose upper-left pixel is at
    (row,col) coordinate ul in img, where nb = img.shape[2] (or 1 for a 2D
    img). Tile pixels outside the image extent are filled with cval; this
    includes tiles that lie entirely outside the image.

    Arguments:
    - img: 2D or 3D array
    - ul: (row,col) of the tile's upper-left pixel (may be negative)
    - tdim: tile height/width in pixels

    Keyword Arguments:
    - verbose: print the computed offsets
    - transpose: optional axis permutation applied to the tile
    - cval: fill value for out-of-image pixels (default 0)

    Returns:
    - the tile (3D before any transpose)

    Raises:
    - Exception if img is not 2D or 3D
    '''
    ndim = img.ndim
    if ndim==3:
        nr,nc,nb = img.shape
    elif ndim==2:
        nr,nc = img.shape
        nb = 1
    else:
        raise Exception('invalid number of image dims %d'%ndim)

    # image rows/cols covered by the tile, clipped to the image extent; the
    # end is floored at the start so a tile lying completely off the image
    # gives an empty range instead of a negative (wrap-around) slice end
    ibeg = max(0,ul[0])
    iend = max(ibeg,min(nr,ul[0]+tdim))
    jbeg = max(0,ul[1])
    jend = max(jbeg,min(nc,ul[1]+tdim))

    # matching destination range inside the tile
    padt,padl = ibeg-ul[0],jbeg-ul[1]
    padb,padr = padt+(iend-ibeg),padl+(jend-jbeg)

    if verbose:
        print(ul,nr,nc)
        print(padt,padb,padl,padr)
        print(ibeg,iend,jbeg,jend)

    imgtile = cval*np.ones([tdim,tdim,nb],dtype=img.dtype)
    if iend>ibeg and jend>jbeg:
        imgtile[padt:padb,padl:padr] = np.atleast_3d(img[ibeg:iend,jbeg:jend])
    if transpose is not None:
        imgtile = imgtile.transpose(transpose)
    return imgtile

def rotxy(x,y,adeg,xc,yc):
    """
    rotxy(x,y,adeg,xc,yc)

    Summary: rotate point(s) (x,y) about (xc,yc) by adeg degrees using the
    rotation matrix [[cos,-sin],[sin,cos]] (counterclockwise for positive
    adeg in a right-handed x/y frame)

    Arguments:
    - x: x coord(s) to rotate (scalar or 1D)
    - y: y coord(s) to rotate (scalar or 1D)
    - adeg: angle of rotation in degrees
    - xc: center x coord
    - yc: center y coord

    Output:
    rotated x,y (scalars for scalar input)
    """
    theta = DEG2RAD*adeg
    c,s = np.cos(theta),np.sin(theta)
    offsets = np.c_[x-xc,y-yc].T   # 2 x N
    rotated = np.dot([[c,-s],[s,c]],offsets)

    if rotated.ndim==2 and rotated.shape[1]==1:
        # single point: collapse to a length-2 vector so scalars come back
        rotated = rotated.squeeze()
    xr,yr = rotated[0],rotated[1]
    return xr+xc,yr+yc
        
def meshpositions(*arrs):
    """
    Return all grid positions of np.meshgrid(*arrs) as a
    [len(arrs) x npoints] array (one row per input coordinate array).
    """
    grid = np.meshgrid(*arrs)
    # np.vstack needs a sequence; passing a map/generator raises TypeError
    # on modern numpy
    positions = np.vstack([g.ravel() for g in grid])
    return positions

def chull(p,**kwargs):
    """
    Convex hull vertices of point array p using scipy.spatial.ConvexHull.

    Keyword Arguments:
    - return_indices: also return the hull vertex indices into p
      (default False)
    - any other keyword is passed to ConvexHull

    Returns:
    - p[hull vertex indices] (and the indices when return_indices)
    """
    from scipy.spatial import ConvexHull as _ConvexHull
    want_indices = kwargs.pop('return_indices',False)
    vertex_idx = _ConvexHull(p,**kwargs).vertices
    hull_points = p[vertex_idx]
    if not want_indices:
        return hull_points
    return hull_points,vertex_idx

def utm2latlon(easting,northing,zone,hemi='North',alpha=None,datum=DATUM_WGS84):
    """
    Convert UTM coordinates to (lat,lon) with LLUTMConv.UTMtoLL.

    Arguments:
    - easting, northing: UTM coordinates
    - zone: UTM zone number

    Keyword Arguments:
    - hemi: 'North' or 'South' (anything else prints a message and returns
      (None,None))
    - alpha: zone letter; defaults to 'N' for North and 'M' for South
    - datum: datum id passed to UTMtoLL (default DATUM_WGS84)

    NOTE(review): arguments are forwarded positionally as
    UTMtoLL(datum, easting, northing, zone+letter), and sl2latlon calls this
    function with (y, x). Confirm both against UTMtoLL's parameter order.
    """
    if hemi in ('North','South'):
        if alpha:
            zone_letter = alpha
        else:
            zone_letter = 'N' if hemi=='North' else 'M'
        zonestr = str(zone)+zone_letter
        lat,lon = LLUTMConv.UTMtoLL(datum,easting,northing,zonestr)
        return lat,lon

    print('invalid hemisphere value=',hemi)
    return None,None

def sl2xy(s,l,**kwargs):
    """
    sl2xy(s,l,ulx=None,uly=None,xps=None,yps=xps,rot=0,mapinfo=None)

    Given pixel coordinates (s,l), convert them to map coordinates (x,y):
    x = ulx + xps*s, y = uly - yps*l, then rotated by rot degrees about
    (ulx,uly) when rot != 0.

    Arguments:
    - s,l: sample, line indices

    Keyword Arguments:
    - ulx,uly: upper left map coordinate (required, here or in mapinfo)
    - xps: x map pixel size (required, here or in mapinfo)
    - yps: y map pixel size (default=xps; a value of 0 also means xps)
    - rot: map rotation in degrees (default=mapinfo 'rotation' or 0)
    - mapinfo: envi map info dict; explicit kwargs above take precedence

    Returns:
    - x,y: x,y map coordinates of sample s, line l

    Raises:
    - ValueError if ulx/uly or xps/yps are undefined
    """
    mapinfo = kwargs.pop('mapinfo',None) or {}
    x0 = kwargs.pop('ulx',mapinfo.get('ulx',None))
    y0 = kwargs.pop('uly',mapinfo.get('uly',None))
    xps = kwargs.pop('xps',mapinfo.get('xps',None))
    yps = kwargs.pop('yps',mapinfo.get('yps',xps))
    rot = kwargs.pop('rot',mapinfo.get('rotation',0))

    if None in (x0,y0):
        raise ValueError("ulx or uly undefined")

    if None in (xps,yps):
        raise ValueError("xps or yps undefined")

    if yps == 0:
        yps = xps

    # lines increase downward, map y increases upward
    xp,yp = x0+xps*s, y0-yps*l
    if rot == 0:
        return xp,yp

    # rotate the unrotated map coordinates about the upper-left corner
    return rotxy(xp,yp,rot,x0,y0)

def sl2latlon(s,l,**kwargs):
    """
    Convert pixel (sample,line) coordinates to (lat,lon) using the map info
    dict passed as mapinfo=.

    Only 'UTM' and 'Geographic Lat/Lon' projections are supported. All
    kwargs (including mapinfo) are forwarded to sl2xy.

    Returns:
    - (lat,lon) for supported projections (utm2latlon may return
      (None,None)); None, after printing a message, for other projections

    Raises:
    - ValueError if mapinfo has no 'proj' entry
    """
    mapinfo = kwargs.get('mapinfo',{})
    proj = mapinfo.get('proj',None)
    if not proj:
        raise ValueError("proj undefined")
    elif proj not in ('UTM','Geographic Lat/Lon'):   
        print('unknown projection:',proj)
        return None
    
    x,y = sl2xy(s,l,**kwargs)
    if proj=='Geographic Lat/Lon':
        # geographic map coords are (x=lon, y=lat); return (lat,lon)
        return y,x

    # NOTE(review): (y,x) is passed into utm2latlon's (easting,northing)
    # parameters although sl2xy returns (x,y) map coordinates; this is only
    # correct if LLUTMConv.UTMtoLL takes northing before easting -- confirm.
    return utm2latlon(y,x,zone=mapinfo['zone'],hemi=mapinfo['hemi'])

def xy2sl(x,y,**kwargs):
    """
    xy2sl(x,y,ulx=None,uly=None,xps=None,yps=xps,rot=0,mapinfo=None)

    Given map coordinates (x,y) of an orthocorrected image, find the
    (fractional) pixel coordinates (s,l). Inverse of sl2xy.

    Arguments:
    - x,y: map coordinates

    Keyword Arguments:
    - ulx,uly: upper left map coordinate (required, here or in mapinfo)
    - xps: x map pixel size (required, here or in mapinfo)
    - yps: y map pixel size (default=xps; 0/None also mean xps)
    - rot: map rotation in degrees (default=mapinfo 'rotation' or 0)
    - mapinfo: envi map info dict; explicit kwargs above take precedence

    Returns:
    - s,l: sample, line coordinates of x,y (not rounded)

    Raises:
    - ValueError if ulx/uly or xps are undefined
    """
    mapinfo = kwargs.pop('mapinfo',None) or {}

    x0 = kwargs.pop('ulx',mapinfo.get('ulx',None))
    y0 = kwargs.pop('uly',mapinfo.get('uly',None))
    xps = kwargs.pop('xps',mapinfo.get('xps',None))
    yps = kwargs.pop('yps',mapinfo.get('yps',xps))
    rot = kwargs.pop('rot',mapinfo.get('rotation',0))

    if None in (x0,y0):
        raise ValueError("ulx or uly undefined")

    if xps is None:
        raise ValueError("pixel size undefined")

    yps = yps or xps
    # offsets from the upper-left corner (line direction points down)
    xp, yp = (x-x0), (y0-y)
    if rot!=0:
        # with the flipped y axis, rotating by +rot undoes sl2xy's rotation
        xp, yp = rotxy(xp,yp,rot,0,0)

    xp,yp = xp/xps,yp/yps
    return xp,yp

def latlon2utm(lat,lon,zone=None,datum=DATUM_WGS84):
    """
    latlon2utm(lat,lon,zone=None,datum=DATUM_WGS84)

    Convert (lat,lon) to UTM with LLUTMConv.LLtoUTM.

    Arguments:
    - lat: latitude coordinate(s)
    - lon: longitude coordinate(s)

    Keyword Arguments:
    - zone:   UTM zone number passed as ZoneNumber (default=None)
    - datum:  datum id passed to LLtoUTM (default=DATUM_WGS84)

    Returns:
    - easting:  UTM easting map coordinate
    - northing: UTM northing map coordinate
    - zone:     zone number (int of the zone string minus its last char)
    - hemi:     last character of the zone string (the zone letter)
    """
    zonestr,easting,northing = LLUTMConv.LLtoUTM(datum,lat,lon,
                                                 ZoneNumber=zone)
    zone_number = int(zonestr[:-1])
    zone_letter = zonestr[-1]
    return easting,northing,zone_number,zone_letter

def latlon2sl(lat,lon,**kwargs):
    """
    Convert (lat,lon) to pixel (sample,line) coordinates using the map info
    dict passed as mapinfo=.

    Geographic images use (lon,lat) directly as map (x,y). UTM images first
    convert with latlon2utm, using mapinfo['zone'] when present.

    Returns:
    - (s,l) from xy2sl; None (after printing a message) for unsupported
      projections

    Raises:
    - ValueError if mapinfo has no 'proj' entry
    """
    mapinfo = kwargs.get('mapinfo',{})
    proj = mapinfo.get('proj',None)
    if not proj:
        raise ValueError("proj undefined")

    if proj=='Geographic Lat/Lon':
        return xy2sl(lon,lat,mapinfo=mapinfo)

    if proj!='UTM':
        print('Unknown projection:',proj)
        return None

    utmzone = int(mapinfo['zone']) if 'zone' in mapinfo else None
    easting,northing = latlon2utm(lat,lon,zone=utmzone)[:2]
    return xy2sl(easting,northing,mapinfo=mapinfo)

def mapdict2str(mapdict):
    """
    Serialize a map info dict (as built by mapinfo()) into an ENVI
    'map info' string: '{ v1, v2, ..., key=val, ..., meta1, ... }'.

    The first 10 values (UTM) or 7 values (other projections) are written
    positionally in insertion order. Remaining entries are written as
    key=val, and the items of a 'metadata' list entry are appended verbatim.
    The input dict is not modified.
    """
    mapmeta = list(mapdict.get('metadata',[]))
    # materialize as a list: dict views/map objects cannot be sliced or
    # concatenated under Python 3
    items = [(k,v) for k,v in mapdict.items() if k!='metadata']
    nargs = 10 if mapdict['proj']=='UTM' else 7
    maplist = [str(v) for _,v in items[:nargs]]
    mapkw = [str(k)+'='+str(v) for k,v in items[nargs:]]
    mapstr = '{ '+(', '.join(maplist+mapkw+mapmeta))+' }'
    return mapstr

def mapinfo(img,astype=dict):
    """
    Parse the ENVI 'map info' entry of img.metadata.

    Arguments:
    - img: object with a .metadata dict (e.g. a spectral image)

    Keyword Arguments:
    - astype: dict (default) -> OrderedDict with proj, xtie, ytie, ulx, uly,
      xps, yps (floats except proj), plus zone/hemi/datum for UTM, any
      'key=val' items (as strings), 'rotation' (float, default 0.0) and a
      'metadata' list of unparsed items;
      list -> the raw map info list;
      str -> the dict serialized with mapdict2str

    Returns:
    - None if img has no 'map info', otherwise as selected by astype
    """
    maplist = img.metadata.get('map info',None)
    if maplist is None:
        return None
    if astype==list:
        return maplist

    mapinfo = OrderedDict()
    mapinfo['proj'] = maplist[0]
    mapinfo['xtie'] = float(maplist[1])
    mapinfo['ytie'] = float(maplist[2])
    mapinfo['ulx']  = float(maplist[3])
    mapinfo['uly']  = float(maplist[4])
    mapinfo['xps']  = float(maplist[5])
    mapinfo['yps']  = float(maplist[6])

    if mapinfo['proj'] == 'UTM':
        mapinfo['zone']  = maplist[7]
        mapinfo['hemi']  = maplist[8]
        mapinfo['datum'] = maplist[9]

    mapmeta = []
    # items after the positional fields: 'key=val' pairs or free metadata
    for mapitem in maplist[len(mapinfo):]:
        if '=' in mapitem:
            # split on the first '=' only so values may contain '='
            key,val = [tok.strip() for tok in mapitem.split('=',1)]
            mapinfo[key] = val
        else:
            mapmeta.append(mapitem)

    mapinfo['rotation'] = float(mapinfo.get('rotation','0'))
    if len(mapmeta)!=0:
        print('unparsed metadata:',mapmeta)
        mapinfo['metadata'] = mapmeta

    if astype==str:
        return mapdict2str(mapinfo)

    return mapinfo

def tile2geotiff(tileimg,tilepos,tilef,baseimgf,validate_coords=True):
    """
    Write tileimg to tilef as a GeoTIFF, georeferenced by its pixel
    position within baseimgf.

    Arguments:
    - tileimg: [rows x cols x bands] tile (2D input is treated as 1 band);
      written as GDT_Byte, so values are expected to fit in uint8
    - tilepos: (line,sample) of the tile's upper-left pixel in baseimgf
    - tilef: output GeoTIFF path
    - baseimgf: GDAL-readable image providing geotransform and projection

    Keyword Arguments:
    - validate_coords: print the base/tile bounding boxes and warn when the
      tile falls (partially) outside the base image footprint

    Raises:
    - IOError if GDAL cannot open baseimgf
    """
    try:
        from osgeo.gdal import Open, GetDriverByName, GA_ReadOnly, GDT_Byte
    except ImportError:
        # legacy GDAL bindings installed a top-level "gdal" module
        from gdal import Open, GetDriverByName, GA_ReadOnly, GDT_Byte

    print('saving',tilef)
    tileimg = np.atleast_3d(tileimg)
    gtifdrv = GetDriverByName('Gtiff')
    nl_tile,ns_tile,nb_tile = tileimg.shape
    g = Open(baseimgf, GA_ReadOnly)
    if g is None:
        raise IOError('unable to open %s'%baseimgf)
    geo_t = g.GetGeoTransform()
    geo_p = g.GetProjectionRef()
    rows,cols = g.RasterYSize,g.RasterXSize
    g = None # release the base dataset; only its metadata is needed

    def pix2map(s,l):
        # apply the GDAL affine geotransform to pixel (sample,line)
        return (geo_t[0]+s*geo_t[1]+l*geo_t[2],
                geo_t[3]+s*geo_t[4]+l*geo_t[5])

    l0,s0 = tilepos
    x_tile,y_tile = pix2map(s0,l0)

    if validate_coords:
        from skimage.measure import points_in_poly

        bbox_s = [0,0,cols,cols]
        bbox_l = [0,rows,rows,0]
        bbox_xy = []
        print(pathsplit(baseimgf)[1],'bounding box')
        for s,l in zip(bbox_s,bbox_l):
            x,y = pix2map(s,l)
            bbox_xy.append((x,y))
            print((s,l),'->\t',(x,y))

        print('tile bounding box')
        trows,tcols = tileimg.shape[0],tileimg.shape[1]
        # tile corners plus its center
        s_off = [0,0,tcols,tcols,round(tcols/2)]
        l_off = [0,trows,trows,0,round(trows/2)]
        inbbox = np.zeros(5,dtype=bool) # np.bool8 was removed in numpy 2
        for i,(so,lo) in enumerate(zip(s_off,l_off)):
            st,lt = s0+so,l0+lo
            xt,yt = pix2map(st,lt)
            inbbox[i] = points_in_poly([(xt,yt)],bbox_xy)[0]
            print((st,lt),'->\t',(xt,yt),'in bbox=',inbbox[i])

        if not inbbox.any():
            print('WARNING: tile outside of bounding box')
        elif not inbbox.all():
            nzo = (inbbox==0).sum()
            print('WARNING: tile overlaps bounding box (%d points outside)'%nzo)

    tile_geo_t = (x_tile, geo_t[1], geo_t[2], y_tile, geo_t[4], geo_t[5])
    gsub = gtifdrv.Create(tilef, ns_tile, nl_tile, nb_tile, GDT_Byte)
    gsub.SetGeoTransform(tile_geo_t)
    gsub.SetProjection(geo_p)
    for i in range(nb_tile):
        gsub.GetRasterBand(i+1).WriteArray(tileimg[:,:,i])
    gsub = None # closing the dataset flushes it to disk

def rdn2rgb(rdn):
    """
    Contrast-stretch rdn to [0,1] between its NaN-aware 1st and 99th
    percentiles (extrema(p=0.99)), clipping values outside that range.

    Returns an all-zero float array when the stretch range is degenerate
    (all values equal within the percentiles, or all NaN) instead of
    producing 0/0 NaNs and infinities.
    """
    rdnmin,rdnmax = extrema(rdn,p=0.99)
    if not (rdnmax > rdnmin): # also True when either bound is NaN
        return np.zeros(np.shape(rdn),dtype=np.float64)
    return np.clip((rdn-rdnmin)/(rdnmax-rdnmin),0.0,1.0)

def array2img(outf,img,mapinfostr=None,bandnames=None,**kwargs):
    """
    Write a numpy array to a bip-interleaved ENVI image at outf, with
    header outf+'.hdr'.

    Arguments:
    - outf: output image path
    - img: array; promoted with np.atleast_3d to [lines x samples x bands]

    Keyword Arguments:
    - mapinfostr: ENVI 'map info' string (e.g. from mapdict2str)
    - bandnames: list of band names
    - overwrite: if False (default) and outf exists, print a message and
      return without writing

    The header 'data ignore value' is set to NODATA.
    """
    outhdrf = outf+'.hdr'
    if pathexists(outf) and not kwargs.pop('overwrite',False):
        print('Cannot write image to path',outf)
        return

    img = np.atleast_3d(img)
    outmeta = dict(samples=img.shape[1], lines=img.shape[0], bands=img.shape[2],
                   interleave='bip')

    outmeta['file type'] = 'ENVI'
    outmeta['byte order'] = 0
    outmeta['header offset'] = 0
    outmeta['data type'] = envitypecode(img.dtype)

    if mapinfostr:
        outmeta['map info'] = mapinfostr

    if bandnames:
        outmeta['band names'] = '{%s}'%", ".join(bandnames)

    outmeta.setdefault('data ignore value',NODATA)

    outimg = createimg(outhdrf,outmeta)
    outmm = openimgmm(outimg,writable=True)
    outmm[:] = img
    # flush explicitly instead of relying on the memmap being garbage
    # collected when the last reference goes away
    if hasattr(outmm,'flush'):
        outmm.flush()
    del outmm
    print('saved %s array to %s'%(str(img.shape),outf))
    
def kde(img,k):
    """
    Weight img by a normalized gaussian-smoothed copy of itself.

    The smoothing uses scipy.ndimage.gaussian_filter(sigma=k, truncate=1);
    the smoothed image is min-max rescaled to [0,1] and multiplied into img.
    """
    import scipy.ndimage as _ndi
    density = _ndi.gaussian_filter(img,sigma=k,truncate=1)
    lo,hi = density.min(),density.max()
    weights = (density-lo)/(hi-lo)
    return img*weights

def relabel_sequential(labimg,**kwargs):
    """Return skimage.segmentation.relabel_sequential(labimg, **kwargs)."""
    import skimage.segmentation as _seg
    return _seg.relabel_sequential(labimg,**kwargs)

def remove_small_objects(labimg,**kwargs):
    """
    Remove small connected objects with skimage.morphology.remove_small_objects.

    Keyword Arguments:
    - min_size: default 9
    - in_place: default False. Newer scikit-image removed this keyword in
      favor of out=, so in_place=True is translated to out=labimg when the
      installed version has no 'in_place' parameter.
    - any other keyword accepted by remove_small_objects
    """
    import inspect
    from skimage.morphology import remove_small_objects as _rso
    kwargs.setdefault('min_size',9)
    in_place = kwargs.pop('in_place',False)

    try:
        params = inspect.signature(_rso).parameters
    except (TypeError, ValueError):
        params = {}

    if 'in_place' in params and 'out' not in params:
        # older scikit-image API
        kwargs['in_place'] = in_place
    elif in_place and 'out' not in kwargs:
        kwargs['out'] = labimg
    return _rso(labimg,**kwargs)

def filtdet(ch4mf,nodata_mask,mfmapinfo,minarea=minarea,mfmin=mfmin,mfmax=mfmax,
            k=kernel,mfminsmall=mfminsmall,skip_kde=False,use_abs=False,
            kde_outf=None,ccomp_outf=None,det_outf=None):
    """
    Filter a matched-filter image into connected detection components.

    Steps:
    1. optionally smooth ch4mf (|ch4mf| if use_abs) with kde(k), then
       rescale to [0,1] between mfmin and mfmax; pixels with a positive
       score are candidates
    2. drop small candidate components (remove_small_objects with
       min_size=minarea)
    3. if mfminsmall >= mfmin, add back removed components that contain a
       pixel with ch4mf >= mfminsmall
    4. label the resulting mask, zero pixels with ch4mf < mfmin (after
       labeling, so interior holes do not split components), and relabel
       sequentially

    Arguments:
    - ch4mf: matched filter image (presumably 2D, ppmm per the log text)
    - nodata_mask: boolean nodata mask indexing the same pixels as ch4mf
    - mfmapinfo: map info dict, serialized with mapdict2str for output
      headers

    Keyword Arguments:
    - minarea, mfmin, mfmax, k, mfminsmall: thresholds (module defaults)
    - skip_kde: skip the kde smoothing step
    - use_abs: smooth |ch4mf| instead of ch4mf
    - kde_outf, ccomp_outf, det_outf: optional ENVI output paths for the
      kde score, the component labels and the thresholded ch4mf

    Returns:
    - detkde: [0,1] score, zeroed where ch4mf < mfmin and in nodata
    - detcomp: sequential component labels, zeroed in nodata
    """
    print('filtering weakly-connected detections below minppm=%.2fppmm'%mfmin)
    mfmapstr=mapdict2str(mfmapinfo)
    detkde = np.abs(ch4mf) if use_abs else ch4mf.copy()
    ch4min = ch4mf >= mfmin
    if not skip_kde:
        detkde = kde(detkde,k=k)
    detkde = np.clip((detkde-mfmin)/(mfmax-mfmin),0.0,1.0)

    if not skip_kde and kde_outf:
        array2img(kde_outf,detkde,mapinfostr=mfmapstr,overwrite=True)
    detmask = (detkde>0)
    print('%d candidate components'%imlabel(detmask).max())

    print('removing detections with <= minarea=%d pixels'%(minarea))
    detsmall = detmask.copy()
    detmask = remove_small_objects(detmask,min_size=minarea,in_place=False)
    print('%d remaining detections'%imlabel(detmask).max())
    if mfminsmall >= mfmin:
        print('adding small detections with mf>=%f ppmm'%(mfminsmall))
        # components removed by remove_small_objects above
        smallcc = imlabel(detsmall!=detmask)
        smallkeep = np.unique(smallcc[(ch4mf>=mfminsmall)])
        smallkeep = smallkeep[smallkeep!=0]
        print(len(smallkeep),'small detections found')
        # np.isin keeps the input shape (np.in1d is deprecated)
        smallmask = np.isin(smallcc,smallkeep)
        # np.bool8 was removed in numpy 2
        detmask = (detmask | smallmask.astype(bool))

    # exclude nodata+ch4min afterward to exclude interior holes from ccimg
    detcomp = imlabel(detmask)
    detcomp[~ch4min] = 0
    detcomp,_,_ = relabel_sequential(detcomp)
    print('%d final detections'%detcomp.max())

    if ccomp_outf:
        detcomp[nodata_mask] = NODATA
        array2img(ccomp_outf,detcomp,mapinfostr=mfmapstr,overwrite=True)

    detkde[~ch4min] = 0
    detkde[nodata_mask] = 0
    detcomp[nodata_mask] = 0

    print('detected %d components'%detcomp.max())
    if det_outf:
        detfilt = ch4mf.copy()
        detfilt[~ch4min]=0

        detfilt[nodata_mask] = NODATA
        array2img(det_outf,detfilt,mapinfostr=mfmapstr,overwrite=True)

    return detkde, detcomp


def loadmaskedimage(maskedimgf,rgb_bands=None,masked_value=np.nan):
    """
    Load an ENVI image (header maskedimgf+'.hdr') as float32 and mask its
    nodata pixels.

    Arguments:
    - maskedimgf: image path (header is maskedimgf+'.hdr')

    Keyword Arguments:
    - rgb_bands: 3 band indices to split off as a contrast-stretched rgb
      image (default None/empty: no rgb split)
    - masked_value: value written into nodata pixels (default NaN)

    Returns a dict with:
    - mapinfo: map info dict (see mapinfo)
    - nodata_mask: [rows x cols] mask of pixels where any band equals the
      header's 'data ignore value'
    - nodata_value: that value as a float
    - rgb: rdn2rgb of the rgb bands (only when rgb_bands has 3 entries and
      the image has >= 3 bands)
    - image: remaining bands (all bands when there is no rgb split),
      squeezed; absent if the rgb bands were the only bands
    """
    rgb_bands = [] if rgb_bands is None else list(rgb_bands)
    maskedimg = envi_open_file(maskedimgf+'.hdr',image=maskedimgf)
    rows,cols,bands = maskedimg.shape
    print('loading [%d,%d,%d] masked input image: "%s"'%(rows,cols,bands,maskedimgf))
    maskeddata = np.float32(maskedimg.load())
    if maskeddata.ndim == 2:
        maskeddata = maskeddata[...,np.newaxis]
    nodata_value = float(maskedimg.metadata['data ignore value'])
    nodata_mask = (maskeddata==nodata_value).any(axis=2)
    maskeddata[nodata_mask] = masked_value
    outdata = dict(mapinfo=mapinfo(maskedimg,astype=dict),
                   nodata_mask=nodata_mask,
                   nodata_value=nodata_value)

    if bands>=3 and len(rgb_bands)==3:
        # keep the non-rgb bands in ascending order (set order is not
        # guaranteed)
        image_bands = [b for b in range(bands) if b not in rgb_bands]
        outdata['rgb'] = rdn2rgb(maskeddata[:,:,rgb_bands])
        if len(image_bands)!=0:
            outdata['image'] = maskeddata[:,:,image_bands].squeeze()
    else:
        outdata['image'] = maskeddata.squeeze()

    return outdata

def rgb2labimg(rgbimg):
    """
    Convert an RGB label mask to a uint8 label image.

    Pixels whose channels sum to 255 with red == 255 (pure [255,0,0]) become
    POINTSRC; with blue == 255 (pure [0,0,255]) they become DIFFSRC.
    Everything else is 0.

    Raises:
    - AssertionError if rgbimg does not have exactly 3 channels
    """
    assert(rgbimg.shape[2]==3)

    red,blue = rgbimg[:,:,0],rgbimg[:,:,2]
    single_channel = (rgbimg.sum(axis=2)==255)

    labels = np.zeros(rgbimg.shape[:2],dtype=np.uint8)
    labels[single_channel & (red==255)] = POINTSRC
    labels[single_channel & (blue==255)] = DIFFSRC
    return labels

def loadlabimg(labf):
    """
    Load a label mask image as uint8 with values in {0, POINTSRC, DIFFSRC}.

    Supported formats:
    - '.png': 3/4-channel images are converted with rgb2labimg; 2D
      (grayscale) images are used as-is
    - extensionless ENVI class map whose name ends with 'class' (header
      labf+'.hdr')

    Raises:
    - Exception for any other file name
    - AssertionError if the image holds values other than 0/POINTSRC/DIFFSRC
    """
    print('loading label mask image: "%s"'%labf)
    filebase,fileext = splitext(pathsplit(labf)[1])
    if fileext=='.png':
        labimg = imread(labf)
        # grayscale pngs load as 2D arrays; only color images need converting
        if labimg.ndim==3 and labimg.shape[2] in (3,4):
            labimg = rgb2labimg(labimg[:,:,:3]).squeeze()

    elif fileext=='' and filebase.endswith('class'):
        # envi class map
        imgdat = envi_open_file(labf+'.hdr',image=labf)
        labimg = imgdat.load().squeeze()
    else:
        raise Exception('Unrecognized format %s'%labf)

    labimg = np.uint8(labimg)
    ulab = np.unique(labimg)
    assert(((ulab==0)|(ulab==POINTSRC)|(ulab==DIFFSRC)).all())

    return labimg

def loadfiltdet(detfilt_imgf):
    """
    Load a filtered detection ENVI image (header detfilt_imgf+'.hdr').

    Returns a dict with:
    - ch4det: float32 squeezed image with nodata pixels set to 0
    - mapinfo: map info dict (see mapinfo)
    - nodata_mask: pixels equal to the header's 'data ignore value'
    - nodata_value: that value as a float
    """
    print('loading filtered detection image: "%s"'%detfilt_imgf)
    hdrpath = detfilt_imgf+'.hdr'
    detimg = envi_open_file(hdrpath,image=detfilt_imgf)
    ch4det = np.float32(detimg.load().squeeze())
    nodata_value = float(detimg.metadata['data ignore value'])
    nodata_mask = (ch4det==nodata_value)
    ch4det[nodata_mask] = 0
    return {'ch4det': ch4det,
            'mapinfo': mapinfo(detimg,astype=dict),
            'nodata_mask': nodata_mask,
            'nodata_value': nodata_value}


def loaddetids(detid_imgf):
    """
    Load a detection id ENVI image (header detid_imgf+'.hdr').

    Returns a dict with:
    - detids: float32 squeezed image with nodata pixels set to 0
    - mapinfo: map info dict (see mapinfo)
    - nodata_mask: pixels equal to the header's 'data ignore value'
    - nodata_value: that value as a float
    """
    print('loading detection id image: "%s"'%detid_imgf)
    hdrpath = detid_imgf+'.hdr'
    detimg = envi_open_file(hdrpath,image=detid_imgf)
    detids = np.float32(detimg.load().squeeze())
    nodata_value = float(detimg.metadata['data ignore value'])
    nodata_mask = (detids==nodata_value)
    detids[nodata_mask] = 0
    return {'detids': detids,
            'mapinfo': mapinfo(detimg,astype=dict),
            'nodata_mask': nodata_mask,
            'nodata_value': nodata_value}


def loadsaliencemap(salience_imgf):
    """
    Load a salience map ENVI image (header salience_imgf+'.hdr').

    Returns a dict with:
    - saliencemap: float32 squeezed image
    - mapinfo: map info dict (see mapinfo)
    """
    print('loading',salience_imgf)
    hdrpath = salience_imgf+'.hdr'
    salienceimg = envi_open_file(hdrpath,image=salience_imgf)
    smapinfo = mapinfo(salienceimg,astype=dict)
    smap = np.float32(salienceimg.load().squeeze())
    return {'saliencemap': smap, 'mapinfo': smapinfo}

def bbox(points,border=0,imgshape=None):
    """
    bbox(points,border=0,imgshape=None)
    computes bounding box of extrema in points array

    Arguments:
    - points: [N x 2] array of [rows, cols]

    Keyword Arguments:
    - border: padding added on each side of the box:
      - number >= 1: that many pixels on both axes
      - number < 1: fraction of the mean (row,col) extent, rounded
      - [rborder,cborder] list or tuple: per-axis padding; values <= 1 are
        fractions of that axis' extent (truncated)
      - 'mindiff' / 'maxdiff': min / max of the (row,col) extents
    - imgshape: (rows,cols) used to clip the box; when None/empty the box
      is only clipped at 0

    Returns:
    - (rmin,cmin),(rmax,cmax) with rmax/cmax exclusive

    Raises:
    - ValueError for an unknown string border mode
    """
    from numpy import atleast_2d
    points = atleast_2d(points)
    minv = points.min(axis=0)
    maxv = points.max(axis=0)
    difv = maxv-minv

    # string modes must be handled before the numeric comparison below:
    # under Python 3, comparing a str with an int raises TypeError
    if isinstance(border,str):
        if border == 'mindiff':
            rborder = cborder = min(difv)
        elif border == 'maxdiff':
            rborder = cborder = max(difv)
        else:
            raise ValueError('unknown border mode %r'%border)
    elif isinstance(border,(list,tuple)):
        rborder,cborder = border
        rborder = rborder if rborder > 1 else int(rborder*difv[0])
        cborder = cborder if cborder > 1 else int(cborder*difv[1])
    elif border < 1:
        rborder = cborder = int(border*difv.mean()+0.5)
    else:
        rborder = cborder = border

    if imgshape is None or len(imgshape)==0:
        imgshape = maxv+max(rborder,cborder)+1

    rmin,rmax = max(0,minv[0]-rborder),min(imgshape[0],maxv[0]+rborder+1)
    cmin,cmax = max(0,minv[1]-cborder),min(imgshape[1],maxv[1]+cborder+1)

    return (rmin,cmin),(rmax,cmax)
    
def maskbbox(mask,border=0):
    """
    maskbbox(mask,border=0)
    computes the bounding box of the nonzero pixel coords of mask,
    clipped to mask.shape

    Arguments:
    - mask: [n x m] binary image mask

    Keyword Arguments:
    - border: padding for the bounding box (see bbox; default=0)

    Returns:
    - bbox = (rowmin,colmin),(rowmax,colmax)
    """
    nonzero_coords = np.column_stack(np.where(mask))
    return bbox(nonzero_coords,border,mask.shape)


if __name__ == '__main__':
    # Ad-hoc visual check of a detection tile against the full-scene
    # products. All paths below are hard-coded for a specific local layout.

    # disabled smoke tests for tile2geotiff and array2rgba
    # (raw_input is Python 2 only)
    if 0:
        baseimg='./data/y16/cmf/ort/ang20160913t205656_cmf_v1n2_img'
        #baseimg='/lustre/ang/y16/cmf/ort/ang20160913t205656_cmf_v1n2_img'    
        tileimg = 255*np.ones([100,100,4],dtype=np.uint8)
        tilepos=(12360,728)
        tilefile='test.tif'
        tile2geotiff(tileimg,tilepos,tilefile,baseimg)
        raw_input()
        z=np.random.rand(10)
        z[3:6] = np.nan
        y=np.random.rand(10,10)
        y[1:5] = np.nan
        r = array2rgba(z)
        print(r,r.shape)
        r = array2rgba(y)
        print(r,r.shape)
        raw_input()
    import pylab as pl
    datadir='./data'    
    #datadir='/lustre/ang'
    # select the 2016 (thorpe_training) or 2015 (thompson_training) inputs
    if 1:
        cmfdir=datadir+'/y16/cmf/ort'
        labdir=pathjoin(cmfdir,'thorpe_training')
        detdir='tiles/thorpe_training/100'

        labimgf=pathjoin(labdir,'ang20160914t180328_cmf_v1n2_img_class')
        cmfimgf=pathjoin(cmfdir,'ang20160914t180328_cmf_v1n2_img')
        #filtdetimgf=pathjoin(detdir,'ang20160914t180328_cmf_v1n2_filt_det_500_1500')
        filtdetimgf='./tiles/thorpe_training/150/ang20160914t180328_cmf_v1n2_filt_det_350_1500'
        
    else:
        cmfdir=datadir+'/y15/cmf/ort'
        labdir=pathjoin(cmfdir,'thompson_training')
        detdir='tiles/thompson_training/100/'
        labimgf=pathjoin(labdir,'ang20150419t161445_cmf_v1f_img_mask.png')
        cmfimgf=pathjoin(cmfdir,'ang20150419t161445_cmf_v1f_img')        
        filtdetimgf = pathjoin(detdir,'ang20150419t161445_cmf_v1f_filt_det_500_1500')

    # tile to inspect: (row,col) upper-left position and size in pixels
    tiledim=150
    tilepos = 4678,1115
    tileimgf = './tiles/thorpe_training/150/ang20160914t180328/tp/tp_det4678_1115.png'
    pngimgf=filtdetimgf+'.png'
    # NOTE(review): machine-specific path to an external tilepredictor_util module
    sys.path.insert(0,'/Users/bbue/Research/src/python/util/tilepredictor/')
    from tilepredictor_util import imread_image,imread_tile
    
    tileimg = imread_tile(tileimgf,tile_shape=[tiledim,tiledim])
    pngimg = imread_image(pngimgf)
    tilepng = pngimg[tilepos[0]:tilepos[0]+tiledim,tilepos[1]:tilepos[1]+tiledim]
    cmfdata = loadmaskedimage(cmfimgf,rgb_bands=range(3))        
    detdata = loadfiltdet(filtdetimgf)    
    labimg = loadlabimg(labimgf)
    tilelab = labimg[tilepos[0]:tilepos[0]+tiledim,tilepos[1]:tilepos[1]+tiledim]
    # make sure nodata masks match 
    nodata_mask = detdata['nodata_mask']
    assert(cmfdata['nodata_mask'].sum()==nodata_mask.sum())

    ch4det = detdata['ch4det']    
    rgbimg = cmfdata['rgb']
    # remove labeled pixels in nodata regions
    # (in practice: detections in nodata or below 500 and unlabeled pixels
    # are set to NaN, presumably so imshow renders them as masked)
    labimg = np.float32(labimg)
    ch4det[nodata_mask | (ch4det<500)] = np.nan
    labimg[~(labimg>0)] = np.nan


    
    # 1) the tile as saved vs. the same region of the png and label images
    fig,ax = pl.subplots(3,1,sharex=True,sharey=True,num=1)
    ax[0].imshow(tileimg)
    ax[1].imshow(tilepng)
    ax[2].imshow(tilelab)
    pl.show()

    # 2) full-scene rgb vs. png (transposed for display)
    # NOTE(review): ax handles are reused after pl.clf(); they may no longer
    # be attached to the figure
    pl.clf()
    ax[0].imshow(rgbimg.swapaxes(0,1))
    ax[1].imshow(pngimg.swapaxes(0,1))
    pl.show()
    
    # 3) detections overlaid on rgb and on the label image
    ax[0].imshow(rgbimg.swapaxes(0,1))
    ax[0].imshow(ch4det.swapaxes(0,1),vmin=500,vmax=1500,cmap='YlOrRd')
    ax[1].imshow(labimg.swapaxes(0,1),vmin=0,vmax=1,cmap='Spectral',alpha=0.7)
    ax[1].imshow(ch4det.swapaxes(0,1),vmin=500,vmax=1500,cmap='YlOrRd')
    pl.show()
