Skip to content

Commit

Permalink
Dropzweights (#14)
Browse files Browse the repository at this point in the history
* remove full nan slices

* don't acr by default

* look for max of model only at unflagged locations

* drop invalid gausspars

* accommodate different keys for different header

* idx -> fidx

* fresq -> freqs

* Exclude any bands that might be awful with `-ds`

* pep8 indetation level

* syntax error fix

* Update spi_fitter.py

* Update spi_fitter.py

---------

Co-authored-by: landmanbester <[email protected]>
  • Loading branch information
Athanaseus and landmanbester authored Nov 2, 2023
1 parent ce15fa1 commit a22c45d
Showing 1 changed file with 36 additions and 9 deletions.
45 changes: 36 additions & 9 deletions spimple/apps/spi_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def spi_fitter():
"Will overwrite in fits headers of output.")
parser.add_argument('-otype', '--out_dtype', default='f4', type=str,
help="Data type of output. Default is single precision")
parser.add_argument('-acr', '--add-convolved-residuals', type=str2bool, nargs='?', const=True, default=True,
parser.add_argument('-acr', '--add-convolved-residuals', type=str2bool, nargs='?', const=True, default=False,
help='Flag to add in the convolved residuals before fitting components')
parser.add_argument('-ms', "--ms", nargs="+", type=str,
help="Mesurement sets used to make the image. \n"
Expand All @@ -95,6 +95,10 @@ def spi_fitter():
help="Correlation typ i.e. linear or circular. ")
parser.add_argument('-band', "--band", type=str, default='l',
help="Band to use with JimBeam. L or UHF")
parser.add_argument('-db', '--deselect-bands', default=None, nargs='+', type=int,
help="Indices of subbands to exclude from the fitting \n"
"By default, all the sub-bands are used for the residual image. \n"
"e.g. -db 1 2 will exclude sub-bands indexed at 1 & 2.")

opts = parser.parse_args()
opts = OmegaConf.create(vars(opts))
Expand Down Expand Up @@ -160,10 +164,10 @@ def spi_fitter():
m_coord, ref_m = data_from_header(mhdr, axis=2)
m_coord -= ref_m

if mhdr["CTYPE4"].lower() == 'freq':
if mhdr["CTYPE4"].lower() in ['freq', 'speclnmf']:
freq_axis = 4
stokes_axis = 3
elif mhdr["CTYPE3"].lower() == 'freq':
elif mhdr["CTYPE3"].lower() in ['freq', 'speclnmf']:
freq_axis = 3
stokes_axis = 4
else:
Expand Down Expand Up @@ -354,11 +358,32 @@ def spi_fitter():
else:
print("No residual provided. Setting threshold i.t.o dynamic range. "
f"Max dynamic range is {opts.maxDR}", file=log)
threshold = model.max()/opts.maxDR
mask = ~np.isnan(model)
threshold = model[mask].max()/opts.maxDR
rms_cube = None

print(f"Threshold set to {threshold} Jy.", file=log)

# remove completely nan slices
freq_mask = np.isnan(model)
fidx = ~np.all(freq_mask, axis=(1,2))

# exclude any bands that might be awful
if opts.deselect_bands:
print(f"Deselected bands are: {opts.deselect_bands}", file=log)
for bidx in opts.deselect_bands:
fidx[bidx] = False

if fidx.any():
model = model[fidx]
beam_image = beam_image[fidx]
freqs = freqs[fidx]
gaussparf = list(gaussparf)
#for i, par in enumerate(gaussparf):
# if not fidx[i]:
# gaussparf.remove(par)
new_hdr = set_header_info(mhdr, ref_freq, freq_axis, opts, tuple(gaussparf))

# get pixels above threshold
minimage = np.amin(model, axis=0)
maskindices = np.argwhere(minimage > threshold)
Expand All @@ -372,29 +397,31 @@ def spi_fitter():

# set weights for fit
if opts.channel_weights is not None:
weights = np.array(opts.channel_weights)
weights = np.array(opts.channel_weights)[fidx]
try:
assert weights.size == nband
print("Using provided channel weights.", file=log)
except Exception as e:
print("Number of provided channel weights not equal "
"to number of imaging bands", file=log)
else:
if len(opts.residual) > 1:
if opts.residual:
print("Getting weights from list of image headers.", file=log)
rhdr = [fits.getheader(res) for res in opts.residual]
rhdr = []
for i, res in enumerate(opts.residual):
rhdr.append(fits.getheader(res))
weights = np.array([hdr["WSCVWSUM"] for hdr in rhdr])
weights /= weights.max()
elif rms_cube is not None:
print("Using RMS in each imaging band to determine weights.",
file=log)
weights = np.where(rms_cube > 0, 1.0/rms_cube**2, 0.0)
weights = np.where(rms_cube[fidx] > 0, 1.0/rms_cube[fidx]**2, 0.0)
# normalise
weights /= weights.max()
else:
print("No residual or channel weights provided. "
"Using equal weights.", file=log)
weights = np.ones(nband, dtype=np.float64)
weights = np.ones(fidx.sum(), dtype=np.float64)
print(f"Channel weights: {weights}", file=log)

ncomps, _ = fitcube.shape
Expand Down

0 comments on commit a22c45d

Please sign in to comment.