# Source code for hyrax.download

import itertools
import logging
from pathlib import Path
from threading import Thread
from typing import Union

import toml
from astropy.io import fits
from astropy.table import Table, hstack

import hyrax.downloadCutout.downloadCutout as dC
from hyrax.download_stats import DownloadStats

# Module-level logger; named after the module so log records can be filtered per module.
logger = logging.getLogger(__name__)
class Downloader:
    """Class with primarily static methods to namespace downloader related constants and functions.

    An instance holds the runtime config, the directory cutouts are written to, and the path of the
    manifest file that records which cutouts have been attempted or downloaded. `run()` is the
    entry point that performs a (possibly resumed) download.
    """

    # These are the fields that are allowed to vary across the locations
    # input from the catalog fits file. Other values for HSC cutout server
    # must be provided by config.
    VARIABLE_FIELDS = ["tract", "ra", "dec"]

    # These are the column names we retain when writing a rect out to the manifest.fits file
    #
    # The construction list(dict.fromkeys()) is to de-duplicate keys that exist in VARIABLE_FIELDS and
    # dC.Rect.immutable_fields but keep the ordering of keys so that VARIABLE_FIELDS are first and all
    # of the immutable fields that we rely on for hash checks are also included.
    RECT_COLUMN_NAMES = list(dict.fromkeys(VARIABLE_FIELDS + dC.Rect.immutable_fields + ["dim"]))

    # Columns of the manifest file: every rect column plus the file name and catalog object id.
    MANIFEST_COLUMN_NAMES = RECT_COLUMN_NAMES + ["filename", "object_id"]

    # Name of the manifest file written into the cutout directory.
    MANIFEST_FILE_NAME = "manifest.fits"

    def __init__(self, config):
        """Store the config and derive the cutout directory and manifest path from it.

        Parameters
        ----------
        config : dict
            Runtime configuration as a nested dictionary. `config["general"]["data_dir"]` is the
            directory cutouts and the manifest are written to.
        """
        self.config = config
        self.cutout_path = Path(config["general"]["data_dir"]).expanduser().resolve()
        self.manifest_file = self.cutout_path / Downloader.MANIFEST_FILE_NAME

        # Populated by run(); initialised here so helpers such as _write_manifest() never hit a
        # missing attribute.
        self.rects = []
        self.thread_manifests = []

    def run(self):
        """Main entrypoint for downloading cutouts from HSC for use with Hyrax.

        Reads the catalog named in the config, builds one rect per sky location, drops rects that
        a prior manifest records as already downloaded, and downloads the remainder on one or more
        threads. The manifest is (re)written when the download finishes or is interrupted.

        Raises
        ------
        RuntimeError
            If credentials are missing or defined in two places, if the number of concurrent
            connections is outside 1..10, or if the existing manifest does not match this download.
        """
        logger.info("Download command Start")

        username, password = self._load_credentials()

        fits_file = Path(self.config["download"]["fits_file"]).expanduser().resolve()
        logger.info(f"Reading in fits catalog: {fits_file}")

        # Filter the fits file for the fields we want
        column_names = ["object_id"] + Downloader.VARIABLE_FIELDS
        locations = Downloader.filterfits(fits_file, column_names)

        # If offset/length specified, filter to that length
        offset = self.config["download"]["offset"]
        num_sources = self.config["download"]["num_sources"]
        if num_sources > 0:
            end = offset + num_sources
            locations = locations[offset:end]

        logger.info(f"Downloading cutouts to {self.cutout_path}")
        logger.info("Making a list of cutouts...")

        # Make a list of rects to pass to downloadCutout
        self.rects = Downloader.create_rects(
            locations, offset=0, default=self._rect_from_config(), path=self.cutout_path
        )

        logger.info("Checking the list against currently downloaded cutouts...")

        # Prune any previously downloaded rects from our list using the manifest from the previous download
        self.rects = self._prune_downloaded_rects()

        # Early return if there is nothing to download.
        if len(self.rects) == 0:
            logger.info("Download already complete according to manifest.")
            return

        num_threads = self.config["download"]["concurrent_connections"]
        if num_threads > 10:
            raise RuntimeError("More than 10 concurrent connections to HSC is disallowed on a per-user basis")
        if num_threads < 1:
            raise RuntimeError(f"concurrent_connections must be at least 1, it is {num_threads}")

        if len(self.rects) < num_threads:
            msg = f"Only {len(self.rects)} sky locations, which is less than the number of threads, so we "
            msg += "will use only one thread."
            logger.info(msg)
            num_threads = 1

        logger.info(f"Dividing {len(self.rects)} sky locations among {num_threads} threads...")

        # One batch per thread. Every rect lands in exactly one batch, so none are dropped and there
        # are never more (or fewer) batches than threads.
        thread_rects = Downloader._split_rects(self.rects, num_threads)

        # Empty dictionaries for the threads to create download manifests in
        self.thread_manifests = [dict() for _ in thread_rects]

        shared_thread_args = (
            username,
            password,
            DownloadStats(print_interval_s=self.config["download"]["stats_print_interval"]),
        )

        shared_thread_kwargs = {
            "retrywait": self.config["download"]["retry_wait"],
            "retries": self.config["download"]["retries"],
            "timeout": self.config["download"]["timeout"],
            "chunksize": self.config["download"]["chunk_size"],
        }

        download_threads = [
            Thread(
                target=Downloader.download_thread,
                name=f"thread_{i}",
                daemon=True,  # daemon so these threads will die when the main thread is interrupted
                # rects, then username/password/download stats, then thread_num and manifest
                args=(batch, *shared_thread_args, i, self.thread_manifests[i]),
                kwargs=shared_thread_kwargs,
            )
            for i, batch in enumerate(thread_rects)
        ]

        try:
            logger.info(f"Started {len(download_threads)} request threads")
            for thread in download_threads:
                thread.start()
            for thread in download_threads:
                thread.join()
        finally:  # Ensure manifest is written even when we get a KeyboardInterrupt during download
            self._write_manifest()

        logger.info("Done")

    def _load_credentials(self):
        """Resolve the HSC username and password from the credentials file or the config.

        Returns
        -------
        tuple
            `(username, password)` taken from the credentials file when it exists, otherwise from
            `config["download"]`.

        Raises
        ------
        RuntimeError
            If credentials are given both in an existing credentials file and in the config, or if
            either the username or the password is missing.
        """
        download_config = self.config["download"]
        credentials_file = download_config["credentials_file"]

        # No credentials file configured is a legitimate setup; treat it as "file does not exist"
        # rather than calling .exists() on None.
        credentials_file_path = Path(credentials_file) if credentials_file else None
        credentials_file_exists = credentials_file_path is not None and credentials_file_path.exists()
        credentials_configured = bool(download_config["username"] or download_config["password"])

        if credentials_file_exists and credentials_configured:
            msg = f"Credentials file {credentials_file} found in addition to username/password in your "
            msg += "hyrax config. Credentials must be provided in only one place."
            raise RuntimeError(msg)

        if credentials_file_exists:
            credentials = toml.load(credentials_file_path)
            username = credentials.get("username", False)
            password = credentials.get("password", False)
        else:
            username = download_config["username"]
            password = download_config["password"]

        if not username or not password:
            msg = "Please define a username and password to the HSC cutout service in credentials "
            msg += f"file: {credentials_file} to use the downloader. \n"
            msg += "The format of the credentials file is:\n"
            msg += '\nusername = "<your username>"\n'
            msg += 'password = "<your password>"\n\n'
            msg += "Please do not check your credentials into git or other version control systems.\n"
            msg += "Accounts can be created at the following url: \n"
            msg += " https://hsc-release.mtk.nao.ac.jp/datasearch/new_user/new "
            raise RuntimeError(msg)

        return username, password

    @staticmethod
    def _split_rects(rects, num_batches):
        """Split `rects` into contiguous batches whose sizes differ by at most one.

        Parameters
        ----------
        rects : list[dC.Rect]
            The rects to divide. Order is preserved, so neighbouring rects stay in the same batch.
        num_batches : int
            Desired number of batches. Clamped to the range 1..len(rects) so no batch is empty
            (except for the single batch returned when `rects` itself is empty).

        Returns
        -------
        list[list[dC.Rect]]
            The batches; concatenating them reproduces `rects`.
        """
        num_batches = max(1, min(num_batches, len(rects)))
        base_size, remainder = divmod(len(rects), num_batches)

        # The first `remainder` batches take one extra rect so that nothing is left over.
        iterator = iter(rects)
        return [
            list(itertools.islice(iterator, base_size + (1 if i < remainder else 0)))
            for i in range(num_batches)
        ]

    def _prune_downloaded_rects(self):
        """Prune already downloaded rects from `self.rects` using the manifest in `cutout_path`.

        `self.rects` is mutated in place by this operation.

        Returns
        -------
        list[dC.Rect]
            Returns `self.rects`. This is only to enable explicit style at the call site.

        Raises
        ------
        RuntimeError
            When the manifest file corresponds to a different set of cutouts than the current
            download being attempted.
        """
        # Read in any prior manifest.
        prior_manifest = self.manifest_to_rects()

        # If we found a manifest, we are resuming a download
        if len(prior_manifest) != 0:
            # Filter rects to figure out which ones are completely downloaded.
            # This operation consumes prior_manifest in the process
            self.rects[:] = [rect for rect in self.rects if Downloader._keep_rect(rect, prior_manifest)]

            # if prior_manifest was not completely consumed, then the earlier download attempted
            # some sky locations which would not be included in the current download, and we have
            # a problem.
            if len(prior_manifest) != 0:
                raise RuntimeError(
                    f"{self.manifest_file} describes a download with sky locations that would not be "
                    "downloaded in the download currently being attempted. Are you sure you are resuming "
                    "the correct download? Deleting the manifest and cutout files will start the download "
                    "from scratch"
                )

        return self.rects

    @staticmethod
    def _keep_rect(location_rect: dC.Rect, prior_manifest: dict[dC.Rect, str]) -> bool:
        """Private helper function to _prune_downloaded_rects which operates the inner loop of the
        prune function, and allows it to be written as a list comprehension.

        This function decides element-by-element for our rects that we want to download whether or
        not these rects have already been downloaded in a prior download, given the manifest from
        that prior download.

        Parameters
        ----------
        location_rect : dC.Rect
            A rectangle on the sky that we are considering downloading.
        prior_manifest : dict[dC.Rect,str]
            The manifest of the prior download. This object is slowly consumed by repeated calls to
            this function. Every manifest entry corresponding to the passed in location_rect is
            removed, whatever the return value.

        Returns
        -------
        bool
            Whether this sky location `location_rect` should be included in the download
        """
        # Keep any location rect if the manifest passed has nothing in it.
        if len(prior_manifest) == 0:
            return True

        keep_rect = False
        # Deliberately no early exit: every exploded rect must be popped so the caller can detect
        # manifest entries that belong to a different download.
        for filter_rect in location_rect.explode():
            # Consume any matching manifest entry, keep the rect if
            # 1) The manifest entry doesn't exist -> pop returns None
            # 2) The manifest entry contains "Attempted" for the filename -> The corresponding file wasn't
            #    successfully downloaded
            matching_manifest_entry = prior_manifest.pop(filter_rect, None)
            if matching_manifest_entry is None or matching_manifest_entry == "Attempted":
                keep_rect = True

        return keep_rect

    def _write_manifest(self):
        """Write out manifest fits file that is an inventory of the download.

        If the filename of an entry is the string "Attempted", a download attempt was made but did
        not succeed. If the object is not present in the manifest, no download was attempted. If
        the object is present in the manifest with any other filename, that file was downloaded
        successfully.

        This function operates additively: if a manifest file is present from an earlier download,
        it is read in and its entire content is included in addition to the manifests produced by
        the download threads. Entries from the current run overwrite entries for the same rect
        from the prior manifest.

        The manifest file has the following columns

        object_id: The object ID from the original catalog
        filename: The file name where the file can be found OR the string "Attempted" indicating
                  the download did not complete successfully.
        tract: The HSC tract ID number this either comes from the catalog or is the tract ID
               returned by the cutout server for downloaded files.
        ra: Right ascension in degrees of the center of the cutout box
        dec: Declination in degrees of the center of the cutout box
        filter: The name of the filter requested
        sw: Semi-width of the cutout box in degrees
        sh: Semi-height of the cutout box in degrees
        rerun: The data release in use e.g. pdr3_wide
        type: coadd, warp, or other values allowed by the HSC docs
        dim: Tuple of integers with the dimensions of the image.
        """
        logger.info("Assembling download manifest")

        # Start building a combined manifest from all threads from the ground truth of the prior manifest
        # in this directory, which we will be overwriting.
        combined_manifest = self.manifest_to_rects()

        # Combine all thread manifests with the prior manifest, so that the current status of a downloaded
        # rect overwrites any status from the prior run (which is no longer relevant.)
        for manifest in self.thread_manifests:
            combined_manifest.update(manifest)

        logger.info(f"Writing out download manifest with {len(combined_manifest)} entries.")

        # Convert the combined manifest into an astropy table by building a dict of {column_name: column_data}
        # for all the fields we require in a manifest
        columns = {column_name: [] for column_name in Downloader.MANIFEST_COLUMN_NAMES}

        for rect, msg in combined_manifest.items():
            # This parsing relies on the name format set up in create_rects to work properly
            # We parse the object_id from rect.name in case the filename is "Attempted" because the
            # download did not finish.
            rect_filename = Path(rect.name).name
            object_id = int(rect_filename.split("_")[0])
            columns["object_id"].append(object_id)

            # Remove the leading path from the filename if any.
            filename = Path(msg).name
            columns["filename"].append(filename)

            for key in Downloader.RECT_COLUMN_NAMES:
                columns[key].append(rect.__dict__[key])

        manifest_table = Table(columns)
        manifest_table.write(self.manifest_file, overwrite=True, format="fits")

        logger.info("Finished writing download manifest")

    def get_manifest(self) -> Union[Table, None]:
        """Get the current downloader manifest, which is a table of files where download has been
        attempted. The format of the table is outlined in _write_manifest()

        Returns
        -------
        astropy.table.Table or None
            The entire download manifest, or None when no manifest file exists.
        """
        if self.manifest_file.exists():
            return Table.read(self.manifest_file, format="fits")
        return None

    def manifest_to_rects(self) -> dict[dC.Rect, str]:
        """Read the manifest file and return its contents as a dictionary with
        downloadCutout rects as keys and filenames as values.

        If no manifest file is found, an empty dict is returned.

        Returns
        -------
        dict[dC.Rect, str]
            A dictionary containing all the rects in the manifest and all the filenames, or empty
            dict if no manifest is found.
        """
        manifest_table = self.get_manifest()
        if manifest_table is None:
            return {}

        # sort=False keeps rects in table-row order, so rects[i] is paired with the filename of the
        # same row. Sorting here would attach filenames to the wrong rects.
        rects = Downloader.create_rects(
            locations=manifest_table,
            fields=Downloader.RECT_COLUMN_NAMES,
            path=self.cutout_path,
            sort=False,
        )
        return dict(zip(rects, manifest_table["filename"]))

    @staticmethod
    def _rect_hook(rect: dC.Rect, filename: Union[Path, str]):
        """Record the image dimensions of a downloaded file on its rect.

        Parameters
        ----------
        rect : dC.Rect
            The rect whose `dim` attribute is set.
        filename : Union[Path, str]
            The fits file to read; the shape of HDU 1 is stored in `rect.dim`.
        """
        with fits.open(filename) as hdul:
            rect.dim = hdul[1].shape

    @staticmethod
    def download_thread(
        rects: list[dC.Rect],
        user: str,
        password: str,
        stats: DownloadStats,
        thread_num: int,
        manifest: dict[dC.Rect, str],
        **kwargs,
    ):
        """Download cutouts to the given directory. Called in its own thread with an id number.

        Calls downloadCutout.download, so supports long lists of rects beyond the limits of the
        HSC web API

        Parameters
        ----------
        rects : list[dC.Rect]
            The rects we would like to download
        user : string
            Username for HSC's download service to use
        password : string
            Password for HSC's download service to use
        stats : DownloadStats
            Instance of DownloadStats to use for stats tracking.
        thread_num : int,
            The ID number of thread we are, sequential from zero to num_threads-1
        manifest :
            A dictionary from dC.Rect to filename which we will fill in as we download rects. This
            is the chief returned piece of data from each thread.
        **kwargs : dict
            Additional arguments for downloadCutout.download. See downloadCutout.download for
            details
        """
        logger.info(f"Thread {thread_num} starting download of {len(rects)} rects")
        with stats as stats_hook:
            dC.download(
                rects,
                user=user,
                password=password,
                onmemory=False,
                request_hook=stats_hook,
                rect_hook=Downloader._rect_hook,
                manifest=manifest,
                **kwargs,
            )

    # TODO add error checking
    @staticmethod
    def filterfits(filename: Path, column_names: list[str]) -> Table:
        """Read a fits file with the required column names for making cutouts

        The easiest way to make such a fits file is to select from the main HSC catalog

        Parameters
        ----------
        filename : Path
            The fits file to read in
        column_names : list[str]
            The columns that are selected from the file and returned in the astropy Table.

        Returns
        -------
        Table
            Returns an astropy table containing only the fields specified in column_names
        """
        t = Table.read(filename)
        columns = [t[column] for column in column_names]
        return hstack(columns, uniq_col_name="{table_name}", table_names=column_names)

    def _rect_from_config(self) -> dC.Rect:
        """Takes our runtime config and loads cutout config common to all cutouts into a
        prototypical Rect for downloading

        Returns
        -------
        dC.Rect
            A single rectangle with fields `sw`, `sh`, `filter`, `rerun`, `type`, `image`, `mask`
            and `variance` populated from the `download` section of the config
        """
        download_config = self.config["download"]
        return dC.Rect.create(
            sw=download_config["sw"],
            sh=download_config["sh"],
            filter=download_config["filter"],
            rerun=download_config["rerun"],
            type=download_config["type"],
            image=download_config["image"],
            mask=download_config["mask"],
            variance=download_config["variance"],
        )

    @staticmethod
    def create_rects(
        locations: Table,
        path: Path,
        offset: int = 0,
        default: dC.Rect | None = None,
        fields: list[str] | None = None,
        sort: bool = True,
    ) -> list[dC.Rect]:
        """Create the rects we will need to pass to the downloader.
        One Rect per location in our list of sky locations.

        Rects are created with all fields in the default rect pre-filled

        Parameters
        ----------
        locations : Table
            Table containing ra, dec locations in the sky. Must also contain `object_id`, which is
            used to build the file name of each rect.
        path : Path
            Directory where the cutout files ought live. Used to generate file names on the rect
            object.
        offset : int, optional
            Index to start the `lineno` field in the rects at, by default 0. The purpose of this is
            to allow multiple downloads on different sections of a larger source list without file
            name clobbering during the download phase. This is important because `lineno` in a rect
            can become a file name parameter. The offset is intended to be the index of the start
            of the locations table within some larger fits file.
        default : dC.Rect, optional
            The default Rect that contains properties common to all sky locations, by default None
        fields : list[str], optional
            Default fields to pull from the locations table. If not provided, defaults to
            ["tract", "ra", "dec"]
        sort : bool, optional
            Whether to sort the returned rects, by default True. Pass False when the result must
            stay aligned row-for-row with `locations`.

        Returns
        -------
        list[dC.Rect]
            Rects populated with sky locations from the table
        """
        rects = []
        fields = fields if fields else Downloader.VARIABLE_FIELDS
        for index, location in enumerate(locations):
            args = {field: location.get(field) for field in fields}
            args["lineno"] = index + offset

            # tracts are ints in the fits files and dC.rect constructor wants them as str
            args["tract"] = str(args["tract"])

            # Sets the file name on the rect to be the object_id, also includes other rect fields
            # which are interpolated at save time, and are native fields of dc.Rect.
            args["name"] = str(
                path / f"{location['object_id']}_{{type}}_{{ra:.5f}}_{{dec:+.5f}}_{{tract}}_{{filter}}"
            )
            rect = dC.Rect.create(default=default, **args)
            rects.append(rect)

        # We sort rects here so they end up tract,ra,dec ordered across all requests made in all threads
        # Threads do their own sorting prior to all chunked request in downloadCutout.py; however
        # sorting at this stage will allow a greater number of rects that are co-located in the sky
        # to end up in the same thread and therefore be sorted into the same request.
        if sort:
            rects.sort()

        return rects