Skip to content

timebasedcv.splitstate

timebasedcv.splitstate.SplitState dataclass

Bases: Generic[DateTimeLike]

A SplitState represents the state of a split in terms of its 4 cut/split points.

Namely these are start and end of training set, start and end of forecasting/test set.

The class ensures that the split is valid by checking that the attributes are of the correct type and are ordered chronologically.

The class provides properties to calculate the length of the training set, forecast set, gap between them, and the total length of the split.

Parameters:

Name Type Description Default
train_start DateTimeLike

The start of the training set.

required
train_end DateTimeLike

The end of the training set.

required
forecast_start DateTimeLike

The start of the forecast set.

required
forecast_end DateTimeLike

The end of the forecast set.

required

Raises:

Type Description
TypeError

If any of the attributes is not of type datetime, date or pd.Timestamp.

ValueError

If the attributes are not ordered chronologically.

Source code in timebasedcv/splitstate.py
@dataclass(frozen=True)
class SplitState(Generic[DateTimeLike]):
    """A `SplitState` represents the state of a split in terms of its 4 cut/split points.

    Namely these are start and end of training set, start and end of forecasting/test set.

    The class ensures that the split is valid by checking that the attributes are of the correct type and are ordered
    chronologically.

    The class provides properties to calculate the length of the training set, forecast set, gap between them, and the
    total length of the split.

    Arguments:
        train_start: The start of the training set.
        train_end: The end of the training set.
        forecast_start: The start of the forecast set.
        forecast_end: The end of the forecast set.

    Raises:
        TypeError: If any of the attributes is not of type `datetime`, `date` or `pd.Timestamp`.
        ValueError: If the attributes are not ordered chronologically.
    """

    __slots__ = (
        "train_start",
        "train_end",
        "forecast_start",
        "forecast_end",
    )

    train_start: DateTimeLike
    train_end: DateTimeLike
    forecast_start: DateTimeLike
    forecast_end: DateTimeLike

    def __post_init__(self: Self) -> None:
        """Post init used to validate the `SplitState` instance attributes."""
        # Validate types
        _slots = self.__slots__
        _values = tuple(getattr(self, _attr) for _attr in _slots)
        _types = tuple(type(_value) for _value in _values)

        pd = get_pandas()

        if not (
            all(_type is datetime for _type in _types)
            or all(_type is date for _type in _types)
            or (pd is not None and all(_type is pd.Timestamp for _type in _types))
        ):
            # cfr: https://stackoverflow.com/questions/16991948/detect-if-a-variable-is-a-datetime-object
            msg = "All attributes must be of type `datetime`, `date` or `pd.Timestamp`."
            raise TypeError(msg)

        # Validate order
        _ordered = tuple(pairwise_comparison(_values, less_or_equal))

        if not all(_ordered):
            _error_msg = "\n".join(
                f"{s1}({v1}) is greater or equal to {s2}({v2})"
                for (s1, s2), (v1, v2), is_ordered in zip(pairwise(_slots), pairwise(_values), _ordered)
                if not is_ordered
            )
            msg = f"`{'`, `'.join(_slots)}` must be ordered. Found:\n{_error_msg}"
            raise ValueError(msg)

    @property
    def train_length(self: Self) -> timedelta:
        """Returns the time between `train_start` and `train_end`.

        Returns:
            A `timedelta` object representing the time between `train_start` and `train_end`.
        """
        return self.train_end - self.train_start

    @property
    def forecast_length(self: Self) -> timedelta:
        """Returns the time between `forecast_start` and `forecast_end`.

        Returns:
            A `timedelta` object representing the time between `forecast_start` and `forecast_end`.
        """
        return self.forecast_end - self.forecast_start

    @property
    def gap_length(self: Self) -> timedelta:
        """Returns the time between `train_end` and `forecast_start`.

        Returns:
            A `timedelta` object representing the time between `train_end` and `forecast_start`.
        """
        return self.forecast_start - self.train_end

    @property
    def total_length(self: Self) -> timedelta:
        """Returns the time between `train_start` and `forecast_end`.

        Returns:
            A `timedelta` object representing the time between `train_start` and `forecast_end`.
        """
        return self.forecast_end - self.train_start

    def __add__(self: Self, other: Union[timedelta, pd.Timedelta]) -> SplitState:
        """Adds `other` to each value of the state."""
        return SplitState(
            train_start=self.train_start + other,
            train_end=self.train_end + other,
            forecast_start=self.forecast_start + other,
            forecast_end=self.forecast_end + other,
        )

    def __sub__(self: Self, other: Union[timedelta, pd.Timedelta]) -> SplitState:
        """Subtracts other to each value of the state."""
        return SplitState(
            train_start=self.train_start - other,
            train_end=self.train_end - other,
            forecast_start=self.forecast_start - other,
            forecast_end=self.forecast_end - other,
        )

forecast_length: timedelta property

Returns the time between forecast_start and forecast_end.

Returns:

Type Description
timedelta

A timedelta object representing the time between forecast_start and forecast_end.

gap_length: timedelta property

Returns the time between train_end and forecast_start.

Returns:

Type Description
timedelta

A timedelta object representing the time between train_end and forecast_start.

total_length: timedelta property

Returns the time between train_start and forecast_end.

Returns:

Type Description
timedelta

A timedelta object representing the time between train_start and forecast_end.

train_length: timedelta property

Returns the time between train_start and train_end.

Returns:

Type Description
timedelta

A timedelta object representing the time between train_start and train_end.

__add__(other)

Adds other to each value of the state.

Source code in timebasedcv/splitstate.py
def __add__(self: Self, other: Union[timedelta, pd.Timedelta]) -> SplitState:
    """Adds `other` to each value of the state."""
    return SplitState(
        train_start=self.train_start + other,
        train_end=self.train_end + other,
        forecast_start=self.forecast_start + other,
        forecast_end=self.forecast_end + other,
    )

__post_init__()

Post init used to validate the SplitState instance attributes.

Source code in timebasedcv/splitstate.py
def __post_init__(self: Self) -> None:
    """Post init used to validate the `SplitState` instance attributes."""
    # Validate types
    _slots = self.__slots__
    _values = tuple(getattr(self, _attr) for _attr in _slots)
    _types = tuple(type(_value) for _value in _values)

    pd = get_pandas()

    if not (
        all(_type is datetime for _type in _types)
        or all(_type is date for _type in _types)
        or (pd is not None and all(_type is pd.Timestamp for _type in _types))
    ):
        # cfr: https://stackoverflow.com/questions/16991948/detect-if-a-variable-is-a-datetime-object
        msg = "All attributes must be of type `datetime`, `date` or `pd.Timestamp`."
        raise TypeError(msg)

    # Validate order
    _ordered = tuple(pairwise_comparison(_values, less_or_equal))

    if not all(_ordered):
        _error_msg = "\n".join(
            f"{s1}({v1}) is greater or equal to {s2}({v2})"
            for (s1, s2), (v1, v2), is_ordered in zip(pairwise(_slots), pairwise(_values), _ordered)
            if not is_ordered
        )
        msg = f"`{'`, `'.join(_slots)}` must be ordered. Found:\n{_error_msg}"
        raise ValueError(msg)

__sub__(other)

Subtracts other to each value of the state.

Source code in timebasedcv/splitstate.py
def __sub__(self: Self, other: Union[timedelta, pd.Timedelta]) -> SplitState:
    """Subtracts other to each value of the state."""
    return SplitState(
        train_start=self.train_start - other,
        train_end=self.train_end - other,
        forecast_start=self.forecast_start - other,
        forecast_end=self.forecast_end - other,
    )