# Source code for ska_ser_skuid.skuid.strings

from __future__ import annotations

from collections.abc import Sequence
from typing import (
    TYPE_CHECKING,
    Any,
    Generic,
    Literal,
    TypeAlias,
    cast,
    get_args,
    get_origin,
    overload,
)

from typing_extensions import Self

from ..entity_types import EntityType
from ..errors import EntityTypeError
from .core import E, SnowflakeSkuid

if TYPE_CHECKING:
    from pydantic import GetCoreSchemaHandler
    from pydantic.annotated_handlers import GetJsonSchemaHandler
    from pydantic.json_schema import JsonSchemaValue
    from pydantic_core import CoreSchema

try:
    from pydantic_core import core_schema
except ImportError:
    core_schema = None  # ty: ignore[invalid-assignment]


class SkuidStrBase(str, Generic[E]):
    """Base ``str`` subclass for SKUID strings with runtime entity-type checks.

    The unparameterised class accepts every ``EntityType`` member.
    Subscripting with ``Literal[...]`` (see ``__class_getitem__``) returns a
    cached subclass whose ``expected_etype`` is narrowed to the given entity
    type, so the restriction is enforced at runtime and not only by static
    type checkers.

    This base class does not define ``__new__``; ``_validate_etype`` and
    ``_get_format_label`` are building blocks for concrete string forms.
    """

    # Entity types accepted by this class; replaced on parameterised subclasses.
    expected_etype: tuple[EntityType, ...] = tuple(EntityType)
    # Defined once here, so every subclass shares this dict. The key includes
    # the class being parameterised, so entries of different classes are distinct.
    _subclass_cache: CacheType[E] = {}

    # https://github.com/pydantic/pydantic/discussions/4262
    # NOTE(review): no-op body; presumably only a static-typing aid so that
    # str/int values can be assigned to attributes annotated with this type.
    def __set__(self, _, value: str | int): ...

    @classmethod
    def _validate_etype(cls, value: str | int) -> SnowflakeSkuid[E]:
        """Turn *value* into a ``SnowflakeSkuid`` accepted by this class.

        Args:
            value: A SKUID string, or a bare integer. Integers carry no entity
                prefix, so they are only accepted when ``expected_etype``
                holds exactly one entity type, which is then applied.

        Returns:
            The ``SnowflakeSkuid`` built from the integer, or parsed from the
            string.

        Raises:
            ValueError: If *value* is an integer and this class is not
                narrowed to a single entity type.
            EntityTypeError: If the parsed entity type is not in
                ``expected_etype``.
        """
        if isinstance(value, int):
            if len(cls.expected_etype) != 1:
                raise ValueError(
                    f"{cls.__name__} cannot be constructed from an integer without a single "
                    "entity type. Use a parameterized form, e.g. "
                    f"{cls.__name__}[Literal[EntityType.TXN]](value)."
                )
            etype = cast(E, cls.expected_etype[0])
            return SnowflakeSkuid(etype, value)
        # The cast target is quoted: ``cast`` ignores its first argument at
        # runtime, so evaluating ``AnySkuidStr[E]`` on every call would only
        # rebuild a typing alias for nothing.
        parsed = SnowflakeSkuid[E].parse(cast("AnySkuidStr[E]", value))
        if parsed.entity_type not in cls.expected_etype:
            raise EntityTypeError(f"Invalid entity prefix: {cls.__name__} does not accept {value}")
        return parsed

    @classmethod
    def _get_entity_specific_subclass(cls, args: Sequence[EntityType]) -> type[SkuidStrBase[E]]:
        """Generate a subclass that accepts specific entity type(s).

        We create a custom subclass here so that we can capture the typing
        information as ``expected_etype`` and enforce that validation at
        runtime, not leaving it up to static type checkers.

        Args:
            args: The arguments of the ``Literal[...]`` parameter. Exactly one
                is currently supported.

        Returns:
            The subclass of ``cls`` for the given entity type. Repeated calls
            with the same arguments return the same class object.

        Raises:
            ValueError: If ``args`` does not contain exactly one item.
        """
        if len(args) != 1:
            # Relax this in the future if multi-type polymorphic fields are needed,
            # e.g. ShortSkuid[Literal[et.SBD, et.SBI, et.EB]]
            raise ValueError(
                f"{cls.__name__} can only be parameterized for a single EntityType, not {args}"
            )
        # Coerce string prefixes to EntityType members so that both
        # Literal[EntityType.SBD] and Literal["sbd"] are accepted at runtime.
        etypes = tuple(EntityType.validate(arg) for arg in args)
        key = (cls, etypes)
        if cached := cls._subclass_cache.get(key):
            return cached

        etypes_str = ", ".join(f"EntityType.{e.name}" for e in etypes)
        name = f"{cls.__name__}[Literal[{etypes_str}]]"
        # The generated class differs from ``cls`` only in ``expected_etype``.
        subclass = type(name, (cls,), {"expected_etype": etypes})
        cls._subclass_cache[key] = subclass
        return subclass

    @classmethod
    @overload
    def __class_getitem__(cls, type_arg: type[EntityType]) -> type[SkuidStrBase[EntityType]]: ...

    @classmethod
    @overload
    def __class_getitem__(cls, type_arg: E) -> type[SkuidStrBase[E]]: ...

    @classmethod
    def __class_getitem__(cls, type_arg: Any) -> Any:
        """
        Override __class_getitem__ to capture runtime type parameters.

        Custom method that captures type parameters for use in runtime
        validation & raises runtime errors for invalid types.

        * ``cls[EntityType]`` returns ``cls`` itself (no narrowing).
        * ``cls[Literal[...]]`` returns the entity-specific subclass.
        * Anything else (e.g. a ``TypeVar``) is handed to the inherited
          implementation.
        """
        if type_arg is EntityType:
            return cls
        elif get_origin(type_arg) is Literal:
            args = get_args(type_arg)
            return cls._get_entity_specific_subclass(args)
        else:
            # Delegate to the inherited __class_getitem__ (in practice the one
            # provided by Generic) for TypeVar / non-Literal parameters.
            # Pylance reports it as missing on the super() proxy even though
            # it exists at runtime; suppress the false-positive here.
            return super().__class_getitem__(type_arg)  # pyright: ignore

    @classmethod
    def _get_format_label(cls) -> str:
        """Return the JSON-schema ``format`` label template for this class.

        The result still contains the literal ``{form}`` placeholder; this
        method does not fill it in. When exactly one entity type is expected,
        its string form is appended as a further dotted component.
        """
        label = "skuid.{form}"
        if len(cls.expected_etype) == 1:
            return label + f".{cls.expected_etype[0]}"
        return label

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source: type[Any],
        _handler: GetCoreSchemaHandler,
    ) -> CoreSchema:
        """Build the Pydantic core schema for this type.

        Validation is performed by calling ``cls`` on the input value, and
        serialization converts the value with ``str``.

        Raises:
            RuntimeError: If ``pydantic_core`` could not be imported.
        """
        if core_schema is None:
            raise RuntimeError("Pydantic is required to build Pydantic core schema.")
        return core_schema.no_info_plain_validator_function(
            cls,
            serialization=core_schema.plain_serializer_function_ser_schema(str),
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls,
        core_schema_: CoreSchema,
        handler: GetJsonSchemaHandler,
    ) -> JsonSchemaValue:
        """Build the JSON schema for this type.

        The incoming ``core_schema_`` is not used; the schema is generated
        from a plain string schema and tagged with this class's format label.

        Raises:
            RuntimeError: If ``pydantic_core`` could not be imported.
        """
        if core_schema is None:
            raise RuntimeError("Pydantic is required to build Pydantic JSON schema.")
        # The plain-validator core schema has no JSON Schema equivalent, so we
        # derive the JSON schema from a plain str schema and annotate it instead.
        schema = handler(core_schema.str_schema())
        schema["format"] = cls._get_format_label()
        return schema


# Type of SkuidStrBase._subclass_cache: maps (class being parameterised,
# tuple of entity types) to the subclass generated for that combination.
# Declared after the class because the alias refers to SkuidStrBase when it
# is evaluated; the class-body annotation that uses it is not evaluated
# eagerly thanks to `from __future__ import annotations`.
CacheType: TypeAlias = dict[
    tuple[type[SkuidStrBase[E]], tuple[EntityType, ...]], type[SkuidStrBase[E]]
]


class ShortSkuid(SkuidStrBase[E]):
    """SKUID string that is always stored in short form (``type-<base32>``).

    The constructor takes either a short or a long SKUID string and keeps
    the short rendering of it.

    Parameterising the class with one specific
    :class:`~ska_ser_skuid.EntityType`, as in
    ``ShortSkuid[Literal[EntityType.SBD]]``, adds two behaviours:

    * strings carrying a different prefix are rejected, and
    * bare snowflake integers are accepted, with the known prefix applied.

    Because ``ShortSkuid`` is a plain :class:`str` subclass it can be used
    wherever a string is expected, and it also works as a Pydantic v2 field
    type.

    Examples::

        from ska_ser_skuid import ShortSkuid, EntityType
        from typing import Literal

        # Unparameterised: any entity type, string input only
        val = ShortSkuid("sbd-0-20260107-3v")  # -> "sbd-3v"

        # Parameterised: prefix is enforced, integer snowflakes are accepted
        TxnShort = ShortSkuid[Literal[EntityType.TXN]]
        val = TxnShort(123)  # -> e.g. "txn-3v"
    """

    def __new__(cls, value: str | int) -> Self:
        """Validate *value* and create the string from its short form."""
        skuid = cls._validate_etype(value)
        return super().__new__(cls, skuid.short())

    @classmethod
    def _get_format_label(cls) -> str:
        """Return the inherited label template with ``form`` set to ``short``."""
        template = super()._get_format_label()
        return template.format(form="short")
class LongSkuid(SkuidStrBase[E]):
    """SKUID string that is always stored in long form.

    The long form is ``type-<generator_id>-<YYYYMMDD>-<base32>``. The
    constructor takes either a short or a long SKUID string and keeps the
    long rendering of it.

    Parameterising the class with one specific
    :class:`~ska_ser_skuid.EntityType`, as in
    ``LongSkuid[Literal[EntityType.SBD]]``, adds two behaviours:

    * strings carrying a different prefix are rejected, and
    * bare snowflake integers are accepted, with the known prefix applied.

    Because ``LongSkuid`` is a plain :class:`str` subclass it can be used
    wherever a string is expected, and it also works as a Pydantic v2 field
    type.

    Examples::

        from ska_ser_skuid import LongSkuid, EntityType
        from typing import Literal

        # Unparameterised: any entity type, string input only
        val = LongSkuid("sbd-3v")  # -> "sbd-0-20260107-3v"

        # Parameterised: prefix is enforced, integer snowflakes are accepted
        TxnLong = LongSkuid[Literal[EntityType.TXN]]
        val = TxnLong(123)  # -> e.g. "txn-0-20260309-3v"
    """

    def __new__(cls, value: str | int) -> Self:
        """Validate *value* and create the string from its long form."""
        skuid = cls._validate_etype(value)
        return super().__new__(cls, skuid.long())

    @classmethod
    def _get_format_label(cls) -> str:
        """Return the inherited label template with ``form`` set to ``long``."""
        template = super()._get_format_label()
        return template.format(form="long")
# Either string form (short or long) of a SKUID for the same entity type
# parameter E.
AnySkuidStr: TypeAlias = ShortSkuid[E] | LongSkuid[E]