Source code for hyrax.data_sets.lsst_dataset

import functools
import logging
from pathlib import Path

from torch.utils.data import Dataset

from .data_set_registry import HyraxDataset, HyraxImageDataset

[docs] logger = logging.getLogger(__name__)
[docs] class LSSTDataset(HyraxDataset, HyraxImageDataset, Dataset): """LSSTDataset: A dataset to access deep_coadd images from lsst pipelines via the butler. Must be run in an RSP. """
[docs] BANDS = ["u", "g", "r", "i", "z", "y"]
[docs] def __init__(self, config, data_location): """ .. py:method:: __init__ Initialize the dataset with either a HATS catalog or astropy table. Config can specify either: - config["data_set"]["hats_catalog"]: path to HATS catalog - config["data_set"]["astropy_table"]: path to any file readable by Astropy Table """ try: import lsst.daf.butler as butler self.butler = butler.Butler( config["data_set"]["butler_repo"], collections=config["data_set"]["butler_collection"] ) self.skymap = self.butler.get("skyMap", {"skymap": config["data_set"]["skymap"]}) except ImportError: msg = "Did not detect a Butler. You may need to run on the RSP" msg += "" logger.info(msg) self.butler = None self.skymap = None # Set filters from config if provided, otherwise use class default if "filters" in config["data_set"] and config["data_set"]["filters"]: # Validate filters valid_filters = ["u", "g", "r", "i", "z", "y"] for band in config["data_set"]["filters"]: if band not in valid_filters: raise ValueError( f"Invalid filter {band} for Rubin-LSST.\ Valid bands are: {valid_filters}" ) LSSTDataset.BANDS = config["data_set"]["filters"] # Load catalog - either from HATS or astropy table
[docs] self.catalog = self._load_catalog(config["data_set"])
[docs] self.sh_deg = config["data_set"]["semi_height_deg"]
[docs] self.sw_deg = config["data_set"]["semi_width_deg"]
self.set_function_transform() self.set_crop_transform() super().__init__(config, metadata_table=self.catalog, object_id_column_name="NOT A REAL COLUMN")
[docs] def _load_catalog(self, data_set_config): """ Load the catalog from either a HATS catalog or an astropy table. """ if "hats_catalog" in data_set_config: return self._load_hats_catalog(data_set_config["hats_catalog"]) elif "astropy_table" in data_set_config: return self._load_astropy_catalog(data_set_config["astropy_table"]) else: raise ValueError("Must specify either 'hats_catalog' or 'astropy_table' in data_set config")
[docs] def _load_hats_catalog(self, hats_path): """Load catalog from HATS format using LSDB.""" try: import lsdb except ImportError as e: msg = "LSDB is required to load HATS catalogs. Install with: pip install lsdb" raise ImportError(msg) from e # We compute the entire catalog so we have a nested frame which we can access return lsdb.read_hats(hats_path).compute()
[docs] def _load_astropy_catalog(self, table_path): """Load catalog from astropy table format or pickled astropy table.""" import pickle from astropy.table import Table table_path = Path(table_path) # Check if it's a pickle file if table_path.suffix.lower() in [".pkl", ".pickle"]: with open(table_path, "rb") as f: table = pickle.load(f) # Verify it's an astropy Table if not isinstance(table, Table): raise ValueError(f"Pickled file {table_path} does not contain an astropy Table") return table else: # Load using astropy's native readers -- can be any format supported by astropy return Table.read(table_path)
[docs] def __len__(self): return len(self.catalog)
[docs] def get_image(self, idxs): """Get image cutouts for the given indices. Parameters ---------- idxs : int or list of int The index or indices of the cutouts to retrieve. Returns ------- list or torch.Tensor Single cutout tensor or list of cutout tensors. """ from astropy.table import Table from nested_pandas import NestedFrame # Handle different catalog types if isinstance(self.catalog, Table): # Astropy table - extract rows directly if isinstance(idxs, (list, tuple)): rows = [self.catalog[idx] for idx in idxs] cutouts = [self._fetch_single_cutout(row) for row in rows] return cutouts else: row = self.catalog[idxs] return self._fetch_single_cutout(row) else: # NestedFrame (HATS catalog) frame = self.catalog.iloc[idxs] frame = frame if isinstance(frame, NestedFrame) else NestedFrame(frame).T cutouts = [self._fetch_single_cutout(row) for _, row in frame.iterrows()] return cutouts if len(cutouts) > 1 else cutouts[0]
[docs] def __getitem__(self, idxs): """Get default data fields for the this dataset. Parameters ---------- idxs : int or list of int The index or indices of the cutouts to retrieve. Returns ------- dict A dictionary containing the default data fields. """ return {"data": {"image": self.get_image(idxs)}}
# def __getitems__(self, idxs): # return __getitem__(self, idxs)
[docs] def _parse_box(self, patch, row): """ Return a Box2I representing the desired cutout in pixel space, given a "row" of catalog data which includes the semi-height (sh) and semi-width (sw) in degrees desired for the cutout. """ from lsst.geom import Box2D, Box2I, degrees radec = self._parse_sphere_point(row) sw = self.sh_deg * degrees sh = self.sw_deg * degrees # Ra/Dec is left handed on the sky. Pixel coordinates are right handed on the sky. # In the variable names below min/max mean the min/max coordinate values in the # right-handed pixel space, not the left-handed sky space. # Move + in ra (0.0) for width and - in dec (270.0) along a great circle min_pt_sky = radec.offset(0.0 * degrees, sw).offset(270.0 * degrees, sh) # Move - in ra (180.0) for width and + in dec (90.0) along a great circle max_pt_sky = radec.offset(180.0 * degrees, sw).offset(90.0 * degrees, sh) wcs = patch.getWcs() minmax_pt_pixel_f = wcs.skyToPixel([min_pt_sky, max_pt_sky]) box_pixel_f = Box2D(*minmax_pt_pixel_f) box_pixel_i = Box2I(box_pixel_f, Box2I.EXPAND) # Throw if box_pixel_i extends outside the patch outer bbox # TODO: Do we want to fill nan's in this case? # Do we want to conditionally fill nan's if a nan-infill strategy is configured in # hyrax and error otherwise? if not patch.getOuterBBox().contains(box_pixel_i): msg = f"Bounding box for object at ra {radec.getLongitude().asDegrees()} deg " msg += f"dec {radec.getLatitude().asDegrees()} with semi-height {sh.asArcseconds()} arcsec " msg += f"and semi-width {sh.asArcseconds()} arcsec extends outside the bounding box of a " msg += "patch. Choose smaller values for config['data_set']['semi_height_deg'] and " msg += "config['data_set']['semi_width_deg']." raise RuntimeError(msg) # Throw if box_pixel_i does not contain any points if box_pixel_i.isEmpty(): msg = "Calculated size for cutout is 0x0 pixels. Did you set " msg += "config['data_set']['semi_height_deg'] and config['data_set']['semi_width_deg']?" raise RuntimeError(msg) return box_pixel_i
[docs] def _parse_sphere_point(self, row): """ Return a SpherePoint with the ra and deck given in the "row" of catalog data. Row must include the RA and dec as "ra" and "dec" columns respectively """ from lsst.geom import SpherePoint, degrees ra = row["coord_ra"] dec = row["coord_dec"] return SpherePoint(ra, dec, degrees)
[docs] def _get_tract_patch(self, row): """ Return (tractInfo, patchInfo) for a given row. This function only returns the single principle tract and patch in the case of overlap. """ radec = self._parse_sphere_point(row) tract_info = self.skymap.findTract(radec) return (tract_info, tract_info.findPatch(radec))
# super basic patch caching @functools.lru_cache(maxsize=128) # noqa: B019
[docs] def _request_patch(self, tract_index, patch_index): """ Request a patch from the butler. This will be a list of lsst.afw.image objects each corresponding to the configured bands Uses functools.lru_cache for basic in-memory caching. """ data = [] # Get the patch images we need for band in LSSTDataset.BANDS: # Set up the data dict butler_dict = { "tract": tract_index, "patch": patch_index, "skymap": self.config["data_set"]["skymap"], "band": band, } # pull from butler image = self.butler.get("deep_coadd", butler_dict) data.append(image.getImage()) return data
[docs] def _fetch_single_cutout(self, row): """ Make a single cutout, returning a torch tensor. Does not handle edge-of-tract/patch type edge cases, will only work near center of a patch. """ import numpy as np from torch import from_numpy tract_info, patch_info = self._get_tract_patch(row) box_i = self._parse_box(patch_info, row) patch_images = self._request_patch(tract_info.getId(), patch_info.sequential_index) # Actually perform a cutout data = [image[box_i].getArray() for image in patch_images] # Convert to torch format data_np = np.array(data) data_torch = from_numpy(data_np.astype(np.float32)) return self.apply_transform(data_torch)