Source code for pyimcom.splitpsf.imsubtract

"""
Routines for implementing the image subtraction step in PSF wing removal.

Functions
---------
pltshow
    Helper to determine where to save a plot.
get_wcs
    Extracts the World Coordinate System from a cached file.
run_imsubtract
    Main workflow for image subtraction step.

"""

import gc
import os
import re
import sys
import time

import asdf
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from astropy.io import fits
from astropy.wcs import WCS
from astropy.wcs.wcsapi import SlicedLowLevelWCS

# import from furry_parakeet
from furry_parakeet import (
    pyimcom_croutines,
)
from scipy.fft import irfft2, next_fast_len, rfft2
from scipy.signal import fftconvolve
from scipy.signal.windows import tukey
from scipy.special import eval_legendre

# local imports
from ..config import Config, Settings
from ..diagnostics.context_figure import ReportFigContext
from ..utils import compareutils
from ..wcsutil import (
    PyIMCOM_WCS,
    get_pix_area,
)


[docs] def fftconvolve_multi(in1, in2, out, mode="full", nb=4, workers=None, verbose=False): """ Convolve two N-dimensional arrays using FFT. This is almost a drop-in replacement for ``scipy.signal.fftconvolve``. The big difference is that the convolution is directly added to `out`, rather than being a return value. For the 2D `mode` = "valid" case, this splits up the data into `nb` blocks for the convolution. It is designed to be efficient when `in1` is much smaller than `in2`. Parameters ---------- in1 : np.ndarray First input. in2 : np.ndarray Second input. Should have the same number of dimensions as `in1`. out : np.ndarray Location to add to the output image. Must have the right dimensions. mode : str, optional Mode; options are "full", "valid", and "same" (just as for the scipy functions). nb : int, optional Number of blocks to use. workers : int, optional Number of workers for the FFTs if requesting parallelism. verbose : bool, optional Whether to print the intermediate steps. Returns ------- None """ t0 = time.time() # if we're not using valid, or not in 2D, use standard fftconvolve if mode != "valid" or len(np.shape(in1)) != 2: out += fftconvolve(in1, in2, mode=mode, workers=workers) return # Now we know we're 2D and in valid mode. 
Get shapes (s1y, s1x) = np.shape(in1) (s2y, s2x) = np.shape(in2) Lx = abs(s1x - s2x) + 1 Ly = abs(s1y - s2y) + 1 # If in1 is big enough that it will break the indexing if s1y >= Ly // nb: out += fftconvolve(in1, in2, mode=mode, workers=workers) return # loop over horizontal bands height = (Ly + nb - 1) // nb if height <= s1y: out += fftconvolve(in1, in2, mode=mode, workers=workers) # also return if the strip is too narrow return lenx = next_fast_len(s1x + s2x) leny = next_fast_len(s1y + height) in1_ = np.zeros((leny, lenx)) in2_ = np.zeros((leny, lenx)) in1_[:s1y, :s1x] = in1 in1_ft = rfft2(in1_, workers=workers) del in1_ for j in range(nb): gc.collect() ybottom = j * height ytop = min((j + 1) * height, Ly) dy = ytop - ybottom in2_[:, :] = 0.0 in2_[: dy + s1y - 1, :s2x] = in2[ybottom : ytop + s1y - 1, :] in2_ft = rfft2(in2_, workers=workers) * in1_ft if verbose: print("y =", ybottom, ytop, "of Ly =", Ly) out[ybottom:ytop, :] += irfft2(in2_ft, s=(leny, lenx), workers=workers)[ s1y - 1 : dy + s1y - 1, s1x - 1 : Lx + s1x - 1 ] # B = fftconvolve(in1, in2[ybottom : ytop + s1y - 1, :], mode="valid") # print(np.shape(A), np.shape(B)) # print(np.amax(np.abs(A)), np.amax(np.abs(B)), np.amax(np.abs(A-B))) print(f"t = {time.time()-t0:6.3f} s, shape =", (leny, lenx), "ft =", np.shape(in1_ft)) sys.stdout.flush() del in2_, in1_ft, in2_ft gc.collect()
def pltshow(plt, display, pars=None):
    """
    Where to save a plot.

    Parameters
    ----------
    plt : matplotlib.pyplot
        The pyplot module to use for plotting.
    display : str or None
        Sends to file (if string), screen (None), or nowhere (if '/dev/null')
    pars : dict, optional
        Parameters for saving the file. Must be provided if a file is requested.

    Returns
    -------
    None

    Raises
    ------
    KeyError
        If a file is requested and `pars` lacks the 'type' key (or, for type
        'window', any of the keys listed below).

    Notes
    -----
    The `pars` dictionary contains the keys:

    * 'type' : str, currently only supports 'window'
    * 'obsid' : int, observation ID
    * 'sca' : int, SCA number
    * 'ix' : int, x block index
    * 'iy' : int, y block index

    For type 'window' the figure is written to
    ``display + f"_{obsid}_{sca}_{ix:02d}_{iy:02d}.png"``. Any other type is
    ignored (nothing is saved).

    """

    if display is None:
        plt.show()
        return
    if display == "/dev/null":
        return

    # a fresh dict per call instead of a shared mutable default argument
    if pars is None:
        pars = {}

    # if we get here, we need to save the file
    if pars["type"].lower() == "window":
        obsid = pars["obsid"]
        sca = pars["sca"]
        ix = pars["ix"]
        iy = pars["iy"]
        plt.savefig(display + f"_{obsid}_{sca}_{ix:02d}_{iy:02d}.png")
def get_wcs(cachefile):
    """
    Read the WCS that belongs to a cached FITS file.

    If HDU 1 carries a ``WCSTYPE`` keyword whose first four characters are
    "gwcs" (case-insensitive), the WCS is taken from the ``"wcs"`` entry of the
    companion ASDF file, whose name is `cachefile` with its last five
    characters replaced by ``_wcs.asdf``. Otherwise the header of the
    ``SCIWCS`` HDU is used.

    Parameters
    ----------
    cachefile : str
        Name of the cached file.

    Returns
    -------
    pyimcom.wcsutil.PyIMCOM_WCS
        The World Coordinate System of the cached file.

    """

    with fits.open(cachefile) as hdul:
        hdr1 = hdul[1].header
        is_gwcs = "WCSTYPE" in hdr1 and hdr1["WCSTYPE"][:4].lower() == "gwcs"

        # plain FITS WCS: everything we need is in the SCIWCS header
        if not is_gwcs:
            return PyIMCOM_WCS(hdul["SCIWCS"].header)

        # gwcs: stored in a side-car ASDF file next to the cache file
        asdf_name = cachefile[:-5] + "_wcs.asdf"
        with asdf.open(asdf_name) as asdf_file:
            return PyIMCOM_WCS(asdf_file["wcs"])
def get_wcs_from_infile(infile):
    """
    Build the "sub-WCS" of an opened FITS file.

    The WCS is constructed from the header of element 0 of `infile` and then
    sliced with index 0 on its first two array axes, keeping the full
    ``NAXIS2`` x ``NAXIS1`` range on the last two. Using SlicedLowLevelWCS
    avoids extra axes which create additional complications in the WCS.

    Parameters
    ----------
    infile : sequence of HDUs (e.g., astropy.io.fits.HDUList)
        The opened file; only ``infile[0].header`` is read.

    Returns
    -------
    astropy.wcs.wcsapi.SlicedLowLevelWCS
        The World Coordinate System of the file with only the necessary axes.

    """

    header = infile[0].header
    full_wcs = WCS(header)
    # array-order slices: two leading axes fixed at 0, then all of y and x
    keep = [0, 0, slice(0, header["NAXIS2"]), slice(0, header["NAXIS1"])]
    return SlicedLowLevelWCS(full_wcs, slices=keep)
def run_imsubtract(
    config_file, display=None, scanum=None, local_output=False, max_img=None, workers=None, wcs_shortcut=True
):
    """
    Main routine to run imsubtract.

    For every cached exposure file found next to the configuration's
    ``inlayercache`` stem, the output blocks that may overlap the SCA are
    windowed, interpolated onto an oversampled SCA-sized canvas, multiplied by
    the pixel area map, convolved with the position-dependent kernel read from
    ``{inlayercache}.psf/psf_{obsid}.fits``, and the result is subtracted from
    each layer of the cached image. The subtracted image is written to a new
    ``*_subI.fits`` file.

    Parameters
    ----------
    config_file : str
        Location of a configuration file.
    display : str or None, optional
        Display location for intermediate steps.
    scanum : int or None, optional
        If not None, only run this SCA. Should be in range 1..18, inclusive.
        (Mostly used for parallelization.)
    local_output : bool, optional
        Whether to direct the file to local output instead of the cache directory.
        (This will normally be the default False; it is provided only so that if more
        than one user runs tests at the same time, they can use True to avoid a collision.)
    max_img : int, optional
        If provided, does computations for a maximum number of SCAs. Most users will
        want the default of None; this is provided mainly for testing.
    workers : int, optional
        Number of workers for the FFTs if requesting parallelism.
    wcs_shortcut : bool, optional
        If set, allows interpolation methods to speed up WCS computations.

    Returns
    -------
    None

    Notes
    -----
    There are several options for `display`:

    * `display` = None : print to screen
    * `display` = '/dev/null' : don't save
    * `display` = any other string : save to ``display+f'_{obsid}_{sca}_{ix:02d}_{iy:02d}.png'``

    Side effects: for each processed exposure a scratch memory-mapped file
    ``<cache file without '.fits'>_data.npy`` is created in the cache
    directory, and the result is written to
    ``{inlayercache}_{obsid:08d}_{sca:02d}_subI.fits`` (or to
    ``{obsid:08d}_{sca:02d}_subI.fits`` in the working directory if
    `local_output` is set), overwriting an existing file.

    Cache files whose names do not end in ``_<obsid>_<sca>.fits`` are skipped.

    """

    # load the file using Config and get information
    cfgdata = Config(config_file)

    info = cfgdata.inlayercache
    block_path = cfgdata.outstem
    ra = cfgdata.ra * (np.pi / 180)  # convert to radians
    dec = cfgdata.dec * (np.pi / 180)  # convert to radians
    lonpole = cfgdata.lonpole * (np.pi / 180)  # convert to radians
    nblock = cfgdata.nblock
    n1 = cfgdata.n1  # number of postage stamps
    n2 = cfgdata.n2  # size of single run
    postage_pad = cfgdata.postage_pad  # postage stamp padding
    dtheta_deg = cfgdata.dtheta
    blocksize_rad = n1 * n2 * dtheta_deg * (np.pi) / 180  # convert to radians

    # get information from Settings
    pix_size = Settings.pixscale_native  # native pixel scale
    sca_nside = Settings.sca_nside  # length of sca side, in native pixels

    # separate the path from the inlayercache info (split at the last "/").
    # A stem without a directory part refers to the current directory; this
    # case used to fail with an undefined `path`.
    head, sep, stem = info.rpartition("/")
    path = head if sep else "."

    # find all the matching fits files. Only the top level of `path` is
    # scanned, because every file is later opened as path + "/" + name.
    # The character before ".fits" must be a digit, which excludes e.g. the
    # "*_subI.fits" outputs written by this function.
    top_level_files = next(os.walk(path), (path, [], []))[2]
    exps = [
        name
        for name in top_level_files
        if name.startswith(stem) and name.endswith(".fits") and len(name) > 5 and name[-6].isdigit()
    ]
    print("list of files:", exps)

    # loop over the list of observation pair files (for each SCA)
    count = 0
    for exp in exps:
        # get SCA and obsid; skip names that cannot be parsed instead of
        # silently re-using the values of the previous file
        m2 = re.search(r"(\w*)_0*(\d+)_(\d+)\.fits$", exp)
        if m2 is None:
            print("skipping (cannot parse obsid/SCA):", exp)
            continue
        obsid = int(m2.group(2))
        sca = int(m2.group(3))
        if scanum is not None and scanum != sca:
            continue  # only do the given SCA
        print("OBSID: ", obsid, "SCA: ", sca)

        # inlayercache data: copy the input image, I, into a disk-backed array
        with fits.open(path + "/" + exp) as hdul:
            I_img = np.memmap(
                path + "/" + exp[:-5] + "_data.npy",
                dtype=np.float32,
                mode="w+",
                shape=np.shape(hdul[0].data),
            )
            I_img[:, :, :] = hdul[0].data  # this is I

        # find number of layers
        nlayer = np.shape(I_img)[-3]

        # get wcs information from fits file (or asdf if indicated)
        sca_wcs = get_wcs(path + "/" + exp)

        # results from splitpsf
        # read in the kernel
        with fits.open(f"{info}.psf/psf_{obsid:d}.fits") as hdul2:
            K = np.copy(hdul2[sca + hdul2[0].header["KERSKIP"]].data)
            # get the oversampling factor
            oversamp = hdul2[0].header["OVSAMP"]  # number of kernel pixels / native pixels
        # get the number of pixels on the axis
        axis_num = K.shape[1]  # kernel pixels
        Ncoeff = K.shape[0]  # number of coefficients

        # SCA padding
        I_pad = int(np.ceil(axis_num / 2 / oversamp))  # native pixels

        # define the first index needed for convolution
        first_index = (oversamp + 2 * oversamp * I_pad - axis_num) // 2

        # get the kernel size
        # NOTE(review): pix_size is treated as arcsec here, but the area
        # normalizations below, (pix_size * 180 / np.pi) ** 2, treat it as
        # radians -- confirm the units of Settings.pixscale_native.
        s_in_rad = pix_size * np.pi / (180 * 3600)  # convert arcsec to radians
        ker_size = axis_num / oversamp * s_in_rad

        # start coordinate transformations
        # define pad
        pad = ker_size / 2  # at least half of the kernel size
        # convert to x, y, z using wcs coords (center of SCA)
        x, y, z, p = compareutils.getfootprint(sca_wcs, pad)
        v = np.array([x, y, z])

        # convert to x', y', z'
        # define coordinates and transformation matrix
        ex = np.array([np.sin(ra), -np.cos(ra), 0])
        ey = np.array([-np.cos(ra) * np.sin(dec), -np.sin(dec) * np.sin(ra), np.cos(dec)])
        ez = np.array([-np.cos(dec) * np.cos(ra), -np.cos(dec) * np.sin(ra), -np.sin(dec)])
        T = np.array([ex, ey, ez])
        # perform transformation and define individual values
        v_p = np.matmul(T, v)
        x_p = v_p[0]
        y_p = v_p[1]
        z_p = v_p[2]

        # define the rotation matrix, coefficient, and additional vector
        rot = np.array([[-np.cos(lonpole), -np.sin(lonpole)], [np.sin(lonpole), -np.cos(lonpole)]])
        coeff = 2 / (1 - z_p)
        v_convert = np.array([x_p, y_p])
        # convert to eta and xi (block coordinates)
        block_coords = coeff * np.matmul(rot, v_convert)

        # find theta in original coordinates, convert to block coordinates
        theta = (
            2 * np.arctan(np.sqrt(p / (2 - p)))
            + blocksize_rad / np.sqrt(2)
            + np.sqrt(2) * pad
            + ker_size / np.sqrt(2)
        ) * coeff
        theta_block = theta / blocksize_rad

        # add theta to this set of coords
        block_coords = np.append(block_coords, theta)

        # convert the units of this coordinate system to blocks
        block_coords_blocks = block_coords / blocksize_rad
        # find the center of SCA relative to the bottom left of the mosaic
        SCA_coords = block_coords_blocks.copy()
        SCA_coords[:2] += nblock / 2  # take only the xi and eta directions

        # find the blocks the SCA covers: all blocks whose centers are within
        # theta_block of the SCA center
        side = np.arange(nblock) + 0.5
        xx, yy = np.meshgrid(side, side)
        distance = np.hypot(xx - SCA_coords[0], yy - SCA_coords[1])
        in_SCA = np.where(distance <= theta_block)
        block_list = np.stack((in_SCA[1], in_SCA[0]), axis=-1)  # rows are (ix, iy)

        # define the canvas to add interpolated blocks
        # size is SCA+padding on both sides scaled back to kernel pixels
        A = oversamp * (sca_nside + 2 * I_pad)

        skipblocks = set()  # blocks we know we can skip since they turned out to have no overlap
        lrbt_table = {}  # the [left, right, bottom, top] of each block

        # get pixel area map (once)
        area_np = (
            get_pix_area(sca_wcs, region=[-I_pad, sca_nside + I_pad, -I_pad, sca_nside + I_pad])
            / (pix_size * 180 / np.pi) ** 2
        ).astype(np.float32)

        # quantities for the convolution that do not depend on the layer
        Nl = int(np.floor(np.sqrt(Ncoeff + 0.5)))  # Legendre orders per axis
        # NOTE(review): the first canvas sample is a sub-pixel center, but the
        # end point uses "+ 0.5 / oversamp", which makes the spacing
        # (A / oversamp) / (A - 1) rather than 1 / oversamp. Left unchanged
        # because it affects the default path -- confirm the intended grid.
        x_canvas = np.linspace(-I_pad - 0.5 + 0.5 / oversamp, sca_nside + I_pad - 0.5 + 0.5 / oversamp, A)
        u_canvas = (x_canvas - (sca_nside - 1) / 2) / (sca_nside / 2)

        # add for loop over layers (nlayers)
        for n in range(nlayer):
            H_canvas = np.zeros((A, A), dtype=np.float32)
            KH = np.zeros((A - axis_num + 1, A - axis_num + 1), dtype=np.float32)

            # loop over the blocks in the list
            for ix, iy in block_list:
                if (ix, iy) in skipblocks:
                    continue
                print("BLOCK: ", ix, iy)
                sys.stdout.flush()
                t0 = time.time()

                # open the block info
                with fits.open(block_path + f"_{ix:02d}_{iy:02d}.fits") as hdul3:
                    block_wcs = get_wcs_from_infile(hdul3)
                    print(f"+ block: {time.time()-t0:6.2f}")
                    sys.stdout.flush()

                    # determine the length of one axis of the block
                    block_length = hdul3[0].header["NAXIS1"]  # length in output pixels
                    overlap = n2 * postage_pad  # size of one overlap region due to postage stamp
                    a1 = 2 * (2 * overlap - 1) / (block_length - 1)  # percentage of region to have
                    # window function taper
                    # the '-1' is due to scipy's convention on alpha that the denominator is the
                    # distance from the first to the last point, so 1 less than the length
                    window = tukey(block_length, alpha=a1).astype(np.float32)
                    # apply window function to block data in both directions
                    # (the product is a new array, so the file can be closed afterwards)
                    block = hdul3[0].data[0, n, :, :] * window[:, None] * window[None, :]
                print(f"+ windowed: {time.time()-t0:6.2f}")
                sys.stdout.flush()

                # check the window function
                if display != "/dev/null":
                    print("FIG")
                    with ReportFigContext(matplotlib, plt):
                        plt.plot(np.arange(len(window)), window, color="indigo")
                        plt.axvline(block_length - 1, c="mediumpurple")
                        plt.axvline(block_length - overlap - 1, c="mediumpurple")
                        plt.axvline(block_length - 2 * overlap - 1, c="mediumpurple")
                        plt.xlim(block_length - 3 * overlap, block_length + overlap)
                        plt.plot(block_length - 2, window[block_length - 2], c="darkmagenta", marker="o")
                        plt.plot(
                            block_length - 2 * overlap,
                            window[block_length - 2 * overlap],
                            c="darkmagenta",
                            marker="o",
                        )
                        plt.plot(
                            block_length - overlap, window[block_length - overlap], c="blueviolet", marker="o"
                        )
                        plt.plot(
                            block_length - overlap - 2,
                            window[block_length - overlap - 2],
                            c="blueviolet",
                            marker="o",
                        )
                        pltshow(
                            plt, display, {"type": "window", "obsid": obsid, "sca": sca, "ix": ix, "iy": iy}
                        )
                print(f"+ figure: {time.time()-t0:6.2f}")
                sys.stdout.flush()
                gc.collect()

                if (ix, iy) in lrbt_table:
                    # get bounding box if we already have it
                    [left, right, bottom, top] = lrbt_table[(ix, iy)]
                else:
                    # find the 'Bounding Box' in SCA coordinates
                    # create mesh grid for output block
                    block_arr = np.arange(block_length)
                    x_out, y_out = np.meshgrid(block_arr, block_arr)
                    # convert to ra and dec using block wcs
                    ra_sca, dec_sca = block_wcs.pixel_to_world_values(x_out, y_out, 0)
                    del x_out, y_out
                    # convert into SCA coordinates
                    x_in, y_in = sca_wcs.all_world2pix(ra_sca, dec_sca, 0)
                    del ra_sca, dec_sca
                    # get the bounding box from the max and min values
                    left = int(np.floor(np.min(x_in)))
                    right = int(np.ceil(np.max(x_in)))
                    bottom = int(np.floor(np.min(y_in)))
                    top = int(np.ceil(np.max(y_in)))
                    del x_in, y_in
                    # trim bounding box to ensure not extending past SCA padding
                    left = np.max([left, -I_pad])
                    right = np.min([right, sca_nside - 1 + I_pad])
                    bottom = np.max([bottom, -I_pad])
                    top = np.min([top, sca_nside - 1 + I_pad])
                    lrbt_table[(ix, iy)] = [left, right, bottom, top]
                gc.collect()
                print(f"+ wcsmap: {time.time()-t0:6.2f}")
                sys.stdout.flush()

                # create the bounding box mesh grid, with ovsamp
                # determine side lengths of the box
                width = int(oversamp * (right - left + 1))
                height = int(oversamp * (top - bottom + 1))
                # check if width, height are positive
                if width <= 0 or height <= 0:
                    skipblocks.add((ix, iy))  # can skip this block for the next layer
                    continue

                # two options for getting the inverse WCS mapping: one with shortcut, one without
                if wcs_shortcut:
                    # evaluate the WCS only on native pixel corners ...
                    x = np.linspace(left - 0.5, right + 0.5, right - left + 2)
                    y = np.linspace(bottom - 0.5, top + 0.5, top - bottom + 2)
                    bb_x, bb_y = np.meshgrid(x, y)
                    # map bounding box from SCA to output block coordinates
                    ra_map, dec_map = sca_wcs.all_pix2world(bb_x, bb_y, 0)
                    del bb_x, bb_y
                    x_bb_temp, y_bb_temp = block_wcs.world_to_pixel_values(ra_map, dec_map, 0)
                    del ra_map, dec_map
                    # ... and interpolate bilinearly to the sub-pixel centers
                    x_bb = np.zeros((height, width))
                    y_bb = np.zeros((height, width))
                    for i in range(oversamp):
                        fi = (i + 0.5) / oversamp
                        x1 = (1 - fi) * x_bb_temp[:, :-1] + fi * x_bb_temp[:, 1:]
                        y1 = (1 - fi) * y_bb_temp[:, :-1] + fi * y_bb_temp[:, 1:]
                        for j in range(oversamp):
                            fj = (j + 0.5) / oversamp
                            x_bb[j::oversamp, i::oversamp] = (1 - fj) * x1[:-1, :] + fj * x1[1:, :]
                            y_bb[j::oversamp, i::oversamp] = (1 - fj) * y1[:-1, :] + fj * y1[1:, :]
                    del x_bb_temp, y_bb_temp, x1, y1
                else:
                    # sub-pixel centers, spaced by 1 / oversamp in both directions
                    # (the y end point previously used "+ 0.5 / oversamp", which
                    # stretched the y grid relative to x and to the shortcut branch)
                    x = np.linspace(left - 0.5 + 0.5 / oversamp, right + 0.5 - 0.5 / oversamp, width)
                    y = np.linspace(bottom - 0.5 + 0.5 / oversamp, top + 0.5 - 0.5 / oversamp, height)
                    bb_x, bb_y = np.meshgrid(x, y)
                    # map bounding box from SCA to output block coordinates
                    ra_map, dec_map = sca_wcs.all_pix2world(bb_x, bb_y, 0)
                    del bb_x, bb_y
                    x_bb, y_bb = block_wcs.world_to_pixel_values(ra_map, dec_map, 0)
                    del ra_map, dec_map
                print(f"+ inv wcs map: {time.time()-t0:6.2f}")
                sys.stdout.flush()

                # add padding to the block (with window applied); the sample
                # positions are shifted by the same 5 pixels
                block_padded = np.pad(block, 5, mode="constant", constant_values=0)[None, :, :].astype(
                    np.float64
                )
                x_bb += 5
                y_bb += 5
                # create interpolated version of block
                H = np.zeros((1, np.size(x_bb)))
                pyimcom_croutines.iG4460C(block_padded, x_bb.ravel(), y_bb.ravel(), H)
                # reshape H
                H = H.reshape(x_bb.shape)
                print(f"+ interp: {time.time()-t0:6.2f}")
                sys.stdout.flush()

                # multiply by Jacobian to H
                # get native pixel size (in units of the ideal pixel)
                if wcs_shortcut:
                    # cut the precomputed area map and replicate each native
                    # pixel onto its oversamp x oversamp sub-pixels
                    native_pix = np.repeat(
                        np.repeat(
                            area_np[I_pad + bottom : I_pad + top + 1, I_pad + left : I_pad + right + 1],
                            oversamp,
                            axis=1,
                        ),
                        oversamp,
                        axis=0,
                    )
                else:
                    native_pix = (
                        get_pix_area(sca_wcs, region=[left, right + 1, bottom, top + 1], ovsamp=oversamp)
                        / (pix_size * 180 / np.pi) ** 2
                    )
                H *= native_pix
                print(f"+ area: {time.time()-t0:6.2f}")
                sys.stdout.flush()

                # add H to H_canvas
                H_canvas[
                    oversamp * (bottom + I_pad) : oversamp * (top + 1 + I_pad),
                    oversamp * (left + I_pad) : oversamp * (right + 1 + I_pad),
                ] += H

                # some cleanup
                del H, native_pix

            # apply convolution to canvas: sum over the Legendre orders (lu, lv)
            for lu in range(Nl):
                # save first multiplication
                Hlu = H_canvas * eval_legendre(lu, u_canvas).astype(np.float32)[None, :]
                for lv in range(Nl):
                    print("Convolve", lu, lv)
                    sys.stdout.flush()
                    fftconvolve_multi(
                        K[lu + lv * Nl, :, :],
                        Hlu * eval_legendre(lv, u_canvas).astype(np.float32)[:, None],
                        KH,
                        mode="valid",
                        nb=6,
                        workers=workers,
                    )

            # subtract from the input image (using less memory): take one
            # sample per native pixel. An explicit stop is used instead of
            # "-first_index", which would give an empty slice if first_index is 0.
            last_index = first_index + oversamp * sca_nside
            I_img[n, :, :] -= KH[first_index:last_index:oversamp, first_index:last_index:oversamp]

        # save outside of the layer for loop
        # write output file for each exposure
        fname = f"{info}_{obsid:08d}_{sca:02d}_subI.fits"
        if local_output:
            fname = f"{obsid:08d}_{sca:02d}_subI.fits"
        print("saving >>", fname)
        sys.stdout.flush()
        # this version copies HDU #1 (which contains the WCS)
        with fits.open(path + "/" + exp) as f_in:
            fits.HDUList([fits.PrimaryHDU(data=I_img), f_in[1]]).writeto(fname, overwrite=True)

        # exit if we've specified a maximum number of SCAs
        count += 1
        if max_img is not None and count == max_img:
            return
if __name__ == "__main__":
    # Command-line entry point:
    #
    #   python3 -m pyimcom.splitpsf.imsubtract <config> <sca> [<output images>]
    #
    # <config> is the configuration file, <sca> selects one SCA (0 = all of
    # them), and the optional third argument is passed on as the `display`
    # stem for the diagnostic figures. If it is omitted, `display` is
    # "/dev/null". The FFTs always use 4 workers.

    start = time.time()

    # get the json file
    config_file = sys.argv[1]

    # get the SCA (0 for all of them)
    sca = int(sys.argv[2])
    sca = None if sca == 0 else sca

    # figure destination
    display = sys.argv[3] if len(sys.argv) > 3 else "/dev/null"

    run_imsubtract(config_file, display=display, scanum=sca, workers=4)

    end = time.time()
    elapsed = end - start
    print(f"finished at t = {elapsed:.2f} s")