Source code for hyrax.data_sets.downloaded_lsst_dataset

import functools
import logging
import threading
from pathlib import Path

import numpy as np
import torch
from astropy.table import Table
from tqdm import tqdm

from .lsst_dataset import LSSTDataset
from .tensor_cache_mixin import TensorCacheMixin

[docs] logger = logging.getLogger(__name__)
[docs] class DownloadedLSSTDataset(LSSTDataset, TensorCacheMixin): """ DownloadedLSSTDataset: A dataset that inherits from LSSTDataset and downloads cutouts from the LSST butler, saving them as `.pt` files during first access. On subsequent accesses, it loads cutouts directly from these cached files. This class also creates a manifest files with the shape of each cutout and the corresponding filename. Public Methods: download_cutouts(indices=None, sync_filesystem=True, max_workers=None, force_retry=False): Download cutouts with parallel processing. Automatically resumes from previous progress. Use max_workers to control thread count, force_retry to re-attempt failed downloads. manifest_stats(): Returns dict with download statistics: total, successful, failed, pending counts and manifest file path. download_progress(): Returns detailed progress metrics including completion percentage and failure rates. reset_failed_downloads(): Resets all failed download attempts to allow retry without force_retry flag. Returns count of reset entries. save_manifest_now(): Forces immediate manifest save (normally saved periodically during downloads). cache_info(): Returns LRU cache statistics for patch fetching performance monitoring. clear_cache(): Clears the patch LRU cache to free memory. Usage Example: # Initialize Hyrax h = hyrax.Hyrax() a = h.prepare() # Download all cutouts (resumes automatically) a.download_cutouts(max_workers=4) WARNING: The LRU Caching scheme is slightly complicated, so it is recommended to use the default max_workers=1 for the first download. Simply using more workers may not always speed up the download process. # Check progress a.download_progress() # Retry failed downloads a.download_cutouts(force_retry=True) # Access cutouts (loads from cache) cutout = a[0] # Single cutout cutouts = a[0:10] # Multiple cutouts File Organization: - Cutouts saved as: cutout_{object_id}.pt or cutout_{index:04d}.pt - Manifest saved as: manifest.fits (Astropy) or manifest.parquet (HATS) - All files stored in config["general"]["data_dir"] """
[docs] def __init__(self, config, data_location):
[docs] self.download_dir = Path(data_location)
self.download_dir.mkdir(exist_ok=True) # Preventing name collision with parent class config
[docs] self._config = config
# Initialize parent class with config super().__init__(config, data_location=data_location) try: import lsst.daf.butler as _butler # noqa: F401 self._butler_config = { "repo": config["data_set"]["butler_repo"], "collections": config["data_set"]["butler_collection"], "skymap": config["data_set"]["skymap"], } except ImportError: self._butler_config = None # Manifest management
[docs] self._manifest_lock = threading.Lock()
[docs] self._updates_since_save = 0
[docs] self._save_interval = 1000
# Determine naming strategy and initialize manifest (includes band filtering validation) self._setup_naming_strategy() self._initialize_manifest() # Add tracking for band failure statistics (use current BANDS which may be filtered)
[docs] self._band_failure_stats = {band: 0 for band in self.BANDS}
[docs] self._band_failure_lock = threading.Lock()
# Initialize filtering state
[docs] self._manifest_filter_object_ids = None
[docs] self._catalog_to_manifest_index_map = None
[docs] self._manifest_to_catalog_index_map = None
# Initialize tensor caching from mixin self._init_tensor_cache(config)
[docs] def get_objectId(self, idx): # noqa: N802 """Get object ID for a given index based on naming strategy.""" if self.use_object_id: if isinstance(self.catalog, Table): return str(self.catalog[idx][self.object_id_column]) else: return str(self.catalog.iloc[idx][self.object_id_column]) else: return str(idx)
[docs] def ids(self, log_every=None): """Generator yielding object IDs for the entire dataset. Required by TensorCacheMixin""" log = log_every is not None and isinstance(log_every, int) for idx in range(len(self.catalog)): if log and idx != 0 and idx % log_every == 0: logger.info(f"Processed {idx} objects") yield self.get_objectId(idx) # Final log message after completing iteration if log and len(self.catalog) > 0: logger.info(f"Processed {len(self.catalog)} objects")
[docs] def _setup_naming_strategy(self): """Setup file naming strategy based on catalog columns.""" catalog_columns = self.catalog.colnames if hasattr(self.catalog, "colnames") else self.catalog.columns self.use_object_id = False if self._config["data_set"]["object_id_column_name"]: self.use_object_id = True self.object_id_column = self._config["data_set"]["object_id_column_name"] elif "object_id" in catalog_columns: self.use_object_id = True self.object_id_column = "object_id" elif "objectId" in catalog_columns: self.use_object_id = True self.object_id_column = "objectId" else: self.object_id_column = "objectId" if not self.use_object_id: dataset_length = len(self.catalog) self.padding_length = max(4, len(str(dataset_length)))
[docs] def _initialize_manifest(self): """Create new manifest or load/merge with existing manifest, with band filtering validation.""" if isinstance(self.catalog, Table): self.catalog_type = "astropy" self.manifest_path = self.download_dir / "manifest.fits" else: self.catalog_type = "hats" self.manifest_path = self.download_dir / "manifest.parquet" # Initialize band filtering flags self._is_filtering_bands = False self._original_bands = None self._filtered_bands = None self._band_indices = None # Try to load existing manifest first if not self.manifest_path.exists(): if self._butler_config is None: msg = "Cannot find any data source. There is no existing manifest, and there is no " msg += "butler available. Please try to run this on an RSP where a butler is available or " msg += "ensure a proper manifest and cutouts are available in " msg += f"{self.config['general']['data_dir']}" raise RuntimeError(msg) # Create new manifest (no existing manifest found) logger.info("Creating new manifest") if self.catalog_type == "astropy": self.manifest = Table() # Copy only the data columns for col_name in self.catalog.colnames: self.manifest[col_name] = self.catalog[col_name] else: self.manifest = self.catalog.copy() self._add_manifest_columns() self._save_manifest() logger.info(f"Initialized new manifest at {self.manifest_path}") else: logger.info(f"Found existing manifest at {self.manifest_path}") try: existing_manifest = self._load_existing_manifest() # Check for band filtering opportunity available_bands_set, original_band_order = self._get_available_bands_from_manifest( existing_manifest ) if available_bands_set is not None and original_band_order is not None: requested_bands = set(self.BANDS) # Only setup filtering if requested bands are a PROPER SUBSET if requested_bands < available_bands_set: # Proper subset (not equal) logger.info( f"Requested bands {sorted(list(requested_bands))} are a subset of " f"available {sorted(list(available_bands_set))}" ) self._setup_band_filtering(requested_bands, original_band_order) elif requested_bands == available_bands_set: logger.info("Requested bands match available bands exactly, no filtering needed") else: missing_bands = requested_bands - available_bands_set raise ValueError( f"Requested bands {sorted(list(missing_bands))} are not available in downloads. " f"Available bands: {sorted(list(available_bands_set))}. " f"Please set up a new data directory or download missing bands first." ) # Perform manifest merge self.manifest, merge_stats = self._merge_manifests(existing_manifest) # Log merge results logger.info( f"Manifest merge completed: {merge_stats['preserved']} preserved, " f"{merge_stats['added']} added" ) # Warn about new objects that need downloading if merge_stats["added"] > 0: logger.warning( f"{merge_stats['added']} new objects were added to the manifest " f"but are not yet downloaded. Consider running download_cutouts() to " f"download these missing objects." ) # Save the merged manifest self._save_manifest() except Exception as e: logger.error(f"Failed to load/merge existing manifest: {e}") logger.error( "Cannot proceed with incompatible manifest. Specify new data directory to continue." ) raise
[docs] def _load_existing_manifest(self): """Load existing manifest file.""" if self.catalog_type == "astropy": return Table.read(self.manifest_path) else: import pandas as pd df = pd.read_parquet(self.manifest_path) # Convert back to original catalog type if needed return df
[docs] def _merge_manifests(self, existing_manifest): """Merge existing manifest with current catalog based on object_id.""" # Verify object_id merging is possible if not self.use_object_id: raise ValueError("Cannot merge manifests without object_id") # Check required columns exist in existing manifest existing_cols = ( existing_manifest.colnames if self.catalog_type == "astropy" else existing_manifest.columns ) required_cols = ["cutout_shape", "filename", "downloaded_bands"] for col in required_cols: if col not in existing_cols: raise ValueError(f"Existing manifest missing required column: {col}") # Create object_id sets for comparison if self.catalog_type == "astropy": current_object_ids = set(self.catalog[self.object_id_column]) existing_object_ids = set(existing_manifest[self.object_id_column]) else: current_object_ids = set(self.catalog[self.object_id_column]) existing_object_ids = set(existing_manifest[self.object_id_column]) # Check if current catalog is a subset of existing manifest if current_object_ids.issubset(existing_object_ids): # All objects exist in manifest - just filter for current operations logger.info( f"Current catalog ({len(current_object_ids)} objects)\ is a subset of existing manifest " f"({len(existing_object_ids)} objects). Using existing manifest\ with filtering for operations." ) # Keep the FULL existing manifest but store filtering info for operations self._manifest_filter_object_ids = current_object_ids self._build_catalog_to_manifest_index_map(existing_manifest) merge_stats = { "preserved": len(existing_object_ids), "added": 0, "filtered_for_operations": len(current_object_ids), } return existing_manifest, merge_stats # Current catalog contains new objects - add them to existing manifest logger.info( f"Current catalog contains {len(current_object_ids - existing_object_ids)} new objects. " f"Adding to existing manifest while preserving all {len(existing_object_ids)}\ existing objects." ) # Start with existing manifest to preserve all data if self.catalog_type == "astropy": merged_manifest = existing_manifest.copy() else: merged_manifest = existing_manifest.copy() # Find new objects to add new_object_ids = current_object_ids - existing_object_ids if new_object_ids: # Get rows for new objects from current catalog new_rows_indices = [] for i in range(len(self.catalog)): if self.catalog_type == "astropy": obj_id = self.catalog[i][self.object_id_column] else: obj_id = self.catalog.iloc[i][self.object_id_column] if obj_id in new_object_ids: new_rows_indices.append(i) # Extract new rows from catalog if self.catalog_type == "astropy": new_rows = self.catalog[new_rows_indices] # Add manifest columns for new rows n_new_rows = len(new_rows) empty_shape = np.array([0, 0, 0], dtype=int) new_rows["cutout_shape"] = [empty_shape] * n_new_rows new_rows["filename"] = [""] * n_new_rows new_rows["filename"] = new_rows["filename"].astype("U50") new_rows["downloaded_bands"] = [""] * n_new_rows new_rows["downloaded_bands"] = new_rows["downloaded_bands"].astype("U20") # Stack the tables from astropy.table import vstack merged_manifest = vstack([merged_manifest, new_rows]) else: new_rows = self.catalog.iloc[new_rows_indices].copy() # Add manifest columns for new rows new_rows["cutout_shape"] = [None] * len(new_rows) new_rows["filename"] = [None] * len(new_rows) new_rows["downloaded_bands"] = [None] * len(new_rows) # Concatenate dataframes import pandas as pd merged_manifest = pd.concat([merged_manifest, new_rows], ignore_index=True) merge_stats = { "preserved": len(existing_object_ids), "added": len(new_object_ids), "total_in_manifest": len(merged_manifest), } return merged_manifest, merge_stats
[docs] def _build_catalog_to_manifest_index_map(self, manifest): """Build efficient mapping from catalog indices to manifest indices.""" if self._manifest_filter_object_ids is None: # No filtering needed self._catalog_to_manifest_index_map = None self._manifest_to_catalog_index_map = None return # Create object_id to manifest index lookup manifest_lookup = {} for manifest_idx in range(len(manifest)): if isinstance(manifest, Table): obj_id = manifest[manifest_idx][self.object_id_column] else: obj_id = manifest.iloc[manifest_idx][self.object_id_column] manifest_lookup[obj_id] = manifest_idx # Build catalog index to manifest index mapping and reverse mapping self._catalog_to_manifest_index_map = {} self._manifest_to_catalog_index_map = {} for catalog_idx in range(len(self.catalog)): if isinstance(self.catalog, Table): catalog_obj_id = self.catalog[catalog_idx][self.object_id_column] else: catalog_obj_id = self.catalog.iloc[catalog_idx][self.object_id_column] if catalog_obj_id in manifest_lookup: manifest_idx = manifest_lookup[catalog_obj_id] self._catalog_to_manifest_index_map[catalog_idx] = manifest_idx self._manifest_to_catalog_index_map[manifest_idx] = catalog_idx else: raise ValueError(f"Object ID {catalog_obj_id} from catalog not found in manifest")
[docs] def _add_manifest_columns(self): """Add cutout_shape, filename, and downloaded_bands columns to manifest.""" if self.catalog_type == "astropy": n_rows = len(self.manifest) # Create shape column as integer array (assuming 3D tensors like [3, 64, 64]) empty_shape = np.array([0, 0, 0], dtype=int) # Placeholder shape self.manifest["cutout_shape"] = [empty_shape] * n_rows # Create filename column self.manifest["filename"] = [""] * n_rows self.manifest["filename"] = self.manifest["filename"].astype("U50") # 50-char strings # Add downloaded_bands column to track successful bands in tensor order self.manifest["downloaded_bands"] = [""] * n_rows self.manifest["downloaded_bands"] = self.manifest["downloaded_bands"].astype("U20") # e.g., "g,r" else: # For NestedFrame/HATS, add None columns self.manifest["cutout_shape"] = [None] * len(self.manifest) self.manifest["filename"] = [None] * len(self.manifest) self.manifest["downloaded_bands"] = [None] * len(self.manifest)
[docs] def _get_available_bands_from_manifest(self, manifest): """Get available bands by checking first 10 successful downloads for consistency.""" if len(manifest) == 0: return None, None successful_entries = [] # Find first 10 successful downloads (check last up to 12 entries, but handle small manifests) start_idx = max(0, len(manifest) - 12) for i in range(start_idx, len(manifest)): if len(successful_entries) >= 10: break if self.catalog_type == "astropy": filename = manifest["filename"][i] downloaded_bands_str = manifest["downloaded_bands"][i] else: filename = manifest.iloc[i]["filename"] downloaded_bands_str = manifest.iloc[i]["downloaded_bands"] # Only consider successful downloads if ( filename and filename != "Attempted" and downloaded_bands_str and str(downloaded_bands_str).strip() ): bands = [b.strip() for b in str(downloaded_bands_str).split(",") if b.strip()] if bands: # Non-empty band list successful_entries.append(bands) if not successful_entries: return None, None # Check that all successful entries have identical band lists first_bands = successful_entries[0] for i, bands in enumerate(successful_entries[1:], 1): if bands != first_bands: raise ValueError( f"Inconsistent band ordering in manifest. Entry 0 has {first_bands}, " f"but entry {i} has {bands}. Cannot determine consistent band structure." ) return set(first_bands), first_bands
[docs] def _setup_band_filtering(self, requested_bands, original_band_order): """Setup band filtering to extract only requested bands from cached cutouts.""" # Store filtering info self._original_bands = original_band_order self._filtered_bands = [band for band in original_band_order if band in requested_bands] self._is_filtering_bands = True # Create mapping from filtered bands to original tensor indices self._band_indices = [] for band in self._filtered_bands: self._band_indices.append(self._original_bands.index(band)) # Override the BANDS property to reflect filtered bands self.BANDS = tuple(self._filtered_bands) logger.info(f"Band filtering setup: {self._original_bands} -> {self._filtered_bands}") logger.info(f"Tensor indices to extract: {self._band_indices}")
[docs] def _get_cutout_path(self, idx): """Generate cutout file path for a given index.""" if self.use_object_id: if isinstance(self.catalog, Table): object_id = self.catalog[idx][self.object_id_column] else: object_id = self.catalog.iloc[idx][self.object_id_column] return self.download_dir / f"cutout_{object_id}.pt" else: return self.download_dir / f"cutout_{idx:0{self.padding_length}d}.pt"
[docs] def _update_manifest_entry(self, idx, cutout_shape=None, filename="Attempted", downloaded_bands=None): """ Thread-safe manifest update with periodic saves. Args: idx: Index in the manifest cutout_shape: Shape tuple of the cutout tensor, or None for failed downloads filename: Basename of the saved file, or "Attempted" only when ALL bands fail downloaded_bands: List of band names successfully downloaded in tensor order """ with self._manifest_lock: # Update manifest entries if cutout_shape is not None: shape_array = np.array(list(cutout_shape), dtype=int) self.manifest["cutout_shape"][idx] = shape_array else: # For completely failed downloads if self.catalog_type == "astropy": self.manifest["cutout_shape"][idx] = np.array([0, 0, 0], dtype=int) else: self.manifest["cutout_shape"][idx] = None self.manifest["filename"][idx] = filename # Update downloaded_bands tracking in manifest if downloaded_bands is not None: downloaded_bands_str = ",".join(downloaded_bands) if downloaded_bands else "" self.manifest["downloaded_bands"][idx] = downloaded_bands_str else: self.manifest["downloaded_bands"][idx] = "" # Increment update counter and save periodically self._updates_since_save += 1 if self._updates_since_save >= self._save_interval: self._save_manifest() self._updates_since_save = 0 logger.debug(f"Periodic manifest save completed ({self._save_interval} updates)")
[docs] def _save_manifest(self): """Save manifest in appropriate format (FITS for Astropy, Parquet for HATS).""" try: if self.catalog_type == "astropy": self.manifest.write(self.manifest_path, overwrite=True) else: # For HATS catalogs, save as Parquet # Convert to pandas DataFrame if needed manifest_df = self.manifest.compute() if hasattr(self.manifest, "compute") else self.manifest manifest_df.to_parquet(self.manifest_path) logger.debug(f"Manifest saved to {self.manifest_path}") except Exception as e: logger.error(f"Failed to save manifest: {e}")
[docs] def _sync_manifest_with_filesystem(self): """Sync manifest with actual downloaded files on disk.""" logger.info("Syncing manifest with filesystem...") synced_count = 0 # When filtering is active, we need to map manifest indices to catalog indices for manifest_idx in range(len(self.manifest)): # Find the corresponding catalog index for this manifest entry catalog_idx = None if self._manifest_to_catalog_index_map is not None: # Filtering is active - use reverse mapping for O(1) lookup catalog_idx = self._manifest_to_catalog_index_map.get(manifest_idx) # If no catalog index maps to this manifest index, skip (object not in current catalog) if catalog_idx is None: continue else: # No filtering - direct mapping catalog_idx = manifest_idx # Ensure catalog index is within bounds if catalog_idx >= len(self.catalog): continue cutout_path = self._get_cutout_path(catalog_idx) # Get current manifest state if self.catalog_type == "astropy": current_filename = self.manifest["filename"][manifest_idx] else: current_filename = self.manifest.iloc[manifest_idx]["filename"] if cutout_path.exists(): # File exists on disk if not current_filename or current_filename == "Attempted": # Manifest doesn't reflect the file exists, update it try: cutout = torch.load(cutout_path, map_location="cpu", weights_only=True) bands_for_existing = ( list(self._original_bands) if self._is_filtering_bands else list(self.BANDS) ) self._update_manifest_entry( manifest_idx, cutout.shape, cutout_path.name, bands_for_existing ) synced_count += 1 except Exception as e: logger.warning(f"Could not load existing cutout {cutout_path}: {e}") else: # File doesn't exist on disk if current_filename and current_filename != "Attempted": # Manifest says file exists but it doesn't, reset entry self._update_manifest_entry(manifest_idx, None, "", []) synced_count += 1 if synced_count > 0: logger.info(f"Synced {synced_count} manifest entries with filesystem") self.save_manifest_now()
@staticmethod @functools.lru_cache(maxsize=128)
[docs] def _request_patch_cached( tract_index, patch_index, butler_repo, butler_collections, skymap_name, bands_tuple ): """ Cached patch fetching using static method. Static method means no 'self' in cache key, making it truly global. Thread-safe because each call creates its own Butler instance. """ try: import lsst.daf.butler as butler # Create fresh Butler instance for this call thread_butler = butler.Butler(butler_repo, collections=butler_collections) # Track successful data and failed bands separately data = [] failed_bands = [] for band in bands_tuple: butler_dict = { "tract": tract_index, "patch": patch_index, "skymap": skymap_name, "band": band, } try: image = thread_butler.get("deep_coadd", butler_dict) data.append(image.getImage()) except Exception as e: logger.warning(f"Failed to fetch band {band} for patch {tract_index}-{patch_index}: {e}") failed_bands.append(band) data.append(None) # Add None placeholder for failed band; will be filled with NaNs later logger.debug(f"Fetched patch {tract_index}-{patch_index} from Butler") if failed_bands: logger.debug(f"Failed bands for patch {tract_index}-{patch_index}: {failed_bands}") return data, failed_bands except Exception as e: logger.error(f"Failed to fetch patch {tract_index}-{patch_index}: {e}") raise
[docs] def _fetch_single_cutout(self, row, idx=None, manifest_idx=None): """Fetch cutout, using saved cutout if available, with optional band filtering.""" if idx is not None: cutout_path = self._get_cutout_path(idx) if cutout_path.exists(): # Use tensor caching for already downloaded cutouts if self.use_object_id: if isinstance(self.catalog, Table): object_id = str(self.catalog[idx][self.object_id_column]) else: object_id = str(self.catalog.iloc[idx][self.object_id_column]) # Use cached tensor loading from mixin cutout = self._object_id_to_tensor_cached(object_id) else: # Fallback for non-object-id based datasets cutout = torch.load(cutout_path, map_location="cpu", weights_only=True) # Apply band filtering if needed if self._is_filtering_bands and self._band_indices is not None: cutout = cutout[self._band_indices] logger.debug(f"Applied band filtering to cached cutout {idx}: {cutout.shape}") return self.apply_transform(cutout) # For main thread, use parent's method (original caching) import threading if threading.current_thread() is threading.main_thread(): cutout = super()._fetch_single_cutout(row) # When using parent method, assume all original bands were successful downloaded_bands = list(self._original_bands) if self._is_filtering_bands else list(self.BANDS) else: # For worker threads, use our cached method cutout, downloaded_bands = self._fetch_cutout_with_cache(row) # Apply band filtering to new downloads if needed original_cutout_shape = cutout.shape if self._is_filtering_bands and self._band_indices is not None: cutout = cutout[self._band_indices] # Update downloaded_bands to reflect only the filtered bands that were actually present downloaded_bands = [] for i in self._band_indices: if i >= len(self._original_bands): raise ValueError( f"Band index {i} is out of bounds for\ original_bands (length {len(self._original_bands)}). " f"This indicates a bug in band filtering setup." ) downloaded_bands.append(self._original_bands[i]) logger.debug(f"Applied band filtering to new cutout: {original_cutout_shape} -> {cutout.shape}") # Save cutout if idx provided (save the filtered version) if idx is not None: cutout_path = self._get_cutout_path(idx) torch.save(cutout, cutout_path) # Use manifest_idx for updating manifest, fallback to idx if not provided update_idx = manifest_idx if manifest_idx is not None else idx # Determine if this is a complete failure (all bands failed) if len(downloaded_bands) == 0: # All bands failed - mark as "Attempted" self._update_manifest_entry(update_idx, None, "Attempted", downloaded_bands) else: # At least some bands succeeded - save with proper filename filename = cutout_path.name self._update_manifest_entry(update_idx, cutout.shape, filename, downloaded_bands) return self.apply_transform(cutout)
[docs] def _fetch_cutout_with_cache(self, row): """Generate cutout using cached patch fetching with NaN filling for failed bands.""" from torch import from_numpy # Get tract and patch info (using parent's methods) tract_info, patch_info = self._get_tract_patch(row) box_i = self._parse_box(patch_info, row) # Use cached patch fetching - convert bands list to tuple for hashability bands_tuple = tuple(self._original_bands) if self._is_filtering_bands else tuple(self.BANDS) # Get patch data and failed bands info patch_images, failed_bands = self._request_patch_cached( tract_info.getId(), patch_info.sequential_index, self._butler_config["repo"], self._butler_config["collections"], self._butler_config["skymap"], bands_tuple, ) # Extract cutout with NaN filling for failed bands cutout_data = [] downloaded_bands = [] # Track successfully downloaded bands in order bands_to_process = self._original_bands if self._is_filtering_bands else self.BANDS for _i, (band, image) in enumerate(zip(bands_to_process, patch_images)): if image is not None: # Successfully retrieved band cutout_data.append(image[box_i].getArray()) downloaded_bands.append(band) else: # Failed band - create NaN-filled array with same shape as box nan_array = np.full((box_i.getHeight(), box_i.getWidth()), np.nan, dtype=np.float32) cutout_data.append(nan_array) logger.debug(f"Filled band {band} with NaN for failed retrieval") # Update global band failure statistics if failed_bands: with self._band_failure_lock: for band in failed_bands: self._band_failure_stats[band] += 1 data_np = np.array(cutout_data) data_torch = from_numpy(data_np.astype(np.float32)) # Return cutout and downloaded bands info for manifest tracking return data_torch, downloaded_bands
[docs] def _load_tensor_for_cache(self, object_id: str): """Implementation of TensorCacheMixin abstract method.""" # Find the catalog index for this object_id catalog_idx = None if isinstance(self.catalog, Table): for i in range(len(self.catalog)): if str(self.catalog[i][self.object_id_column]) == object_id: catalog_idx = i break else: # pandas/hats catalog mask = self.catalog[self.object_id_column] == object_id matching_indices = self.catalog.index[mask].tolist() if matching_indices: catalog_idx = matching_indices[0] if catalog_idx is None: raise ValueError(f"Object ID {object_id} not found in catalog") cutout_path = self._get_cutout_path(catalog_idx) if cutout_path.exists(): # Load cached cutout cutout = torch.load(cutout_path, map_location="cpu", weights_only=True) # Apply band filtering if needed if self._is_filtering_bands and self._band_indices is not None: cutout = cutout[self._band_indices] return cutout else: # Cutout not downloaded yet, cannot load for cache raise FileNotFoundError(f"Cutout file {cutout_path} not found. Download cutouts first.")
[docs] def __len__(self): """Return length of current catalog, not the full manifest.""" return len(self.catalog)
[docs] def _get_manifest_index_for_catalog_index(self, catalog_idx): """Map catalog index to manifest index when filtering is active.""" if self._catalog_to_manifest_index_map is None: # No filtering - direct mapping return catalog_idx # Use pre-built mapping for efficiency if catalog_idx in self._catalog_to_manifest_index_map: return self._catalog_to_manifest_index_map[catalog_idx] raise ValueError(f"Catalog index {catalog_idx} not found in mapping")
[docs] def get_image(self, idxs): """Fetch image cutout(s) for given index or indices, using caching and band filtering. Parameters: ----------- idxs: int or slice or list Index or indices to fetch. Returns: -------- torch.Tensor or list of torch.Tensor: Single cutout tensor or list of cutout tensors. """ # Handle single index if isinstance(idxs, int): row = self.catalog[idxs] if isinstance(self.catalog, Table) else self.catalog.iloc[idxs] manifest_idx = self._get_manifest_index_for_catalog_index(idxs) return self._fetch_single_cutout(row, idx=idxs, manifest_idx=manifest_idx) # Handle multiple indices cutouts = [] for idx in idxs: row = self.catalog[idx] if isinstance(self.catalog, Table) else self.catalog.iloc[idx] manifest_idx = self._get_manifest_index_for_catalog_index(idx) cutouts.append(self._fetch_single_cutout(row, idx=idx, manifest_idx=manifest_idx)) return cutouts
[docs] def __getitem__(self, idxs) -> dict: """Modified to pass index for saving cutouts. Parameters: ----------- idxs: int or slice or list Index or indices to fetch. Returns: -------- dict: Dictionary with key 'data' containing another dict of default data fields to return. Currently only 'image' is supported. """ return {"data": {"image": self.get_image(idxs)}}
[docs] def download_cutouts(self, indices=None, sync_filesystem=True, max_workers=None, force_retry=False): """Download cutouts using multiple threads with caching. Args: indices: List of indices to download, or None for all sync_filesystem: Whether to sync manifest with existing files on disk max_workers: Maximum number of worker threads, or None to use default force_retry: Whether to retry previously failed downloads """ from concurrent.futures import ThreadPoolExecutor, as_completed if indices is None: indices = range(len(self)) # Optionally sync manifest with filesystem before downloading if sync_filesystem: self._sync_manifest_with_filesystem() # Determine which cutouts need downloading indices_to_download = [] for catalog_idx in indices: manifest_idx = self._get_manifest_index_for_catalog_index(catalog_idx) cutout_path = self._get_cutout_path(catalog_idx) # Check if file exists on disk if cutout_path.exists(): continue # Check manifest status if self.catalog_type == "astropy": filename = self.manifest["filename"][manifest_idx] else: filename = self.manifest.iloc[manifest_idx]["filename"] # Skip if already attempted and failed (unless force_retry is True) if filename == "Attempted" and not force_retry: logger.debug( f"Skipping previously failed download for catalog\ index {catalog_idx} (manifest index {manifest_idx})" ) continue indices_to_download.append((catalog_idx, manifest_idx)) if not indices_to_download: logger.info("All cutouts already downloaded") return # Determine number of workers if max_workers is None: max_workers = self._determine_numprocs_download() logger.info(f"Downloading {len(indices_to_download)} cutouts using {max_workers} threads.") with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = { executor.submit(self._download_single_cutout, catalog_idx, manifest_idx): ( catalog_idx, manifest_idx, ) for catalog_idx, manifest_idx in indices_to_download } with tqdm(total=len(indices_to_download), desc="Downloading cutouts") as pbar: for future in as_completed(futures): try: future.result() pbar.update(1) except Exception as e: catalog_idx, manifest_idx = futures[future] logger.error( f"Failed to download cutout\ catalog_idx={catalog_idx}, manifest_idx={manifest_idx}: {e}" ) self._update_manifest_entry(manifest_idx, None, "Attempted", []) pbar.update(1) # Final manifest save with self._manifest_lock: if self._updates_since_save > 0: self._save_manifest() self._updates_since_save = 0 # Log cache and download stats cache_info = self._request_patch_cached.cache_info() logger.info(f"Download complete. Cache stats: {cache_info}") logger.info(f"Manifest saved to {self.manifest_path}")
[docs] def _download_single_cutout(self, catalog_idx, manifest_idx): """Helper method to download a single cutout.""" cutout_path = self._get_cutout_path(catalog_idx) if cutout_path.exists(): return try: row = ( self.catalog[catalog_idx] if isinstance(self.catalog, Table) else self.catalog.iloc[catalog_idx] ) cutout, downloaded_bands = self._fetch_cutout_with_cache(row) # Only save cutout and mark as successful if at least one band worked if len(downloaded_bands) == 0: # All bands failed - don't save file, mark as "Attempted" self._update_manifest_entry(manifest_idx, None, "Attempted", downloaded_bands) raise RuntimeError( f"All bands failed for cutout\ catalog_idx={catalog_idx}, manifest_idx={manifest_idx}" ) else: # At least some bands succeeded - save the cutout torch.save(cutout, cutout_path) filename = cutout_path.name self._update_manifest_entry(manifest_idx, cutout.shape, filename, downloaded_bands) except Exception as e: logger.error( f"Failed to download cutout\ catalog_idx={catalog_idx}, manifest_idx={manifest_idx}: {e}" ) # Update manifest with failed attempt (all bands failed) self._update_manifest_entry(manifest_idx, None, "Attempted", []) raise
[docs] def cache_info(self): """Get cache statistics.""" return self._request_patch_cached.cache_info()
[docs] def clear_cache(self): """Clear the LRU cache.""" self._request_patch_cached.cache_clear() logger.info("Cleared patch cache")
[docs] def manifest_stats(self): """Get manifest statistics including downloaded bands information.""" with self._manifest_lock: if self.catalog_type == "astropy": successful = sum( 1 for filename in self.manifest["filename"] if filename and filename != "Attempted" ) failed = sum(1 for filename in self.manifest["filename"] if filename == "Attempted") pending = sum(1 for filename in self.manifest["filename"] if not filename) expected_band_count = ( len(self._original_bands) if self._is_filtering_bands else len(self.BANDS) ) # Add statistics about partial downloads (cutouts with missing bands) partial_downloads = sum( 1 for i, (filename, downloaded_bands) in enumerate( zip(self.manifest["filename"], self.manifest["downloaded_bands"]) ) if filename and filename != "Attempted" and downloaded_bands and len(downloaded_bands.split(",")) < expected_band_count ) else: successful = sum( 1 for filename in self.manifest["filename"] if filename and filename != "Attempted" ) failed = sum(1 for filename in self.manifest["filename"] if filename == "Attempted") pending = sum(1 for filename in self.manifest["filename"] if filename is None) expected_band_count = ( len(self._original_bands) if self._is_filtering_bands else len(self.BANDS) ) partial_downloads = sum( 1 for i, (filename, downloaded_bands) in enumerate( zip(self.manifest["filename"], self.manifest["downloaded_bands"]) ) if filename and filename != "Attempted" and downloaded_bands and len(str(downloaded_bands).split(",")) < expected_band_count ) # Collect band failure statistics with self._band_failure_lock: band_stats = dict(self._band_failure_stats) return { "total": len(self.manifest), "successful": successful, "failed": failed, # All bands failed "pending": pending, "partial_downloads": partial_downloads, # Some bands missing "band_failure_counts": band_stats, "manifest_path": str(self.manifest_path), }
[docs] def band_filtering_info(self): """Get information about current band filtering configuration.""" if not self._is_filtering_bands: return { "is_filtering": False, "requested_bands": list(self.BANDS), "original_bands": None, "filtered_bands": None, "band_indices": None, } return { "is_filtering": True, "requested_bands": list(self.BANDS), "original_bands": self._original_bands, "filtered_bands": self._filtered_bands, "band_indices": self._band_indices, }
[docs] def save_manifest_now(self): """Force immediate manifest save.""" with self._manifest_lock: self._save_manifest() self._updates_since_save = 0 logger.info("Manifest manually saved")
@staticmethod
[docs] def _determine_numprocs_download(): """Determine number of threads for downloading.""" # TODO:This is a placeholder for actual logic to determine number of threads. return 1
[docs] def reset_failed_downloads(self): """Reset failed download attempts to allow retry.""" reset_count = 0 for idx in range(len(self.manifest)): if self.catalog_type == "astropy": filename = self.manifest["filename"][idx] else: filename = self.manifest.iloc[idx]["filename"] if filename == "Attempted": self._update_manifest_entry(idx, None, "") reset_count += 1 if reset_count > 0: logger.info(f"Reset {reset_count} failed download attempts") self.save_manifest_now() return reset_count
[docs] def download_progress(self): """Get detailed download progress information.""" stats = self.manifest_stats() # Calculate additional metrics total = stats["total"] successful = stats["successful"] failed = stats["failed"] pending = stats["pending"] progress_percent = (successful / total * 100) if total > 0 else 0 failure_rate = (failed / (successful + failed) * 100) if (successful + failed) > 0 else 0 return { **stats, "progress_percent": round(progress_percent, 2), "failure_rate": round(failure_rate, 2), "completed": successful + failed, "remaining": pending, }
[docs] def download_summary(self): """ Get detailed download and band analysis, accounting for band filtering. """ stats = self.manifest_stats() # Determine which bands to analyze based on filtering bands_to_analyze = self._filtered_bands if self._is_filtering_bands else list(self.BANDS) all_possible_bands = self._original_bands if self._is_filtering_bands else list(self.BANDS) # Analyze downloaded bands per cutout band_success_analysis = {band: 0 for band in bands_to_analyze} complete_downloads = 0 if self.catalog_type == "astropy": downloaded_bands_entries = self.manifest["downloaded_bands"] else: downloaded_bands_entries = self.manifest["downloaded_bands"] for downloaded_bands_str in downloaded_bands_entries: if downloaded_bands_str and str(downloaded_bands_str).strip(): downloaded_bands = [b.strip() for b in str(downloaded_bands_str).split(",") if b.strip()] # Filter to only bands we're interested in relevant_bands = [band for band in downloaded_bands if band in bands_to_analyze] # Count successful downloads per band for band in relevant_bands: if band in band_success_analysis: band_success_analysis[band] += 1 # Count complete downloads (all requested bands present) if len(relevant_bands) == len(bands_to_analyze): complete_downloads += 1 filtering_info = "No filtering applied" if self._is_filtering_bands: filtering_info = f"Filtering {all_possible_bands} -> {bands_to_analyze}" return { "total_cutouts": stats["total"], "complete_downloads": complete_downloads, # All requested bands present "partial_downloads": stats["partial_downloads"], # Some requested bands missing "failed_downloads": stats["failed"], # All bands failed "pending_downloads": stats["pending"], "band_success_counts": band_success_analysis, # How many cutouts have each requested band "band_failure_counts": stats["band_failure_counts"], # How many times each band failed "expected_bands": bands_to_analyze, "band_filtering_info": filtering_info, "percentage_complete": round(complete_downloads / stats["total"] * 100, 2) if stats["total"] > 0 else 0, }