Source code for fermipy.skymap

# Licensed under a 3-clause BSD style license - see LICENSE.rst
from __future__ import absolute_import, division, print_function
import copy
import numpy as np
import healpy as hp
from scipy.interpolate import RegularGridInterpolator
from scipy.ndimage import map_coordinates
from astropy.io import fits
from astropy.wcs import WCS
from astropy.table import Table
from astropy.coordinates import SkyCoord
from astropy.coordinates import Galactic, ICRS
import gammapy
import fermipy.utils as utils
import fermipy.wcs_utils as wcs_utils
import fermipy.hpx_utils as hpx_utils
import fermipy.fits_utils as fits_utils
from fermipy.hpx_utils import HPX, HpxToWcsMapping


def coadd_maps(geom, maps, preserve_counts=True):
    """Coadd a sequence of `~gammapy.maps.Map` objects.

    Parameters
    ----------
    geom : object
        Geometry passed to ``gammapy.maps.Map.from_geom`` to create the
        output map.
    maps : iterable
        Maps to accumulate into the output map.
    preserve_counts : bool
        Forwarded to ``upsample`` when a HEALPix input map has to be
        brought to the order of the output map.

    Returns
    -------
    map_out : `~gammapy.maps.Map`
        Map created from ``geom`` with every input coadded into it.
    """
    # FIXME: This functionality should be built into the Map.coadd method
    coadded = gammapy.maps.Map.from_geom(geom)
    for current in maps:
        to_add = current
        # Only HEALPix maps of lower order than the output are upsampled;
        # everything else is coadded as-is.
        is_coarser_hpx = (isinstance(current, gammapy.maps.HpxNDMap) and
                          current.geom.order < coadded.geom.order)
        if is_coarser_hpx:
            scale = coadded.geom.nside // current.geom.nside
            to_add = current.upsample(scale, preserve_counts=preserve_counts)
        coadded.coadd(to_add)
    return coadded
def make_coadd_map(maps, proj, shape, preserve_counts=True):
    """Coadd ``maps`` onto the projection ``proj``.

    Dispatches on the type of ``proj``: a `~astropy.wcs.WCS` is handled by
    `make_coadd_wcs` and an `HPX` by `make_coadd_hpx`.

    Parameters
    ----------
    maps : iterable
        Maps to coadd.
    proj : `~astropy.wcs.WCS` or `HPX`
        Target projection.
    shape : tuple
        Shape of the output data array.
    preserve_counts : bool
        Only used for HEALPix projections; forwarded to `make_coadd_hpx`.

    Raises
    ------
    Exception
        If ``proj`` is neither a WCS nor an HPX object.
    """
    if isinstance(proj, WCS):
        return make_coadd_wcs(maps, proj, shape)
    if isinstance(proj, HPX):
        return make_coadd_hpx(maps, proj, shape,
                              preserve_counts=preserve_counts)
    raise Exception("Can't co-add map of unknown type %s" % type(proj))
def make_coadd_wcs(maps, wcs, shape):
    """Coadd WCS-based maps by histogramming them onto a target WCS grid.

    Parameters
    ----------
    maps : iterable
        Input maps; each must expose ``wcs`` and ``counts``.
    wcs : `~astropy.wcs.WCS`
        Target projection.
    shape : tuple
        Shape of the output data array.

    Returns
    -------
    map_out : `Map`
        New map holding the summed histogram and a deep copy of ``wcs``.
    """
    coadded = np.zeros(shape)
    bin_edges = wcs_utils.wcs_to_axes(wcs, shape)
    for component in maps:
        coords = wcs_utils.wcs_to_coords(component.wcs,
                                         component.counts.shape)
        # The bin edges are passed in reversed order relative to what
        # wcs_to_axes returns.
        hist, _ = np.histogramdd(coords.T, bins=bin_edges[::-1],
                                 weights=np.ravel(component.counts))
        coadded += hist
    return Map(coadded, copy.deepcopy(wcs))
def make_coadd_hpx(maps, hpx, shape, preserve_counts=True):
    """Coadd HEALPix-based maps by histogramming them onto a target HPX.

    Parameters
    ----------
    maps : iterable
        Input maps; each must expose ``hpx``, ``counts`` and ``ud_grade``.
    hpx : `HPX`
        Target pixelization.
    shape : tuple
        Shape of the output data array.
    preserve_counts : bool
        Passed to ``ud_grade`` when an input map has a different order
        than ``hpx``.

    Returns
    -------
    map_out : `HpxMap`
        New map holding the summed histogram and a deep copy of ``hpx``.
    """
    coadded = np.zeros(shape)
    bin_edges = hpx_utils.hpx_to_axes(hpx, shape)
    for component in maps:
        # Bring the input to the order of the target before binning.
        if component.hpx.order == hpx.order:
            regraded = component
        else:
            regraded = component.ud_grade(hpx.order, preserve_counts)
        coords = hpx_utils.hpx_to_coords(regraded.hpx,
                                         regraded.counts.shape)
        hist, _ = np.histogramdd(coords.T, bins=bin_edges,
                                 weights=np.ravel(regraded.counts))
        coadded += hist
    return HpxMap(coadded, copy.deepcopy(hpx))
def read_map_from_fits(fitsfile, extname=None):
    """Read a map from a FITS file, choosing the class from its projection.

    Parameters
    ----------
    fitsfile : str
        Path of the FITS file.
    extname : str or None
        Extension name forwarded to
        ``fits_utils.read_projection_from_fits``.

    Returns
    -------
    m : `Map` or `HpxMap`
        `Map` when the projection is a WCS, `HpxMap` when it is an HPX.

    Raises
    ------
    Exception
        If the projection is neither a WCS nor an HPX object.
    """
    proj, hdulist, hdu = fits_utils.read_projection_from_fits(fitsfile,
                                                             extname)
    if isinstance(proj, WCS):
        energy_edges = fits_utils.find_and_read_ebins(hdulist)
        return Map(hdu.data, proj, ebins=energy_edges)
    if isinstance(proj, HPX):
        return HpxMap.create_from_hdu(hdu, proj.ebins)
    raise Exception("Did not recognize projection type %s" % type(proj))
class Map_Base(object):
    """Abstract representation of a 2D or 3D counts map.

    Stores the counts array and declares the coordinate/lookup methods
    that concrete map classes are expected to override.
    """

    def __init__(self, counts):
        self._counts = counts

    @property
    def counts(self):
        """The stored counts array."""
        return self._counts

    @property
    def data(self):
        """Alias of `counts`; assignment requires an identical shape."""
        return self._counts

    @data.setter
    def data(self, new_counts):
        current_shape = self.data.shape
        if new_counts.shape != current_shape:
            raise Exception('Wrong shape.')
        self._counts = new_counts

    def get_pixel_skydirs(self):
        """Get a list of sky coordinates for the centers of every pixel.

        Not implemented in the base class.
        """
        raise NotImplementedError("MapBase.get_pixel_skydirs()")

    def get_pixel_indices(self, lats, lons):
        """Return the indices in the flat array corresponding to a set of
        coordinates.

        Not implemented in the base class.
        """
        raise NotImplementedError("MapBase.get_pixel_indices()")

    def sum_over_energy(self):
        """Reduce a counts cube to a counts map by summing over the energy
        planes.

        Not implemented in the base class.
        """
        raise NotImplementedError("MapBase.sum_over_energy()")

    def get_map_values(self, lons, lats, ibin=None):
        """Return the map values corresponding to a set of coordinates.

        Not implemented in the base class.
        """
        raise NotImplementedError("MapBase.get_map_values()")

    def interpolate(self, lon, lat, egy=None):
        """Return the interpolated map values corresponding to a set of
        coordinates.

        Not implemented in the base class.
        """
        raise NotImplementedError("MapBase.interpolate()")
class Map(Map_Base):
    """Representation of a 2D or 3D counts map using WCS."""

    def __init__(self, counts, wcs, ebins=None):
        """
        Parameters
        ----------
        counts : `~numpy.ndarray`
            Counts array with 2 or 3 dimensions.  ``npix`` is the reversed
            array shape, i.e. the last array axis is paired with the first
            WCS axis.
        wcs : `~astropy.wcs.WCS`
            Projection of the map.
        ebins : array-like or None
            Energy bin edges.  When given, the bin centers are computed as
            the centers of the edges in log-space.

        Raises
        ------
        Exception
            If ``counts`` is neither 2- nor 3-dimensional.
        """
        Map_Base.__init__(self, counts)
        self._wcs = wcs
        self._npix = counts.shape[::-1]

        if len(self._npix) == 3:
            self._xindex = 2
            self._yindex = 1
        elif len(self._npix) == 2:
            self._xindex = 1
            self._yindex = 0
        else:
            raise Exception('Wrong number of dimensions for Map object.')

        pix_size = np.abs(self.wcs.wcs.cdelt[:2])
        self._width = np.array([pix_size[0] * self.npix[0],
                                pix_size[1] * self.npix[1]])
        self._pix_center = np.array([(self.npix[0] - 1.0) / 2.,
                                     (self.npix[1] - 1.0) / 2.])
        self._pix_size = np.array([pix_size[0], pix_size[1]])

        self._skydir = SkyCoord.from_pixel(self._pix_center[0],
                                           self._pix_center[1],
                                           self.wcs)
        self._ebins = ebins
        if ebins is not None:
            self._ectr = np.exp(utils.edge_to_center(np.log(ebins)))
        else:
            self._ectr = None

    @property
    def wcs(self):
        return self._wcs

    @property
    def npix(self):
        return self._npix

    @property
    def skydir(self):
        """Return the sky coordinate of the image center."""
        return self._skydir

    @property
    def width(self):
        """Return the dimensions of the image."""
        return self._width

    @property
    def pix_size(self):
        """Return the pixel size along the two image dimensions."""
        return self._pix_size

    @property
    def pix_center(self):
        """Return the ROI center in pixel coordinates."""
        return self._pix_center

    @classmethod
    def create_from_hdu(cls, hdu, wcs):
        """Create a map from the transposed data of ``hdu`` and ``wcs``."""
        return cls(hdu.data.T, wcs)

    @classmethod
    def create_from_fits(cls, fitsfile, **kwargs):
        """Create a map from a FITS file.

        Parameters
        ----------
        fitsfile : str
            Path of the FITS file.
        hdu : int or str, optional
            Keyword argument selecting the image HDU (default 0).

        Energy bin edges are taken from an ``ENERGIES`` extension (bin
        centers, converted to edges in log-space) or, failing that, from
        an ``EBOUNDS`` extension.
        """
        hdu = kwargs.get('hdu', 0)

        with fits.open(fitsfile) as hdulist:
            header = hdulist[hdu].header
            data = hdulist[hdu].data
            header = fits.Header.fromstring(header.tostring())
            wcs = WCS(header)

            ebins = None
            if 'ENERGIES' in hdulist:
                tab = Table.read(fitsfile, 'ENERGIES')
                ectr = np.array(tab.columns[0])
                ebins = np.exp(utils.center_to_edge(np.log(ectr)))
            elif 'EBOUNDS' in hdulist:
                tab = Table.read(fitsfile, 'EBOUNDS')
                # E_MIN/E_MAX are scaled by 1e-3 (presumably keV -> MeV;
                # confirm against the file convention).
                emin = np.array(tab['E_MIN']) / 1E3
                emax = np.array(tab['E_MAX']) / 1E3
                ebins = np.append(emin, emax[-1])

        return cls(data, wcs, ebins)

    @classmethod
    def create(cls, skydir, cdelt, npix, coordsys='CEL', projection='AIT',
               ebins=None, differential=False):
        """Create an empty (zero-filled) map.

        Parameters
        ----------
        skydir, cdelt, coordsys, projection
            Forwarded to ``wcs_utils.create_wcs``.
        npix : sequence of int
            Number of pixels along each spatial dimension.
        ebins : array-like or None
            Energies; when given a third (energy) axis is added.
        differential : bool
            If True the energy axis has ``len(ebins)`` planes, otherwise
            ``len(ebins) - 1``.
        """
        crpix = np.array([n / 2. + 0.5 for n in npix])

        if ebins is not None:
            nebins = len(ebins) if differential else len(ebins) - 1
            data = np.zeros(list(npix) + [nebins]).T
            naxis = 3
        else:
            data = np.zeros(npix).T
            naxis = 2

        wcs = wcs_utils.create_wcs(skydir, coordsys, projection,
                                   cdelt, crpix, naxis=naxis,
                                   energies=ebins)
        return cls(data, wcs, ebins=ebins)

    def create_image_hdu(self, name=None, **kwargs):
        """Return the counts as a FITS image HDU with the WCS header."""
        return fits.ImageHDU(self.counts, header=self.wcs.to_header(),
                             name=name)

    def create_primary_hdu(self):
        """Return the counts as a FITS primary HDU with the WCS header."""
        return fits.PrimaryHDU(self.counts, header=self.wcs.to_header())

    def sum_over_energy(self):
        """ Reduce a 3D counts cube to a 2D counts map
        """
        # Note that the array is using the opposite convention from WCS
        # so we sum over axis 0 in the array, but drop axis 2 in the WCS
        # object
        return Map(np.sum(self.counts, axis=0), self.wcs.dropaxis(2))

    def xypix_to_ipix(self, xypix, colwise=False):
        """Return the flattened pixel indices from an array
        multi-dimensional pixel indices.

        Parameters
        ----------
        xypix : list
            List of pixel indices in the order (LON,LAT,ENERGY).

        colwise : bool
            Use column-wise pixel indexing.
        """
        return np.ravel_multi_index(xypix, self.npix,
                                    order='F' if colwise else 'C',
                                    mode='raise')

    def ipix_to_xypix(self, ipix, colwise=False):
        """Return array multi-dimensional pixel indices from flattened
        index.

        Parameters
        ----------
        colwise : bool
            Use column-wise pixel indexing.
        """
        return np.unravel_index(ipix, self.npix,
                                order='F' if colwise else 'C')

    def ipix_swap_axes(self, ipix, colwise=False):
        """ Return the transposed pixel index from the pixel xy coordinates

        if colwise is True (False) this assumes the original index was
        in column wise scheme
        """
        xy = self.ipix_to_xypix(ipix, colwise)
        return self.xypix_to_ipix(xy, not colwise)

    def get_pixel_skydirs(self):
        """Get a list of sky coordinates for the centers of every pixel.
        """
        xpix = np.linspace(0, self.npix[0] - 1., self.npix[0])
        ypix = np.linspace(0, self.npix[1] - 1., self.npix[1])
        xypix = np.meshgrid(xpix, ypix, indexing='ij')
        return SkyCoord.from_pixel(np.ravel(xypix[0]),
                                   np.ravel(xypix[1]), self.wcs)

    def get_pixel_indices(self, lons, lats, ibin=None):
        """Return the indices in the flat array corresponding to a set of
        coordinates

        Parameters
        ----------
        lons : array-like
            'Longitudes' (RA or GLON)

        lats : array-like
            'Latitidues' (DEC or GLAT)

        ibin : int or array-like
            Extract data only for a given energy bin.  None -> extract
            data for all energy bins.

        Returns
        ----------
        pixcrd : list
            Pixel indices along each dimension of the map.

        Raises
        ------
        RuntimeError
            If ``lons`` and ``lats`` have different lengths.
        """
        lons = np.array(lons, ndmin=1)
        lats = np.array(lats, ndmin=1)

        if len(lats) != len(lons):
            raise RuntimeError('Map.get_pixel_indices, input lengths '
                               'do not match %i %i' % (len(lons), len(lats)))
        if len(self._npix) == 2:
            pix_x, pix_y = self._wcs.wcs_world2pix(lons, lats, 0)
            pixcrd = [np.floor(pix_x).astype(int),
                      np.floor(pix_y).astype(int)]
        elif len(self._npix) == 3:
            all_lons = np.expand_dims(lons, -1)
            all_lats = np.expand_dims(lats, -1)
            if ibin is None:
                all_bins = (np.expand_dims(
                    np.arange(self.npix[2]), -1) * np.ones(lons.shape)).T
            else:
                # Broadcast against the coordinates so that a plain int
                # (which has no ``astype``) works as well as an array, and
                # so that all returned index arrays share one shape.
                all_bins = np.asarray(ibin) * np.ones(all_lons.shape)

            pix = self.wcs.wcs_world2pix(all_lons, all_lats, all_bins, 0)
            pixcrd = [np.floor(pix[0]).astype(int),
                      np.floor(pix[1]).astype(int),
                      all_bins.astype(int)]

        return pixcrd

    def get_map_values(self, lons, lats, ibin=None):
        """Return the map values corresponding to a set of coordinates.

        Parameters
        ----------
        lons : array-like
            'Longitudes' (RA or GLON)

        lats : array-like
            'Latitidues' (DEC or GLAT)

        ibin : int or array-like
            Extract data only for a given energy bin.  None -> extract
            data for all bins

        Returns
        ----------
        vals : numpy.ndarray((n))
            Values of pixels in the flattened map, np.nan used to flag
            coords outside of map.  Non-floating counts are returned as
            float so that NaN can be stored.
        """
        idxs = self.get_pixel_indices(lons, lats, ibin)

        inside = np.ones(idxs[0].shape, dtype=bool)
        for idx, npix in zip(idxs, self._npix):
            inside &= (idx >= 0) & (idx < npix)
            # Point out-of-range entries at a valid pixel so the lookup
            # below cannot fail; they are overwritten with NaN afterwards.
            idx[~inside] = 0

        # The index must be a tuple: NumPy no longer interprets a list of
        # arrays as a multi-dimensional index.
        vals = self.counts.T[tuple(idxs)]
        if not np.issubdtype(vals.dtype, np.floating):
            vals = vals.astype(float)
        vals[~inside] = np.nan
        return vals

    def interpolate(self, lon, lat, egy=None):
        """Return the interpolated map values at a set of coordinates.

        Parameters
        ----------
        lon, lat : array-like
            World coordinates in the frame of the map WCS.
        egy : array-like or None
            Energies (3D maps only).  None -> use the energy bin centers.

        The interpolation is linear on the pixel-index grid and
        extrapolates outside of the map (``bounds_error=False``,
        ``fill_value=None``).
        """
        if len(self.npix) == 2:
            pixcrd = self.wcs.wcs_world2pix(lon, lat, 0)
        else:
            if egy is None:
                egy = self._ectr

            pixcrd = self.wcs.wcs_world2pix(lon, lat, egy, 0)
            # Replace the energy pixel coordinate with the fractional
            # position of log(egy) among the log bin centers.
            pixcrd[2] = np.array(utils.val_to_pix(np.log(self._ectr),
                                                  np.log(egy)), ndmin=1)

        points = [np.linspace(0, npix - 1., npix) for npix in self.npix]
        fn = RegularGridInterpolator(points, self.counts.T,
                                     bounds_error=False,
                                     fill_value=None)
        return fn(np.column_stack(pixcrd))

    def interpolate_at_skydir(self, skydir):
        """Interpolate the map at ``skydir`` after transforming it to the
        coordinate system of the map (ICRS for 'CEL', Galactic otherwise).
        """
        coordsys = wcs_utils.get_coordsys(self.wcs)
        if coordsys == 'CEL':
            skydir = skydir.transform_to('icrs')
            return self.interpolate(skydir.ra.deg, skydir.dec.deg)
        else:
            skydir = skydir.transform_to('galactic')
            return self.interpolate(skydir.l.deg, skydir.b.deg)
class HpxMap(Map_Base):
    """Representation of a 2D or 3D counts map using HEALPix."""

    def __init__(self, counts, hpx):
        """C'tor, fill with a counts vector and a HPX object.

        The cached WCS state (2D WCS, projection, oversampling factor and
        HEALPix->WCS mapping) starts empty and is filled by
        `make_wcs_from_hpx`.
        """
        super(HpxMap, self).__init__(counts)
        self._hpx = hpx
        # Use the attribute names that make_wcs_from_hpx() and
        # convert_to_cached_wcs() actually read and write.
        self._wcs_2d = None
        self._wcs_proj = None
        self._wcs_oversample = None
        self._hpx2wcs = None

    @property
    def hpx(self):
        return self._hpx

    @classmethod
    def create_from_hdu(cls, hdu, ebins):
        """Create and return an HpxMap object from a FITS HDU.

        Parameters
        ----------
        hdu : FITS table HDU
            HDU holding the map data.
        ebins : array-like
            Energy bin edges [optional]

        For the ``FGST_SRCMAP_SPARSE`` convention the ``PIX``, ``CHANNEL``
        and ``VALUE`` columns are scattered into a dense
        ``(len(ebins), npix)`` array.  Otherwise every column whose name
        starts with the convention's column prefix becomes one plane.
        """
        hpx = HPX.create_from_hdu(hdu, ebins)
        if hpx.conv.convname == 'FGST_SRCMAP_SPARSE':
            pixs = hdu.data.field('PIX')
            chans = hdu.data.field('CHANNEL')
            keys = chans * hpx.npix + pixs
            vals = hdu.data.field('VALUE')
            nebin = len(ebins)
            data = np.zeros((nebin, hpx.npix))
            data.flat[keys] = vals
        else:
            prefix = hpx.conv.colstring
            cnames = [c for c in hdu.columns.names if c.startswith(prefix)]
            data = np.ndarray((len(cnames), hpx.npix))
            for i, cname in enumerate(cnames):
                data[i, 0:] = hdu.data.field(cname)

        return cls(data, hpx)

    @classmethod
    def create_from_hdulist(cls, hdulist, **kwargs):
        """Create and return an HpxMap object from a FITS HDUList.

        Parameters
        ----------
        hdulist : `~astropy.io.fits.HDUList`
            Opened FITS file.
        hdu : str or int, optional
            Keyword argument selecting the HDU with the map data.
            Defaults to the name of the second HDU.
        """
        # Only look at hdulist[1] when no HDU was requested explicitly.
        if 'hdu' in kwargs:
            extname = kwargs['hdu']
        else:
            extname = hdulist[1].name
        ebins = fits_utils.find_and_read_ebins(hdulist)
        return cls.create_from_hdu(hdulist[extname], ebins)

    @classmethod
    def create_from_fits(cls, fitsfile, **kwargs):
        """Create and return an HpxMap object from a FITS file.

        The file is closed again once the map has been built;
        `create_from_hdu` copies the column data into a new array.
        """
        with fits.open(fitsfile) as hdulist:
            return cls.create_from_hdulist(hdulist, **kwargs)

    def create_image_hdu(self, name=None, **kwargs):
        """Return an HDU for the counts built by ``hpx.make_hdu``.

        ``name`` is passed on as the ``extname`` keyword.
        """
        kwargs['extname'] = name
        return self.hpx.make_hdu(self.counts, **kwargs)

    def make_wcs_from_hpx(self, sum_ebins=False, proj='CAR', oversample=2,
                          normalize=True):
        """Make a WCS object and convert HEALPix data into WCS projection

        NOTE: this re-calculates the mapping, if you have already
        calculated the mapping it is much faster to use
        convert_to_cached_wcs() instead

        Parameters
        ----------
        sum_ebins  : bool
           sum energy bins over energy bins before reprojecting

        proj       : str
           WCS-projection

        oversample : int
           Oversampling factor for WCS map

        normalize  : bool
           True -> perserve integral by splitting HEALPix values between
           bins

        returns (WCS object, np.ndarray() with reprojected data)
        """
        self._wcs_proj = proj
        self._wcs_oversample = oversample
        self._wcs_2d = self.hpx.make_wcs(2, proj=proj, oversample=oversample)
        self._hpx2wcs = HpxToWcsMapping(self.hpx, self._wcs_2d)
        return self.convert_to_cached_wcs(self.counts, sum_ebins, normalize)

    def convert_to_cached_wcs(self, hpx_in, sum_ebins=False, normalize=True):
        """ Make a WCS object and convert HEALPix data into WCS projection

        Parameters
        ----------
        hpx_in     : `~numpy.ndarray`
           HEALPix input data
        sum_ebins  : bool
           sum energy bins over energy bins before reprojecting
        normalize  : bool
           True -> perserve integral by splitting HEALPix values between
           bins

        returns (WCS object, np.ndarray() with reprojected data)

        Raises
        ------
        Exception
            If called before `make_wcs_from_hpx`, or if ``hpx_in`` is
            neither 1- nor 2-dimensional.
        """
        if self._hpx2wcs is None:
            raise Exception('HpxMap.convert_to_cached_wcs() called '
                            'before make_wcs_from_hpx()')

        # Output buffers are zero-initialised so that any pixel the
        # mapping does not write holds a defined value.
        if len(hpx_in.shape) == 1:
            wcs_data = np.zeros(self._hpx2wcs.npix)
            loop_ebins = False
            hpx_data = hpx_in
        elif len(hpx_in.shape) == 2:
            if sum_ebins:
                wcs_data = np.zeros(self._hpx2wcs.npix)
                hpx_data = hpx_in.sum(0)
                loop_ebins = False
            else:
                wcs_data = np.zeros((self.counts.shape[0],
                                     self._hpx2wcs.npix[0],
                                     self._hpx2wcs.npix[1]))
                hpx_data = hpx_in
                loop_ebins = True
        else:
            raise Exception('Wrong dimension for HpxMap %i' %
                            len(hpx_in.shape))

        if loop_ebins:
            for i in range(hpx_data.shape[0]):
                self._hpx2wcs.fill_wcs_map_from_hpx_data(
                    hpx_data[i], wcs_data[i], normalize)
            # replace the WCS with a 3D one
            wcs = self.hpx.make_wcs(3, proj=self._wcs_proj,
                                    energies=np.log10(self.hpx.ebins),
                                    oversample=self._wcs_oversample)
        else:
            self._hpx2wcs.fill_wcs_map_from_hpx_data(
                hpx_data, wcs_data, normalize)
            wcs = self._wcs_2d

        return wcs, wcs_data

    def get_pixel_skydirs(self):
        """Get a list of sky coordinates for the centers of every pixel. """
        sky_coords = self._hpx.get_sky_coords()
        if self.hpx.coordsys == 'GAL':
            return SkyCoord(l=sky_coords.T[0], b=sky_coords.T[1],
                            unit='deg', frame='galactic')
        else:
            return SkyCoord(ra=sky_coords.T[0], dec=sky_coords.T[1],
                            unit='deg', frame='icrs')

    def get_pixel_indices(self, lats, lons):
        """Return the indices in the flat array corresponding to a set of
        coordinates.

        Delegates to ``hpx.get_pixel_indices(lats, lons)``; note that
        latitudes come first in this signature.
        """
        return self._hpx.get_pixel_indices(lats, lons)

    def sum_over_energy(self):
        """ Reduce a counts cube to a counts map """
        # We sum over axis 0 in the array, and drop the energy binning in
        # the hpx object
        return HpxMap(np.sum(self.counts, axis=0),
                      self.hpx.copy_and_drop_energy())

    def get_map_values(self, lons, lats, ibin=None):
        """Return the map values corresponding to a set of coordinates.

        Parameters
        ----------
        lons : array-like
            'Longitudes' (RA or GLON), in degrees

        lats : array-like
            'Latitidues' (DEC or GLAT), in degrees

        ibin : int or array-like
            Extract data only for a given energy bin.  None -> extract
            data for all bins

        Returns
        ----------
        vals : numpy.ndarray
            Values of the pixels containing the coordinates.  For a 2D
            data array with ``ibin`` None the first axis runs over the
            energy planes.
        """
        theta = np.pi / 2. - np.radians(lats)
        phi = np.radians(lons)

        pix = hp.ang2pix(self.hpx.nside, theta, phi, nest=self.hpx.nest)

        if self.data.ndim == 2:
            return self.data[:, pix] if ibin is None else self.data[ibin, pix]
        else:
            return self.data[pix]

    def interpolate(self, lon, lat, egy=None, interp_log=True):
        """Interpolate map values.

        Parameters
        ----------
        lon, lat : array-like
            Coordinates in degrees.
        egy : array-like or None
            Energies; only used when the data has an energy axis.
        interp_log : bool
            Interpolate the z-coordinate in logspace.
        """
        if self.data.ndim == 1:
            theta = np.pi / 2. - np.radians(lat)
            phi = np.radians(lon)
            return hp.pixelfunc.get_interp_val(self.counts, theta,
                                               phi, nest=self.hpx.nest)
        else:
            return self._interpolate_cube(lon, lat, egy, interp_log)

    def _interpolate_cube(self, lon, lat, egy=None, interp_log=True):
        """Perform interpolation on a healpix cube.  If egy is None
        then interpolation will be performed on the existing energy
        planes.
        """
        shape = np.broadcast(lon, lat, egy).shape
        lon = lon * np.ones(shape)
        lat = lat * np.ones(shape)
        theta = np.pi / 2. - np.radians(lat)
        phi = np.radians(lon)
        vals = []
        # Spatially interpolate every energy plane, stacking the planes
        # along a new trailing axis.
        for i, _ in enumerate(self.hpx.evals):
            v = hp.pixelfunc.get_interp_val(self.counts[i], theta,
                                            phi, nest=self.hpx.nest)
            vals += [np.expand_dims(np.array(v, ndmin=1), -1)]

        vals = np.concatenate(vals, axis=-1)

        if egy is None:
            return vals.T

        egy = egy * np.ones(shape)

        if interp_log:
            xvals = utils.val_to_pix(np.log(self.hpx.evals), np.log(egy))
        else:
            xvals = utils.val_to_pix(self.hpx.evals, egy)

        # Linear interpolation along the energy axis, one row per point.
        vals = vals.reshape((-1, vals.shape[-1]))
        xvals = np.ravel(xvals)
        v = map_coordinates(vals, [np.arange(vals.shape[0]), xvals],
                            order=1)
        return v.reshape(shape)

    def swap_scheme(self):
        """Return a copy of the map with the pixel ordering scheme swapped.

        The data is reordered with ``n2r=True`` when ``hpx.nest`` is true
        and with ``r2n=True`` otherwise; the new HPX comes from
        ``hpx.make_swapped_hpx()``.
        """
        hpx_out = self.hpx.make_swapped_hpx()
        direction = {'n2r': True} if self.hpx.nest else {'r2n': True}
        if self.data.ndim == 2:
            data_out = np.vstack([hp.pixelfunc.reorder(plane, **direction)
                                  for plane in self.data])
        else:
            data_out = hp.pixelfunc.reorder(self.data, **direction)
        return HpxMap(data_out, hpx_out)

    def expanded_counts_map(self):
        """ return the full counts map """
        if self.hpx._ipix is None:
            return self.counts

        output = np.zeros(
            (self.counts.shape[0], self.hpx._maxpix), self.counts.dtype)
        for i in range(self.counts.shape[0]):
            output[i][self.hpx._ipix] = self.counts[i]
        return output

    def explicit_counts_map(self, pixels=None):
        """ return a counts map with explicit index scheme

        Parameters
        ----------
        pixels : `np.ndarray` or None
            If set, grab only those pixels.
            If none, grab only non-zero pixels

        Returns
        -------
        (pixels, data) : tuple
            Pixel indices and the counts at those pixels.

        Raises
        ------
        RuntimeError
            If ``pixels`` is given for a map that already has an explicit
            pixel index.
        """
        # No pixel index, so build one
        if self.hpx._ipix is None:
            if self.data.ndim == 2:
                if pixels is None:
                    nz = self.counts.sum(0).nonzero()[0]
                else:
                    nz = pixels
                # np.vstack needs a real sequence, not a generator.
                data_out = np.vstack([self.data[i].flat[nz]
                                      for i in range(self.data.shape[0])])
            else:
                if pixels is None:
                    nz = self.data.nonzero()[0]
                else:
                    nz = pixels
                data_out = self.data[nz]
            return (nz, data_out)
        else:
            if pixels is None:
                return (self.hpx._ipix, self.data)
        # FIXME, can we catch this
        raise RuntimeError(
            'HPX.explicit_counts_map called with pixels for a map that already has pixels')

    def sparse_counts_map(self):
        """ return a counts map with sparse index scheme

        Returns
        -------
        (nz, data) : tuple
            Indices of the non-zero entries of the flattened counts array
            and the values at those indices.
        """
        if self.hpx._ipix is None:
            flatarray = self.data.flatten()
        else:
            # The expanded map is 2D; flatten it so that the indices refer
            # to the flat array in both branches.
            flatarray = self.expanded_counts_map().flatten()
        nz = flatarray.nonzero()[0]
        data_out = flatarray[nz]
        return (nz, data_out)

    def ud_grade(self, order, preserve_counts=False):
        """Return a copy of the map up- or down-graded to ``order``.

        Parameters
        ----------
        order : int
            Target order, passed to ``hpx.ud_graded_hpx``.
        preserve_counts : bool
            If True, ``power=-2`` is passed to ``healpy.ud_grade``;
            otherwise ``power=0``.
        """
        new_hpx = self.hpx.ud_graded_hpx(order)
        power = -2. if preserve_counts else 0

        def regrade(plane):
            return hp.pixelfunc.ud_grade(plane,
                                         nside_out=new_hpx.nside,
                                         order_in=new_hpx.ordering,
                                         order_out=new_hpx.ordering,
                                         power=power)

        if len(self.counts.shape) == 1:
            new_data = regrade(self.counts)
        else:
            new_data = np.vstack([regrade(self.counts[i])
                                  for i in range(self.counts.shape[0])])
        return HpxMap(new_data, new_hpx)