Bases: Generic[DateTimeLike]
A SplitState
represents the state of a split in terms of its four cut/split points.
Namely these are start and end of training set, start and end of forecasting/test set.
The class ensures that the split is valid by checking that the attributes are of the correct type and are ordered
chronologically.
The class provides properties to calculate the length of the training set, forecast set, gap between them, and the
total length of the split.
Parameters:
Name |
Type |
Description |
Default |
train_start
|
DateTimeLike
|
The start of the training set.
|
required
|
train_end
|
DateTimeLike
|
The end of the training set.
|
required
|
forecast_start
|
DateTimeLike
|
The start of the forecast set.
|
required
|
forecast_end
|
DateTimeLike
|
The end of the forecast set.
|
required
|
Raises:
Type |
Description |
TypeError
|
If any of the attributes is not of type `datetime`, `date` or `pd.Timestamp`.
|
ValueError
|
If the attributes are not ordered chronologically.
|
Source code in timebasedcv/splitstate.py
@dataclass(frozen=True)
class SplitState(Generic[DateTimeLike]):
    """A `SplitState` represents the state of a split in terms of its four cut/split points.

    Namely these are start and end of training set, start and end of forecasting/test set.

    The class ensures that the split is valid by checking that the attributes are of the correct type and are ordered
    chronologically.

    The class provides properties to calculate the length of the training set, forecast set, gap between them, and the
    total length of the split.

    Arguments:
        train_start: The start of the training set.
        train_end: The end of the training set.
        forecast_start: The start of the forecast set.
        forecast_end: The end of the forecast set.

    Raises:
        TypeError: If any of the attributes is not of type `datetime`, `date` or `pd.Timestamp`.
        ValueError: If the attributes are not ordered chronologically.
    """

    # The order of the slots is meaningful: `__post_init__` walks them pairwise to check chronological order.
    __slots__ = (  # noqa: RUF023
        "train_start",
        "train_end",
        "forecast_start",
        "forecast_end",
    )

    train_start: DateTimeLike
    train_end: DateTimeLike
    forecast_start: DateTimeLike
    forecast_end: DateTimeLike

    def __post_init__(self: Self) -> None:
        """Post init used to validate the `SplitState` instance attributes.

        Raises:
            TypeError: If the four attributes are not all exactly `datetime`, all exactly `date`, or all exactly
                `pd.Timestamp`.
            ValueError: If any attribute is greater than the one that follows it in `__slots__` order.
        """
        # Validate types
        _slots = self.__slots__
        _values = tuple(getattr(self, _attr) for _attr in _slots)
        _types = tuple(type(_value) for _value in _values)

        # `get_pandas` may return `None`; in that case the `pd.Timestamp` branch is skipped.
        pd = get_pandas()

        # Types are compared by identity (`is`), not `isinstance`: the four values must share one exact type.
        if not (
            all(_type is datetime for _type in _types)
            or all(_type is date for _type in _types)
            or (pd is not None and all(_type is pd.Timestamp for _type in _types))
        ):
            # cfr: https://stackoverflow.com/questions/16991948/detect-if-a-variable-is-a-datetime-object
            msg = "All attributes must be of type `datetime`, `date` or `pd.Timestamp`."
            raise TypeError(msg)

        # Validate order
        _ordered = tuple(pairwise_comparison(_values, less_or_equal))

        if not all(_ordered):
            # NOTE(review): assumes `less_or_equal(a, b)` is `a <= b`, as its name suggests. Equal neighbours are
            # then valid, so a failing pair means "strictly greater", and the message says exactly that.
            _error_msg = "\n".join(
                f"{s1}({v1}) is greater than {s2}({v2})"
                for (s1, s2), (v1, v2), is_ordered in zip(pairwise(_slots), pairwise(_values), _ordered)
                if not is_ordered
            )
            msg = f"`{'`, `'.join(_slots)}` must be ordered. Found:\n{_error_msg}"
            raise ValueError(msg)

    @property
    def train_length(self: Self) -> relativedelta:
        """Returns the time between `train_start` and `train_end`.

        Returns:
            A `relativedelta` object representing the time between `train_start` and `train_end`.
        """
        return relativedelta(self.train_end, self.train_start)

    @property
    def forecast_length(self: Self) -> relativedelta:
        """Returns the time between `forecast_start` and `forecast_end`.

        Returns:
            A `relativedelta` object representing the time between `forecast_start` and `forecast_end`.
        """
        return relativedelta(self.forecast_end, self.forecast_start)

    @property
    def gap_length(self: Self) -> relativedelta:
        """Returns the time between `train_end` and `forecast_start`.

        Returns:
            A `relativedelta` object representing the time between `train_end` and `forecast_start`.
        """
        return relativedelta(self.forecast_start, self.train_end)

    @property
    def total_length(self: Self) -> relativedelta:
        """Returns the time between `train_start` and `forecast_end`.

        Returns:
            A `relativedelta` object representing the time between `train_start` and `forecast_end`.
        """
        return relativedelta(self.forecast_end, self.train_start)

    def __add__(self: Self, other: timedelta | relativedelta | pd.Timedelta) -> SplitState:
        """Adds `other` to each value of the state.

        Arguments:
            other: The offset added to each of the four cut points.

        Returns:
            A new `SplitState` built from the shifted values.
        """
        return SplitState(
            train_start=self.train_start + other,
            train_end=self.train_end + other,
            forecast_start=self.forecast_start + other,
            forecast_end=self.forecast_end + other,
        )

    def __sub__(self: Self, other: timedelta | relativedelta | pd.Timedelta) -> SplitState:
        """Subtracts `other` from each value of the state.

        Arguments:
            other: The offset subtracted from each of the four cut points.

        Returns:
            A new `SplitState` built from the shifted values.
        """
        return SplitState(
            train_start=self.train_start - other,
            train_end=self.train_end - other,
            forecast_start=self.forecast_start - other,
            forecast_end=self.forecast_end - other,
        )
|
forecast_length
property
forecast_length: relativedelta
Returns the time between `forecast_start` and `forecast_end`.
Returns:
Type |
Description |
relativedelta
|
A `relativedelta` object representing the time between `forecast_start` and `forecast_end`.
|
gap_length
property
gap_length: relativedelta
Returns the time between `train_end` and `forecast_start`.
Returns:
Type |
Description |
relativedelta
|
A `relativedelta` object representing the time between `train_end` and `forecast_start`.
|
total_length
property
total_length: relativedelta
Returns the time between `train_start` and `forecast_end`.
Returns:
Type |
Description |
relativedelta
|
A `relativedelta` object representing the time between `train_start` and `forecast_end`.
|
train_length
property
train_length: relativedelta
Returns the time between `train_start` and `train_end`.
Returns:
Type |
Description |
relativedelta
|
A `relativedelta` object representing the time between `train_start` and `train_end`.
|
__add__
__add__(other: timedelta | relativedelta | Timedelta) -> SplitState
Adds `other` to each value of the state.
Source code in timebasedcv/splitstate.py
def __add__(self: Self, other: timedelta | relativedelta | pd.Timedelta) -> SplitState:
    """Return a new `SplitState` with `other` added to each of the four cut points.

    Arguments:
        other: The offset added to `train_start`, `train_end`, `forecast_start` and `forecast_end`.

    Returns:
        A `SplitState` built through the regular constructor from the shifted values.
    """
    # Same evaluation order as the fields: train_start, train_end, forecast_start, forecast_end.
    _names = ("train_start", "train_end", "forecast_start", "forecast_end")
    _shifted = {_name: getattr(self, _name) + other for _name in _names}
    return SplitState(**_shifted)
|
__post_init__
Post init used to validate the `SplitState` instance attributes.
Source code in timebasedcv/splitstate.py
def __post_init__(self: Self) -> None:
    """Post init used to validate the `SplitState` instance attributes.

    Raises:
        TypeError: If the attributes are not all exactly `datetime`, all exactly `date`, or all exactly
            `pd.Timestamp`.
        ValueError: If any attribute is greater than the one that follows it in `__slots__` order.
    """
    # Validate types
    _slots = self.__slots__
    _values = tuple(getattr(self, _attr) for _attr in _slots)
    _types = tuple(type(_value) for _value in _values)

    # `get_pandas` may return `None`; in that case the `pd.Timestamp` branch is skipped.
    pd = get_pandas()

    # Types are compared by identity (`is`), not `isinstance`: all values must share one exact type.
    _all_valid_type = (
        all(_type is datetime for _type in _types)
        or all(_type is date for _type in _types)
        or (pd is not None and all(_type is pd.Timestamp for _type in _types))
    )
    if not _all_valid_type:
        # cfr: https://stackoverflow.com/questions/16991948/detect-if-a-variable-is-a-datetime-object
        msg = "All attributes must be of type `datetime`, `date` or `pd.Timestamp`."
        raise TypeError(msg)

    # Validate order
    _ordered = tuple(pairwise_comparison(_values, less_or_equal))
    if all(_ordered):
        return

    # NOTE(review): assumes `less_or_equal(a, b)` is `a <= b`, as its name suggests. Equal neighbours are then
    # valid, so a failing pair means "strictly greater", and the message says exactly that.
    _error_msg = "\n".join(
        f"{s1}({v1}) is greater than {s2}({v2})"
        for (s1, s2), (v1, v2), is_ordered in zip(pairwise(_slots), pairwise(_values), _ordered)
        if not is_ordered
    )
    msg = f"`{'`, `'.join(_slots)}` must be ordered. Found:\n{_error_msg}"
    raise ValueError(msg)
|
__sub__
__sub__(other: timedelta | relativedelta | Timedelta) -> SplitState
Subtracts `other` from each value of the state.
Source code in timebasedcv/splitstate.py
def __sub__(self: Self, other: timedelta | relativedelta | pd.Timedelta) -> SplitState:
    """Return a new `SplitState` with `other` subtracted from each of the four cut points.

    Arguments:
        other: The offset subtracted from `train_start`, `train_end`, `forecast_start` and `forecast_end`.

    Returns:
        A `SplitState` built through the regular constructor from the shifted values.
    """
    # Shift the cut points one by one, in field order, before rebuilding the state.
    _train_start = self.train_start - other
    _train_end = self.train_end - other
    _forecast_start = self.forecast_start - other
    _forecast_end = self.forecast_end - other
    return SplitState(
        train_start=_train_start,
        train_end=_train_end,
        forecast_start=_forecast_start,
        forecast_end=_forecast_end,
    )
|