Source code for sqlspec.data_dictionary._types

from typing import TYPE_CHECKING, TypedDict, cast

from mypy_extensions import mypyc_attr

if TYPE_CHECKING:
    from re import Pattern

__all__ = (
    "ColumnMetadata",
    "DialectConfig",
    "FeatureFlags",
    "FeatureVersions",
    "ForeignKeyMetadata",
    "IndexMetadata",
    "TableMetadata",
    "VersionCacheResult",
    "VersionInfo",
)


class ColumnMetadata(TypedDict, total=False):
    """Metadata for a database column."""

    schema_name: str
    table_name: str
    column_name: str
    data_type: str
    is_nullable: str | bool | None
    column_default: str | None
    ordinal_position: int
    max_length: int
    numeric_precision: int
    numeric_scale: int
    is_primary: bool | int
    is_unique: bool | int
    extra: str


class TableMetadata(TypedDict, total=False):
    """Metadata for a database table."""

    schema_name: str
    table_name: str
    table_type: str
    table_catalog: str
    table_schema: str
    dependency_level: int
    level: int


class IndexMetadata(TypedDict, total=False):
    """Metadata for a database index."""

    schema_name: str
    table_name: str
    index_name: str
    columns: list[str] | str | None
    is_unique: bool | int
    is_primary: bool | int


@mypyc_attr(allow_interpreted_subclasses=False)
class ForeignKeyMetadata:
    """Metadata for a foreign key constraint."""

    __slots__ = (
        "column_name",
        "constraint_name",
        "referenced_column",
        "referenced_schema",
        "referenced_table",
        "schema",
        "table_name",
    )

    def __init__(
        self,
        table_name: str,
        column_name: str,
        referenced_table: str,
        referenced_column: str,
        constraint_name: str | None = None,
        schema: str | None = None,
        referenced_schema: str | None = None,
    ) -> None:
        self.table_name = table_name
        self.column_name = column_name
        self.referenced_table = referenced_table
        self.referenced_column = referenced_column
        self.constraint_name = constraint_name
        self.schema = schema
        self.referenced_schema = referenced_schema

    def __repr__(self) -> str:
        return (
            f"ForeignKeyMetadata(table_name={self.table_name!r}, column_name={self.column_name!r}, "
            f"referenced_table={self.referenced_table!r}, referenced_column={self.referenced_column!r}, "
            f"constraint_name={self.constraint_name!r}, schema={self.schema!r}, "
            f"referenced_schema={self.referenced_schema!r})"
        )

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ForeignKeyMetadata):
            return NotImplemented
        return (
            self.table_name == other.table_name
            and self.column_name == other.column_name
            and self.referenced_table == other.referenced_table
            and self.referenced_column == other.referenced_column
            and self.constraint_name == other.constraint_name
            and self.schema == other.schema
            and self.referenced_schema == other.referenced_schema
        )

    def __hash__(self) -> int:
        return hash((
            self.table_name,
            self.column_name,
            self.referenced_table,
            self.referenced_column,
            self.constraint_name,
            self.schema,
            self.referenced_schema,
        ))

    def __reduce__(
        self,
    ) -> "tuple[type[ForeignKeyMetadata], tuple[str, str, str, str, str | None, str | None, str | None]]":
        return (
            self.__class__,
            (
                self.table_name,
                self.column_name,
                self.referenced_table,
                self.referenced_column,
                self.constraint_name,
                self.schema,
                self.referenced_schema,
            ),
        )


@mypyc_attr(allow_interpreted_subclasses=True)
class VersionInfo:
    """Parsed database version info."""

    __slots__ = ("major", "minor", "patch")

    def __init__(self, major: int, minor: int = 0, patch: int = 0) -> None:
        """Initialize version info.

        Args:
            major: Major version number
            minor: Minor version number
            patch: Patch version number
        """
        self.major = major
        self.minor = minor
        self.patch = patch

    @property
    def version_tuple(self) -> "tuple[int, int, int]":
        """Get version as tuple for comparison."""
        return (self.major, self.minor, self.patch)

    def __str__(self) -> str:
        """String representation of version info."""
        return f"{self.major}.{self.minor}.{self.patch}"

    def __repr__(self) -> str:
        """Detailed string representation."""
        return f"VersionInfo({self.major}, {self.minor}, {self.patch})"

    def __eq__(self, other: object) -> bool:
        """Check version equality."""
        if not isinstance(other, VersionInfo):
            return NotImplemented
        return self.version_tuple == other.version_tuple

    def __lt__(self, other: "VersionInfo") -> bool:
        """Check if this version is less than another."""
        return self.version_tuple < other.version_tuple

    def __le__(self, other: "VersionInfo") -> bool:
        """Check if this version is less than or equal to another."""
        return self.version_tuple <= other.version_tuple

    def __gt__(self, other: "VersionInfo") -> bool:
        """Check if this version is greater than another."""
        return self.version_tuple > other.version_tuple

    def __ge__(self, other: "VersionInfo") -> bool:
        """Check if this version is greater than or equal to another."""
        return self.version_tuple >= other.version_tuple

    def __hash__(self) -> int:
        """Make VersionInfo hashable based on version tuple."""
        return hash(self.version_tuple)

    def __reduce__(self) -> "tuple[type[VersionInfo], tuple[int, int, int]]":
        return (self.__class__, self.version_tuple)


VersionCacheResult = tuple[bool, VersionInfo | None]
"""Return type for version cache lookup methods."""


[docs] class FeatureFlags(TypedDict, total=False): """Typed feature flags for data dictionary dialects.""" supports_arrays: bool supports_clustering: bool supports_cte: bool supports_generators: bool supports_geography: bool supports_in_memory: bool supports_index_clustering: bool supports_interleaved_tables: bool supports_json: bool supports_maps: bool supports_partitioning: bool supports_prepared_statements: bool supports_returning: bool supports_schemas: bool supports_structs: bool supports_transactions: bool supports_upsert: bool supports_uuid: bool supports_window_functions: bool
[docs] class FeatureVersions(TypedDict, total=False): """Typed feature version requirements for data dictionary dialects.""" supports_cte: "VersionInfo" supports_json: "VersionInfo" supports_jsonb: "VersionInfo" supports_partitioning: "VersionInfo" supports_returning: "VersionInfo" supports_upsert: "VersionInfo" supports_window_functions: "VersionInfo"
@mypyc_attr(allow_interpreted_subclasses=False) class DialectConfig: """Static configuration for a database dialect.""" __slots__ = ( "default_schema", "feature_flags", "feature_versions", "name", "parameter_style", "type_mappings", "version_pattern", ) def __init__( self, name: str, feature_versions: "FeatureVersions", feature_flags: "FeatureFlags", type_mappings: "dict[str, str]", version_pattern: "Pattern[str]", default_schema: "str | None" = None, parameter_style: str = "named", ) -> None: """Initialize a dialect configuration. Args: name: Dialect name used for lookups. feature_versions: Minimum versions required for features. feature_flags: Static boolean feature flags. type_mappings: Logical type to dialect type mapping. version_pattern: Regex used to parse version strings. default_schema: Default schema for dialect. parameter_style: Default parameter style for dialect SQL. """ self.name: str = name self.feature_versions: FeatureVersions = feature_versions self.feature_flags: FeatureFlags = feature_flags self.type_mappings: dict[str, str] = type_mappings self.version_pattern: Pattern[str] = version_pattern self.default_schema: str | None = default_schema self.parameter_style: str = parameter_style def get_feature_flag(self, feature: str) -> "bool | None": """Return a feature flag value if defined. Args: feature: Feature flag name. Returns: Feature flag value or None if unknown. """ return cast("bool | None", self.feature_flags.get(feature)) def get_feature_version(self, feature: str) -> "VersionInfo | None": """Return required version for a feature if defined. Args: feature: Feature version name. Returns: VersionInfo if defined, otherwise None. """ return cast("VersionInfo | None", self.feature_versions.get(feature)) def get_optimal_type(self, logical_type: str) -> str: """Return the dialect-specific type for a logical type. Args: logical_type: Logical type name. Returns: Dialect-specific type string. """ default_type = self.type_mappings.get("text", "TEXT") return self.type_mappings.get(logical_type, default_type)