Source code for pyimcom.imdestripe

"""
Program to remove correlated noise stripes from RST images.

Classes
-------
Sca_img
     Class defining an SCA image object
Parameters
        Class holding the destriping parameters for a given mosaic
Cost_models
        Class holding the cost function models. Currently only quadratic is supported

Functions
---------
write_to_file
    Function to write some text to an output file.
    If filename is None, output is directed to stdout
save_fits
    Save a 2D image to a FITS file with locking, retries, and atomic rename.
apply_object_mask
    Apply a bright object mask to an image.
quadratic
    Quadratic cost function f(x) = x^2
absolute
    Absolute cost function f(x) = |x|
huber_loss
    Huber loss cost function
quad_prime
    Derivative of quadratic cost function f'(x) = 2x
abs_prime
    Derivative of absolute cost function f'(x) = sign(x)
huber_prime
    Derivative of Huber loss cost function
get_scas
    Function to get a list of all SCA images and their WCSs for this mosaic
interpolate_image_bilinear
    Interpolate values from a "reference" SCA image onto a "target" SCA coordinate grid
transpose_interpolate
    Interpolate backwards from image_A to image_B space.
transpose_par
    Sum up the values of an image across rows
get_effective_gain
    Retrieve the effective gain and N_eff of the image. Valid only for already-interpolated images.
get_ids
    Take an SCA label and parse it out to get the Obsid and SCA id strings.
save_snapshot
    Save restart state to pickle file.
get_neighbors
    Get the neighboring SCAs for each SCA in the mosaic
residual_function
    Calculate the residual image, = grad(epsilon)
residual_function_single
    Helper function to calculate residuals for a single SCA
cost_function_single
    Helper function to calculate cost for a single SCA
linear_search_general
    Perform a linear search to find the optimal step size alpha along a given direction
linear_search_quadratic
    Calculate optimal step size alpha along a given direction, for quadratic cost function
conjugate_gradient
    Perform the conjugate gradient algorithm to minimize the cost function
"""

import copy
import cProfile
import csv
import gc
import glob
import io
import logging
import multiprocessing as mp
import os
import pickle
import pstats
import random
import re
import sys
import time
import traceback
import uuid
from concurrent.futures import ProcessPoolExecutor, as_completed

import asdf
import numpy as np
from astropy import wcs
from astropy.io import fits
from filelock import FileLock, Timeout
from memory_profiler import memory_usage
from scipy.ndimage import binary_dilation, binary_propagation

from .config import JWST, Config, Settings
from .utils import compareutils
from .wcsutil import PyIMCOM_WCS

try:
    import furry_parakeet.pyimcom_croutines as pyimcom_croutines
except ImportError:
    import pyimcom_croutines

import warnings

from asdf.exceptions import AsdfConversionWarning, AsdfPackageVersionWarning

# Switches for optional diagnostic outputs written while destriping.
testoutputs = {
    "testing": False,  # Is this run for testing only? (usually False)
    "test_image_dir": "./",  # Location for test outputs (gets overwritten)
}

# Suppress ASDF warnings
warnings.filterwarnings("ignore", category=AsdfConversionWarning)
warnings.filterwarnings("ignore", category=AsdfPackageVersionWarning)

# Switch the global Settings object to its JWST configuration when requested.
if JWST:
    Settings.jwst()

# Filter-name lookup used when building input file names.
filters = Settings.RomanFilters

t0_global = time.time()  # after imports

# Module settings
use_output_float = np.float32
# Scratch directory: $TMPDIR (with a trailing slash) if set, otherwise the current directory.
tempdir = str(os.environ["TMPDIR"]) + "/" if "TMPDIR" in os.environ else "./"

# For test outputs: set sca=0 to not produce test outputs.
img_full_output = {"obsid": 670, "scaid": 10}
if not testoutputs["testing"]:
    img_full_output = {"obsid": -1, "scaid": -1}  # don't do these big outputs
class Cost_models:
    """
    Select the cost function and its derivative named by the configuration.

    Parameters
    --------
    cfg : Config object
        must provide ``cost_model`` (one of "quadratic", "absolute", "huber_loss")
        and, for the Huber model, ``hub_thresh``

    Attributes
    --------
    model : Str
        name of the selected cost model
    thresh : Float or None
        Huber threshold; None for the other models
    f : Function
        the cost function
    f_prime : Function
        derivative of the cost function
    """

    def __init__(self, cfg):
        self.model = cfg.cost_model

        uses_huber = self.model == "huber_loss"
        self.thresh = cfg.hub_thresh if uses_huber else None
        if uses_huber:
            write_to_file(f"Cost model is Huber Loss with threshold: {self.thresh}")

        # name -> (cost function, derivative)
        cost_pairs = {
            "quadratic": (quadratic, quad_prime),
            "absolute": (absolute, abs_prime),
            "huber_loss": (huber_loss, huber_prime),
        }
        chosen = cost_pairs[self.model]
        self.f = chosen[0]
        self.f_prime = chosen[1]
class Sca_img:
    """
    Class defining an SCA image object.

    Parameters
    --------
    obsid : Str
        the observation id
    scaid : Str
        the SCA id
    cfg : Config object
        built from the configuration file
    tempdir : Str
        directory holding temporary files (effective gain maps, interpolations)
    interpolated : Bool
        True if you want the interpolated version of this SCA and not the original.
        Default False
    add_objmask : Bool
        True if you want to apply the permanent pixel mask and a bright object mask.
        Default True
    indata_type : Str
        input data format: 'fits', 'asdf' or 'jwst'. Default 'fits'

    Attributes
    --------
    image : 2D np array
        the SCA image (float64)
    shape : Tuple
        the shape of the image
    w : WCS object
        the WCS object associated with this SCA
    header : FITS header or None
        header of the image HDU (None for ASDF input)
    type : Str
        'interpolated', 'fits', 'asdf' or 'jwst'
    obsid : Str
        observation ID of this SCA image
    scaid : Str
        SCA ID (position on focal plane) of this SCA image
    mask : 2D np array
        The full pixel mask that is used on this image (True = good pixel).
        Is correct only after applying masks to image
    g_eff : 2D np array
        Effective gain in each pixel of the image
    params_subtracted : Bool
        True if parameters have been subtracted from this image.
    cfg : Config object
        the configuration object passed in at initialization

    Methods
    --------
    apply_noise
        Apply the appropriate lab noise frame to the SCA image
    apply_permanent_mask
        Apply the SCA permanent pixel mask to the image
    apply_asdf_mask
        Apply the SCA ASDF file mask to the image
    apply_jwst_mask
        Mask NaN pixels of a JWST image
    apply_all_mask
        Apply the full SCA mask to the image
    subtract_parameters
        Subtract a given set of parameters from self.image;
        updates self.image, self.params_subtracted
    get_coordinates
        Create an array of ra, dec coords for the image
    make_interpolated
        Construct a version of this SCA interpolated from other, overlapping ones.
        Writes the interpolated image out to the disk, to be read/used later
    """

    def __init__(
        self, obsid, scaid, cfg, tempdir=tempdir, interpolated=False, add_objmask=True, indata_type="fits"
    ):
        if interpolated:
            self._load_fits_image(
                tempdir + "interpolations/" + obsid + "_" + scaid + "_interp.fits", "PRIMARY"
            )
            self.type = "interpolated"
        elif indata_type == "fits":
            self._load_fits_image(
                cfg.ds_obsfile + filters[cfg.use_filter] + "_" + obsid + "_" + scaid + ".fits", "PRIMARY"
            )
            self.type = "fits"
            self.mask_threshold = [0, 0.3]  # m, c for object masking
        elif indata_type == "asdf":
            fp = cfg.ds_obsfile + filters[cfg.use_filter] + "_" + obsid + "_" + scaid + ".asdf"
            # Note: _file_handle is deliberately kept open (original comment: "to maintain memmap")
            self._file_handle = asdf.open(fp, memmap=True, lazy_load=True)
            self.w = PyIMCOM_WCS(self._file_handle["roman"]["meta"]["wcs"])
            self.image = np.array(self._file_handle["roman"]["data"], dtype=np.float64)
            self.header = None
            self.shape = np.shape(self.image)
            self.type = "asdf"
            self.mask_threshold = [0, 0.3]  # m, c for object masking
        elif indata_type == "jwst":
            # ds_obsfile is 'path/to/jw'
            # obsid is <ppppp><ooo><vvv>_<gg><s><aa>_<eeeee>
            # scaid is nrcb<n>
            self._load_fits_image(cfg.ds_obsfile + obsid + "_" + scaid + "_crf.fits", "SCI")
            self.type = "jwst"
            self.mask_threshold = [15, 5]  # m, c for object masking
        else:
            # Previously an unknown type left the object without image/WCS attributes and
            # failed later with an unrelated AttributeError.
            raise ValueError(f"Unknown indata_type {indata_type!r}; expected 'fits', 'asdf' or 'jwst'.")

        self.obsid = obsid
        self.scaid = scaid
        self.mask = np.ones(self.shape, dtype=bool)
        self.params_subtracted = False
        self.cfg = cfg

        # Calculate effective gain
        if cfg.gaindir is False:
            geff_path = tempdir + obsid + "_" + scaid + "_geff.dat"
            if not os.path.isfile(geff_path):
                g_eff = np.memmap(geff_path, dtype="float64", mode="w+", shape=self.shape)
                # One pixel of padding on each side so central differences exist at the edges.
                ra, dec = self.get_coordinates(pad=2.0)
                ra = ra.reshape((Settings.sca_nside + 2, Settings.sca_nside + 2))
                dec = dec.reshape((Settings.sca_nside + 2, Settings.sca_nside + 2))
                # Central-difference derivatives of (ra, dec) along both array axes.
                derivs = np.array(
                    (
                        (ra[1:-1, 2:] - ra[1:-1, :-2]) / 2,
                        (ra[2:, 1:-1] - ra[:-2, 1:-1]) / 2,
                        (dec[1:-1, 2:] - dec[1:-1, :-2]) / 2,
                        (dec[2:, 1:-1] - dec[:-2, 1:-1]) / 2,
                    )
                )
                # NOTE(review): np.transpose reverses all axes, so det_mat appears to be indexed
                # [x, y] while dec below is [y, x] -- confirm this is intended.
                derivs_px = np.reshape(np.transpose(derivs), (Settings.sca_nside**2, 2, 2))
                det_mat = np.reshape(np.linalg.det(derivs_px), (Settings.sca_nside, Settings.sca_nside))
                g_eff[:, :] = 1 / (
                    np.abs(det_mat)
                    * np.cos(np.deg2rad(dec[1 : Settings.sca_nside + 1, 1 : Settings.sca_nside + 1]))
                )
                g_eff.flush()
                del g_eff
            self.g_eff = np.memmap(geff_path, dtype="float64", mode="r", shape=self.shape)
        else:
            # PLACEHOLDER for reading in real flat fields as gain
            # Needs to be adapted once actual file format is known
            g_eff_file = asdf.open(cfg.gaindir + cfg.use_filter + "_geff.fits", memmap=True)
            self.g_eff = g_eff_file[int(scaid) - 1].data.astype(np.float64)
            g_eff_file.close()

        # Add a noise frame if specified in config file
        if cfg.ds_noisefile is not False:
            self.apply_noise()

        if add_objmask:
            if indata_type == "asdf":
                self.apply_asdf_mask()
            elif indata_type == "fits":
                self.apply_permanent_mask()
            elif indata_type == "jwst":
                self.apply_jwst_mask()
            _, object_mask = apply_object_mask(
                self.image,
                threshold_m=self.mask_threshold[0],
                threshold_c=self.mask_threshold[1],
                type=self.type,
            )
            # self.mask = True for good pixels, so set object_mask'ed pixels to False
            self.mask *= np.logical_not(object_mask)

            # Test output of the mask. The existence check now looks in the directory
            # that save_fits writes to ("masks/"); before, it checked a different path.
            mask_dir = cfg.ds_outpath + "masks/"
            mask_name = self.obsid + "_" + self.scaid + "_mask"
            if testoutputs["testing"] and not os.path.exists(mask_dir + mask_name + ".fits"):
                save_fits(self.mask.astype("uint8"), mask_name, dir=mask_dir, overwrite=True)

    def _load_fits_image(self, path, image_hdu):
        """Read image, header and WCS from one HDU of a FITS file, closing the file afterwards."""
        with fits.open(path, memmap=True) as hdul:
            hdu = hdul[image_hdu]
            self.header = hdu.header
            self.w = wcs.WCS(self.header)
            # Copy so the array stays valid after the (memory-mapped) file is closed.
            self.image = np.copy(hdu.data).astype(np.float64)
        self.shape = np.shape(self.image)
        self._file_handle = None

    def apply_noise(self):
        """
        Add detector noise to self.image.

        Reads the noise frame for this obsid/scaid from cfg.ds_noisefile, scales it and adds the
        central (unpadded) region to self.image. The noisy image is also written to the test
        image directory if no such file exists there yet.
        """
        noise_path = self.cfg.ds_noisefile + self.obsid + "_" + self.scaid + ".fits"
        with fits.open(noise_path) as hdul:
            noiseframe = np.copy(hdul["PRIMARY"].data) * 1.458 * 50  # times gain and N_frames
        # The slice skips a 4-pixel border on each side of the noise frame.
        self.image += noiseframe[4 : Settings.sca_nside + 4, 4 : Settings.sca_nside + 4]

        filename = self.obsid + "_" + self.scaid + "_noise"
        if not os.path.exists(testoutputs["test_image_dir"] + filename + ".fits"):
            save_fits(self.image, filename, dir=testoutputs["test_image_dir"], overwrite=True)

    def apply_permanent_mask(self):
        """
        Apply permanent pixel mask. Updates self.image and self.mask.

        The mask comes from the "MASK" HDU of the input image file when cfg.permanent_mask is
        None, otherwise from plane (scaid - 1) of the primary HDU of cfg.permanent_mask.
        Nonzero mask values mark bad pixels.
        """
        if self.cfg.permanent_mask is None:
            image_path = (
                self.cfg.ds_obsfile + filters[self.cfg.use_filter] + "_" + self.obsid + "_" + self.scaid + ".fits"
            )
            with fits.open(image_path) as hdul:
                pm = hdul["MASK"].data.astype(bool)
        else:
            with fits.open(self.cfg.permanent_mask) as hdul:
                pm = hdul[0].data[int(self.scaid) - 1].astype(bool)
        self.image *= ~pm
        self.mask *= ~pm

    def apply_asdf_mask(self):
        """
        Apply ASDF file mask. Updates self.image and self.mask.

        The mask is read from HDU 1 of the companion "<filter>_<obsid>_<scaid>_mask.fits" file;
        nonzero values mark bad pixels.
        """
        mask_path = (
            self.cfg.ds_obsfile + filters[self.cfg.use_filter] + "_" + self.obsid + "_" + self.scaid + "_mask.fits"
        )
        # memmap=False: change once masks are uint8 type
        with fits.open(mask_path, memmap=False) as hdul:
            mask = hdul[1].data.astype(bool)
        self.image *= ~mask
        self.mask *= ~mask

    def apply_jwst_mask(self):
        """
        Apply JWST data quality mask. Updates self.image and self.mask.
        """
        # JWST bad pixels are represented as NaN in SCI; map them to zero and mark as masked.
        valid = ~np.isnan(self.image)
        self.image = np.where(valid, self.image, 0.0)
        self.mask = np.logical_and(self.mask, valid)

    def apply_all_mask(self):
        """
        Apply the full pixel mask (self.mask) to the image. Replaces self.image.

        Masked pixels and NaN pixels are set to zero.
        """
        # Multiply by mask but set NaNs to zero
        self.image = np.where(np.isnan(self.image), 0, self.image * self.mask)

    def subtract_parameters(self, p, j):
        """
        Subtract a set of parameters from the SCA image.
        Updates self.image and self.params_subtracted.

        Exits the program (sys.exit) if parameters were already subtracted from this image.

        Parameters
        --------
        p : Parameters object
            containing params of current iteration
        j : int
            the index of the SCA image into all_scas list
        """
        if self.params_subtracted:
            write_to_file("WARNING: PARAMS HAVE ALREADY BEEN SUBTRACTED. ABORTING NOW")
            sys.exit()

        params_image = p.forward_par(j)  # Make destriping params into an image
        self.image = self.image - params_image  # Update I_A.image to have the params image subtracted off
        self.params_subtracted = True

    def get_coordinates(self, pad=0.0):
        """
        Create an array of ra, dec coords for the image.

        Parameters
        --------
        pad : Float64
            N pixels of padding to add to the array (pad/2 on each side). Default 0.0

        Returns
        --------
        ra, dec : 1D np.arrays of length (height + pad) * (width + pad)
            1D arrays of ra, dec coordinates for each pixel in the (padded) image,
            flattened in row-major (y, x) order
        """
        n_y = self.shape[0] + pad
        n_x = self.shape[1] + pad
        # With indexing="xy" the first vector runs along columns (x) and the second along
        # rows (y). x must therefore span the image width; the two were swapped before,
        # which only worked for square images. Float dtype keeps the in-place shift valid
        # when pad is passed as an int.
        x_i, y_i = np.meshgrid(
            np.arange(n_x, dtype=np.float64), np.arange(n_y, dtype=np.float64), indexing="xy"
        )
        x_i -= pad / 2.0
        y_i -= pad / 2.0
        x_flat = x_i.flatten()
        y_flat = y_i.flatten()
        ra, dec = self.w.all_pix2world(x_flat, y_flat, 0)  # origin=0: 0-based pixel coordinates
        return ra, dec

    def make_interpolated(self, ind, scalist, neighbors, tempdir=tempdir, params=None, N_eff_min=0.5):
        """
        Construct a version of this SCA interpolated from other, overlapping ones.
        Writes the interpolated image out to the disk, to be read/used later.
        The N_eff_min parameter requires some minimum effective coverage, otherwise masks that pixel.

        Parameters
        --------
        ind : int
            index of this SCA in all_scas list
        scalist : List of Str
            the list of all SCAs in this mosaic
        neighbors : Dict
            dictionary where keys are SCA indices and values are lists of indices of overlapping SCAs
        tempdir : Str
            directory to write temporary files to
        params : Parameters object
            parameters to be subtracted from contributing SCAs; default None
        N_eff_min : float
            Effective coverage needed for a pixel to contribute to the interpolation

        Returns
        --------
        this_interp : 2D np array
            the interpolated image (coverage-normalized, divided by g_eff)
        new_mask : 2D bool np array
            True where the effective coverage exceeds N_eff_min
        """
        this_interp = np.zeros(self.shape)

        # The coverage map N_eff is built the first time only and re-used from disk afterwards.
        neff_path = tempdir + self.obsid + "_" + self.scaid + "_Neff.dat"
        make_Neff = not os.path.isfile(neff_path)
        N_eff = np.memmap(neff_path, dtype="float32", mode="w+" if make_Neff else "r", shape=self.shape)

        N_BinA = 0
        for k in neighbors[ind]:
            sca_b = scalist[k]
            obsid_B, scaid_B = get_ids(sca_b, indata_type=self.type)
            N_BinA += 1

            I_B = Sca_img(obsid_B, scaid_B, self.cfg, indata_type=self.type)  # Initialize image B
            if params is not None:
                I_B.subtract_parameters(params, k)
            I_B.apply_all_mask()  # now I_B is masked

            B_interp = np.zeros_like(self.image)
            interpolate_image_bilinear(I_B, self, B_interp)

            if make_Neff:
                B_mask_interp = np.zeros_like(self.image)
                # interpolate B pixel mask onto A grid
                interpolate_image_bilinear(I_B, self, B_mask_interp, mask=I_B.mask)

            if img_full_output["scaid"] != 0 and testoutputs["testing"]:
                # Diagnostic dumps for the one SCA selected in img_full_output.
                obsid_match = obsid_B == str(img_full_output["obsid"])
                scaid_match = scaid_B == str(img_full_output["scaid"])
                if obsid_match and scaid_match and make_Neff:
                    filename = (
                        f'{img_full_output["obsid"]}_{img_full_output["scaid"]}_'
                        f'B_{self.obsid}_{self.scaid}_interp'
                    )
                    save_fits(B_interp, filename, dir=testoutputs["test_image_dir"])

                obsid_match = self.obsid == str(img_full_output["obsid"])
                scaid_match = self.scaid == str(img_full_output["scaid"])
                if obsid_match and scaid_match and make_Neff:
                    filename = (
                        f'{img_full_output["obsid"]}_{img_full_output["scaid"]}_'
                        f'A_{obsid_B}_{scaid_B}_interp'
                    )
                    save_fits(B_interp, filename, dir=testoutputs["test_image_dir"])

            this_interp += B_interp
            if make_Neff:
                N_eff += B_mask_interp

        write_to_file(
            f"Interpolation of {self.obsid}_{self.scaid} done. Number of contributing SCAs: {N_BinA}"
        )

        new_mask = N_eff > N_eff_min
        # Normalize by coverage; the inner where() only guards the division for masked pixels.
        this_interp = np.where(new_mask, this_interp / np.where(new_mask, N_eff, N_eff_min), 0)
        header = self.w.to_header(relax=True)
        this_interp = np.divide(this_interp, self.g_eff)

        # KL move these to tmp?
        save_fits(
            this_interp,
            self.obsid + "_" + self.scaid + "_interp",
            dir=tempdir + "interpolations/",
            header=header,
        )

        if make_Neff:
            N_eff.flush()
        del N_eff

        print(f"Finished Interpolation SCA {self.obsid}_{self.scaid}")
        sys.stdout.flush()
        return this_interp, new_mask
class Parameters:
    """
    Class holding the parameters for a given mosaic.

    This can be the destriping parameters, or additional parameters that need to be the same
    shape and have the same methods.

    Parameters
    ----------
    cfg : Config object
        built from the configuration file; must provide ds_model, ds_rows and amp_cols,
        and may provide sca_nside
    scalist : list of Strings, optional
        the list of SCAs in this mosaic. Default: empty list

    Attributes
    ----------
    model : Str
        Which destriping model to use, which then specifies the number of parameters per row.
        Must be a key of the model_params dict
    n_rows : Int
        Number of rows in the image, or number of rows to fit ds model over
    n_cols : Int
        Number of image columns (cfg.sca_nside if present, else Settings.sca_nside)
    params_per_row : Int
        Number of parameters per row, set by model_params[model]
    amp_cols : Int or None
        Width in columns of one column block; None if column blocks are disabled
    n_col_blocks : Int
        Number of column-block parameters per SCA (0 if disabled)
    params : 2D np array
        The actual array of parameters, one row per SCA. Each row holds
        n_rows * params_per_row row parameters followed by n_col_blocks column-block parameters
    current_shape : Str
        The current shape (1D or 2D) of SCA params
    scalist : list of Strings
        the list of SCAs in this mosaic, format: filter_obsid_scaid

    Methods
    -------
    params_2_images
        Reshape params into a 2D array, with one row per SCA
    forward_par
        Reshape one row of params array (one SCA) into a 2D array by projection along rows
    """

    def __init__(self, cfg, scalist=None):
        model_params = {"constant": 1, "linear": 2}
        if cfg.ds_model not in model_params:
            raise ValueError(f"Model {cfg.ds_model} not in model_params dictionary.")

        # A fresh list per instance; a shared mutable default would leak between instances.
        scalist = [] if scalist is None else scalist

        self.model = cfg.ds_model
        self.n_rows = cfg.ds_rows
        # Only fall back to the global Settings when cfg does not define sca_nside.
        self.n_cols = cfg.sca_nside if hasattr(cfg, "sca_nside") else Settings.sca_nside
        self.params_per_row = model_params[self.model]

        if cfg.amp_cols is not None and cfg.amp_cols > 0:
            if self.n_cols % cfg.amp_cols != 0:
                raise ValueError(
                    f"Column width (amp_cols={cfg.amp_cols}) does not evenly divide "
                    f"image columns ({self.n_cols}). Ensure n_cols % amp_cols == 0."
                )
            self.amp_cols = cfg.amp_cols
            self.n_col_blocks = self.n_cols // cfg.amp_cols
        else:
            self.amp_cols = None
            self.n_col_blocks = 0

        params_per_sca = self.n_rows * self.params_per_row + self.n_col_blocks
        self.params = np.zeros((len(scalist), params_per_sca))
        self.current_shape = "2D"
        self.scalist = scalist

    def params_2_images(self):
        """
        Reshape flattened parameters into a 2D array with 1 row per SCA.

        Each row has n_rows * params_per_row + n_col_blocks entries. The column-block
        parameters were left out of the target width before, so the reshape raised
        whenever column blocks were enabled.
        """
        params_per_sca = self.n_rows * self.params_per_row + self.n_col_blocks
        self.params = np.reshape(self.params, (len(self.scalist), params_per_sca))
        self.current_shape = "2D"

    def forward_par(self, sca_i):
        """
        Takes one SCA row from the params and casts it into 2D (n_rows x n_cols)
        with additive row and column-block offsets.

        Parameters
        --------
        sca_i : Int
            Index of which SCA to recast into 2D

        Returns
        --------
        2D np.array, the image of SCA_i's parameters (stripe image)
        """
        if self.current_shape != "2D":
            self.params_2_images()

        params_row_len = self.n_rows * self.params_per_row

        # Row offsets: one value per row, broadcast across all columns.
        # NOTE(review): for params_per_row > 1 ("linear") this broadcast does not match
        # (n_rows, n_cols) and raises -- confirm how that model is meant to be projected.
        row_params = np.array(self.params[sca_i, :params_row_len])[:, np.newaxis]
        row_image = row_params * np.ones((self.n_rows, self.n_cols))

        if self.n_col_blocks > 0:
            col_params = self.params[sca_i, params_row_len : params_row_len + self.n_col_blocks]
            # Each block value covers amp_cols adjacent columns and is the same on every row.
            col_profile = np.repeat(col_params, self.amp_cols)
            col_image = np.tile(col_profile, (self.n_rows, 1))
            return row_image + col_image
        else:
            return row_image
def write_to_file(text, filename=None):
    """
    Function to write some text to an output file.

    The text is appended as one line; the file is created if it does not exist.

    Parameters
    --------
    text : Str
        The text to print
    filename : Str
        Filename to write out to, or if None, output is directed to stdout
    """
    if filename is None:
        print(text)
        return

    # Append mode creates a missing file itself. The former exists()-then-open("w+") pattern
    # could truncate lines written by another process between the check and the open.
    with open(filename, "a") as f:
        f.write(text + "\n")
def save_fits(image, filename, dir=None, overwrite=True, s=False, header=None, retries=3):
    """
    Save a 2D image to a FITS file with locking, retries, and atomic rename.

    The image is first written to a uniquely named temporary file next to the target and then
    moved onto the target path with os.replace, all while holding "<target>.lock".

    Parameters
    ----------
    image : np.ndarray
        2D array to write.
    filename : str
        Output filename without extension.
    dir : str or None
        Directory to save into. None (the default) means the current directory.
    overwrite : bool
        Passed to the write of the temporary file.
        NOTE(review): the final os.replace always replaces an existing target, so
        overwrite=False does not protect it -- confirm whether that is intended.
    s : bool
        Whether to print status messages.
    header : fits.Header or None
        Optional FITS header.
    retries : int
        Number of write retry attempts if write fails.

    Raises
    ------
    RuntimeError
        If the write still fails with an OSError on the last attempt.
    """
    # os.path.join(None, ...) raises TypeError, so the documented default could not be used.
    out_dir = "." if dir is None else dir
    fp = os.path.join(out_dir, filename + ".fits")
    lock = FileLock(fp + ".lock")

    for attempt in range(retries):
        tmp_fp = fp + f".{uuid.uuid4().hex}.tmp"
        try:
            with lock.acquire(timeout=30):
                hdu = fits.PrimaryHDU(image, header=header)
                hdu.writeto(tmp_fp, overwrite=overwrite)
                os.replace(tmp_fp, fp)  # Atomic move to final path
                if s:
                    write_to_file(f"Array {filename} written out to {fp}")
                return  # Success

        except Timeout:
            # Best effort: report and give up without raising, as before.
            write_to_file(f"Failed to write {filename}; lock acquire timeout")
            return

        except OSError as e:
            # Do not leave a partial temporary file behind after a failed attempt.
            try:
                os.remove(tmp_fp)
            except OSError:
                pass  # nothing to clean up (or cleanup impossible); the original error matters

            if attempt < retries - 1:
                wait_time = 1 + random.random()
                print(f"Write failed for {fp} (attempt {attempt + 1}): {e}. Retrying in {wait_time:.2f}s...")
                time.sleep(wait_time)
            else:
                raise RuntimeError(f"Failed to write {fp} after {retries} attempts. Last error: {e}") from e
def _jwst_high_value_mask(image, threshold_c):
    """Bright-pixel mask for JWST scenes: clipped background estimate plus seeded region growing."""
    work = np.asarray(image)  # never modified, so no copy is needed
    valid = np.isfinite(work)
    if not np.any(valid):
        return np.zeros_like(work, dtype=bool)

    # Iterative clipping to estimate sky background robustly.
    clip_vals = work[valid]
    for _ in range(3):
        bkg = np.median(clip_vals)
        mad = np.median(np.abs(clip_vals - bkg))
        sigma = 1.4826 * mad  # MAD -> Gaussian-equivalent standard deviation
        if sigma <= 0:
            break
        keep = np.abs(clip_vals - bkg) < (3.0 * sigma)
        if np.count_nonzero(keep) < 100:
            break  # too few pixels left for a meaningful estimate
        clip_vals = clip_vals[keep]

    bkg = np.median(clip_vals)
    mad = np.median(np.abs(clip_vals - bkg))
    sigma = 1.4826 * mad
    if not np.isfinite(sigma) or sigma <= 0:
        sigma = np.std(clip_vals) if clip_vals.size > 1 else 0.0

    residual = np.zeros_like(work)
    residual[valid] = work[valid] - bkg

    # Two-level thresholding: bright seeds + lower-threshold growth.
    seed_threshold = max(threshold_c, 6.0 * sigma)
    grow_threshold = max(0.5 * threshold_c, 2.5 * sigma)
    seed_mask = np.logical_and(valid, residual >= seed_threshold)
    grow_candidates = np.logical_and(valid, residual >= grow_threshold)
    grown_mask = binary_propagation(seed_mask, mask=grow_candidates)
    return binary_dilation(grown_mask, structure=np.ones((3, 3), dtype=bool), iterations=2)


def apply_object_mask(
    image,
    mask=None,
    threshold_m=0,
    threshold_c=0.3,
    inplace=False,
    type="fits",
):
    """
    Apply a bright object mask to an image.

    Unless a mask is supplied, bright pixels are detected and the detection is grown with a
    5x5 dilation. For type 'jwst' detection uses a sigma-clipped background and seeded region
    growing (threshold_m is not used); for all other types a pixel is bright when
    image >= threshold_m * median(image) + threshold_c.

    Parameters
    --------
    image : 2D numpy array
        the image to be masked.
    mask : 2D boolean array, optional
        the pre-existing object mask; used as-is if given. Default: None
    threshold_m : float
        factor to multiply with the median for thresholding.
    threshold_c : float
        constant to add to the threshold.
    inplace : Bool
        Whether to modify the input image directly.
    type : Str
        Type of the input image. Options: 'fits', 'asdf', 'jwst'. Default: 'fits'

    Returns
    --------
    image_out : 2D np.array
        the masked image (masked pixels set to 0).
    neighbor_mask : 2D np.array
        the mask applied (True = masked)
    """
    if mask is not None and isinstance(mask, np.ndarray):
        neighbor_mask = mask
    else:
        if type == "jwst":
            high_value_mask = _jwst_high_value_mask(image, threshold_c)
        else:
            # The median is only needed on this path; it used to be computed for JWST too.
            high_value_mask = image >= threshold_m * np.median(image) + threshold_c
        neighbor_mask = binary_dilation(high_value_mask, structure=np.ones((5, 5), dtype=bool))

    if inplace:
        image[neighbor_mask] = 0
        return image, neighbor_mask

    image_out = np.where(neighbor_mask, 0, image)
    return image_out, neighbor_mask
def quadratic(x):
    """
    Quadratic cost function f(x) = x^2

    Works element-wise on arrays and on plain numbers.
    """
    squared = x * x
    return squared
def absolute(x):
    """
    Absolute cost function f(x) = |x|

    Works element-wise on arrays and on plain numbers.
    """
    magnitude = np.absolute(x)
    return magnitude
def huber_loss(x, d):
    """
    Huber loss cost function

    Quadratic (x^2) where |x| <= d, and linear, d^2 + 2d(|x| - d), beyond that.

    Parameters
    --------
    x : np.array or float
        the residual(s)
    d : float
        threshold at which the loss switches from quadratic to linear
    """
    magnitude = np.abs(x)
    inside = magnitude <= d
    quadratic_part = x**2
    linear_part = d**2 + 2 * d * (magnitude - d)
    return np.where(inside, quadratic_part, linear_part)
def quad_prime(x):
    """
    Derivative of quadratic cost function f'(x) = 2x

    Works element-wise on arrays and on plain numbers.
    """
    slope = 2 * x
    return slope
def abs_prime(x):
    """
    Derivative of absolute cost function f'(x) = sign(x)

    Returns -1, 0 or +1 per element; the value at x = 0 is 0.
    """
    direction = np.sign(x)
    return direction
def huber_prime(x, d):
    """
    Derivative of Huber loss cost function

    Equals 2x where |x| <= d and 2d * sign(x) beyond that.

    Parameters
    --------
    x : np.array or float
        the residual(s)
    d : float
        threshold at which the loss switches from quadratic to linear
    """
    inside = np.abs(x) <= d
    quadratic_slope = 2 * x
    linear_slope = 2 * d * np.sign(x)
    return np.where(inside, quadratic_slope, linear_slope)
def get_scas(filter_, obsfile, cfg, indata_type="fits", of=None):
    """
    Function to get a list of all SCA images and their WCSs for this mosaic

    Files are found by globbing: "<obsfile>*_crf.fits" for JWST input, otherwise
    "<obsfile><filter_>_*". Files are processed in sorted order, and a file is skipped when
    its path does not contain a recognizable SCA label.

    Parameters
    --------
    filter_ : Str
        which filter to use for this run. Options: Y106, J129, H158, F184, K213
    obsfile : Str
        prefix / path to the SCA images
    cfg : Config object
        built from the configuration file (not used by this function)
    indata_type : Str
        input data type: 'fits', 'asdf' or 'jwst'. Default 'fits'
    of : Str
        filename to write output info to, or if None, output is directed to stdout

    Returns
    --------
    all_scas : list of strings
        list of all the SCAs in this mosaic
    all_wcs : list of WCS objects
        the WCS object for each SCA in all_scas (same order)
    """
    all_scas = []
    all_wcs = []

    if indata_type == "jwst":
        # example to match: jw04793001001_02101_00001_nrcb1_crf.fits
        label_re = re.compile(r"(jw\d+_\d+_\d+_nrcb\d+)")
        for f in sorted(glob.glob(obsfile + "*_crf.fits")):
            m = label_re.search(f)
            if not m:
                continue
            # Context manager: the file is closed even if reading the WCS fails.
            with fits.open(f, memmap=True) as this_file:
                this_wcs = wcs.WCS(this_file["SCI"].header)
            all_scas.append(str(m.group(1)))
            all_wcs.append(this_wcs)
    else:
        label_re = re.compile(r"(\w\d+)_(\d+)_(\d+)")
        for f in sorted(glob.glob(obsfile + filter_ + "_*")):
            m = label_re.search(f)
            if not m:
                continue
            if indata_type == "fits":
                with fits.open(f, memmap=True) as this_file:
                    this_wcs = wcs.WCS(this_file["PRIMARY"].header)
            elif indata_type == "asdf":
                # Companion noise and mask files share the prefix but are not SCA images.
                if ("noise" in f) or ("mask" in f):
                    continue
                with asdf.open(f, memmap=False, lazy_load=True) as this_file:
                    this_wcs = PyIMCOM_WCS(this_file["roman"]["meta"]["wcs"])
            else:
                continue  # other types contribute nothing, as before
            all_scas.append(str(m.group(0)))
            all_wcs.append(this_wcs)

    n_scas = len(all_scas)
    write_to_file(f"N SCA images in this mosaic: {str(n_scas)}", of)
    write_to_file("------- SCA List -------", of)
    for i, s in enumerate(all_scas):
        write_to_file(f"SCA {i}: {s}", of)

    return all_scas, all_wcs
def interpolate_image_bilinear(image_B, image_A, interpolated_image, mask=None):
    """
    Interpolate values from a "reference" SCA image onto a "target" SCA coordinate grid

    Uses pyimcom_croutines.bilinear_interpolation(
        float* image, float* g_eff, int rows, int cols, float* coords,
        int num_coords, float* interpolated_image)

    Parameters
    --------
    image_B : SCA object
        the image to be interpolated
    image_A : SCA object
        the image whose grid you are interpolating B onto
    interpolated_image : 2D np array
        all zeros with shape of Image A. Updated in place to be the interpolation of
        img. B onto A's grid
    mask : 2D np array, optional
        if provided, this mask is interpolated instead of image_B.image,
        with a unit gain map in place of image_B.g_eff
    """
    # Positions of A's pixels in B's pixel frame; the in-bounds flags are not needed here.
    x_in_B, y_in_B, _ = compareutils.map_sca2sca(image_A.w, image_B.w, pad=0)
    coords = np.column_stack((y_in_B.ravel(), x_in_B.ravel()))  # (y, x) pairs

    n_rows_B = int(image_B.shape[0])
    n_cols_B = int(image_B.shape[1])
    n_points = coords.shape[0]

    # Choose what gets interpolated: the supplied mask with unit gain, or B's image with its gain.
    interpolate_mask = mask is not None and isinstance(mask, np.ndarray)
    if interpolate_mask:
        source = mask
        gain = np.ones_like(image_A.image)
    else:
        source = image_B.image
        gain = image_B.g_eff

    # Flush Python-side buffers on both sides of the C call.
    sys.stdout.flush()
    sys.stderr.flush()
    pyimcom_croutines.bilinear_interpolation(
        source, gain, n_rows_B, n_cols_B, coords, n_points, interpolated_image
    )
    sys.stdout.flush()
    sys.stderr.flush()
def transpose_interpolate(image_A, wcs_A, image_B, original_image):
    """
    Interpolate backwards from image_A to image_B space.

    Uses bilinear_transpose(
        float* image, int rows, int cols, float* coords, int num_coords, float* original_image)

    Parameters
    --------
    image_A : 2D np array
        the already-interpolated gradient image
    wcs_A : wcs.WCS object
        image A's WCS object
    image_B : SCA object
        the image we're interpolating the gradient back onto
    original_image : 2D np array
        the gradient image re-interpolated into image B space. Updated in place
    """
    # Positions of A's pixels in B's pixel frame; the in-bounds flags are not needed here.
    mapped = compareutils.map_sca2sca(wcs_A, image_B.w, pad=0)
    x_in_B = mapped[0]
    y_in_B = mapped[1]
    coords = np.column_stack((y_in_B.ravel(), x_in_B.ravel()))  # (y, x) pairs

    n_rows_B, n_cols_B = int(image_B.shape[0]), int(image_B.shape[1])
    pyimcom_croutines.bilinear_transpose(
        image_A, n_rows_B, n_cols_B, coords, coords.shape[0], original_image
    )
[docs] def transpose_par(img, cfg=None): """ Extract parameter contributions from an image via adjoint of forward_par. For row-only model: sums each row across columns. For row+column model: also sums each column-block across rows. Parameters -------- img : 2D np.array Input gradient/residual image cfg : Config object or None If provided and amp_cols is enabled, returns concatenated [row_contributions, col_block_contributions]. If None, returns only row sums. Returns -------- 1D np.array, concatenation of [row sums, column-block sums] if cfg enables col_blocks, otherwise just row sums (backward compatible) """ row_contrib = np.sum(img, axis=1) if cfg is not None and hasattr(cfg, "amp_cols") and cfg.amp_cols is not None and cfg.amp_cols > 0: n_cols = img.shape[1] # Use actual image width, not Settings.sca_nside n_col_blocks = n_cols // cfg.amp_cols col_contrib = np.zeros(n_col_blocks) for b in range(n_col_blocks): j_start = b * cfg.amp_cols j_end = j_start + cfg.amp_cols col_contrib[b] = np.sum(img[:, j_start:j_end]) # Sum over all rows in this block return np.concatenate([row_contrib, col_contrib]) else: return row_contrib
def get_effective_gain(sca, tempdir=tempdir, indata_type="fits"):
    """
    Retrieve the effective gain and n_eff of the image.
    Valid only for already-interpolated images.

    Parameters
    --------
    sca : Str
        format like "<prefix>_<obsid>_<scaid>" describing which SCA to get the effective gain for
    tempdir : Str
        directory where the effective gain files are stored
    indata_type : Str
        The type of input data; default "fits". Every type other than "jwst" uses the
        "<prefix>_<obsid>_<scaid>" label format

    Returns
    --------
    g_eff : memmap 2D np.array
        the effective gain in each pixel
    N_eff : memmap 2D np.array
        how many image "B"s contributed to that interpolated image
    """
    # Share the label parsing with get_ids instead of duplicating the regexes. Only two label
    # formats exist, so every non-JWST type (e.g. "asdf") is parsed as "fits"; before, such
    # types left obsid/scaid undefined and raised UnboundLocalError.
    label_format = "jwst" if indata_type == "jwst" else "fits"
    obsid, scaid = get_ids(sca, indata_type=label_format)

    file_stem = tempdir + obsid + "_" + scaid
    full_shape = (Settings.sca_nside, Settings.sca_nside)

    g_eff = np.memmap(file_stem + "_geff.dat", dtype="float64", mode="r", shape=full_shape)
    N_eff = np.memmap(file_stem + "_Neff.dat", dtype="float32", mode="r", shape=full_shape)

    return g_eff, N_eff
def get_ids(sca, indata_type="fits"):
    """
    Take an SCA label and parse it out to get the Obsid and SCA id strings.

    Parameters
    --------
    sca : Str
        The sca name from all_scas list. For JWST: "jw<digits>_<digits>_<digits>_nrcb<n>";
        otherwise "<prefix>_<obsid>_<scaid>" with numeric obsid and scaid
    indata_type : Str
        The type of input data; default "fits". "jwst" selects the JWST label format; every
        other value (e.g. "fits", "asdf", "interpolated") uses the "<prefix>_<obsid>_<scaid>"
        format

    Returns
    --------
    obsid : Str
        the observation ID
    scaid : Str
        the SCA ID (position in focal plane)

    Raises
    --------
    ValueError
        if the label does not match the expected format
    """
    # Types other than "fits"/"jwst" used to fall through both branches and hit an
    # UnboundLocalError; their labels have the same shape as FITS labels.
    if indata_type == "jwst":
        pattern = r"jw(\d+_\d+_\d+)_(nrcb\d+)"
    else:
        pattern = r"_(\d+)_(\d+)"

    m = re.search(pattern, sca)
    if m is None:
        raise ValueError(f"Cannot parse obsid/scaid from SCA label {sca!r} (indata_type={indata_type!r})")

    return m.group(1), m.group(2)
def save_snapshot(
    p,
    grad,
    epsilon,
    psi,
    direction,
    grad_prev,
    direction_prev,
    cg_model,
    tol,
    thresh,
    norm_0,
    cost_model,
    i,
    restart_file,
    of=None,
):
    """
    Save restart state to pickle file.

    The state is written to a temporary file next to restart_file and then moved into place,
    so an interrupted write cannot leave a truncated restart file behind.

    Parameters
    --------
    p : Parameters object
        current parameters
    grad : 2D np array
        current gradient
    epsilon : 2D np array
        current cost
    psi : 3D np array
        current residuals
    direction : 2D np array
        current CG direction
    grad_prev : 2D np array
        previous gradient
    direction_prev : 2D np array
        previous CG direction
    cg_model : Str
        which CG model is being used
    tol : Float
        tolerance for convergence
    thresh : Float
        threshold for Huber loss cost function
    norm_0 : Float
        initial norm of the gradient
    cost_model : Str
        which cost function is being used
    i : Int
        current iteration number (stored as-is; the log message reports i + 1)
    restart_file : Str
        path to the restart pickle file
    of : Str
        filename to write output info to, or if None, output is directed to stdout
    """
    crash_state = {
        "iteration": i,
        "p": p,
        "grad": grad,
        "epsilon": epsilon,
        "direction": direction,
        "grad_prev": grad_prev,
        "psi": psi,
        "direction_prev": direction_prev,
        "cg_model": cg_model,
        "tol": tol,
        "thresh": thresh,
        "norm_0": norm_0,
        "cost_model": cost_model,
    }

    tmp_file = f"{restart_file}.{os.getpid()}.tmp"
    try:
        with open(tmp_file, "wb") as f:
            pickle.dump(crash_state, f)
        os.replace(tmp_file, restart_file)  # atomic: readers see the old or the new snapshot
    finally:
        # Only present here if dumping or replacing failed.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)

    write_to_file(f"Checkpoint saved at iteration {i+1} -> {restart_file}", of)
def get_neighbors(scalist, ov_mat, overlap_thresh=0.1):
    """
    Build, for every SCA in the mosaic, the list of SCAs that overlap it.

    Parameters
    ----------
    scalist : list of str
        the list of all SCAs in this mosaic
    ov_mat : 2D np array
        the overlap matrix for all SCAs in this mosaic, indexed as ov_mat[k, j]
    overlap_thresh : float
        minimum overlap fraction to consider two SCAs as neighbors; default 0.1

    Returns
    -------
    neighbors : dict
        maps each SCA index k to the list of indices j != k with
        ov_mat[k, j] >= overlap_thresh, in increasing order of j
    """
    n_sca = len(scalist)
    return {
        k: [j for j in range(n_sca) if ov_mat[k, j] >= overlap_thresh and j != k]
        for k in range(n_sca)
    }
def residual_function(
    psi,
    f_prime,
    scalist,
    wcslist,
    neighbors,
    thresh,
    workers,
    cfg,
    extrareturn=False,
    of=None,
    indata_type="fits",
):
    """
    Calculate the residual image, = grad(epsilon).

    One task per SCA is submitted to a process pool (residual_function_single).
    Each task returns the SCA's own term (subtracted from its row) and a list
    of contributions to its neighbors' rows (added to those rows).

    Parameters
    ----------
    psi : 3D np array
        the image difference array (I_A - J_A), indexed as psi[k, :, :]
    f_prime : function
        the derivative of the cost function f
    scalist : list of str
        the list of all SCAs in this mosaic
    wcslist : list of WCS objects
        the WCS object for each SCA in scalist (same order)
    neighbors : dict
        dictionary where keys are SCA indices and values are lists of indices
        of overlapping SCAs
    thresh : float
        threshold for Huber loss cost function, or None
    workers : int
        number of parallel workers to use
    cfg : Config object
        the configuration for this run
    extrareturn : bool
        if True, also return residual terms 1 and 2 separately; default False
    of : str
        filename to write output info to, or if None, output is directed to stdout
    indata_type : str
        input data type: one of 'fits', 'asdf', 'jwst'. Default 'fits'

    Returns
    -------
    resids : 2D np array
        one row per SCA and one column per parameter
    resids1, resids2 : 2D np arrays
        only when extrareturn is True: the term-1 and term-2 parts of resids
    """
    resids = Parameters(cfg, scalist).params
    if extrareturn:
        resids1 = np.zeros_like(resids)
        resids2 = np.zeros_like(resids)
        own_term_targets = (resids, resids1)
        neighbor_term_targets = (resids, resids2)
    else:
        own_term_targets = (resids,)
        neighbor_term_targets = (resids,)

    write_to_file("Residual calculation started", of)
    sys.stdout.flush()
    t_start = time.time()

    pool_context = mp.get_context("forkserver" if os.name.lower() == "posix" else "spawn")
    with ProcessPoolExecutor(max_workers=workers, mp_context=pool_context) as executor:
        pending = []
        for idx, sca_label in enumerate(scalist):
            pending.append(
                executor.submit(
                    residual_function_single,
                    idx,
                    sca_label,
                    wcslist[idx],
                    psi[idx, :, :],
                    f_prime,
                    scalist,
                    neighbors,
                    thresh,
                    cfg,
                    of=of,
                    indata_type=indata_type,
                )
            )

        for done in as_completed(pending):
            k, term_1, term_2_list = done.result()

            # Term 1: this SCA's own contribution.
            for target in own_term_targets:
                target[k, :] -= term_1

            # Term 2: contributions this SCA makes to each overlapping SCA.
            for j, term_2 in term_2_list:
                for target in neighbor_term_targets:
                    target[j, :] += term_2

    write_to_file(f"Residuals calculation finished in {(time.time() - t_start) / 60} minutes.", of)
    write_to_file(f"Average time making resids per sca: {(time.time() - t_start) / len(scalist)} seconds", of)

    if not extrareturn:
        return resids
    return resids, resids1, resids2
def residual_function_single(
    k, sca_a, wcs_a, psi_a, f_prime, scalist, neighbors, thresh, cfg, of=None, indata_type="fits"
):
    """
    Calculate the residual terms for a single SCA image.

    Parameters
    ----------
    k : int
        index of this SCA in scalist
    sca_a : str
        the SCA label
    wcs_a : wcs.WCS object
        the WCS object for this SCA
    psi_a : 2D np array
        the difference image I_A - J_A for this SCA
    f_prime : function
        the derivative of the cost function f; called as f_prime(psi_a, thresh)
        when thresh is not None, otherwise as f_prime(psi_a)
    scalist : list of str
        the list of all SCAs in this mosaic
    neighbors : dict
        dictionary where keys are SCA indices and values are lists of indices
        of overlapping SCAs
    thresh : float
        threshold for Huber loss cost function, or None
    cfg : Config object
        the configuration for this run
    of : str
        filename to write output info to (not used inside this function)
    indata_type : str
        input data type: one of 'fits', 'asdf', 'jwst'. Default 'fits'

    Returns
    -------
    k : int
        index of this SCA in scalist
    term_1 : 1D np array
        the first residual term for this SCA
    term_2_list : list of tuples
        (j, term_2) pairs, one per SCA j in neighbors[k]
    """
    # Parsed for parity with the original flow; the ids themselves are not used below.
    get_ids(sca_a, indata_type=indata_type)

    # Gradient of the cost with respect to psi_a, then its transpose-parameter sum.
    if thresh is not None:
        grad_interp = f_prime(psi_a, thresh)
    else:
        grad_interp = f_prime(psi_a)
    term_1 = transpose_par(grad_interp, cfg=cfg)

    # Normalize by g_eff * N_eff before transposing back; pixels with N_eff == 0
    # are zeroed instead of divided.
    g_eff_A, n_eff_A = get_effective_gain(sca_a, indata_type=indata_type)
    has_contrib = n_eff_A != 0
    weights = g_eff_A[has_contrib] * n_eff_A[has_contrib]
    grad_interp[has_contrib] = grad_interp[has_contrib] / weights
    grad_interp[~has_contrib] = 0

    term_2_list = []
    for j in neighbors[k]:
        obsid_B, scaid_B = get_ids(scalist[j], indata_type=indata_type)
        image_b = Sca_img(obsid_B, scaid_B, cfg, indata_type=indata_type)

        # Accumulate the normalized gradient back onto image B's grid.
        grad_on_b = np.zeros(image_b.shape)
        transpose_interpolate(grad_interp, wcs_a, image_b, grad_on_b)
        grad_on_b *= image_b.g_eff

        term_2_list.append((j, transpose_par(grad_on_b, cfg=cfg)))

    return k, term_1, term_2_list
def compute_boundary_continuity_penalty(
    destriped_image, mask, amp_cols, col_boundary_const, chunk_width=50, chunk_height=100
):
    """
    Penalize large discontinuities in the destriped image across amp-width
    boundaries, pushing the final destriped image towards being continuous.

    For every column-block boundary, row chunks of ``chunk_height`` rows are
    sampled every ``4 * chunk_height`` rows. In each chunk the mean of the good
    pixels in the ``chunk_width`` columns left of the boundary is compared with
    the mean of the good pixels in the ``chunk_width`` columns right of it; the
    squared difference of the two means is accumulated. Chunks in which either
    side has no good pixels are skipped (they would otherwise contribute NaN).

    Parameters
    ----------
    destriped_image : 2D np.array
        The destriped image I_A, already masked
    mask : 2D bool array
        The mask (True = good pixels), same shape as destriped_image
    amp_cols : int or None
        Column-block width. If None or <= 0, no penalty.
    col_boundary_const : int
        Strength of the column boundary penalty. If <= 0, no penalty.
    chunk_width : int
        HALF width of the chunks to consider for the penalty. Mean is over
        this many cols on each side. (Default: 50)
    chunk_height : int
        Height of the chunks to consider for the penalty. This is a number of
        rows. (Default: 100)

    Returns
    -------
    penalty : float
        Scalar penalty value to add to cost function
    """
    if amp_cols is None or amp_cols <= 0:
        return 0.0

    lambda_reg = col_boundary_const
    if lambda_reg <= 0:
        return 0.0

    n_rows, n_cols = destriped_image.shape
    n_col_blocks = n_cols // amp_cols
    # Only every fourth chunk of rows is sampled.
    row_starts = range(0, n_rows, 4 * chunk_height)
    penalty = 0.0

    write_to_file(
        f"Computing boundary continuity penalty for {n_col_blocks} column blocks with chunks of size"
        f" {chunk_width}x{chunk_height} (yielding N chunks per boundary: {len(row_starts)} )"
    )

    # Loop over each column-block boundary
    for b in range(1, n_col_blocks):
        boundary_col = b * amp_cols
        # Clamp at 0 so a wide chunk near the left edge cannot wrap to a negative index.
        left_cols = slice(max(boundary_col - chunk_width, 0), boundary_col)
        right_cols = slice(boundary_col, boundary_col + chunk_width)

        for row_start in row_starts:
            rows = slice(row_start, min(row_start + chunk_height, n_rows))

            # Keep only the unmasked (good) pixels on each side of the boundary.
            left_good = destriped_image[rows, left_cols][mask[rows, left_cols]]
            right_good = destriped_image[rows, right_cols][mask[rows, right_cols]]

            # A side with no good pixels has no defined mean; skip the chunk
            # rather than let NaN propagate into the total cost.
            if left_good.size == 0 or right_good.size == 0:
                continue

            mean_diff = np.mean(left_good) - np.mean(right_good)
            # Squaring penalizes positive and negative jumps equally.
            penalty += mean_diff**2

    return lambda_reg * penalty
def cost_function_single(j, sca_a, p, f, scalist, neighbors, thresh, cfg, of=None, indata_type="fits"):
    """
    Calculate the cost function for a single SCA image.

    Parameters
    ----------
    j : int
        index of this SCA in scalist
    sca_a : str
        the SCA label
    p : Parameters object
        the current parameters for de-striping
    f : function
        the cost function form; called as f(psi, thresh) when thresh is not
        None, otherwise as f(psi)
    scalist : list of str
        the list of all SCAs in this mosaic
    neighbors : dict
        dictionary where keys are SCA indices and values are lists of indices
        of overlapping SCAs
    thresh : float
        threshold for Huber loss cost function, or None
    cfg : Config object
        the configuration for this run
    of : str
        filename to write output info to, or if None, output is directed to stdout
    indata_type : str
        input data type: one of 'fits', 'asdf', 'jwst'. Default 'fits'

    Returns
    -------
    j : int
        index of this SCA in scalist
    psi : 2D np array (float32)
        the difference image I_A - J_A for this SCA, zero where masked
    local_epsilon : float
        the cost function value for this SCA (including the boundary
        continuity penalty when that is enabled in cfg)
    """
    full_output = globals().get("img_full_output", {"scaid": 0, "obsid": -1})

    obsid_A, scaid_A = get_ids(sca_a, indata_type=indata_type)
    I_A = Sca_img(obsid_A, scaid_A, cfg, indata_type=indata_type)
    I_A.subtract_parameters(p, j)
    I_A.apply_all_mask()

    # Diagnostics are written only for the one designated example SCA, in testing mode.
    is_example = bool(
        full_output["scaid"] != 0
        and testoutputs["testing"]
        and obsid_A == str(full_output["obsid"])
        and scaid_A == str(full_output["scaid"])
    )

    if is_example:
        hdu = fits.PrimaryHDU(I_A.image)
        hdu.writeto(
            testoutputs["test_image_dir"]
            + f'{full_output["obsid"]}_{full_output["scaid"]}_I_A_sub_masked.fits',
            overwrite=True,
        )

    J_A_image, J_A_mask = I_A.make_interpolated(j, scalist, neighbors, params=p)
    J_A_mask *= I_A.mask

    psi = np.where(J_A_mask, I_A.image - J_A_image, 0).astype("float32")
    result = f(psi, thresh) if thresh is not None else f(psi)
    local_epsilon = np.sum(result)

    # Defined unconditionally so the diagnostics below can always report it,
    # including when the penalty is disabled in cfg.
    boundary_penalty = 0.0
    if cfg.amp_cols is not None and cfg.amp_cols > 0 and cfg.col_boundary_const > 0:
        boundary_penalty = compute_boundary_continuity_penalty(
            I_A.image, I_A.mask, cfg.amp_cols, cfg.col_boundary_const
        )
        local_epsilon += boundary_penalty

    # NOTE(review): the diagnostics are assumed to sit at function level (not
    # inside the penalty branch above) -- confirm against the repository source.
    if is_example:
        write_to_file(
            f"SCA {j}: boundary continuity penalty = {boundary_penalty:.2e}, "
            f"total cost = {local_epsilon:.2e}",
            of,
        )
        hdu = fits.PrimaryHDU(J_A_image * J_A_mask)
        hdu.writeto(
            testoutputs["test_image_dir"] + f'{full_output["obsid"]}_{full_output["scaid"]}_J_A_masked.fits',
            overwrite=True,
        )
        hdu = fits.PrimaryHDU(psi)
        hdu.writeto(
            testoutputs["test_image_dir"] + f'{full_output["obsid"]}_{full_output["scaid"]}_Psi.fits',
            overwrite=True,
        )
        write_to_file(f"Sample stats for SCA {full_output}:", of)
        write_to_file(f"Image A mean: {np.mean(I_A.image)}", of)
        write_to_file(f"Image B mean: {np.mean(J_A_image)}", of)
        write_to_file(f"Psi mean: {np.mean(psi)}", of)
        write_to_file(f"f(Psi) mean: {np.mean(result)}", of)
        write_to_file(f"Local epsilon for SCA {j}: {local_epsilon}", of)

    return j, psi, local_epsilon
def cost_function(
    p, f, thresh, workers, scalist, neighbors, cfg, tempdir=tempdir, of=None, indata_type="fits"
):
    """
    Calculate the cost function with the current de-striping parameters.

    One task per SCA is submitted to a process pool (cost_function_single);
    each returns its difference image and its share of the cost.

    Parameters
    ----------
    p : Parameters object
        the current parameters for de-striping
    f : function
        the cost function form, forwarded to cost_function_single
    thresh : float
        threshold for Huber loss cost function, or None
    workers : int
        number of parallel workers to use
    scalist : list of str
        the list of all SCAs in this mosaic
    neighbors : dict
        dictionary where keys are SCA indices and values are lists of indices
        of threshold-overlapping SCAs
    cfg : Config object
        the configuration for this run
    tempdir : str
        prefix (directory) for the temporary file backing psi
    of : str
        filename to write output info to, or if None, output is directed to stdout
    indata_type : str
        input data type: one of 'fits', 'asdf', 'jwst'. Default 'fits'

    Returns
    -------
    epsilon : float
        the total cost function summed over all images
    psi : 3D np.memmap
        the difference images I_A - J_A, shape
        (len(scalist), Settings.sca_nside, Settings.sca_nside)

    Notes
    -----
    The returned array is always backed by ``tempdir + "psi_all.dat"``, opened
    with mode "w+". NOTE(review): arrays returned by earlier calls with the
    same tempdir therefore share that backing file -- confirm callers do not
    rely on an older psi after calling this function again.
    """
    write_to_file("Initializing cost function", of)
    t_begin = time.time()

    n_sca = len(scalist)
    psi = np.memmap(
        tempdir + "psi_all.dat",
        dtype=use_output_float,
        mode="w+",
        shape=(n_sca, Settings.sca_nside, Settings.sca_nside),
    )
    psi.fill(0)
    epsilon = 0

    pool_context = mp.get_context("forkserver" if os.name.lower() == "posix" else "spawn")
    with ProcessPoolExecutor(max_workers=workers, mp_context=pool_context) as executor:
        pending = []
        for idx, sca_label in enumerate(scalist):
            pending.append(
                executor.submit(
                    cost_function_single,
                    idx,
                    sca_label,
                    p,
                    f,
                    scalist,
                    neighbors,
                    thresh,
                    cfg,
                    of=of,
                    indata_type=indata_type,
                )
            )

        for done in as_completed(pending):
            try:
                j, psi_j, local_eps = done.result()
                psi[j] = psi_j
                del psi_j
                epsilon += local_eps
            except Exception as e:
                # Log the worker failure, then let it propagate to the caller.
                write_to_file(f"Worker failed with exception: {e}", of)
                traceback.print_exc()
                raise

    write_to_file(f"Ending cost function. Time elapsed: {(time.time() - t_begin) / 60} minutes", of)
    write_to_file(
        f"Average time per cost function iteration: {(time.time() - t_begin) / len(scalist)} seconds", of
    )
    return epsilon, psi
def linear_search_general(
    p,
    direction,
    f,
    f_prime,
    cost_model,
    epsilon_current,
    psi_current,
    grad_current,
    thresh,
    workers,
    scalist,
    wcslist,
    neighbors,
    cfg,
    n_iter=100,
    tol=10**-4,
    of=None,
    indata_type="fits",
):
    """
    Linear search via combination bisection and secant methods for parameters
    that minimize the function d_epsilon/d_alpha in the given direction.

    Note alpha = depth of step in direction.

    Parameters
    ----------
    p : Parameters object
        the current de-striping parameters
    direction : 2D np array
        direction of conjugate gradient search
    f : function
        cost function form
    f_prime : function
        derivative of cost function form
    cost_model : str
        which cost function is being used; options: 'quadratic', 'huber_loss'
    epsilon_current : float
        current cost function value
    psi_current : 3D np array
        current difference images (I_A - J_A)
    grad_current : 2D np array
        current gradient AKA current residuals
    thresh : float
        threshold for Huber loss cost function, or None
    workers : int
        number of parallel workers to use
    scalist : list of str
        the list of all SCAs in this mosaic
    wcslist : list of WCS objects
        the WCS object for each SCA in scalist (same order)
    neighbors : dict
        dictionary where keys are SCA indices and values are lists of indices
        of overlapping SCAs
    cfg : Config object
        the configuration for this run
    n_iter : int
        number of iterations at which to stop searching
    tol : float
        tolerance used in the acceptance test and in the bracket updates
    of : str
        filename to write output info to, or if None, output is directed to stdout
    indata_type : str
        input data type: one of 'fits', 'asdf', 'jwst'. Default 'fits'

    Returns
    -------
    best_p : Parameters object
        the best parameters found via search
    best_psi : 3D np array
        the difference images made with the best_p params subtracted off
    best_resids : 2D np array
        the residuals (gradient) at best_p
    best_epsilon : float
        the cost function value at best_p

    Raises
    ------
    ValueError
        If ``cost_model`` is not one of the supported options.

    Notes
    -----
    If no trial step is accepted within ``n_iter`` iterations, the inputs
    (``p``, ``psi_current``, ``grad_current``, ``epsilon_current``) are
    returned, in the same 4-tuple layout as on success.
    NOTE(review): each trial calls cost_function, which reuses a single
    backing file for psi; the contents of ``psi_current`` may have been
    overwritten by the time the fallback returns -- confirm.
    """
    # Fallback result: "no step taken". best_resids is initialised here so both
    # return paths hand back the same 4-tuple the caller unpacks.
    best_epsilon, best_psi = epsilon_current, psi_current
    best_p = copy.deepcopy(p)
    best_resids = grad_current

    working_p = copy.deepcopy(p)
    max_p = copy.deepcopy(p)
    min_p = copy.deepcopy(p)

    convergence_crit = 99.0
    method = "bisection"
    eta = 0.1
    d_cost_init = np.sum(grad_current * direction)
    d_cost_tol = np.abs(d_cost_init * 1 * 10**-3)

    if cost_model == "quadratic":
        alpha_test = -eta * (np.sum(grad_current * direction)) / (np.sum(direction * direction) + 1e-12)
        if alpha_test <= 0:
            # Not a descent direction -- fallback bracket
            alpha_min = -0.9
            alpha_max = 1.0
        else:
            # Curvature-based search window
            alpha_min = alpha_test * 1e-4
            alpha_max = alpha_test * 10
    elif cost_model == "huber_loss":
        alpha_test = 1.0
        alpha_min = 1e-4
        alpha_max = 10
    else:
        # Previously this fell through and failed later on an unbound alpha bracket.
        raise ValueError(f"Unsupported cost_model for linear search: {cost_model!r}")

    # Cost and directional derivative at both ends of the bracket; needed for the secant update.
    write_to_file("### Calculating min and max epsilon and cost", of)
    max_p.params = p.params + alpha_max * direction
    max_epsilon, max_psi = cost_function(
        max_p, f, thresh, workers, scalist, neighbors, cfg, indata_type=indata_type
    )
    max_resids = residual_function(
        max_psi, f_prime, scalist, wcslist, neighbors, thresh, workers, cfg, indata_type=indata_type
    )
    del max_psi
    d_cost_max = np.sum(max_resids * direction)

    min_p.params = p.params + alpha_min * direction
    min_epsilon, min_psi = cost_function(
        min_p, f, thresh, workers, scalist, neighbors, cfg, indata_type=indata_type
    )
    min_resids = residual_function(
        min_psi, f_prime, scalist, wcslist, neighbors, thresh, workers, cfg, indata_type=indata_type
    )
    del min_psi
    d_cost_min = np.sum(min_resids * direction)

    conv_params = []

    for k in range(1, n_iter):
        t0_ls_iter = time.time()

        if k == 1:
            write_to_file("### Beginning linear search", of)
            write_to_file(f"LS Direction: {direction}", of)
            write_to_file(f"Initial params: {p.params}", of)
            write_to_file(f"Initial epsilon: {best_epsilon}", of)
            write_to_file(f"Initial d_cost: {d_cost_init}, d_cost tol: {d_cost_tol}", of)
            write_to_file(f"Initial alpha (min, test, max): ({alpha_min}, {alpha_test}, {alpha_max})", of)

        if k == n_iter - 1:
            write_to_file("WARNING: Linear search did not converge!!", of)

        if k != 1:
            alpha_test = alpha_min - (
                d_cost_min * (alpha_max - alpha_min) / (d_cost_max - d_cost_min)
            )  # secant update
            write_to_file(f"Secant update: alpha_test={alpha_test}", of)
            method = "secant"

            # A zero slope difference yields NaN or +/-inf; neither is a usable step.
            if not np.isfinite(alpha_test):
                write_to_file("Secant update fail-- bisecting instead", of)
                alpha_test = 0.5 * (alpha_min + alpha_max)  # bisection update
                write_to_file(f"Bisection update: alpha_test={alpha_test}", of)
                method = "bisection"
        else:
            alpha_test = 0.5 * (alpha_min + alpha_max)  # bisection update
            write_to_file(f"Bisection update: alpha_test={alpha_test}", of)

        working_p.params = p.params + alpha_test * direction

        working_epsilon, working_psi = cost_function(
            working_p, f, thresh, workers, scalist, neighbors, cfg, indata_type=indata_type
        )
        print(f"Global elapsed t = {(time.time()-t0_global)/60:8.1f}")
        working_resids = residual_function(
            working_psi, f_prime, scalist, wcslist, neighbors, thresh, workers, cfg, indata_type=indata_type
        )
        print(f"Global elapsed t = {(time.time()-t0_global)/60:8.1f}")

        d_cost = np.sum(working_resids * direction)
        convergence_crit = alpha_max - alpha_min
        conv_params.append([working_epsilon, alpha_test, d_cost])

        write_to_file(f"Ending LS iteration {k}", of)
        write_to_file(f"Current d_cost = {d_cost}, epsilon = {working_epsilon}", of)
        write_to_file(f"Working resids: {working_resids}", of)
        write_to_file(f"Working params: {working_p.params}", of)
        write_to_file(f"Current alpha range (min, test, max): {alpha_min, alpha_test, alpha_max}", of)
        write_to_file(f"Current delta alpha: {convergence_crit}", of)
        write_to_file(f"Time spent in this LS iteration: {(time.time() - t0_ls_iter) / 60} minutes.", of)

        # Accept the trial step: sufficient decrease and a non-negligible step size.
        if (working_epsilon < best_epsilon + tol * alpha_test * d_cost) and (np.abs(alpha_test) >= 1e-6):
            best_epsilon = working_epsilon
            best_p = copy.deepcopy(working_p)
            best_psi = working_psi
            best_resids = working_resids
            write_to_file(f"Linear search convergence in {k} iterations", of)
            if testoutputs["testing"]:
                save_fits(best_p.params, "best_p", dir=testoutputs["test_image_dir"], overwrite=True)
                save_fits(
                    np.array(conv_params), "conv_params", dir=testoutputs["test_image_dir"], overwrite=True
                )
            return best_p, best_psi, best_resids, best_epsilon

        # Shrink the bracket for the next iteration.
        if d_cost > tol and method == "bisection":
            alpha_max = alpha_test
            d_cost_max = d_cost
        elif d_cost < -tol and method == "bisection":
            alpha_min = alpha_test
            d_cost_min = d_cost
        elif d_cost * d_cost_min < 0 and method == "secant":
            alpha_max = alpha_test
            d_cost_max = d_cost
        elif d_cost * d_cost_max < 0 and method == "secant":
            alpha_min = alpha_test
            d_cost_min = d_cost

    # No acceptable step found: return the starting state in the same layout as above.
    return best_p, best_psi, best_resids, best_epsilon
def linear_search_quadratic(
    p,
    direction,
    f,
    f_prime,
    grad_current,
    thresh,
    workers,
    scalist,
    wcslist,
    neighbors,
    cfg,
    of=None,
    indata_type="fits",
):
    """
    For the quadratic cost function, directly calculate the alpha that
    minimizes the function d_epsilon/d_alpha in the given direction.

    Note alpha = depth of step in direction. The gradient is evaluated once at
    a trial step ``alpha_max``; the step that zeroes the directional derivative
    is then obtained by linear interpolation between the current gradient and
    the trial gradient. The new parameters and difference images are computed,
    and the new cost and convergence criteria are logged.

    Parameters
    ----------
    p : Parameters object
        the current de-striping parameters
    direction : 2D np array
        direction of conjugate gradient search
    f : function
        cost function form
    f_prime : function
        derivative of cost function form
    grad_current : 2D np array
        current gradient AKA current residuals
    thresh : float or None
        threshold for Huber loss cost function
    workers : int
        number of parallel workers to use
    scalist : list of str
        the list of all SCAs in this mosaic
    wcslist : list of WCS objects
        the WCS object for each SCA in scalist (same order)
    neighbors : dict
        dictionary where keys are SCA indices and values are lists of indices
        of overlapping SCAs
    cfg : Config object
        the configuration for this run
    of : str
        filename to write output info to, or if None, output is directed to stdout
    indata_type : str
        input data type: one of 'fits', 'asdf', 'jwst'. Default 'fits'

    Returns
    -------
    new_p : Parameters object
        the new parameters found via direct calculation
    new_psi : 3D np array
        the difference images made with the new_p params subtracted off
    new_resids : 2D np array
        the new residuals, interpolated between the current and trial gradients
    new_epsilon : float
        the new cost function value calculated with new_p
    """
    t_begin = time.time()

    # Directional derivative at the current point and squared length of the direction.
    slope_now = np.sum(grad_current * direction)
    eta = 0.1
    alpha_test = -eta * slope_now / (np.sum(direction * direction) + 1e-12)
    if alpha_test <= 0:
        alpha_max = 1.0
    else:
        alpha_max = alpha_test * 10

    # Evaluate the gradient at the trial step alpha_max.
    trial_p = copy.deepcopy(p)
    trial_p.params = p.params + alpha_max * direction
    trial_epsilon, trial_psi = cost_function(
        trial_p, f, thresh, workers, scalist, neighbors, cfg, indata_type=indata_type
    )
    trial_resids = residual_function(
        trial_psi, f_prime, scalist, wcslist, neighbors, thresh, workers, cfg, indata_type=indata_type
    )
    del trial_psi, trial_epsilon

    # Step at which the (linearly varying) directional derivative vanishes.
    resid_change = trial_resids - grad_current
    alpha_new = alpha_max * (-slope_now) / (np.sum(direction * resid_change) + 1e-12)

    new_p = copy.deepcopy(p)
    new_p.params = p.params + alpha_new * direction
    new_epsilon, new_psi = cost_function(
        new_p, f, thresh, workers, scalist, neighbors, cfg, indata_type=indata_type
    )
    # Gradient at alpha_new by interpolation; avoids another residual evaluation.
    new_resids = grad_current + (alpha_new / alpha_max) * resid_change

    write_to_file(f"(Inside LS) Global elapsed t = {(time.time()-t0_global)/60:8.1f}", of)
    sys.stdout.flush()

    d_cost = np.sum(new_resids * direction)

    write_to_file("Ending LS", of)
    write_to_file(f"Current d_cost = {d_cost}", of)
    write_to_file(f"Current epsilon = {new_epsilon}", of)
    write_to_file(f"Working resids: {new_resids}", of)
    write_to_file(f"Working params: {new_p.params}", of)
    write_to_file(f"Current alpha: {alpha_new}", of)
    write_to_file(f"Time spent in this LS: {(time.time() - t_begin) / 60} minutes.", of)
    sys.stdout.flush()

    if testoutputs["testing"]:
        save_fits(new_p.params, "best_p", dir=testoutputs["test_image_dir"], overwrite=True)

    return new_p, new_psi, new_resids, new_epsilon
def conjugate_gradient(
    p,
    f,
    f_prime,
    thresh,
    workers,
    scalist,
    wcslist,
    neighbors,
    restart_file=None,
    time_limit=None,
    cfg=None,
    of=None,
    indata_type="fits",
):
    """
    Algorithm to use conjugate gradient descent to optimize the parameters for
    destriping.

    The direction update is selected by the CG model ("FR", "PR", "HS" or
    "DY"); every 10th iteration the direction is reset to steepest descent.

    Parameters
    ----------
    p : Parameters object
        initial parameters guess
    f : function
        functional form to use for cost function
    f_prime : function
        the derivative of f
    thresh : float or None
        threshold for Huber loss cost function
    workers : int
        number of parallel workers to use
    scalist : list of str
        the list of all SCAs in this mosaic
    wcslist : list of WCS objects
        the WCS object for each SCA in scalist (same order)
    neighbors : dict
        dictionary where keys are SCA indices and values are lists of indices
        of overlapping SCAs
    restart_file : str or None
        where to write the walltime checkpoint when ``cfg.ds_restart`` is set.
        When ``cfg.ds_restart`` is None, or when this argument is None, the
        checkpoint goes to ``<cfg.ds_outpath>/cg_restart.pkl``. (The state to
        resume from is always read from ``cfg.ds_restart``.)
    time_limit : int or None
        if not None, how much time to elapse before stopping (minutes, measured
        from the module-level ``t0_global``)
    cfg : Config object
        all config parameters; required
    of : str
        output file to write log messages to. If None, messages are printed to stdout
    indata_type : str
        the type of input data; default "fits"

    Returns
    -------
    p : Parameters object
        the best fit parameters for destriping the SCA images

    Raises
    ------
    ValueError
        If ``cfg`` is None, or if the CG model is not one of FR/PR/HS/DY.
    """
    if cfg is None:
        raise ValueError("conjugate_gradient requires a Config object (cfg).")

    write_to_file("### Starting conjugate gradient optimization", of)
    write_to_file(f"Global elapsed t = {(time.time()-t0_global)/60:8.1f}", of)
    write_to_file(f"HL Threshold (None, if cost fn is not Huber Loss): {thresh}", of)
    write_to_file(f"Restart?: {cfg.ds_restart}\n", of)

    testoutputs["test_image_dir"] = cfg.ds_outpath + "test_images/" + str(0) + "/"
    log_file = os.path.join(cfg.ds_outpath, "cg_log.csv")
    log_header = [
        "Iteration",
        "Current Norm",
        "Convergence Rate",
        "Step Size",
        "Gradient Magnitude",
        "Final d_cost",
        "Final Epsilon",
        "Time (min)",
        "LS time (min)",
        "MSE",
        "Parameter Change",
    ]

    if cfg.ds_restart is not None:
        # NOTE: pickle.load executes arbitrary code from the file; only restart
        # from snapshots written by this program.
        with open(cfg.ds_restart, "rb") as f_in:
            state = pickle.load(f_in)
        write_to_file(f"Restarting CG from snapshot {cfg.ds_restart} at iteration {state['iteration']+1}", of)
        i = state["iteration"]
        p = state["p"]
        grad = state["grad"]
        epsilon = state["epsilon"]
        direction = state["direction"]
        grad_prev = state["grad_prev"]
        direction_prev = state["direction_prev"]
        cg_model = state["cg_model"]
        tol = state["tol"]
        thresh = state["thresh"]
        psi = state["psi"]
        norm_0 = state["norm_0"]
        cost_model = state["cost_model"]
    else:
        os.makedirs(testoutputs["test_image_dir"], exist_ok=True)  # make output directory

        # Initialize variables
        cg_model = cfg.cg_model
        cost_model = cfg.cost_model
        tol = cfg.cg_tol
        i = -1
        grad_prev = None  # No previous gradient initially
        grad = None
        direction = None  # No initial direction

        write_to_file("### Starting initial cost function", of)
        epsilon, psi = cost_function(p, f, thresh, workers, scalist, neighbors, cfg, indata_type=indata_type)
        print(f"Global elapsed t = {(time.time()-t0_global)/60:8.1f}")

    # Start a new log on a fresh run; on a restart keep appending to the
    # existing log, and only create one (with header) if it is missing.
    if cfg.ds_restart is None or not os.path.isfile(log_file):
        with open(log_file, "w", newline="") as csvfile:
            csv.writer(csvfile).writerow(log_header)

    sys.stdout.flush()

    first_iter = i + 1
    converged = False
    current_norm = None  # stays None if the loop below never runs

    for i in range(first_iter, cfg.cg_maxiter):
        write_to_file(f"### CG Iteration: {i + 1}", of)
        testoutputs["test_image_dir"] = cfg.ds_outpath + "test_images/" + str(i + 1) + "/"
        os.makedirs(testoutputs["test_image_dir"], exist_ok=True)
        t_start_CG_iter = time.time()

        # Compute the gradient (only needed before the first line search)
        if grad is None:
            grad = residual_function(
                psi, f_prime, scalist, wcslist, neighbors, thresh, workers, cfg, indata_type=indata_type
            )
            write_to_file(
                f"Minutes spent in initial residual function: {(time.time() - t_start_CG_iter) / 60}", of
            )
            print(f"Global elapsed t = {(time.time()-t0_global)/60:8.1f}")
        sys.stdout.flush()

        # Compute the norm of the gradient
        current_norm = np.linalg.norm(grad)

        if i == 0 and grad_prev is None:
            write_to_file(f"Initial gradient: {grad}", of)
            norm_0 = np.linalg.norm(grad)
            write_to_file(f"Initial norm: {norm_0}", of)
            write_to_file(f"Initial epsilon: {epsilon}", of)
            # From here on tol is an absolute threshold on the gradient norm.
            tol = tol * norm_0
            direction = -grad
        elif (i + 1) % 10 == 0:
            # Periodic reset to steepest descent.
            beta = 0
            write_to_file(f"Current Beta: {beta} (using method: {cg_model})", of)
            direction = -grad + beta * direction_prev
        else:
            # Calculate beta (direction scaling) depending on cg_model
            if cg_model == "FR":
                beta = np.sum(np.square(grad)) / np.sum(np.square(grad_prev))
            elif cg_model == "PR":
                beta = max(0, np.sum(grad * (grad - grad_prev)) / (np.sum(np.square(grad_prev))))
            elif cg_model == "HS":
                beta = np.sum(grad * (grad - grad_prev)) / np.sum(-direction_prev * (grad - grad_prev))
            elif cg_model == "DY":
                beta = np.sum(np.square(grad)) / np.sum(-direction_prev * (grad - grad_prev))
            else:
                raise ValueError(f"Unknown method for CG direction update: {cg_model}")

            write_to_file(f"Current Beta: {beta} (using method: {cg_model})", of)
            direction = -grad + beta * direction_prev

        if current_norm < tol:
            write_to_file(
                f"Convergence reached at iteration: {i + 1} via norm {current_norm} < tol {tol}", of
            )
            converged = True
            break

        # Perform linear search
        t_start_LS = time.time()
        write_to_file(f"Initiating linear search in direction: {direction}", of)
        sys.stdout.flush()

        if cost_model == "quadratic":
            p_new, psi_new, grad_new, epsilon_new = linear_search_quadratic(
                p,
                direction,
                f,
                f_prime,
                grad,
                thresh,
                workers,
                scalist,
                wcslist,
                neighbors,
                cfg,
                of=of,
                indata_type=indata_type,
            )
        else:
            p_new, psi_new, grad_new, epsilon_new = linear_search_general(
                p,
                direction,
                f,
                f_prime,
                cost_model,
                epsilon,
                psi,
                grad,
                thresh,
                workers,
                scalist,
                wcslist,
                neighbors,
                cfg,
                of=of,
                indata_type=indata_type,
            )

        write_to_file(f"Global elapsed t = {(time.time()-t0_global)/60:8.1f}", of)
        ls_time = (time.time() - t_start_LS) / 60
        write_to_file(f"Total time spent in linear search: {ls_time}", of)
        write_to_file(
            f"Current norm: {current_norm}, Tol * Norm_0: {tol}, Difference (CN-TOL): {current_norm - tol}",
            of,
        )
        sys.stdout.flush()

        # Calculate additional metrics
        gradient_magnitude = np.linalg.norm(grad_new)
        convergence_rate = (current_norm - gradient_magnitude) / current_norm
        step_size = np.linalg.norm(p_new.params - p.params)
        parameter_change = step_size
        mse = np.mean(psi_new**2)

        with open(log_file, "a", newline="") as csvfile:
            csv.writer(csvfile).writerow(
                [
                    i + 1,
                    current_norm,
                    convergence_rate,
                    step_size,
                    gradient_magnitude,
                    # "Final" columns report the state after the line search.
                    np.sum(grad_new * direction),
                    epsilon_new,
                    (time.time() - t_start_CG_iter) / 60,
                    ls_time,
                    mse,
                    parameter_change,
                ]
            )

        # Update to current values
        p = p_new
        psi = psi_new
        epsilon = epsilon_new
        grad_prev = grad
        grad = grad_new
        direction_prev = direction

        write_to_file(
            f"Total time spent in this CG iteration: {(time.time() - t_start_CG_iter) / 60} minutes.", of
        )
        write_to_file(f"Global elapsed t = {(time.time()-t0_global)/60:8.1f}", of)
        sys.stdout.flush()

        # Save checkpoint if walltime exceeded
        if time_limit is not None:
            elapsed_minutes = (time.time() - t0_global) / 60
            if elapsed_minutes >= time_limit:
                write_to_file(f"Walltime limit {time_limit} min reached. Exiting early!!!", of)
                # Never hand save_snapshot a None path.
                if cfg.ds_restart is None or restart_file is None:
                    restart_file = os.path.join(cfg.ds_outpath, "cg_restart.pkl")
                save_snapshot(
                    p,
                    grad,
                    epsilon,
                    psi,
                    direction,
                    grad_prev,
                    direction_prev,
                    cg_model,
                    tol,
                    thresh,
                    norm_0,
                    cost_model,
                    i,
                    restart_file,
                    of=of,
                )
                return p

    if not converged and i == cfg.cg_maxiter - 1:
        write_to_file(f"CG reached MAX ITERATIONS {cfg.cg_maxiter} and DID NOT converge!!!!", of)

    write_to_file(f"Conjugate gradient complete. Finished in {i + 1} / {cfg.cg_maxiter} iterations", of)
    write_to_file(f"Final parameters: {p.params}", of)
    write_to_file(f"Final norm: {current_norm}", of)
    return p
def main(cfg_file=None, overlaponly=False, of=None, testing=False):
    """
    Main function to run destriping via conjugate gradient descent.

    Parameters
    ----------
    cfg_file : str, optional
        Configuration file (if not provided, reads from command line arguments).
    overlaponly : bool, optional
        Only compute the overlap matrix, then stop.
    of : str, optional
        Output file for logging. If None, logs are printed to stdout.
    testing : bool, optional
        If True, saves additional diagnostic images and info for testing and validation.

    Returns
    -------
    str or None
        Prefix for destriped images. Add f"_{obsid}_{sca}.fits" to get the full file name.
        None is returned when `overlaponly` is True, since the function stops
        right after the overlap matrix is available.

    Raises
    ------
    ValueError
        If no configuration file is supplied (neither as an argument nor as
        ``sys.argv[1]``), or if ``CFG.cg_model`` is not a supported CG model.
    Exception
        If the conjugate gradient solver fails before returning a parameter set,
        the failure is logged and the original exception is re-raised.

    """
    global img_full_output

    CG_models = {"FR", "PR", "HS", "DY"}

    # Fall back to the first command line argument when no file was passed in.
    if cfg_file is None and len(sys.argv) > 1:
        cfg_file = sys.argv[1]
    if cfg_file is None:
        raise ValueError("Please provide a config file as a command line argument.")
    CFG = Config(cfg_file=cfg_file)

    outpath = CFG.ds_outpath

    if JWST:
        indata_type = "jwst"
        filter_ = CFG.use_filter
    else:
        indata_type = "fits"
        filter_ = filters[CFG.use_filter]

    if testing:
        testoutputs["testing"] = True

    # For test outputs: set sca=0 to not produce test outputs.
    img_full_output = (
        {"obsid": "04793009001_02101_00001", "scaid": "nrcb2"} if JWST else {"obsid": 670, "scaid": 10}
    )
    if not testoutputs["testing"]:
        img_full_output = {"obsid": -1, "scaid": -1}  # don't do these big outputs

    # Prior on cost function is not yet implemented
    # if CFG.cost_prior != 0:
    #     cost_prior = CFG.cost_prior

    if CFG.cg_model not in CG_models:
        raise ValueError(f"CG model {CFG.cg_model} not in CG_models dictionary.")

    CFG()

    t0 = time.time()

    if "OMP_NUM_THREADS" in os.environ:
        # os.cpu_count() may return None, and the integer division may yield 0
        # when OMP_NUM_THREADS exceeds the core count; always keep >= 1 worker.
        workers = max(1, (os.cpu_count() or 1) // int(os.environ["OMP_NUM_THREADS"]))
    else:
        workers = 12
    write_to_file(f"## Using {workers} workers for parallel processing.", of)

    all_scas, all_wcs = get_scas(filter_, CFG.ds_obsfile, CFG, indata_type=indata_type, of=of)
    write_to_file(f"{len(all_scas)} SCAs in this mosaic", of)
    sys.stdout.flush()

    # get overlap matrix (reuse the cached copy if it is already on disk)
    if os.path.isfile(outpath + "ovmat.npy"):
        ov_mat = np.load(outpath + "ovmat.npy")
    else:
        ovmat_t0 = time.time()
        write_to_file("Overlap matrix computing start", of)
        ov_mat = compareutils.get_overlap_matrix(all_wcs, verbose=True, subsamp=4)
        np.save(outpath + "ovmat.npy", ov_mat)
        write_to_file(f"Overlap matrix complete. Duration: {(time.time() - ovmat_t0) / 60} Minutes", of)
        write_to_file(f"Overlap matrix saved to: {outpath}ovmat.npy", of)

    # if we're only computing overlap matrices, can stop here
    if overlaponly:
        return
    sys.stdout.flush()

    neighbors = get_neighbors(all_scas, ov_mat)

    # Initialize parameters
    p0 = Parameters(cfg=CFG, scalist=all_scas)
    cm = Cost_models(cfg=CFG)

    # Do it
    p = None  # stays None if the solver fails before returning a parameter set
    try:
        # NOTE: conjugate_gradient compares time_limit against elapsed *minutes*.
        p = conjugate_gradient(
            p0,
            cm.f,
            cm.f_prime,
            cm.thresh,
            workers,
            all_scas,
            all_wcs,
            neighbors,
            time_limit=7200,
            cfg=CFG,
            of=of,
            indata_type=indata_type,
        )
        hdu = fits.PrimaryHDU(p.params)
        hdu.writeto(outpath + "final_params.fits", overwrite=True)
        write_to_file(outpath + "final_params.fits created \n", of)
        sys.stdout.flush()

    except Exception as e:
        write_to_file(f"Exception: {e}", of)
        logging.exception("An error occurred:")
        write_to_file("Conjugate gradient failed. Restart state saved to cg_restart.pkl\n", of)
        if p is None:
            # No parameters exist, so the destriped images below cannot be built.
            # Surface the real failure instead of an UnboundLocalError on `p`.
            raise

    # Subtract the fitted stripe model from every SCA and write the results.
    for i, sca in enumerate(all_scas):
        obsid, scaid = get_ids(sca, indata_type=indata_type)
        this_sca = Sca_img(obsid, scaid, CFG, add_objmask=False, indata_type=indata_type)
        this_param_set = p.forward_par(i)
        ds_image = this_sca.image - this_param_set

        hdu = fits.PrimaryHDU(ds_image, header=this_sca.header)
        hdu.header["TYPE"] = "DESTRIPED_IMAGE"
        hdu2 = fits.ImageHDU(this_sca.image, header=this_sca.header)
        hdu2.header["TYPE"] = "SCA_IMAGE"
        hdu3 = fits.ImageHDU(this_param_set, header=this_sca.header)
        hdu3.header["TYPE"] = "PARAMS_IMAGE"
        hdulist = fits.HDUList([hdu, hdu2, hdu3])
        hdulist.writeto(outpath + filter_ + "_DS_" + obsid + "_" + scaid + ".fits", overwrite=True)

        # Release the image buffers before loading the next SCA.
        del this_sca.image
        del this_sca
        gc.collect()

    write_to_file(f"Destriped images saved to {outpath + filter_} _DS_*.fits", of)
    write_to_file(f"Total hours elapsed: {(time.time() - t0) / 3600}", of)

    return outpath + filter_ + "_DS"
if __name__ == "__main__":
    # Script entry point: run main() under cProfile while sampling memory use,
    # then dump both reports to text files in the current working directory.
    profiler = cProfile.Profile()
    profiler.enable()

    mem_usage = None
    try:
        mem_usage = memory_usage(main, interval=120, retval=False)
    finally:
        # Always stop the profiler and persist its report, even if main() raised.
        profiler.disable()

        report_buffer = io.StringIO()
        pstats.Stats(profiler, stream=report_buffer).sort_stats("cumulative").print_stats()

        with open("profile_results.txt", "w") as profile_out:
            profile_out.write(report_buffer.getvalue())

        # mem_usage is still None when main() failed; nothing to write then.
        if mem_usage is not None:
            with open("memory_profile_results.txt", "w") as memory_out:
                memory_out.writelines(
                    f"{idx}\t{sample:.2f} MiB\n" for idx, sample in enumerate(mem_usage)
                )