Source code for lsdb.core.search.index_search
from __future__ import annotations
from typing import Any
import nested_pandas as npd
import numpy as np
from hats.catalog.index.index_catalog import IndexCatalog as HCIndexCatalog
from lsdb.core.search.abstract_search import AbstractSearch
from lsdb.types import HCCatalogTypeVar
[docs]
class IndexSearch(AbstractSearch):
    """Find rows by column values using HATS index catalogs.

    Every queried field is paired with an index catalog. The index catalogs
    are used to narrow a catalog down to the partitions that hold the queried
    values, and the rows of each partition are then matched against the
    queried values of every field.
    """

    values: dict[str, Any]
    """Mapping of field name to the value we want to match it to"""

    index_catalogs: dict[str, HCIndexCatalog]
    """Mapping of field name to respective index catalog"""

    def __init__(
        self,
        values: dict[str, Any],
        index_catalogs: dict[str, HCIndexCatalog],
        fine: bool = True,
    ):
        """Create the search object.

        Parameters
        ----------
        values: dict[str, Any]
            Mapping of field name to the value (or list of values) to match.
        index_catalogs: dict[str, HCIndexCatalog]
            Mapping of field name to the index catalog for that field. It
            must contain an entry for every key of `values`.
        fine: bool
            Forwarded unchanged to the parent class constructor.

        Raises
        ------
        ValueError
            If a queried field has no corresponding index catalog.
        """
        super().__init__(fine)
        if not all(key in index_catalogs for key in values):
            raise ValueError(
                "There is a mismatch between the queried fields: "
                f"{values.keys()} and the fields of the provided index"
                f" catalogs: {index_catalogs.keys()}"
            )
        self.values = values
        self.index_catalogs = index_catalogs

    @staticmethod
    def _as_list(value: Any) -> list:
        """Return `value` unchanged if it is a list, else wrap it in one."""
        return value if isinstance(value, list) else [value]

    def perform_hc_catalog_filter(self, hc_structure: HCCatalogTypeVar) -> HCCatalogTypeVar:
        """Determine the pixels for which there is a result in each field

        Parameters
        ----------
        hc_structure: HCCatalogTypeVar
            The hats catalog where partitions will be filtered.

        Returns
        -------
        HCCatalogTypeVar
            The filtered hats catalog.
        """
        # Start from every pixel of the catalog and keep only those that the
        # index catalog of each queried field reports for the queried values.
        matching_pixels = set(hc_structure.get_healpix_pixels())
        for field_name, field_value in self.values.items():
            index_catalog = self.index_catalogs[field_name]
            pixels_for_field = index_catalog.loc_partitions(self._as_list(field_value))
            matching_pixels.intersection_update(pixels_for_field)
        return hc_structure.filter_from_pixel_list(list(matching_pixels))

    def search_points(self, frame: npd.NestedFrame, _) -> npd.NestedFrame:
        """Determine the search results within a data frame

        Parameters
        ----------
        frame: npd.NestedFrame
            A pixel data frame.
        _: hc.catalog.TableProperties
            The HATS catalog properties. Not used by this search.

        Returns
        -------
        npd.NestedFrame
            The filtered pixel data frame.
        """
        # The builtin `bool` dtype is used because the `np.bool` alias does
        # not exist in NumPy 1.24-1.26.
        filter_mask = np.ones(len(frame), dtype=bool)
        # Iterate over the queried fields (not over the index catalogs): the
        # constructor only guarantees that every queried field has an index
        # catalog, so an index catalog may exist without a queried value.
        for field_name, field_value in self.values.items():
            index_column = self.index_catalogs[field_name].catalog_info.indexing_column
            # A row is kept only if it matches the queried values of all fields.
            filter_mask = filter_mask & frame[index_column].isin(self._as_list(field_value))
        return frame[filter_mask]