diff --git a/spimple/apps/spi_fitter.py b/spimple/apps/spi_fitter.py index 35d0d07..b9a054d 100755 --- a/spimple/apps/spi_fitter.py +++ b/spimple/apps/spi_fitter.py @@ -72,7 +72,7 @@ def spi_fitter(): "Will overwrite in fits headers of output.") parser.add_argument('-otype', '--out_dtype', default='f4', type=str, help="Data type of output. Default is single precision") - parser.add_argument('-acr', '--add-convolved-residuals', type=str2bool, nargs='?', const=True, default=True, + parser.add_argument('-acr', '--add-convolved-residuals', type=str2bool, nargs='?', const=True, default=False, help='Flag to add in the convolved residuals before fitting components') parser.add_argument('-ms', "--ms", nargs="+", type=str, help="Mesurement sets used to make the image. \n" @@ -95,6 +95,10 @@ def spi_fitter(): help="Correlation typ i.e. linear or circular. ") parser.add_argument('-band', "--band", type=str, default='l', help="Band to use with JimBeam. L or UHF") + parser.add_argument('-db', '--deselect-bands', default=None, nargs='+', type=int, + help="Indices of subbands to exclude from the fitting \n" + "By default, all the sub-bands are used for the residual image. \n" + "e.g. -db 1 2 will exclude sub-bands indexed at 1 & 2.") opts = parser.parse_args() opts = OmegaConf.create(vars(opts)) @@ -160,10 +164,10 @@ def spi_fitter(): m_coord, ref_m = data_from_header(mhdr, axis=2) m_coord -= ref_m - if mhdr["CTYPE4"].lower() == 'freq': + if mhdr["CTYPE4"].lower() in ['freq', 'speclnmf']: freq_axis = 4 stokes_axis = 3 - elif mhdr["CTYPE3"].lower() == 'freq': + elif mhdr["CTYPE3"].lower() in ['freq', 'speclnmf']: freq_axis = 3 stokes_axis = 4 else: @@ -354,11 +358,32 @@ def spi_fitter(): else: print("No residual provided. Setting threshold i.t.o dynamic range. " f"Max dynamic range is {opts.maxDR}", file=log) - threshold = model.max()/opts.maxDR + mask = ~np.isnan(model) + threshold = model[mask].max()/opts.maxDR rms_cube = None print(f"Threshold set to {threshold} Jy.", file=log) + # remove completely nan slices + freq_mask = np.isnan(model) + fidx = ~np.all(freq_mask, axis=(1,2)) + + # exclude any bands that might be awful + if opts.deselect_bands: + print(f"Deselected bands are: {opts.deselect_bands}", file=log) + for bidx in opts.deselect_bands: + fidx[bidx] = False + + if fidx.any(): + model = model[fidx] + beam_image = beam_image[fidx] + freqs = freqs[fidx] + gaussparf = list(gaussparf) + #for i, par in enumerate(gaussparf): + # if not fidx[i]: + # gaussparf.remove(par) + new_hdr = set_header_info(mhdr, ref_freq, freq_axis, opts, tuple(gaussparf)) + # get pixels above threshold minimage = np.amin(model, axis=0) maskindices = np.argwhere(minimage > threshold) @@ -372,7 +397,7 @@ def spi_fitter(): # set weights for fit if opts.channel_weights is not None: - weights = np.array(opts.channel_weights) + weights = np.array(opts.channel_weights)[fidx] try: assert weights.size == nband print("Using provided channel weights.", file=log) @@ -380,21 +405,23 @@ def spi_fitter(): print("Number of provided channel weights not equal " "to number of imaging bands", file=log) else: - if len(opts.residual) > 1: + if opts.residual: print("Getting weights from list of image headers.", file=log) - rhdr = [fits.getheader(res) for res in opts.residual] + rhdr = [] + for i, res in enumerate(opts.residual): + rhdr.append(fits.getheader(res)) weights = np.array([hdr["WSCVWSUM"] for hdr in rhdr]) weights /= weights.max() elif rms_cube is not None: print("Using RMS in each imaging band to determine weights.", file=log) - weights = np.where(rms_cube > 0, 1.0/rms_cube**2, 0.0) + weights = np.where(rms_cube[fidx] > 0, 1.0/rms_cube[fidx]**2, 0.0) # normalise weights /= weights.max() else: print("No residual or channel weights provided. " "Using equal weights.", file=log) - weights = np.ones(nband, dtype=np.float64) + weights = np.ones(fidx.sum(), dtype=np.float64) print(f"Channel weights: {weights}", file=log) ncomps, _ = fitcube.shape