from __future__ import annotations
from collections.abc import Sequence
from typing import (
TYPE_CHECKING,
Any,
Generic,
Literal,
TypeAlias,
cast,
get_args,
get_origin,
overload,
)
from typing_extensions import Self
from ..entity_types import EntityType
from ..errors import EntityTypeError
from .core import E, SnowflakeSkuid
if TYPE_CHECKING:
from pydantic import GetCoreSchemaHandler
from pydantic.annotated_handlers import GetJsonSchemaHandler
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import CoreSchema
try:
from pydantic_core import core_schema
except ImportError:
core_schema = None # ty: ignore[invalid-assignment]
class SkuidStrBase(str, Generic[E]):
    """Shared base for the SKUID string types (``ShortSkuid`` / ``LongSkuid``).

    Provides:

    * ``expected_etype`` - the entity types accepted at runtime;
    * ``X[Literal[EntityType.FOO]]`` parameterisation, which generates (and
      caches) a subclass restricted to that single entity type;
    * Pydantic v2 core/JSON schema hooks so subclasses work as field types.

    Subclasses are expected to define ``__new__`` (normalising the parsed
    SKUID into their own string form) and to fill the ``{form}`` placeholder
    returned by :meth:`_get_format_label`.
    """

    # Entity types accepted at runtime. Unparameterised classes accept every
    # EntityType member; generated subclasses override this with their own.
    expected_etype: tuple[EntityType, ...] = tuple(EntityType)

    # A single dict shared by the whole hierarchy (it is only mutated, never
    # rebound), keyed by (class, entity types) so each parameterisation is
    # generated once and repeated subscriptions return the same class object.
    _subclass_cache: CacheType[E] = {}

    if TYPE_CHECKING:
        # https://github.com/pydantic/pydantic/discussions/4262
        # Static-typing aid only: lets a field annotated with this type accept
        # ``str | int`` on assignment. It is kept out of the runtime class on
        # purpose: a runtime ``__set__`` makes every instance a data
        # descriptor, so an instance stored as a class attribute (e.g. a
        # dataclass field default) would silently swallow
        # ``obj.attr = value`` assignments.
        def __set__(self, _, value: str | int): ...

    @classmethod
    def _validate_etype(cls, value: str | int) -> SnowflakeSkuid[E]:
        """Parse *value* into a ``SnowflakeSkuid`` and enforce ``expected_etype``.

        Args:
            value: A short or long SKUID string, or a bare snowflake integer.
                Integers are only accepted when the class expects exactly one
                entity type, which is then used as the prefix.

        Returns:
            The parsed ``SnowflakeSkuid``.

        Raises:
            ValueError: *value* is a ``bool``, or an integer was given to a
                class that does not expect exactly one entity type.
            EntityTypeError: The string's entity type is not in
                ``expected_etype``.

        Errors raised by ``SnowflakeSkuid.parse`` for malformed strings
        propagate unchanged.
        """
        # bool is a subclass of int; without this guard ``True``/``False``
        # would silently become snowflakes 1/0. ValueError (rather than
        # TypeError) so Pydantic's plain validator reports a ValidationError.
        if isinstance(value, bool):
            raise ValueError(
                f"{cls.__name__} cannot be constructed from a bool ({value!r}); "
                "expected a SKUID string or a snowflake integer."
            )
        if isinstance(value, int):
            # An integer carries no prefix, so the entity type must be
            # unambiguous from the class parameterisation.
            if len(cls.expected_etype) != 1:
                raise ValueError(
                    f"{cls.__name__} cannot be constructed from an integer without a single "
                    "entity type. Use a parameterized form, e.g. "
                    f"{cls.__name__}[Literal[EntityType.TXN]](value)."
                )
            etype = cast(E, cls.expected_etype[0])
            return SnowflakeSkuid(etype, value)
        parsed = SnowflakeSkuid[E].parse(cast(AnySkuidStr[E], value))
        if parsed.entity_type not in cls.expected_etype:
            raise EntityTypeError(f"Invalid entity prefix: {cls.__name__} does not accept {value}")
        return parsed

    @classmethod
    def _get_entity_specific_subclass(cls, args: Sequence[EntityType]) -> type[SkuidStrBase[E]]:
        """Generate (or fetch from cache) a subclass restricted to *args*.

        A real subclass is created so that the typing information is captured
        in ``expected_etype`` and enforced at runtime by
        :meth:`_validate_etype`, instead of being left to static type checkers.

        Args:
            args: The ``Literal[...]`` arguments. Each may be an ``EntityType``
                member or a value accepted by ``EntityType.validate`` (e.g. the
                string prefix).

        Returns:
            A subclass of ``cls`` named ``"<cls>[Literal[EntityType.<NAME>]]"``.

        Raises:
            ValueError: *args* does not contain exactly one entry.

        Errors from ``EntityType.validate`` for unrecognised values propagate.
        """
        if len(args) != 1:
            # Relax this in the future if multi-type polymorphic fields are needed,
            # e.g. ShortSkuid[Literal[et.SBD, et.SBI, et.EB]]
            raise ValueError(
                f"{cls.__name__} can only be parameterized for a single EntityType, not {args}"
            )
        # Coerce string prefixes to EntityType members so that both
        # Literal[EntityType.SBD] and Literal["sbd"] are accepted at runtime.
        etypes = tuple(EntityType.validate(arg) for arg in args)
        key = (cls, etypes)
        if cached := cls._subclass_cache.get(key):
            return cached
        etypes_str = ", ".join(f"EntityType.{e.name}" for e in etypes)
        name = f"{cls.__name__}[Literal[{etypes_str}]]"
        subclass = type(name, (cls,), {"expected_etype": etypes})
        cls._subclass_cache[key] = subclass
        return subclass

    @classmethod
    @overload
    def __class_getitem__(cls, type_arg: type[EntityType]) -> type[SkuidStrBase[EntityType]]: ...

    @classmethod
    @overload
    def __class_getitem__(cls, type_arg: E) -> type[SkuidStrBase[E]]: ...

    @classmethod
    def __class_getitem__(cls, type_arg: Any) -> Any:
        """Capture runtime type parameters for validation.

        * ``X[EntityType]`` returns ``X`` unchanged (any entity type).
        * ``X[Literal[...]]`` returns a cached subclass whose
          ``expected_etype`` is restricted to the literal's entity type.
        * Anything else (e.g. a ``TypeVar`` as in ``SkuidStrBase[E]``) is
          delegated to the standard generic-alias machinery.
        """
        if type_arg is EntityType:
            return cls
        elif get_origin(type_arg) is Literal:
            args = get_args(type_arg)
            return cls._get_entity_specific_subclass(args)
        else:
            # Delegate TypeVar / non-Literal parameters to the next
            # __class_getitem__ in the MRO; str defines none, so this reaches
            # typing.Generic's. The accessor is missing from Pylance's typeshed
            # stubs even though it exists at runtime; suppress the
            # false-positive here.
            return super().__class_getitem__(type_arg)  # pyright: ignore

    @classmethod
    def _get_format_label(cls) -> str:
        """Return the JSON-schema ``format`` template for this class.

        The template is ``"skuid.{form}"``, suffixed with ``".<etype>"`` when
        exactly one entity type is expected. The ``{form}`` placeholder is
        left unformatted for subclasses to fill in.
        """
        label = "skuid.{form}"
        if len(cls.expected_etype) == 1:
            # Rendered via the EntityType member's f-string formatting.
            return label + f".{cls.expected_etype[0]}"
        return label

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source: type[Any],
        _handler: GetCoreSchemaHandler,
    ) -> CoreSchema:
        """Build the Pydantic v2 core schema for this SKUID type.

        Validation calls the class itself, so parsing, entity-type checks and
        normalisation all apply; serialisation emits a plain ``str``.

        Raises:
            RuntimeError: ``pydantic_core`` is not installed.
        """
        if core_schema is None:
            raise RuntimeError("Pydantic is required to build Pydantic core schema.")
        return core_schema.no_info_plain_validator_function(
            cls,
            serialization=core_schema.plain_serializer_function_ser_schema(str),
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls,
        core_schema_: CoreSchema,
        handler: GetJsonSchemaHandler,
    ) -> JsonSchemaValue:
        """Build the JSON schema: a string with a ``skuid.*`` ``format``.

        Raises:
            RuntimeError: ``pydantic_core`` is not installed.
        """
        if core_schema is None:
            raise RuntimeError("Pydantic is required to build Pydantic JSON schema.")
        # The plain-validator core schema has no JSON Schema equivalent, so we
        # derive the JSON schema from a plain str schema and annotate it instead.
        schema = handler(core_schema.str_schema())
        schema["format"] = cls._get_format_label()
        return schema
# Type of SkuidStrBase._subclass_cache: maps a (base class, accepted entity
# types) pair to the subclass generated for that parameterisation.
CacheType: TypeAlias = dict[
    tuple[
        type[SkuidStrBase[E]],
        tuple[EntityType, ...],
    ],
    type[SkuidStrBase[E]],
]
[docs]
class ShortSkuid(SkuidStrBase[E]):
    """SKUID string held in its short form, ``type-<base32>``.

    Construction accepts either the short or the long form of a SKUID and
    always stores the short form. Parameterising the class with a single
    :class:`~ska_ser_skuid.EntityType`, e.g.
    ``ShortSkuid[Literal[EntityType.SBD]]``, additionally:

    * rejects strings whose entity prefix differs, and
    * accepts a bare snowflake integer, applying the known prefix to it.

    Because it subclasses :class:`str`, a ``ShortSkuid`` can be passed wherever
    a string is expected, and it can be used directly as a Pydantic v2 field
    type.

    Examples::

        from ska_ser_skuid import ShortSkuid, EntityType
        from typing import Literal

        # Unparameterised: any entity type, string input only
        val = ShortSkuid("sbd-0-20260107-3v")  # -> "sbd-3v"

        # Parameterised: prefix enforced, integer snowflake input allowed
        TxnShort = ShortSkuid[Literal[EntityType.TXN]]
        val = TxnShort(123)  # -> e.g. "txn-3v"
    """

    def __new__(cls, value: str | int) -> Self:
        # Parse and check the entity type first; only a valid SKUID reaches
        # str construction.
        short_text = cls._validate_etype(value).short()
        return super().__new__(cls, short_text)

    @classmethod
    def _get_format_label(cls) -> str:
        # Fill the "{form}" placeholder left by the base class template.
        base_label = super()._get_format_label()
        return base_label.format(form="short")
[docs]
class LongSkuid(SkuidStrBase[E]):
    """SKUID string held in its long form, ``type-<generator_id>-<YYYYMMDD>-<base32>``.

    Construction accepts either the short or the long form of a SKUID and
    always stores the long form. Parameterising the class with a single
    :class:`~ska_ser_skuid.EntityType`, e.g.
    ``LongSkuid[Literal[EntityType.SBD]]``, additionally:

    * rejects strings whose entity prefix differs, and
    * accepts a bare snowflake integer, applying the known prefix to it.

    Because it subclasses :class:`str`, a ``LongSkuid`` can be passed wherever
    a string is expected, and it can be used directly as a Pydantic v2 field
    type.

    Examples::

        from ska_ser_skuid import LongSkuid, EntityType
        from typing import Literal

        # Unparameterised: any entity type, string input only
        val = LongSkuid("sbd-3v")  # -> "sbd-0-20260107-3v"

        # Parameterised: prefix enforced, integer snowflake input allowed
        TxnLong = LongSkuid[Literal[EntityType.TXN]]
        val = TxnLong(123)  # -> e.g. "txn-0-20260309-3v"
    """

    def __new__(cls, value: str | int) -> Self:
        # Parse and check the entity type first; only a valid SKUID reaches
        # str construction.
        long_text = cls._validate_etype(value).long()
        return super().__new__(cls, long_text)

    @classmethod
    def _get_format_label(cls) -> str:
        # Fill the "{form}" placeholder left by the base class template.
        base_label = super()._get_format_label()
        return base_label.format(form="long")
# Either SKUID string form for entity type E. SkuidStrBase._validate_etype
# casts string input to this type before calling SnowflakeSkuid.parse.
# Subscripting with the TypeVar E goes through SkuidStrBase.__class_getitem__'s
# generic-alias fallback, so no entity-specific subclass is generated here.
AnySkuidStr: TypeAlias = ShortSkuid[E] | LongSkuid[E]