83 lines
2.9 KiB
Python
83 lines
2.9 KiB
Python
|
from typing import Tuple
|
||
|
|
||
|
import pandas as pd
|
||
|
from pandas import DataFrame
|
||
|
from sklearn.model_selection import train_test_split
|
||
|
|
||
|
|
||
|
def split_stratified_into_train_val_test(
    df_input: DataFrame,
    target_colname: str = "z",
    stratify_colname: str = "y",
    frac_train: float = 0.6,
    frac_val: float = 0.15,
    frac_test: float = 0.25,
    random_state=None,
) -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:
    """
    Split a Pandas dataframe into stratified train, val, and test subsets.

    The subsets follow the fractional ratios provided by the user, and each
    split is stratified by the values in ``stratify_colname`` (that is, each
    subset has approximately the same relative frequency of the values in that
    column). The split is performed by calling train_test_split() once to
    carve out the train subset and, when both ``frac_val`` and ``frac_test``
    are positive, a second time to divide the remainder into val and test.

    Parameters
    ----------
    df_input : Pandas dataframe
        Input dataframe to be split.
    target_colname : str
        The name of the column returned separately as the ``y_*`` outputs.
    stratify_colname : str
        The name of the column that will be used for stratification. Usually
        this column would be for the label.
    frac_train : float
    frac_val : float
    frac_test : float
        The ratios with which the dataframe will be split into train, val, and
        test data. The values should be expressed as float fractions and should
        sum to 1.0 (within floating-point rounding error).
    random_state : int, None, or RandomStateInstance
        Value to be passed to train_test_split(). The same value is used for
        both splits.

    Returns
    -------
    df_train, df_val, df_test, y_train, y_val, y_test :
        ``df_*`` are the three splits of ``df_input`` with all of its columns;
        ``y_*`` are the matching single-column dataframes holding
        ``target_colname``. When ``frac_val <= 0`` the val outputs are empty
        dataframes, and when ``frac_test <= 0`` the test outputs are empty
        dataframes.

    Raises
    ------
    ValueError
        If the fractions do not add up to 1.0, or if ``stratify_colname`` or
        ``target_colname`` is not a column of ``df_input``.
    """
    # Compare with a tolerance rather than `!= 1.0`: sums of valid fractions
    # such as 0.7 + 0.2 + 0.1 evaluate to 0.9999999999999999 in binary floating
    # point and would otherwise be rejected.
    frac_total = frac_train + frac_val + frac_test
    if abs(frac_total - 1.0) > 1e-9:
        raise ValueError(
            "fractions %f, %f, %f do not add up to 1.0"
            % (frac_train, frac_val, frac_test)
        )

    if stratify_colname not in df_input.columns:
        raise ValueError(
            "%s is not a column in the dataframe" % (stratify_colname,)
        )

    if target_colname not in df_input.columns:
        raise ValueError(
            "%s is not a column in the dataframe" % (target_colname,)
        )

    X = df_input  # Contains all columns.
    y = df_input[[target_colname]]  # Dataframe of just the target column.
    z = df_input[[stratify_colname]]  # Dataframe of just the stratify column.

    # Split original dataframe into train and temp dataframes.
    df_train, df_temp, y_train, y_temp = train_test_split(
        X, y, stratify=z, test_size=(1.0 - frac_train), random_state=random_state
    )

    if frac_val <= 0:
        # No validation subset requested: everything left over is test data.
        assert len(df_input) == len(df_train) + len(df_temp)
        return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp

    if frac_test <= 0:
        # No test subset requested: everything left over is validation data.
        # A second split would be asked for test_size=0.0, so skip it.
        assert len(df_input) == len(df_train) + len(df_temp)
        return df_train, df_temp, pd.DataFrame(), y_train, y_temp, pd.DataFrame()

    # Split the temp dataframe into val and test dataframes. The test size is
    # rescaled because it is now a fraction of df_temp, not of df_input.
    relative_frac_test = frac_test / (frac_val + frac_test)
    df_val, df_test, y_val, y_test = train_test_split(
        df_temp,
        y_temp,
        stratify=df_temp[[stratify_colname]],
        test_size=relative_frac_test,
        random_state=random_state,
    )

    assert len(df_input) == len(df_train) + len(df_val) + len(df_test)
    return df_train, df_val, df_test, y_train, y_val, y_test
|