Skip to content

timebasedcv.splitstate

timebasedcv.splitstate.SplitState dataclass

Bases: Generic[DateTimeLike]

A SplitState represents the state of a split in terms of its four cut/split points.

Namely these are start and end of training set, start and end of forecasting/test set.

The class ensures that the split is valid by checking that the attributes are of the correct type and are ordered chronologically.

The class provides properties to calculate the length of the training set, forecast set, gap between them, and the total length of the split.

Parameters:

Name Type Description Default
train_start DateTimeLike

The start of the training set.

required
train_end DateTimeLike

The end of the training set.

required
forecast_start DateTimeLike

The start of the forecast set.

required
forecast_end DateTimeLike

The end of the forecast set.

required

Raises:

Type Description
TypeError

If any of the attributes is not of type datetime, date or pd.Timestamp.

ValueError

If the attributes are not ordered chronologically.

Source code in timebasedcv/splitstate.py
@dataclass(frozen=True)
class SplitState(Generic[DateTimeLike]):
    """A `SplitState` represents the state of a split in terms of its four cut/split points.

    Namely these are start and end of training set, start and end of forecasting/test set.

    The class ensures that the split is valid by checking that the attributes are of the correct type and are ordered
    chronologically.

    The class provides properties to calculate the length of the training set, forecast set, gap between them, and the
    total length of the split.

    Arguments:
        train_start: The start of the training set.
        train_end: The end of the training set.
        forecast_start: The start of the forecast set.
        forecast_end: The end of the forecast set.

    Raises:
        TypeError: If the four attributes do not all share the same exact type, which must be one of `datetime`,
            `date` or `pd.Timestamp`.
        ValueError: If the attributes are not ordered chronologically.
    """

    # The slot order is relied upon by `__post_init__`: it is the order in which the values are compared.
    __slots__ = (  # noqa: RUF023
        "train_start",
        "train_end",
        "forecast_start",
        "forecast_end",
    )

    train_start: DateTimeLike
    train_end: DateTimeLike
    forecast_start: DateTimeLike
    forecast_end: DateTimeLike

    def __post_init__(self: Self) -> None:
        """Post init used to validate the `SplitState` instance attributes.

        Raises:
            TypeError: If the attributes are not all of the same exact type among `datetime`, `date` or
                `pd.Timestamp`.
            ValueError: If any attribute is greater than the one that follows it in `__slots__` order.
        """
        # Validate types
        _slots = self.__slots__
        _values = tuple(getattr(self, _attr) for _attr in _slots)
        _types = tuple(type(_value) for _value in _values)

        pd = get_pandas()

        # Identity checks (`is`) rather than `isinstance`: every value must have exactly the same type, so that a
        # mix of e.g. `date` and `datetime` instances is rejected.
        if not (
            all(_type is datetime for _type in _types)
            or all(_type is date for _type in _types)
            or (pd is not None and all(_type is pd.Timestamp for _type in _types))
        ):
            # cfr: https://stackoverflow.com/questions/16991948/detect-if-a-variable-is-a-datetime-object
            msg = "All attributes must be of type `datetime`, `date` or `pd.Timestamp`."
            raise TypeError(msg)

        # Validate order: one flag per consecutive pair of values.
        _ordered = tuple(pairwise_comparison(_values, less_or_equal))

        if not all(_ordered):
            # The comparison is `less_or_equal`, hence a failing pair means the first value is strictly greater
            # than the second one (equal values are accepted).
            _error_msg = "\n".join(
                f"{s1}({v1}) is greater than {s2}({v2})"
                for (s1, s2), (v1, v2), is_ordered in zip(pairwise(_slots), pairwise(_values), _ordered)
                if not is_ordered
            )
            msg = f"`{'`, `'.join(_slots)}` must be ordered. Found:\n{_error_msg}"
            raise ValueError(msg)

    @property
    def train_length(self: Self) -> relativedelta:
        """Returns the time between `train_start` and `train_end`.

        Returns:
            A `relativedelta` object representing the time between `train_start` and `train_end`.
        """
        return relativedelta(self.train_end, self.train_start)

    @property
    def forecast_length(self: Self) -> relativedelta:
        """Returns the time between `forecast_start` and `forecast_end`.

        Returns:
            A `relativedelta` object representing the time between `forecast_start` and `forecast_end`.
        """
        return relativedelta(self.forecast_end, self.forecast_start)

    @property
    def gap_length(self: Self) -> relativedelta:
        """Returns the time between `train_end` and `forecast_start`.

        Returns:
            A `relativedelta` object representing the time between `train_end` and `forecast_start`.
        """
        return relativedelta(self.forecast_start, self.train_end)

    @property
    def total_length(self: Self) -> relativedelta:
        """Returns the time between `train_start` and `forecast_end`.

        Returns:
            A `relativedelta` object representing the time between `train_start` and `forecast_end`.
        """
        return relativedelta(self.forecast_end, self.train_start)

    def __add__(self: Self, other: timedelta | relativedelta | pd.Timedelta) -> SplitState:
        """Adds `other` to each value of the state.

        Arguments:
            other: The offset added to each of the four cut points.

        Returns:
            A new `SplitState` built from the shifted values.
        """
        return SplitState(
            train_start=self.train_start + other,
            train_end=self.train_end + other,
            forecast_start=self.forecast_start + other,
            forecast_end=self.forecast_end + other,
        )

    def __sub__(self: Self, other: timedelta | relativedelta | pd.Timedelta) -> SplitState:
        """Subtracts `other` from each value of the state.

        Arguments:
            other: The offset subtracted from each of the four cut points.

        Returns:
            A new `SplitState` built from the shifted values.
        """
        return SplitState(
            train_start=self.train_start - other,
            train_end=self.train_end - other,
            forecast_start=self.forecast_start - other,
            forecast_end=self.forecast_end - other,
        )

forecast_length property

forecast_length: relativedelta

Returns the time between forecast_start and forecast_end.

Returns:

Type Description
relativedelta

A relativedelta object representing the time between forecast_start and forecast_end.

gap_length property

gap_length: relativedelta

Returns the time between train_end and forecast_start.

Returns:

Type Description
relativedelta

A relativedelta object representing the time between train_end and forecast_start.

total_length property

total_length: relativedelta

Returns the time between train_start and forecast_end.

Returns:

Type Description
relativedelta

A relativedelta object representing the time between train_start and forecast_end.

train_length property

train_length: relativedelta

Returns the time between train_start and train_end.

Returns:

Type Description
relativedelta

A relativedelta object representing the time between train_start and train_end.

__add__

__add__(other: timedelta | relativedelta | Timedelta) -> SplitState

Adds other to each value of the state.

Source code in timebasedcv/splitstate.py
def __add__(self: Self, other: timedelta | relativedelta | pd.Timedelta) -> SplitState:
    """Build a new `SplitState` in which every cut point is offset by `other`.

    Arguments:
        other: The offset added to each of the four cut points.

    Returns:
        A new `SplitState` built from the shifted values.
    """
    # Insertion order matches the original evaluation order of the four additions.
    shifted = {
        "train_start": self.train_start + other,
        "train_end": self.train_end + other,
        "forecast_start": self.forecast_start + other,
        "forecast_end": self.forecast_end + other,
    }
    return SplitState(**shifted)

__post_init__

__post_init__() -> None

Post init used to validate the SplitState instance attributes.

Source code in timebasedcv/splitstate.py
def __post_init__(self: Self) -> None:
    """Post init used to validate the `SplitState` instance attributes.

    Raises:
        TypeError: If the attributes are not all of the same exact type among `datetime`, `date` or `pd.Timestamp`.
        ValueError: If any attribute is greater than the one that follows it in `__slots__` order.
    """
    # Validate types
    _slots = self.__slots__
    _values = tuple(getattr(self, _attr) for _attr in _slots)
    _types = tuple(type(_value) for _value in _values)

    pd = get_pandas()

    # Identity checks (`is`) rather than `isinstance`: every value must have exactly the same type, so that a mix of
    # e.g. `date` and `datetime` instances is rejected.
    if not (
        all(_type is datetime for _type in _types)
        or all(_type is date for _type in _types)
        or (pd is not None and all(_type is pd.Timestamp for _type in _types))
    ):
        # cfr: https://stackoverflow.com/questions/16991948/detect-if-a-variable-is-a-datetime-object
        msg = "All attributes must be of type `datetime`, `date` or `pd.Timestamp`."
        raise TypeError(msg)

    # Validate order: one flag per consecutive pair of values.
    _ordered = tuple(pairwise_comparison(_values, less_or_equal))

    if not all(_ordered):
        # The comparison is `less_or_equal`, hence a failing pair means the first value is strictly greater than the
        # second one (equal values are accepted).
        _error_msg = "\n".join(
            f"{s1}({v1}) is greater than {s2}({v2})"
            for (s1, s2), (v1, v2), is_ordered in zip(pairwise(_slots), pairwise(_values), _ordered)
            if not is_ordered
        )
        msg = f"`{'`, `'.join(_slots)}` must be ordered. Found:\n{_error_msg}"
        raise ValueError(msg)

__sub__

__sub__(other: timedelta | relativedelta | Timedelta) -> SplitState

Subtracts other from each value of the state.

Source code in timebasedcv/splitstate.py
def __sub__(self: Self, other: timedelta | relativedelta | pd.Timedelta) -> SplitState:
    """Build a new `SplitState` in which `other` is subtracted from every cut point.

    Arguments:
        other: The offset subtracted from each of the four cut points.

    Returns:
        A new `SplitState` built from the shifted values.
    """
    # Insertion order matches the original evaluation order of the four subtractions.
    shifted = {
        "train_start": self.train_start - other,
        "train_end": self.train_end - other,
        "forecast_start": self.forecast_start - other,
        "forecast_end": self.forecast_end - other,
    }
    return SplitState(**shifted)