from __future__ import absolute_import, division, print_function

import sys,os
from warnings import filterwarnings, warn
# Substrings of warning messages silenced at import time. Each entry is
# wrapped in '.*...*' because warnings.filterwarnings matches its `message`
# regex against the start of the warning text. Entries are used as regex
# fragments verbatim (not escaped).
# NOTE(review): 'is a low contrast image' presumably comes from
# skimage.io.imsave (wrapped by imsave() in this module) - confirm.
warnings_ignore = [
    'is a low contrast image',
]

for msg in warnings_ignore:
    filterwarnings(action="ignore", message='.*{}.*'.format(msg))
    
import numpy as np

from os.path import join as pathjoin, exists as pathexists, split as pathsplit
from os.path import splitext, abspath, getsize, realpath, isabs, expandvars
from os.path import expanduser
from collections import OrderedDict
from spectral.io.envi import open as envi_open_file
import LatLongUTMconversion as LLUTMConv

# Both names are aliases of os.makedirs (creates intermediate directories too).
mkdir = mkdirs = os.makedirs

def filename(path):
    '''
    Return the last component of path (file name including its extension).

    /path/to/file.ext -> file.ext
    '''
    _head, tail = pathsplit(path)
    return tail

def fileext(path):
    '''
    Return the extension of path, including the leading dot ('' if none).

    /path/to/file.ext -> .ext
    '''
    _root, ext = splitext(path)
    return ext

def dirname(path):
    '''
    Return everything before the last path component.

    /path/to/file.ext -> /path/to
    '''
    head, _tail = pathsplit(path)
    return head

def basename(path):
    '''
    Return the file name of path with its (last) extension removed.

    /path/to/file.ext -> file
    '''
    stem, _ext = splitext(pathsplit(path)[1])
    return stem

# Root directory: the SRCFINDER_ROOT environment variable when set and
# non-empty, otherwise the directory containing this module. Printed on import.
SRCFINDER_ROOT = os.environ.get('SRCFINDER_ROOT') or dirname(__file__)
print('SRCFINDER_ROOT: "%s"' % (str(SRCFINDER_ROOT),))

# Reference-ellipsoid code used as the default datum for the
# LatLongUTMconversion calls below (23 == WGS-84).
DATUM_WGS84 = 23  # WGS-84
DEG2RAD = np.pi / 180.0  # degrees -> radians multiplier

# Fill value written as 'data ignore value' in output ENVI headers.
NODATA = -9999

# Interpolation orders (passed as `order` to skimage resize in imresize).
ORDER_NN, ORDER_BILIN, ORDER_QUAD, ORDER_CUBIC = 0, 1, 2, 3

# skimage connectivity values: 1 -> 4-connected, 2 -> 8-connected (2D).
CONN4, CONN8 = 1, 2

# Source type codes.
POINTSRC, DIFFSRC = 1, 2

use_absmf = False

# Default detection-filter parameters (defaults of filtdet).
kernel = 50
mfmin, mfmax = 500, 1500
minarea = 9
mfminsmall = 1250

def imsave(fname, img, **kwargs):
    '''
    Write img to fname with skimage.io.imsave.

    The keyword `verbose` (default False) is consumed here; when true the
    destination is printed first. All other keywords go to skimage.
    '''
    from skimage.io import imsave as _imsave
    verbose = kwargs.pop('verbose', False)
    if verbose:
        print('saving', fname)
    return _imsave(fname, img, **kwargs)

def array2rgba(a, **kwargs):
    '''
    Map the values of array `a` to uint8 RGBA colors through a colormap.

    Keyword arguments:
    - cmap: colormap name (default: rcParams['image.cmap'])
    - lut:  number of colormap entries (default: rcParams['image.lut'])
    - vmin, vmax: value range stretched onto the colormap (values outside are
      clipped); default to the min/max of the non-NaN values of `a`

    Returns:
    - uint8 array of shape a.shape + (4,). NaN inputs map to (0, 0, 0, 0).
      If every value is NaN, or vmax <= vmin, a warning is issued and an
      all-zero array is returned.

    NOTE(review): matplotlib.cm.get_cmap is deprecated in recent matplotlib
    releases (removed in 3.9); confirm the pinned version.
    '''
    from pylab import rcParams
    from matplotlib.cm import get_cmap
    kwargs.setdefault('cmap', rcParams['image.cmap'])
    kwargs.setdefault('lut', rcParams['image.lut'])
    cm = get_cmap(kwargs['cmap'], lut=kwargs['lut'])
    aflat = np.float32(a.copy().ravel())
    nanmask = np.isnan(aflat)
    nvalid = np.count_nonzero(~nanmask)
    avals = aflat[~nanmask]

    # Only fall back to the data range when there is data. min()/max() of an
    # empty array raise ValueError, which used to pre-empt the all-NaN warning.
    vmin = kwargs.pop('vmin', None)
    vmax = kwargs.pop('vmax', None)
    if vmin is None:
        vmin = avals.min() if nvalid > 0 else np.nan
    if vmax is None:
        vmax = avals.max() if nvalid > 0 else np.nan
    vmin, vmax = float(vmin), float(vmax)

    if nvalid > 0 and vmax > vmin:
        aflat[nanmask] = vmin  # placeholder value; NaN pixels are zeroed below
        aflat = np.clip((aflat - vmin) / (vmax - vmin), 0., 1.)
        rgba = np.uint8(cm(aflat) * 255)
        rgba[nanmask] = 0
        rgba = rgba.reshape(list(a.shape) + [4])
    else:
        if nvalid == 0:
            warn('all values nan')
        elif vmax <= vmin:
            warn('vmax <= vmin')
        rgba = np.zeros(list(a.shape) + [4], dtype=np.uint8)
    return rgba

def float2rgba(img, cmap='binary', vmin=0.0, vmax=1.0, alpha=0):
    """
    float2rgba(img,cmap='binary',vmin=0.0,vmax=1.0,alpha=0)

    Summary: pack a single-band float image into a 4-band uint8 RGBA image

    Arguments:
    - img: [n,m] or [n,m,1] single-band float image; every value must lie in
      [vmin,vmax] (checked with an assert)
    - cmap: 'binary' (default) packs each value into the 24 RGB bits; any
      other colormap name colors the image via array2rgba (lut=4096)
    - vmin, vmax: value range; binary mode does NOT rescale and assumes [0,1]
    - alpha: constant written to the alpha band (default=0)

    Output:
    - [n,m,4] uint8 rgba image

    Notes:
    - binary mode: a value v becomes the integer uint32(v*(2**24-1)), whose
      bytes are reinterpreted as 4 uint8 channels in native byte order (on a
      little-endian host the first channel holds the least significant byte).
      The top byte is then overwritten with alpha, so only 24 bits survive.
    - alpha=0 produces transparent pixels in viewers that honour alpha;
      use alpha=255 (or ignore the alpha band) to see the image
    """
    assert (img.min() >= vmin) & (img.max() <= vmax)
    if cmap == 'binary':
        scaled = ((2**24) - 1) * img
        packed = np.uint32(scaled)
        rgbavec = packed.view(dtype=np.uint8)
    else:
        flat = np.float32(img.ravel())
        colored = array2rgba(flat, cmap=cmap, vmin=vmin, vmax=vmax, lut=4096)
        rgbavec = np.uint8(colored)
    nrows, ncols = img.shape[0], img.shape[1]
    rgba = rgbavec.reshape([nrows, ncols, 4])
    rgba[..., -1] = np.uint8(alpha)
    return rgba

def rgba2float(img, cmap='binary', alpha=0):
    '''
    Invert float2rgba: decode a uint8 RGBA image back to floats in [0, 1].

    Arguments:
    - img: [n,m,4] uint8 rgba image (e.g. produced by float2rgba); not modified
    - cmap: 'binary' (default) unpacks the 24-bit RGB integer encoding.
      Any other colormap name assigns every pixel the index of the nearest
      entry (squared RGB distance) of a 4096-entry lookup table of that
      colormap, scaled to [0, 1].
    - alpha: value written to the alpha byte before unpacking in binary mode.
      Only alpha=0 recovers the encoded value; a non-zero alpha lands in the
      top byte of the decoded integer.

    Returns:
    - float array of shape [n,m] (squeezed in binary mode)
    '''
    imgc = img.copy()
    if cmap == 'binary':
        imgc[..., -1] = np.uint8(alpha)
        imgc = imgc.view(np.uint32) / np.float32((2**24) - 1)
        return imgc.squeeze()

    # matplotlib is only needed for the colormap path
    from matplotlib.cm import get_cmap
    cm = get_cmap(cmap, lut=4096)
    # Sample every LUT entry by integer index. `cm.colors` only exists on
    # ListedColormap, so it failed for LinearSegmentedColormaps like 'YlOrRd'.
    lut = cm(np.arange(cm.N))[:, :3]
    # NOTE: d has shape (len(lut), n, m); memory grows with 4096 * n * m
    d = ((imgc[..., :3] / 255.0 - lut[:, None, None, :])**2).sum(axis=-1)
    return d.argmin(axis=0) / np.float32(len(lut) - 1)

def rgbdet2ql(rgb, det, detmask=None):
    '''
    Build a uint8 quicklook: an RGB background with detections painted on top.

    Arguments:
    - rgb: [n,m,>=3] float image, expected in [0,1] (converted with
      uint8(255*rgb), which truncates; values > 1 overflow)
    - det: [n,m] detection image; its values at detmask pixels are colored
      with the 'YlOrRd' colormap via array2rgba
    - detmask: [n,m] boolean detection mask. None (default) or an empty
      sequence leaves the background untouched. None replaces the former
      mutable default [].

    Returns:
    - uint8 image with the same shape as rgb
    '''
    rgbimg = np.uint8(255 * rgb)
    if detmask is not None and len(detmask) != 0 and np.any(detmask):
        rows, cols = np.where(detmask)
        ch4rgba = array2rgba(det[rows, cols], cmap='YlOrRd')
        rgbimg[rows, cols, :3] = ch4rgba[:, :3]
    return rgbimg

def progressbar(caption, maxval=None):
    """
    progressbar(caption,maxval=None)

    Summary: build a progressbar.ProgressBar with a standard widget layout

    Arguments:
    - caption: text shown before the bar (omitted when empty/None)
    - maxval: total number of steps. None means unknown length; a step
      counter then replaces the percentage and ETA widgets.

    Output:
    - progress bar instance (pbar.update(i) to step, pbar.finish() to close)
    """
    from progressbar import ProgressBar, Bar, UnknownLength
    from progressbar import Percentage, Counter, ETA

    prefix = (caption + ': ') if caption else ''
    if maxval is None:
        return ProgressBar(widgets=[prefix, Counter(), ' ', Bar('=')],
                           maxval=UnknownLength)
    return ProgressBar(widgets=[prefix, Percentage(), ' ', Bar('='), ' ', ETA()],
                       maxval=maxval)

def findboundaries(labimg, **kwargs):
    '''
    Boolean boundary mask of a label image (skimage find_boundaries).

    connectivity defaults to CONN8; all keywords are forwarded.
    '''
    from skimage.segmentation import find_boundaries
    if 'connectivity' not in kwargs:
        kwargs['connectivity'] = CONN8
    return find_boundaries(labimg, **kwargs)

def thickboundaries(labimg):
    '''Boundary mask of labimg using find_boundaries mode 'thick'.'''
    boundary_mode = 'thick'
    return findboundaries(labimg, mode=boundary_mode)

def innerboundaries(labimg):
    '''Boundary mask of labimg using find_boundaries mode 'inner'.'''
    boundary_mode = 'inner'
    return findboundaries(labimg, mode=boundary_mode)

def outerboundaries(labimg):
    '''Boundary mask of labimg using find_boundaries mode 'outer'.'''
    boundary_mode = 'outer'
    return findboundaries(labimg, mode=boundary_mode)

def createimg(hdrf, metadata, **kwargs):
    '''
    Create a new ENVI image with spectral.io.envi.create_image.

    Arguments:
    - hdrf: output header filename
    - metadata: ENVI header metadata dict

    Keyword arguments are forwarded to create_image; they were previously
    accepted and silently dropped. `ext` defaults to '' (no extension
    appended to the data file name). `force` defaults to True (existing
    files are overwritten).
    '''
    from spectral.io.envi import create_image
    kwargs.setdefault('ext', '')
    kwargs.setdefault('force', True)
    return create_image(hdrf, metadata, **kwargs)

def imlabel(img, **kwargs):
    '''
    Label connected regions of img with skimage.measure.label.

    connectivity defaults to CONN8; all keywords are forwarded.
    '''
    from skimage.measure import label as _label
    if 'connectivity' not in kwargs:
        kwargs['connectivity'] = CONN8
    return _label(img, **kwargs)

def disk(radius, **kwargs):
    '''Disk-shaped structuring element (skimage.morphology.disk).'''
    from skimage.morphology import disk as _skdisk
    element = _skdisk(radius, **kwargs)
    return element

def bwdilate(bwimg, **kwargs):
    '''
    Binary dilation with skimage.morphology.binary_dilation.

    The structuring element defaults to disk(3) and is passed as `selem`.
    NOTE(review): newer scikit-image versions renamed `selem` to `footprint`;
    confirm the pinned version still accepts `selem`.
    '''
    from skimage.morphology import binary_dilation as _bwd
    if 'selem' not in kwargs:
        kwargs['selem'] = disk(3)
    return _bwd(bwimg, **kwargs)

def bwdist(bwimg, **kwargs):
    '''
    Distance transform: for each non-zero pixel of bwimg, the distance to the
    nearest zero pixel (zero pixels get 0).

    Keyword arguments:
    - metric: 'euclidean' (default, scipy.ndimage.distance_transform_edt),
      'chessboard' or 'taxicab' (scipy.ndimage.distance_transform_cdt)
    - return_distances (default=True), return_indices (default=False) and any
      other keywords are forwarded to the scipy function

    Raises:
    - ValueError for any other metric (previously an UnboundLocalError)
    '''
    # public scipy.ndimage namespace; scipy.ndimage.morphology is deprecated
    from scipy.ndimage import distance_transform_edt, distance_transform_cdt
    metric = kwargs.get('metric', 'euclidean')
    if metric == 'euclidean':
        kwargs.pop('metric', None)  # edt takes no metric argument
        _bwdist = distance_transform_edt
    elif metric in ('chessboard', 'taxicab'):
        _bwdist = distance_transform_cdt
    else:
        raise ValueError('unsupported metric "%s"' % metric)
    kwargs.setdefault('return_distances', True)
    kwargs.setdefault('return_indices', False)
    return _bwdist(bwimg, **kwargs)

def mergelabels(labimg, mergedist, return_merged=False, doplot=False):
    '''
    Merge labeled regions that lie within mergedist pixels of each other.

    Pixels whose chessboard distance to the foreground (labimg != 0) is
    <= mergedist are labeled with imlabel. Each resulting component gives
    one new label 1..K to the foreground pixels it contains. Background
    stays 0.

    Arguments:
    - labimg: integer label image (0 = background)
    - mergedist: merge distance in pixels (chessboard metric)
    - return_merged: also return {merged component label: array of the
      original labels it contains}
    - doplot: show before/after images with pylab

    Returns:
    - outimg, or (outimg, outmerged) when return_merged
    '''
    fgmask = (labimg != 0)
    mergecomp = imlabel(bwdist(~fgmask, metric='chessboard') <= mergedist)
    mergelab = np.unique(mergecomp)
    mergelab = mergelab[mergelab != 0]
    outimg = np.zeros_like(labimg)
    outmerged = {}
    for i, ml in enumerate(mergelab):
        mlmask = (mergecomp == ml) & fgmask
        outimg[mlmask] = i + 1
        if return_merged:
            outmerged[ml] = np.unique(labimg[mlmask])
    # Count non-zero labels explicitly: len(unique)-1 undercounts when an
    # image has no background (0) pixels.
    nbefore = np.count_nonzero(np.unique(labimg))
    print(nbefore, 'labels before merge', len(mergelab), 'after merge')
    if doplot:
        import pylab as pl
        fig, ax = pl.subplots(1, 2, sharex=True, sharey=True)
        ax[0].imshow(labimg)
        ax[0].set_title('before')
        ax[1].imshow(outimg)
        ax[1].set_title('after')
        pl.show()
    if return_merged:
        return outimg, outmerged
    return outimg

def imread(fname, **kwargs):
    '''Read an image with skimage.io.imread; `plugin` defaults to None.'''
    from skimage.io import imread as _imread
    options = dict(plugin=None)
    options.update(kwargs)
    return _imread(fname, **options)

def imresize(img, output_shape, **kwargs):
    '''
    Resize img to output_shape with skimage.transform.resize.

    Defaults differ from skimage: order=ORDER_NN, clip=False,
    preserve_range=True. Extra keywords consumed here:
    - anti_alias (default False): Gaussian-smooth img before resizing
    - anti_alias_sigma: smoothing sigma, only read when anti_alias is true.
      When omitted it is derived per axis as max(0, (factor - 1) / 2), with
      factor = input size / output size.
    `cval` (default 0) and `mode` (default 'constant') are used for the
    smoothing and are also forwarded to resize.
    '''
    from skimage.transform import resize as _imresize

    for key, default in (('order', ORDER_NN), ('clip', False),
                         ('preserve_range', True)):
        kwargs.setdefault(key, default)

    source = img
    if kwargs.pop('anti_alias', False):
        from scipy import ndimage as ndi
        sigma = kwargs.pop('anti_alias_sigma', None)
        if sigma is None:
            scale = (np.asarray(img.shape, dtype=float) /
                     np.asarray(output_shape, dtype=float))
            sigma = np.maximum(0, (scale - 1) / 2)
        source = ndi.gaussian_filter(img, sigma,
                                     cval=kwargs.get('cval', 0),
                                     mode=kwargs.get('mode', 'constant'))

    return _imresize(source, output_shape, **kwargs)

def filename2flightid(filename):
    '''
    Flight id: the part of the file's base name before the first '_'.

    ang20160922t184215_cmf_v1g_img -> ang20160922t184215
    '''
    stem = basename(filename)
    flightid, _sep, _rest = stem.partition('_')
    return flightid

def filename2flightdate(filename):
    '''
    Get the flight date (year, month, day) as ints from a filename.

    The flight id (see filename2flightid) is cut at its first 't':
    - ids like ang20160922t184215: the last 8 characters before the 't'
      are YYYYMMDD
        ang20160922t184215_cmf_v1g_img -> (2016, 9, 22)
    - ids starting with 'f' (avcl): the 6 characters after the 'f' are
      YYMMDD, read as year 20YY
        f160922t..._img -> (2016, 9, 22)
    '''
    flightid = filename2flightid(filename)
    datepart = flightid.split('t')[0]
    if flightid.startswith('f'):  # avcl
        # fYYMMDD: all six digits are needed. The old [1:6] slice dropped the
        # last day digit, and a 'yyayy' typo left yyyy undefined.
        datestr = datepart[1:7]
        yyyy, mm, dd = map(int, ['20' + datestr[:2], datestr[2:4], datestr[4:6]])
    else:
        datestr = datepart[-8:]
        yyyy, mm, dd = map(int, [datestr[:4], datestr[4:6], datestr[6:]])

    return yyyy, mm, dd

def shortstr(s, width=20, placeholder='...', quote=False):
    '''
    Truncate s to `width` characters, appending `placeholder` when cut;
    optionally wrap the result in double quotes.
    '''
    if len(s) > width:
        s = s[:width] + placeholder
    if quote:
        return '"%s"' % s
    return s

def filename2calid(infile):
    '''
    Calibration id from the '_'-separated fields of a file name.

    ./ort/ang20160915t194328_cmf_v1n2_img -> v1n2   (third field)
    Names starting with 'f' (avcl) join the second and third fields with '_'.
    '''
    _head, tail = pathsplit(infile)
    fields = tail.split('_')
    if not tail.startswith('f'):
        return fields[2]
    return str(fields[1] + '_' + fields[2])  # avcl

# Alias: l1calid(infile) is the same function as filename2calid.
l1calid = filename2calid

def extrema(a, **kwargs):
    '''
    Return (low, high) extremes of `a`, optionally robust to outliers.

    Keyword arguments:
    - p: 1.0 (default) -> (min, max) via np.amin/np.amax;
         0.0 -> (max, min);
         0 < p < 1 -> (percentile 100*(1-p), percentile 100*p) computed
         with np.nanpercentile (NaNs ignored, 'nearest' method)
    - axis: reduction axis (default None = flattened)
    - other keywords are forwarded to amin/amax when p is 0 or 1, and
      ignored otherwise
    '''
    p = kwargs.pop('p', 1.0)
    if p == 1.0:
        return np.amin(a, **kwargs), np.amax(a, **kwargs)
    elif p == 0.0:
        return np.amax(a, **kwargs), np.amin(a, **kwargs)
    assert p > 0.0 and p < 1.0
    axis = kwargs.pop('axis', None)

    def apercent(q):
        try:
            return np.nanpercentile(a, q=q, axis=axis, method='nearest')
        except TypeError:
            # numpy < 1.22 only accepts the (now deprecated) `interpolation`
            return np.nanpercentile(a, q=q, axis=axis, interpolation='nearest')

    return apercent((1 - p) * 100), apercent(p * 100)

def envitypecode(np_dtype):
    '''
    ENVI 'data type' code for a numpy dtype, looked up in spectral's
    dtype_to_envi table by the dtype's character code.
    '''
    from numpy import dtype
    from spectral.io.envi import dtype_to_envi
    return dtype_to_envi[dtype(np_dtype).char]

def extract_tile(img, ul, tdim, verbose=False, transpose=None, cval=0):
    '''
    Extract a (tdim, tdim, nbands) tile whose upper-left corner sits at
    pixel ul=(row, col) of img.

    Tile pixels that fall outside img are filled with cval. A tile lying
    entirely outside img is returned filled with cval; this used to raise
    a broadcasting error.

    Arguments:
    - img: [rows, cols] or [rows, cols, bands] array (2D gives one band)
    - ul: (row, col) of the tile's upper-left corner; may be negative
    - tdim: tile edge length in pixels
    - verbose: print offsets and padding for debugging
    - transpose: optional axes order applied to the tile before returning
    - cval: fill value for out-of-image pixels (default=0)

    Returns:
    - the tile (dtype follows cval * ones(img.dtype))
    '''
    ndim = img.ndim
    if ndim == 3:
        nr, nc, nb = img.shape
    elif ndim == 2:
        nr, nc = img.shape
        nb = 1
    else:
        raise Exception('invalid number of image dims %d' % ndim)

    lr = (ul[0] + tdim, ul[1] + tdim)
    # image window covered by the tile (empty if ibeg >= iend or jbeg >= jend)
    ibeg, iend = max(0, ul[0]), min(nr, lr[0])
    jbeg, jend = max(0, ul[1]), min(nc, lr[1])
    # where that window lands inside the tile
    padt, padl = ibeg - ul[0], jbeg - ul[1]
    padb = padt + max(0, iend - ibeg)
    padr = padl + max(0, jend - jbeg)

    if verbose:
        print(ul, nr, nc)
        print(padt, padb, padl, padr)
        print(ibeg, iend, jbeg, jend)

    imgtile = cval * np.ones([tdim, tdim, nb], dtype=img.dtype)
    if iend > ibeg and jend > jbeg:
        imgtile[padt:padb, padl:padr] = np.atleast_3d(img[ibeg:iend, jbeg:jend])
    if transpose is not None:
        imgtile = imgtile.transpose(transpose)
    return imgtile

def rotxy(x, y, adeg, xc, yc):
    """
    rotxy(x,y,adeg,xc,yc)

    Summary: rotate point(s) (x,y) by adeg degrees about (xc,yc), using the
    matrix [[cos,-sin],[sin,cos]] (counter-clockwise in a y-up frame)

    Arguments:
    - x, y: coordinates to rotate (scalars or 1-D arrays)
    - adeg: rotation angle in degrees
    - xc, yc: center of rotation

    Output:
    rotated (x, y); scalar inputs give scalar outputs
    """
    theta = DEG2RAD * adeg
    c, s = np.cos(theta), np.sin(theta)
    offsets = np.c_[x - xc, y - yc].T  # (2, npoints) for scalar/1-D input
    rotated = np.dot([[c, -s], [s, c]], offsets)

    single_point = rotated.ndim == 2 and rotated.shape[1] == 1
    if single_point:
        rotated = rotated.squeeze()
    return rotated[0] + xc, rotated[1] + yc
        
def meshpositions(*arrs):
    '''
    Stack the flattened coordinate grids of np.meshgrid(*arrs) as rows.

    Returns an array with one row per input array; column k holds the
    coordinates of the k-th grid point.
    '''
    grid = np.meshgrid(*arrs)
    # np.vstack needs a sequence; a lazy map object raises TypeError on
    # recent numpy releases
    return np.vstack([np.ravel(g) for g in grid])

def chull(p, **kwargs):
    '''
    Convex hull vertices of the points p (scipy.spatial.ConvexHull).

    Arguments:
    - p: [npoints, ndim] numpy array of points
    - return_indices: also return the vertex indices into p (default=False)
    - other keywords are forwarded to ConvexHull

    Returns:
    - p[vertices], or (p[vertices], vertices) when return_indices
    '''
    from scipy.spatial import ConvexHull
    want_indices = kwargs.pop('return_indices', False)
    vertices = ConvexHull(p, **kwargs).vertices
    if not want_indices:
        return p[vertices]
    return p[vertices], vertices

def utm2latlon(easting, northing, zone, hemi='North', alpha=None, datum=DATUM_WGS84):
    '''
    Convert UTM coordinates to (lat, lon) with LatLongUTMconversion.UTMtoLL.

    Arguments:
    - easting, northing: forwarded to UTMtoLL in this positional order
    - zone: UTM zone number
    - hemi: 'North' or 'South'; anything else prints a message and returns
      (None, None)
    - alpha: zone letter appended to the zone number; defaults to 'N' for
      'North' and 'M' for 'South'
    - datum: reference ellipsoid code (default DATUM_WGS84)

    NOTE(review): sl2latlon calls this as utm2latlon(y, x, ...), i.e. with
    the northing first. Confirm UTMtoLL's positional order before renaming or
    reordering anything here. The widely used LatLongUTMconversion module
    declares UTMtoLL(ellipsoid, northing, easting, zone).
    '''
    if hemi not in ('North', 'South'):
        print('invalid hemisphere value=', hemi)
        return None, None

    if alpha:
        zone_alpha = alpha
    else:
        zone_alpha = 'N' if hemi == 'North' else 'M'
    utmzone = str(zone) + zone_alpha
    lat, lon = LLUTMConv.UTMtoLL(datum, easting, northing, utmzone)
    return lat, lon

def sl2xy(s, l, **kwargs):
    """
    sl2xy(s,l,ulx=None,uly=None,xps=None,yps=xps,rot=0,mapinfo=None)

    Convert pixel coordinates (sample s, line l) to map coordinates (x, y).

    Arguments:
    - s,l: sample (column) and line (row) indices, scalars or arrays

    Keyword Arguments (each overrides the matching mapinfo entry):
    - ulx,uly: map coordinate of the upper-left corner (mapinfo 'ulx','uly')
    - xps: x pixel size (mapinfo 'xps')
    - yps: y pixel size (mapinfo 'yps'; default xps, and 0 also means xps)
    - rot: rotation in degrees about (ulx,uly) (mapinfo 'rotation', default 0)
    - mapinfo: map info dict as returned by mapinfo() (None = empty)

    Returns:
    - x,y: x = ulx + xps*s and y = uly - yps*l, rotated about (ulx,uly)
      with rotxy when rot != 0

    Raises:
    - ValueError if ulx/uly or xps/yps are undefined
    """
    mapinfo = kwargs.pop('mapinfo', None) or {}
    x0 = kwargs.pop('ulx', mapinfo.get('ulx', None))
    y0 = kwargs.pop('uly', mapinfo.get('uly', None))
    xps = kwargs.pop('xps', mapinfo.get('xps', None))
    yps = kwargs.pop('yps', mapinfo.get('yps', xps))
    rot = kwargs.pop('rot', mapinfo.get('rotation', 0))

    if None in (x0, y0):
        raise ValueError("ulx or uly undefined")

    if None in (xps, yps):
        raise ValueError("xps or yps undefined")

    if yps == 0:
        yps = xps

    # map y decreases as the line index increases
    xp, yp = x0 + xps * s, y0 - yps * l
    if rot == 0:
        return xp, yp
    return rotxy(xp, yp, rot, x0, y0)

def sl2latlon(s, l, **kwargs):
    '''
    Convert pixel (sample, line) coordinates to (lat, lon).

    kwargs['mapinfo'] must carry 'proj' ('UTM', which also needs 'zone' and
    'hemi', or 'Geographic Lat/Lon'). All kwargs are passed on to sl2xy.

    Returns (lat, lon). For other projections it prints a message and
    returns None. Raises ValueError when proj is missing.
    '''
    mapinfo = kwargs.get('mapinfo', {})
    proj = mapinfo.get('proj', None)
    if not proj:
        raise ValueError("proj undefined")
    if proj not in ('UTM', 'Geographic Lat/Lon'):
        print('unknown projection:', proj)
        return None

    x, y = sl2xy(s, l, **kwargs)
    if proj == 'Geographic Lat/Lon':
        return y, x
    # NOTE(review): passes (y, x), i.e. northing first, as utm2latlon's first
    # two arguments; utm2latlon forwards them unchanged to UTMtoLL. Confirm
    # that function's argument order before changing this.
    return utm2latlon(y, x, zone=mapinfo['zone'], hemi=mapinfo['hemi'])

def xy2sl(x, y, **kwargs):
    """
    xy2sl(x,y,ulx=None,uly=None,xps=None,yps=xps,rot=0,mapinfo=None)

    Convert map coordinates (x, y) to fractional pixel coordinates
    (sample, line) of an orthocorrected image; the inverse of sl2xy.

    Arguments:
    - x,y: map coordinates (scalars or arrays)

    Keyword Arguments (each overrides the matching mapinfo entry):
    - ulx,uly: map coordinate of the upper-left corner
    - xps: x pixel size
    - yps: y pixel size (default xps; 0 also means xps)
    - rot: rotation in degrees (mapinfo 'rotation', default 0)
    - mapinfo: map info dict as returned by mapinfo() (None = empty)

    Returns:
    - s,l: sample, line coordinates of x,y

    Raises:
    - ValueError if ulx/uly or xps are undefined
    """
    mapinfo = kwargs.pop('mapinfo', None) or {}

    x0 = kwargs.pop('ulx', mapinfo.get('ulx', None))
    y0 = kwargs.pop('uly', mapinfo.get('uly', None))
    xps = kwargs.pop('xps', mapinfo.get('xps', None))
    yps = kwargs.pop('yps', mapinfo.get('yps', xps))
    rot = kwargs.pop('rot', mapinfo.get('rotation', 0))

    if None in (x0, y0):
        raise ValueError("ulx or uly undefined")

    if xps is None:
        raise ValueError("pixel size undefined")

    yps = yps or xps
    # offsets from the upper-left corner; the line axis points down in y
    xp, yp = (x - x0), (y0 - y)
    if rot != 0:
        xp, yp = rotxy(xp, yp, rot, 0, 0)

    return xp / xps, yp / yps

def latlon2utm(lat, lon, zone=None, datum=DATUM_WGS84):
    """
    latlon2utm(lat,lon,zone=None,datum=DATUM_WGS84)

    Convert (lat, lon) to UTM with LatLongUTMconversion.LLtoUTM.

    Arguments:
    - lat: latitude coordinate(s)
    - lon: longitude coordinate(s)

    Keyword Arguments:
    - zone:   UTM zone number (default=None, chosen by LLtoUTM)
    - datum:  reference ellipsoid code (default=DATUM_WGS84)

    Returns:
    - easting:  UTM easting map coordinate
    - northing: UTM northing map coordinate
    - zone:     UTM zone number (int)
    - hemi:     trailing zone letter (hemi >= 'N' -> Northern hemisphere)
    """
    zonealpha, easting, northing = LLUTMConv.LLtoUTM(datum, lat, lon,
                                                     ZoneNumber=zone)
    zone_number, zone_letter = int(zonealpha[:-1]), zonealpha[-1]
    return easting, northing, zone_number, zone_letter

def latlon2sl(lat, lon, **kwargs):
    '''
    Convert (lat, lon) to fractional (sample, line) pixel coordinates.

    Only kwargs['mapinfo'] is used. Its 'proj' must be 'Geographic Lat/Lon'
    (lon/lat are the map coordinates) or 'UTM' (converted with latlon2utm,
    using mapinfo['zone'] when present).

    Returns (sample, line). For other projections it prints a message and
    returns None. Raises ValueError when proj is missing.
    '''
    mapinfo = kwargs.get('mapinfo', {})
    proj = mapinfo.get('proj', None)
    if not proj:
        raise ValueError("proj undefined")
    if proj not in ('UTM', 'Geographic Lat/Lon'):
        print('Unknown projection:', proj)
        return None

    if proj == 'Geographic Lat/Lon':
        return xy2sl(lon, lat, mapinfo=mapinfo)

    zone = int(mapinfo['zone']) if 'zone' in mapinfo else None
    easting, northing, _zone, _letter = latlon2utm(lat, lon, zone=zone)
    return xy2sl(easting, northing, mapinfo=mapinfo)

def mapdict2str(mapdict):
    '''
    Format a map info dict (as built by mapinfo()) as an ENVI 'map info'
    header string.

    The first 10 values (UTM) or 7 values (other projections) are written
    positionally in dict order. The remaining entries are written as
    key=value, followed by the strings in mapdict['metadata'] (if any):
        { UTM, 1.0, 1.0, 500000.0, ..., rotation=0.0 }

    The input dict is not modified; the old code popped 'metadata' from it.
    Dict views are turned into lists so slicing works on Python 3.
    '''
    mapmeta = list(mapdict.get('metadata', []))
    items = [(k, v) for k, v in mapdict.items() if k != 'metadata']
    nargs = 10 if mapdict['proj'] == 'UTM' else 7
    positional = [str(v) for _, v in items[:nargs]]
    keyword = [str(k) + '=' + str(v) for k, v in items[nargs:]]
    return '{ ' + ', '.join(positional + keyword + mapmeta) + ' }'

def findhdr(img_file):
    '''
    Locate the ENVI header belonging to img_file.

    Candidates, in order:
    1. img_file itself, if it is an existing '.hdr' file (returned as given)
    2. img_file + '.hdr' (returned as an absolute path)
    3. img_file with its extension replaced by '.hdr' (returned as given)
    Returns None when none of them exists.
    '''
    from os import path
    root, ext = path.splitext(img_file)
    if ext == '.hdr' and path.isfile(img_file):
        return img_file

    appended = img_file + '.hdr'  # [img_file.img].hdr or [img_file].hdr
    if path.isfile(appended):
        return path.abspath(appended)

    replaced = root + '.hdr'  # [img_file.img] -> [img_file].hdr
    return replaced if path.isfile(replaced) else None

def openimg(imgf, hdrf=None, **kwargs):
    '''
    Open an ENVI image with spectral.io.envi.open.

    An already-open SpyFile is returned unchanged. When hdrf is not given
    it is located with findhdr(imgf). Extra keywords go to envi.open.
    '''
    from spectral.io.envi import open as _open
    from spectral import SpyFile
    if isinstance(imgf, SpyFile):
        return imgf
    if not hdrf:
        hdrf = findhdr(imgf)
    return _open(hdrf, imgf, **kwargs)

def openmm(img, interleave='source', writable=False):
    '''
    Return a numpy memmap of an ENVI image's data (SpyFile.open_memmap).

    img may be a path string (opened with openimg) or an already-open image.
    '''
    source = openimg(img) if isinstance(img, str) else img
    return source.open_memmap(interleave=interleave, writable=writable)

def openimgmm(imgf, interleave='source', writable=False, hdrf=None):
    """
    openimgmm(imgf,interleave='source',writable=False,hdrf=None)

    Open an ENVI image and a numpy memmap of its data.

    Arguments:
    - imgf: image file (or an open SpyFile)

    Keyword Arguments:
    - interleave: memmap interleave (default='source')
    - writable: allow writing to the memmap (default=False)
    - hdrf: header file for imgf (default=None: located with findhdr). This
      parameter was documented before but did not exist; it is added last
      so positional callers are unaffected.

    Returns:
    - img: spectralpython image object
    - img_mm: numpy memmap of img data
    """
    img = openimg(imgf, hdrf=hdrf)
    imgmm = openmm(img, interleave=interleave, writable=writable)
    return img, imgmm

def mapinfo(img, astype=dict):
    '''
    Parse the ENVI 'map info' header entry of an image.

    Arguments:
    - img: open SpyFile or an image path (opened with openimg)
    - astype: dict (default) -> OrderedDict with proj, xtie, ytie, ulx, uly,
      xps, yps (+ zone, hemi, datum for UTM), then any 'key=value' items and
      'rotation' (float, default 0); list -> the raw header list;
      str -> the dict formatted back into a header string via mapdict2str

    Returns None when the header has no 'map info'. Items after the
    positional fields that contain no '=' are printed and stored under
    'metadata'.
    '''
    from spectral import SpyFile
    _img = img if isinstance(img, SpyFile) else openimg(img)
    maplist = _img.metadata.get('map info', None)
    if maplist is None or astype == list:
        return maplist

    info = OrderedDict()
    info['proj'] = maplist[0]
    for i, key in enumerate(('xtie', 'ytie', 'ulx', 'uly', 'xps', 'yps'), start=1):
        info[key] = float(maplist[i])

    if info['proj'] == 'UTM':
        info['zone'] = maplist[7]
        info['hemi'] = maplist[8]
        info['datum'] = maplist[9]

    mapmeta = []
    # len(info) == number of positional fields consumed (7 or 10)
    for mapitem in maplist[len(info):]:
        if '=' in mapitem:
            # split on the first '=' only, so values containing '=' survive
            key, val = (part.strip() for part in mapitem.split('=', 1))
            info[key] = val
        else:
            mapmeta.append(mapitem)

    info['rotation'] = float(info.get('rotation', '0'))
    if len(mapmeta) != 0:
        print('unparsed metadata:', mapmeta)
        info['metadata'] = mapmeta

    if astype == str:
        return mapdict2str(info)

    return info

def prob2geotiff(proboutf, probimg, probimgf):
    '''
    Write probimg ([lines, samples, bands]) to proboutf as a float32
    GeoTIFF. The projection and the pixel-size/rotation terms of the
    geotransform are copied from the GDAL-readable image probimgf.

    NOTE(review): the output geotransform origin is hard-coded to (1, 1)
    rather than probimgf's (geo_t[0], geo_t[3]). The output is therefore not
    placed at the reference image's map location; confirm this is intended.
    '''
    from gdal import Open, GetDriverByName, GA_ReadOnly, GDT_Float32

    print('saving', proboutf)
    driver = GetDriverByName('Gtiff')
    nlines, nsamples, nbands = probimg.shape
    ref = Open(probimgf, GA_ReadOnly)
    geo_t = ref.GetGeoTransform()
    projection = ref.GetProjectionRef()

    out_geo_t = (1, geo_t[1], geo_t[2], 1, geo_t[4], geo_t[5])
    dst = driver.Create(proboutf, nsamples, nlines, nbands, GDT_Float32)
    dst.SetGeoTransform(out_geo_t)
    dst.SetProjection(projection)
    for band in range(nbands):
        dst.GetRasterBand(band + 1).WriteArray(probimg[:, :, band])
    dst = None  # release the dataset so GDAL flushes it to disk
    
    
def tile2geotiff(tileimg,tilepos,tilef,baseimgf,validate_coords=True,
                 border='bbox'):
    '''
    Write tileimg as a byte GeoTIFF georeferenced as a sub-window of baseimgf.

    Arguments:
    - tileimg: [lines, samples, bands] tile (written as GDT_Byte)
    - tilepos: (line, sample) of the tile's upper-left pixel in baseimgf
    - tilef: output GeoTIFF filename
    - baseimgf: reference image, opened with GDAL and also parsed with
      mapinfo() for the coordinate cross-check
    - validate_coords: print the GDAL corner coordinates of image and tile.
      Pauses for keyboard input when GDAL and mapinfo-derived corners differ
      by more than 1.0 map units. Warns when tile corners fall outside the
      image.
    - border: 'bbox' (write every pixel) or 'circle' (zero pixels outside a
      disk mask and show a diagnostic plot)

    Raises:
    - ValueError for any other border value (previously a NameError)
    '''
    from gdal import Open, GetDriverByName, GA_ReadOnly, GDT_Byte

    print('saving',tilef)
    gtifdrv = GetDriverByName('Gtiff')
    nl_tile,ns_tile,nb_tile = tileimg.shape
    g = Open(baseimgf, GA_ReadOnly)
    geo_t = g.GetGeoTransform()
    geo_p = g.GetProjectionRef()
    nl_img,ns_img = g.RasterYSize,g.RasterXSize

    # plain bool dtype: the np.bool8 alias was removed in numpy 2.0
    if border == 'bbox':
        bordermask = np.ones([nl_tile,ns_tile],dtype=bool)
    elif border == 'circle':
        trad,tbufh = tileimg.shape[0]//2, 1
        bordermask = disk(trad+tbufh)[tbufh:-tbufh,tbufh:-tbufh].astype(bool)
    else:
        raise ValueError('unknown border "%s" (expected "bbox" or "circle")'%border)
    assert(bordermask.shape==tileimg.shape[:2])

    def sl2map(s,l,ulx,uly,xps):
        return sl2xy(s,l,mapinfo=dict(ulx=ulx,uly=uly,xps=xps,yps=xps,rot=0))

    def ct(s,l,mapdict):
        # (s,l) -> (x,y) from the ENVI map info, used to cross-check GDAL
        ulx,uly,xps = [mapdict[k] for k in ('ulx','uly','xps')]
        rot = mapdict.get('rotation',0)
        x,y = sl2map(s,l,ulx,uly,xps)
        if rot==0:
            return x,y
        return rotxy(x,y,rot,ulx,uly)

    mapdict = mapinfo(baseimgf)

    # tile corners in base-image (s,l): UL, LL, UR, LR
    l0,s0 = tilepos
    tile_bbox_s = [s+s0 for s in [0,     0, ns_tile, ns_tile]]
    tile_bbox_l = [l+l0 for l in [0, nl_tile, 0,     nl_tile]]
    # a list, not a zip iterator: it is measured and iterated more than once
    tile_bbox_sl = list(zip(tile_bbox_s,tile_bbox_l))

    # GDAL affine geotransform evaluated at the tile's upper-left pixel
    x_tile = geo_t[0]+s0*geo_t[1]+l0*geo_t[2]
    y_tile = geo_t[3]+s0*geo_t[4]+l0*geo_t[5]

    if validate_coords:
        from skimage.measure import points_in_poly
        try:
            _pause = raw_input  # Python 2
        except NameError:
            _pause = input  # Python 3 (raw_input was renamed)

        tile_bbox_xy = [ct(s,l,mapdict) for s,l in tile_bbox_sl]

        img_bbox_s = [0,0,ns_img,ns_img]
        img_bbox_l = [0,nl_img,nl_img,0]
        img_bbox_xy = []
        print(pathsplit(baseimgf)[1],'(s,l) -> (x,y) bounding box')
        for s,l in zip(img_bbox_s,img_bbox_l):
            x = geo_t[0]+s*geo_t[1]+l*geo_t[2]
            y = geo_t[3]+s*geo_t[4]+l*geo_t[5]
            img_bbox_xy.append((x,y))
            print((s,l),'->\t',(x,y))

        print('tile (s,l) -> (x,y) bounding box')
        off_lab = ['Upper Left','Lower Left','Upper Right','Lower Right']

        in_img = np.zeros(len(tile_bbox_sl),dtype=bool)
        for i,(st,lt) in enumerate(tile_bbox_sl):
            xt = geo_t[0]+st*geo_t[1]+lt*geo_t[2]
            yt = geo_t[3]+st*geo_t[4]+lt*geo_t[5]
            xydist = (np.float32((xt,yt))-np.float32(tile_bbox_xy[i]))**2
            xydist = np.sqrt(xydist.sum())
            if xydist > 1.0:
                # if gdal/numpy differ by more than 1m, wait for the user
                print('WARNING:','xydist=',xydist,'gdal (x,y)=',(xt,yt),
                      'numpy (x,y)=',tile_bbox_xy[i])
                _pause()
            in_img[i] = points_in_poly([(xt,yt)],img_bbox_xy)[0]
            print(off_lab[i],(st,lt),'->\t',(xt,yt),'in image bbox=',in_img[i])

        if not in_img.any():
            print('WARNING: tile outside of bounding box')
        elif not in_img.all():
            nzo = (in_img==0).sum()
            print('WARNING: tile overlaps bounding box (%d points outside)'%nzo)

    tile_geo_t = (x_tile, geo_t[1], geo_t[2], y_tile, geo_t[4], geo_t[5])
    gsub = gtifdrv.Create(tilef, ns_tile, nl_tile, nb_tile, GDT_Byte)
    gsub.SetGeoTransform(tile_geo_t)
    gsub.SetProjection(geo_p)
    for i in range(nb_tile):
        tileband = tileimg[:,:,i].copy()
        if border!='bbox':
            tileband[~bordermask] = 0
            if i==0:
                import pylab as pl
                fig0,ax0 = pl.subplots(1,2,sharex=True,sharey=True)
                pl.suptitle(tilef)
                ax0[0].imshow(tileimg); ax0[0].set_xlabel('contour')
                ax0[1].imshow(bordermask); ax0[1].set_xlabel('boundary')
                pl.show()
        gsub.GetRasterBand(i+1).WriteArray(tileband)
    gsub = None # flush to disk

def rdn2rgb(rdn):
    '''
    Stretch rdn to [0, 1] for display. The 1st-99th percentile range is
    mapped linearly to [0, 1] (extrema with p=0.99, NaNs ignored) and values
    outside it are clipped.

    Returns an all-zero float array when that range is empty or undefined
    (e.g. constant or all-NaN input). Previously this divided by zero and
    produced NaNs.
    '''
    rdnmin, rdnmax = extrema(rdn, p=0.99)
    span = rdnmax - rdnmin
    if not span > 0:
        return np.zeros(np.shape(rdn))
    return np.clip((rdn - rdnmin) / span, 0.0, 1.0)

def array2img(outf, img, mapinfostr=None, bandnames=None, **kwargs):
    '''
    Save a numpy array as a bip-interleaved ENVI image at outf, with header
    outf + '.hdr'.

    Arguments:
    - outf: output image filename
    - img: 2D or 3D array (2D is written as a single band)
    - mapinfostr: optional 'map info' header value (e.g. from mapdict2str)
    - bandnames: optional list of band names
    - overwrite (keyword): unless true, an existing outf is left untouched;
      a message is printed and nothing is written

    'data ignore value' is set to NODATA.
    '''
    outhdrf = outf + '.hdr'
    if pathexists(outf) and not kwargs.pop('overwrite', False):
        print('Cannot write image to path', outf)
        return

    data = np.atleast_3d(img)
    outmeta = {
        'samples': data.shape[1],
        'lines': data.shape[0],
        'bands': data.shape[2],
        'interleave': 'bip',
        'file type': 'ENVI',
        'byte order': 0,
        'header offset': 0,
        'data type': envitypecode(data.dtype),
    }
    if mapinfostr:
        outmeta['map info'] = mapinfostr
    if bandnames:
        outmeta['band names'] = '{%s}' % ", ".join(bandnames)
    outmeta['data ignore value'] = NODATA

    outimg = createimg(outhdrf, outmeta)
    outmm = openmm(outimg, writable=True)
    outmm[:] = data
    del outmm  # drop the memmap so the data is flushed to disk
    print('saved %s array to %s' % (str(data.shape), outf))

def mad(a, axis=0, medval=None, unbiased=False):
    '''
    Median absolute deviation, median(abs(a - center)) / c along `axis`,
    computed with statsmodels.robust.scale.mad.

    Arguments:
    - a: array-like input
    - axis: reduction axis (default 0; None = flattened)
    - medval: center to use (default: the median of `a` along `axis`).
      A center of 0 is now honoured; the old truthiness test replaced it with
      the median and raised on array-valued centers.
    - unbiased: if True, c = 0.67448975... (the standard-normal 75th
      percentile), scaling the MAD into an estimate of the standard deviation
      for normally distributed data; otherwise c = 1 (raw MAD)
    '''
    from statsmodels.robust.scale import mad as _mad
    c = 0.67448975019608171 if unbiased else 1.0
    if medval is not None:
        center = medval
    elif axis is None:
        center = np.median(a)
    else:
        # Hand statsmodels the callable so it takes the median along `axis`
        # itself (keeping the reduced dimension). A precomputed per-axis
        # median does not broadcast against `a` for axis != 0.
        center = np.median
    return _mad(a, c=c, axis=axis, center=center)
    
def kde(img, k):
    '''
    Weight img by a smoothed, min-max normalized copy of itself.

    img is Gaussian-filtered (sigma=k, truncate=1), the filtered image is
    rescaled to [0, 1], and img * weights is returned.
    NOTE(review): a constant filtered image gives 0/0 -> NaN weights.
    '''
    from scipy.ndimage import gaussian_filter
    smooth = gaussian_filter(img, sigma=k, truncate=1)
    lo, hi = smooth.min(), smooth.max()
    weights = (smooth - lo) / (hi - lo)
    return img * weights

def absnorm(img, mask):
    '''
    Map a 2D image onto [0, 1] using a range symmetric about zero.

    The range is [-imax, imax], where imax is the largest |value| among the
    pixels where mask is False; values outside the range are clipped.

    Returns (normalized float32 image, -imax, imax).
    '''
    assert len(img.shape) == 2
    print('normalizing image to absolute range')
    vals = np.float32(img)
    bound = np.abs(vals[~mask]).max()
    lo, hi = -bound, bound
    normed = np.clip((vals - lo) / (hi - lo), 0.0, 1.0)
    return normed, lo, hi

def smoothbil(img, mask, d, sigmaColor, sigmaSpace, normalize=True):
    '''
    Edge-preserving smoothing of img with OpenCV's bilateralFilter.

    Arguments:
    - img: 2D image
    - mask: boolean mask of pixels excluded from the normalization range
      (used only when normalize is True)
    - d, sigmaColor, sigmaSpace: cv2.bilateralFilter parameters
    - normalize: if True (default), img is first scaled to [0, 1] with absnorm
      (so sigmaColor is in normalized units) and the filtered result is mapped
      back to [imin, imax]. If False, img is filtered as-is and sigmaColor is
      in the image's own units.

    Returns the filtered image.
    '''
    from cv2 import bilateralFilter
    print('running bilateralFilter')
    if not normalize:
        # Filter raw values. The old code still rescaled the result by the
        # image extrema although it had never been normalized.
        return bilateralFilter(img.copy(), d, sigmaColor, sigmaSpace)
    imgn, imin, imax = absnorm(img, mask)
    imgn = bilateralFilter(imgn, d, sigmaColor, sigmaSpace)
    return imin + (imgn * (imax - imin))

def relabel_sequential(labimg, **kwargs):
    '''
    Wrapper around skimage.segmentation.relabel_sequential; keywords are
    forwarded and its return value is passed through unchanged.
    '''
    from skimage.segmentation import relabel_sequential as _relabel
    result = _relabel(labimg, **kwargs)
    return result

def remove_small_objects(labimg, **kwargs):
    '''
    skimage.morphology.remove_small_objects with defaults min_size=9 and
    in_place=False; other keywords are forwarded.

    NOTE(review): newer scikit-image versions deprecate `in_place` in favour
    of `out`; confirm against the pinned version.
    '''
    from skimage.morphology import remove_small_objects as _rso
    for key, default in (('min_size', 9), ('in_place', False)):
        kwargs.setdefault(key, default)
    return _rso(labimg, **kwargs)

def filtdet(ch4mf,nodata_mask,mfmapinfo,minarea=minarea,mfmin=mfmin,mfmax=mfmax,
            k=kernel,mfminsmall=mfminsmall,skip_kde=False,use_abs=False,
            kde_outf=None,ccomp_outf=None,det_outf=None):
    '''
    Filter a matched-filter image (units reported as ppmm in the log
    messages) into labeled candidate detections.

    Steps:
    1. detkde = ch4mf (or |ch4mf| if use_abs), weighted with kde(., k)
       unless skip_kde, then rescaled so [mfmin, mfmax] -> [0, 1] (clipped)
    2. candidates = detkde > 0. Connected components smaller than minarea
       pixels are dropped. When mfminsmall >= mfmin, dropped components that
       contain a pixel with ch4mf >= mfminsmall are added back.
    3. the candidates are labeled, pixels with ch4mf < mfmin are zeroed,
       and labels are made sequential (detcomp)
    Optional ENVI outputs (kde_outf, ccomp_outf, det_outf) are written with
    array2img, using mfmapinfo formatted by mapdict2str.

    Returns:
    - detkde: scaled detection strength, 0 where ch4mf < mfmin or nodata
    - detcomp: sequential component labels, 0 where nodata
    '''
    print('filtering weakly-connected detections below minppm=%.2fppmm'%mfmin)
    mfmapstr=mapdict2str(mfmapinfo)
    detkde = np.abs(ch4mf) if use_abs else ch4mf.copy()
    ch4min = ch4mf >= mfmin
    if not skip_kde:
        detkde = kde(detkde,k=k)
    detkde = np.clip((detkde-mfmin)/(mfmax-mfmin),0.0,1.0)

    if not skip_kde and kde_outf:
        array2img(kde_outf,detkde,mapinfostr=mfmapstr,overwrite=True)
    detmask = (detkde>0)
    print('%d candidate components'%imlabel(detmask).max())

    # remove_small_objects drops components with fewer than min_size pixels
    print('removing detections with < minarea=%d pixels'%(minarea))
    detsmall = detmask.copy()
    detmask = remove_small_objects(detmask,min_size=minarea,in_place=False)
    print('%d remaining detections'%imlabel(detmask).max())
    if mfminsmall >= mfmin:
        print('adding small detections with mf>=%f ppmm'%(mfminsmall))
        smallcc = imlabel(detsmall!=detmask)
        smallkeep = np.unique(smallcc[(ch4mf>=mfminsmall)])
        smallkeep = smallkeep[smallkeep!=0]
        print(len(smallkeep),'small detections found')
        # np.isin keeps the input shape; the np.bool8 alias used here
        # before was removed in numpy 2.0
        smallmask = np.isin(smallcc,smallkeep)
        detmask = (detmask | smallmask)

    # exclude nodata+ch4min afterward to exclude interior holes from ccimg
    detcomp = imlabel(detmask)
    detcomp[~ch4min] = 0
    detcomp,_,_ = relabel_sequential(detcomp)
    print('%d final detections'%detcomp.max())

    if ccomp_outf:
        detcomp[nodata_mask] = NODATA
        array2img(ccomp_outf,detcomp,mapinfostr=mfmapstr,overwrite=True)

    detkde[~ch4min] = 0
    detkde[nodata_mask] = 0
    detcomp[nodata_mask] = 0

    print('detected %d components'%detcomp.max())
    if det_outf:
        detfilt = ch4mf.copy()
        detfilt[~ch4min]=0
        detfilt[nodata_mask] = NODATA
        array2img(det_outf,detfilt,mapinfostr=mfmapstr,overwrite=True)

    return detkde, detcomp

def loadobsimage(obsinfile,obs_bands=11): #Thorpe time start
    '''
    Open an ENVI observation image and return it along with its map info.

    Arguments:
    - obsinfile: path to the ENVI image; the header is read from obsinfile+'.hdr'

    Keyword Arguments:
    - obs_bands: not used by this function

    Returns:
    - obsimg: the image object returned by envi_open_file
    - obs_outdata: mapinfo(obsimg,astype=dict)
    '''
    obsimg = envi_open_file(obsinfile+'.hdr',image=obsinfile)
    # unpacking raises ValueError unless the image shape has exactly 3 entries
    nrow,ncol,nband = obsimg.shape
    print('obsimg',obsimg)
    obs_info = mapinfo(obsimg,astype=dict)
    return obsimg,obs_info #Thorpe time end

def loadmaskedimage(maskedimgf,rgb_bands=None,masked_value=np.nan,
                    load_bands=None,memmap=False,verbose=0):
    '''
    Load an ENVI image, flag its nodata pixels and split it into rgb/image parts.

    Arguments:
    - maskedimgf: path to the ENVI image; the header is read from maskedimgf+'.hdr'

    Keyword Arguments:
    - rgb_bands: band indices (into the loaded array) passed to rdn2rgb.
      Three indices on an image with >= 3 loaded bands select an rgb
      composite; on a 2-band image, indices that are all the same select a
      single (grayscale) band. Otherwise no 'rgb' entry is produced
      (default: none)
    - masked_value: value written into nodata pixels when memmap is False
      (default=np.nan)
    - load_bands: subset of file bands to read when memmap is False
      (default: all bands)
    - memmap: open the image with a read-only memmap instead of loading it;
      nodata pixels are then left unchanged
    - verbose: print the image dimensions before loading

    Returns:
    - dict with keys 'mapinfo', 'nodata_mask' (pixels where any band equals
      the header 'data ignore value'), 'nodata_value', 'rgb' (only when
      rgb_bands selects a composite) and 'image' (the non-rgb bands,
      squeezed; absent when every loaded band is an rgb band)
    '''
    rgb_bands = [] if rgb_bands is None else rgb_bands
    load_bands = [] if load_bands is None else load_bands

    maskedimg = envi_open_file(maskedimgf+'.hdr',image=maskedimgf)
    rows,cols,bands = maskedimg.shape
    if verbose:
        print('Loading [%d,%d,%d] masked input image: "%s"'%(rows,cols,bands,maskedimgf))
    if memmap:
        maskeddata = maskedimg.open_memmap(interleave='source',writeable=False)
    elif len(load_bands)>0:
        maskeddata = maskedimg.read_bands(list(load_bands))
    else:
        maskeddata = maskedimg.load() # load everything

    maskeddata = np.float32(maskeddata)
    if maskeddata.ndim == 2:
        maskeddata = maskeddata[...,np.newaxis]
    # a missing 'data ignore value' gives nan, which matches no pixel
    nodata_value = float(maskedimg.metadata.get('data ignore value',np.nan))
    nodata_mask = (maskeddata==nodata_value).any(axis=2)
    if not memmap:
        maskeddata[nodata_mask] = masked_value

    print(np.count_nonzero(nodata_mask),'nodata pixels in image',maskedimgf)

    outdata = dict(mapinfo=mapinfo(maskedimg,astype=dict),
                   nodata_mask=nodata_mask,
                   nodata_value=nodata_value)

    # band indices refer to the loaded array, which holds fewer bands than
    # the file when load_bands selects a subset
    nloaded = maskeddata.shape[2]
    rgb_composite = nloaded>=3 and len(rgb_bands)==3       # ang cmf: 3 rgb + cmf band(s)
    gray_composite = nloaded==2 and len(set(rgb_bands))==1 # avcl cmf: 1 gray + 1 cmf band
    if rgb_composite or gray_composite:
        rgb_set = set(rgb_bands)
        # ordered list rather than a set difference so band order is explicit
        image_bands = [b for b in range(nloaded) if b not in rgb_set]
        outdata['rgb'] = rdn2rgb(maskeddata[:,:,rgb_bands])
        if len(image_bands)!=0:
            outdata['image'] = maskeddata[:,:,image_bands].squeeze()
    else:
        outdata['image'] = maskeddata.squeeze()

    return outdata

def rgb2labimg(rgbimg):
    '''
    Convert an rgb label-mask image into a uint8 class-label image.

    Pure red pixels [255,0,0] become POINTSRC and pure blue pixels
    [0,0,255] become DIFFSRC; every other pixel is 0. A pixel is "pure" when
    its channel sum is 255 and the given channel is 255, which for
    non-negative (e.g. uint8) input means the other channels are 0.

    Arguments:
    - rgbimg: [rows x cols x 3] image (asserted to have 3 channels)

    Returns:
    - [rows x cols] uint8 label image
    '''
    assert rgbimg.shape[2]==3

    nrows,ncols = rgbimg.shape[:2]
    out = np.zeros((nrows,ncols),dtype=np.uint8)

    single = (rgbimg.sum(axis=2)==255)
    red,blue = rgbimg[:,:,0],rgbimg[:,:,2]
    out[single & (red==255)] = POINTSRC
    out[single & (blue==255)] = DIFFSRC

    return out


def loadlabimg(labf):
    '''
    Load a label mask image as a uint8 array of {0, POINTSRC, DIFFSRC}.

    Supported inputs:
    - '.png' (any case): rgb/rgba masks are converted with rgb2labimg; a
      single-channel (2-D) png is used as-is
    - extensionless ENVI class map whose basename ends with 'class'
      (header read from labf+'.hdr')

    Arguments:
    - labf: path to the label image

    Returns:
    - uint8 label image

    Raises:
    - Exception for any other filename pattern
    - AssertionError if the image holds values other than 0/POINTSRC/DIFFSRC
    '''
    print('loading label mask image: "%s"'%labf)
    labfile = pathsplit(labf)[1]
    # named labext so the module-level fileext() helper is not shadowed
    filebase,labext = splitext(labfile)
    labext = labext.lower()
    if labext=='.png':
        labimg = imread(labf)
        # grayscale pngs are 2-D and have no shape[2]
        if labimg.ndim==3 and labimg.shape[2] in (3,4):
            labimg = rgb2labimg(labimg[:,:,:3]).squeeze()

    elif labext=='' and filebase.endswith('class'):
        # envi class map
        imgdat = envi_open_file(labf+'.hdr',image=labf)
        labimg = imgdat.load().squeeze()
    else:
        raise Exception('Unrecognized format %s'%labf)

    labimg = np.uint8(labimg)
    ulab = np.unique(labimg)
    assert(((ulab==0)|(ulab==POINTSRC)|(ulab==DIFFSRC)).all())

    return labimg

def loadfiltdet(detfilt_imgf):
    '''
    Load a filtered detection image and zero out its nodata pixels.

    Arguments:
    - detfilt_imgf: path to the ENVI image; header read from detfilt_imgf+'.hdr'

    Returns:
    - dict with keys:
      - ch4det: squeezed float32 detection values, nodata pixels set to 0
      - mapinfo: mapinfo(image,astype=dict)
      - nodata_mask: boolean mask of pixels equal to nodata_value
      - nodata_value: the header 'data ignore value' as float

    Fails if the header has no 'data ignore value' entry.
    '''
    print('loading filtered detection image: "%s"'%detfilt_imgf)
    hdrf = detfilt_imgf+'.hdr'
    img = envi_open_file(hdrf,image=detfilt_imgf)
    values = np.float32(img.load().squeeze())
    ignore = float(img.metadata['data ignore value'])
    ignored = (values==ignore)
    values[ignored] = 0
    return {'ch4det':values,
            'mapinfo':mapinfo(img,astype=dict),
            'nodata_mask':ignored,
            'nodata_value':ignore}

def loaddetids(detid_imgf):
    '''
    Load a detection-id image and zero out its nodata pixels.

    Arguments:
    - detid_imgf: path to the ENVI image; header read from detid_imgf+'.hdr'

    Returns:
    - dict with keys:
      - detids: squeezed float32 ids, nodata pixels set to 0
      - mapinfo: mapinfo(image,astype=dict)
      - nodata_mask: boolean mask of pixels equal to nodata_value
      - nodata_value: the header 'data ignore value' as float

    Fails if the header has no 'data ignore value' entry.
    '''
    print('loading detection id image: "%s"'%detid_imgf)
    img = envi_open_file(detid_imgf+'.hdr',image=detid_imgf)
    ids = np.float32(img.load().squeeze())
    header = img.metadata
    ignore = float(header['data ignore value'])
    ignored = (ids==ignore)
    ids[ignored] = 0
    return {'detids':ids,
            'mapinfo':mapinfo(img,astype=dict),
            'nodata_mask':ignored,
            'nodata_value':ignore}

def loadsaliencemap(salience_imgf):
    '''
    Load a salience map image as a squeezed float32 array.

    Arguments:
    - salience_imgf: path to the ENVI image; header read from salience_imgf+'.hdr'

    Returns:
    - dict with keys 'saliencemap' (float32 array) and 'mapinfo'
      (mapinfo(image,astype=dict))
    '''
    print('loading',salience_imgf)
    img = envi_open_file(salience_imgf+'.hdr',image=salience_imgf)
    info = mapinfo(img,astype=dict)
    smap = np.float32(img.load().squeeze())
    return {'saliencemap':smap,'mapinfo':info}

def bbox(points,border=0,imgshape=None):
    """
    bbox(points,border=0,imgshape=None)
    computes bounding box of extrema in points array

    Arguments:
    - points: [N x 2] array of [rows, cols] (a single [row, col] is accepted)

    Keyword Arguments:
    - border: padding added around the extrema (default=0), one of:
        - number >= 1: padding in pixels on every side
        - number < 1: fraction of the mean row/col extent, rounded to the
          nearest pixel
        - [rborder, cborder] (list, tuple or array): separate row/col
          padding; each entry >= 1 is pixels, < 1 is a fraction of that
          axis' extent (truncated)
        - 'mindiff' / 'maxdiff': padding equal to the smaller / larger of
          the row and col extents
    - imgshape: (rows, cols) used to clip the upper bounds; None or empty
      disables clipping (default=None)

    Returns:
    - (rmin,cmin),(rmax,cmax): upper bounds are exclusive; lower bounds are
      clipped at 0
    """
    points = np.atleast_2d(points)
    minv = points.min(axis=0)
    maxv = points.max(axis=0)
    difv = maxv-minv

    # sequence and string modes must be tested before "border < 1": in
    # Python 3 ordering a str/tuple against a number raises TypeError
    if isinstance(border,(list,tuple,np.ndarray)):
        rborder,cborder = border
        # same pixel/fraction threshold as the scalar case (1 == one pixel)
        rborder = rborder if rborder >= 1 else int(rborder*difv[0])
        cborder = cborder if cborder >= 1 else int(cborder*difv[1])
    elif border == 'mindiff':
        rborder = cborder = min(difv)
    elif border == 'maxdiff':
        rborder = cborder = max(difv)
    elif border < 1:
        rborder = cborder = int(border*difv.mean()+0.5)
    else:
        rborder = cborder = border

    if imgshape is None or len(imgshape)==0:
        # upper bound large enough that clipping is a no-op
        imgshape = maxv+max(rborder,cborder)+1

    rmin,rmax = max(0,minv[0]-rborder),min(imgshape[0],maxv[0]+rborder+1)
    cmin,cmax = max(0,minv[1]-cborder),min(imgshape[1],maxv[1]+cborder+1)

    return (rmin,cmin),(rmax,cmax)
    
def maskbbox(mask,border=0):
    """
    maskbbox(mask,border=0)
    bounding box of the nonzero pixels of a mask, clipped to the mask shape

    Arguments:
    - mask: [n x m] binary image mask

    Keyword Arguments:
    - border: padding passed through to bbox (default=0)

    Returns:
    - bbox = (rowmin,colmin),(rowmax,colmax)
    """
    coords = np.column_stack(np.nonzero(mask))
    return bbox(coords,border,mask.shape)


def region_maxima(img,mask,return_index=False):
    '''
    Maximum value of img within each connected component of mask.

    Arguments:
    - img: 2-D intensity image (indexed as img[row,col])
    - mask: binary mask; components are labeled with imlabel

    Keyword Arguments:
    - return_index: also return the (row,col) of each component's maximum

    Returns:
    - rcmax: per-component maxima with img's dtype, in regionprops order
    - rcidx (only if return_index): [ncomponents x 2] integer array of the
      (row,col) position of each maximum (first occurrence on ties)
    '''
    from skimage.measure import regionprops
    ccimg = imlabel(mask)
    rcidx,rcmax = [],[]
    for region in regionprops(ccimg):
        rc = region.coords
        vals = img[rc[:,0],rc[:,1]]
        # max taken from the region's own pixels: equivalent to the
        # regionprops intensity max without depending on its property name
        imax = np.argmax(vals)
        rcmax.append(vals[imax])
        if return_index:
            rcidx.append(rc[imax])
    rcmax = np.array(rcmax,dtype=img.dtype)
    if return_index:
        # np.int was removed in NumPy 1.24; reshape keeps (0,2) when empty
        return rcmax,np.array(rcidx,dtype=np.intp).reshape(-1,2)
    return rcmax

def local_maxima(im,rad,image_max=[]):
    '''
    Coordinates of local maxima of im found with skimage's peak_local_max.

    Arguments:
    - im: input image
    - rad: radius; peaks are requested with min_distance=2*rad

    Keyword Arguments:
    - image_max: ignored (kept so existing callers still work)

    Returns:
    - the coordinate array returned by peak_local_max
    '''
    from skimage.feature import peak_local_max

    min_dist = 2*rad
    return peak_local_max(im, min_distance=min_dist)

def runcmd(cmd,verbose=0):
    '''
    Run a command, optionally redirecting its output to a file.

    Arguments:
    - cmd: command string, or list of tokens joined with single spaces. A
      shell-style redirection ('> file', '>& file' or '>> file') sends both
      stdout and stderr to file ('>>' appends, the others truncate). The
      command is split on whitespace, so quoted arguments are not supported.

    Keyword Arguments:
    - verbose: print the command before running it

    Returns:
    - (out, err, retcode): captured stdout/stderr bytes (both None when
      output is redirected to a file) and the process return code
    '''
    from subprocess import Popen, PIPE
    cmdstr = ' '.join(cmd) if isinstance(cmd,list) else cmd
    if verbose:
        print("running command:",cmdstr)

    outfile = None
    for rstr in ['>>','>&','>']:
        if rstr in cmdstr:
            cmdstr,outpath = [s.strip() for s in cmdstr.split(rstr)]
            mode = 'w' if rstr!='>>' else 'a'
            if outfile is not None:
                # a second redirect replaces the first; don't leak its handle
                outfile.close()
            outfile = open(outpath,mode)

    cmdout = PIPE if outfile is None else outfile
    try:
        p = Popen(cmdstr.split(), stdout=cmdout, stderr=cmdout)
        out, err = p.communicate()
    finally:
        # close the redirect target even if Popen fails (e.g. missing binary)
        if outfile is not None:
            outfile.close()

    return out,err,p.returncode

def retrieve_rgb(rgbf):
    '''
    Download the quicklook rgb image for an AVIRIS-NG flightline with wget.

    The flightline id comes from filename2flightid(rgbf); the year used in the
    URL is lid[5:7] (the two-digit year for ids shaped like 'angYYYY...').
    Year '17' uses the geo quicklook URL, other years the rgb locator URL.

    Arguments:
    - rgbf: output path; its basename is also the remote file name for
      non-'17' years

    Returns:
    - 0 if rgbf already exists or was downloaded, otherwise nonzero (wget's
      return code, or 1 if the command could not be run). On failure any
      file left at rgbf by wget is removed, so a later call retries.

    Raises:
    - Exception if the flightline id does not start with 'ang'
    '''
    wgetbin='wget --no-verbose'
    wgeturl='https://avirisng.jpl.nasa.gov'
    # -O uses rgbf directly: '{rgbdir}/{rgbfile}' became '/<file>' for bare filenames
    wgetrgb='{wgetbin} -O {rgbout} {wgeturl}/aviris_locator/y{rgbyear}_RGB/{rgbfile}'
    wgetgeo='{wgetbin} -O {rgbout} {wgeturl}/ql/{rgbyear}qlook/{lid}_geo.jpeg'

    wgetretc = 1
    lid = filename2flightid(rgbf)
    if not lid.startswith('ang'):
        raise Exception('retrieve_rgb only works with AVIRIS-NG flightlines')

    if pathexists(rgbf):
        return 0

    try:
        (rgbdir,rgbfile),rgbyear = pathsplit(rgbf),lid[5:7]
        print(rgbdir,rgbfile,rgbyear)
        wgetcmd = wgetrgb if rgbyear!='17' else wgetgeo
        wgetcmd = wgetcmd.format(wgetbin=wgetbin,wgeturl=wgeturl,rgbout=rgbf,
                                 rgbfile=rgbfile,rgbyear=rgbyear,lid=lid)
        print(wgetcmd)
        wgetout = runcmd(wgetcmd)
        wgetretc = wgetout[-1]
        if wgetretc != 0:
            print(rgbf,'wget failure, skipping')
            print('wget output:', wgetout[1])

    except Exception as e:
        print(rgbf,'not found and unable to retreive, skipping')
        print('error:', e)

    if wgetretc != 0 and pathexists(rgbf):
        # wget -O creates the output file even when the download fails;
        # remove it so the early pathexists() check doesn't report success
        try:
            os.remove(rgbf)
        except OSError:
            pass

    return wgetretc

if __name__ == '__main__':
    # Ad-hoc developer checks. Relies on local ./tiles, ./data and a
    # hard-coded tilepredictor path; intended for interactive use only.

    # raw_input exists only in Python 2 (NameError on Python 3)
    try:
        _pause = raw_input
    except NameError:
        _pause = input

    rgbf= 'tiles/tag_mini_intensive/rgb/ang20160922t183143_RGB.jpeg'
    print(retrieve_rgb(rgbf))
    _pause()
    # round-trip a random float tile through rgba and back, then plot the error
    cmap = 'inferno'
    tilefloat = np.random.rand(100,100)
    tilergba = float2rgba(tilefloat,cmap=cmap)
    rgbafloat = rgba2float(tilergba,cmap=cmap)
    diff = np.sqrt((tilefloat-rgbafloat)**2)
    print(rgbafloat.shape,diff.mean())
    import pylab as pl
    fig,ax = pl.subplots(1,4,sharex=True,sharey=True)
    ax[0].imshow(tilefloat,cmap=cmap,vmin=0,vmax=1)
    ax[1].imshow(tilergba[...,:3])
    ax[2].imshow(rgbafloat,cmap=cmap,vmin=0,vmax=1)
    ax[3].imshow(diff,vmin=0,vmax=0.01)
    pl.show()
    
    if 0:
        # disabled: geotiff export and array2rgba nan handling checks
        baseimg='./data/y16/cmf/ort/ang20160913t205656_cmf_v1n2_img'
        #baseimg='/lustre/ang/y16/cmf/ort/ang20160913t205656_cmf_v1n2_img'    
        tileimg = 255*np.ones([100,100,4],dtype=np.uint8)
        tilepos=(12360,728)
        tilefile='test.tif'
        tile2geotiff(tileimg,tilepos,tilefile,baseimg)
        _pause()
        z=np.random.rand(10)
        z[3:6] = np.nan
        y=np.random.rand(10,10)
        y[1:5] = np.nan
        r = array2rgba(z)
        print(r,r.shape)
        r = array2rgba(y)
        print(r,r.shape)
        _pause()
    import pylab as pl
    datadir='./data'    
    #datadir='/lustre/ang'
    if 1:
        cmfdir=datadir+'/y16/cmf/ort'
        labdir=pathjoin(cmfdir,'thorpe_training')
        detdir='tiles/thorpe_training/100'

        labimgf=pathjoin(labdir,'ang20160914t180328_cmf_v1n2_img_class')
        cmfimgf=pathjoin(cmfdir,'ang20160914t180328_cmf_v1n2_img')
        #filtdetimgf=pathjoin(detdir,'ang20160914t180328_cmf_v1n2_filt_det_500_1500')
        filtdetimgf='./tiles/thorpe_training/150/ang20160914t180328_cmf_v1n2_filt_det_350_1500'
        
    else:
        cmfdir=datadir+'/y15/cmf/ort'
        labdir=pathjoin(cmfdir,'thompson_training')
        detdir='tiles/thompson_training/100/'
        labimgf=pathjoin(labdir,'ang20150419t161445_cmf_v1f_img_mask.png')
        cmfimgf=pathjoin(cmfdir,'ang20150419t161445_cmf_v1f_img')        
        filtdetimgf = pathjoin(detdir,'ang20150419t161445_cmf_v1f_filt_det_500_1500')

    tiledim=150
    tilepos = 4678,1115
    tileimgf = './tiles/thorpe_training/150/ang20160914t180328/tp/tp_det4678_1115.png'
    pngimgf=filtdetimgf+'.png'
    # NOTE(review): developer-specific path; tilepredictor_util must be importable
    sys.path.insert(0,'/Users/bbue/Research/src/python/util/tilepredictor/')
    from tilepredictor_util import imread_rgb,imread_tile
    
    tileimg = imread_tile(tileimgf,tile_shape=[tiledim,tiledim])
    pngimg = imread_rgb(pngimgf)
    tilepng = pngimg[tilepos[0]:tilepos[0]+tiledim,tilepos[1]:tilepos[1]+tiledim]
    cmfdata = loadmaskedimage(cmfimgf,rgb_bands=range(3))        
    detdata = loadfiltdet(filtdetimgf)    
    labimg = loadlabimg(labimgf)
    tilelab = labimg[tilepos[0]:tilepos[0]+tiledim,tilepos[1]:tilepos[1]+tiledim]
    # make sure nodata masks match 
    nodata_mask = detdata['nodata_mask']
    assert(cmfdata['nodata_mask'].sum()==nodata_mask.sum())

    ch4det = detdata['ch4det']    
    rgbimg = cmfdata['rgb']
    # remove labeled pixels in nodata regions
    labimg = np.float32(labimg)
    ch4det[nodata_mask | (ch4det<500)] = np.nan
    labimg[~(labimg>0)] = np.nan


    
    fig,ax = pl.subplots(3,1,sharex=True,sharey=True,num=1)
    ax[0].imshow(tileimg)
    ax[1].imshow(tilepng)
    ax[2].imshow(tilelab)
    pl.show()

    pl.clf()
    ax[0].imshow(rgbimg.swapaxes(0,1))
    ax[1].imshow(pngimg.swapaxes(0,1))
    pl.show()
    
    ax[0].imshow(rgbimg.swapaxes(0,1))
    ax[0].imshow(ch4det.swapaxes(0,1),vmin=500,vmax=1500,cmap='YlOrRd')
    ax[1].imshow(labimg.swapaxes(0,1),vmin=0,vmax=1,cmap='Spectral',alpha=0.7)
    ax[1].imshow(ch4det.swapaxes(0,1),vmin=500,vmax=1500,cmap='YlOrRd')
    pl.show()

    
