Source code for sparklightautoml.validation.base

from abc import ABC, abstractmethod
from contextlib import contextmanager
from copy import copy
from typing import Tuple, cast, Sequence, Optional, Union, Any

from lightautoml.dataset.base import LAMLDataset
from lightautoml.ml_algo.base import MLAlgo
from lightautoml.ml_algo.tuning.base import ParamsTuner
from lightautoml.pipelines.features.base import FeaturesPipeline
from lightautoml.pipelines.selection.base import SelectionPipeline, ImportanceEstimator
from lightautoml.validation.base import TrainValidIterator
from pyspark.sql import functions as sf

from sparklightautoml import VALIDATION_COLUMN
from sparklightautoml.dataset.base import SparkDataset, Unpersistable
from sparklightautoml.pipelines.features.base import SparkFeaturesPipeline
from sparklightautoml.utils import SparkDataFrame

# In the Spark flavour a "train/val" object is a single dataset: the split is
# encoded in a marker service column rather than a (train, valid) pair.
TrainVal = SparkDataset


def mark_as_train(sdf: SparkDataFrame, is_val_col: str):
    return sdf.withColumn(is_val_col, sf.lit(0))


def mark_as_val(sdf: SparkDataFrame, is_val_col: str):
    return sdf.withColumn(is_val_col, sf.lit(1))


def split_out_train(sdf: SparkDataFrame, is_val_col: str):
    return sdf.where(sf.col(is_val_col) == 0).drop(is_val_col)


def split_out_val(sdf: SparkDataFrame, is_val_col: str):
    return sdf.where(sf.col(is_val_col) == 1).drop(is_val_col)

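# --- Usage sketch (illustrative only, not part of the module) ---
# The four helpers above encode a train/validation split as a single 0/1
# service column instead of materializing two separate DataFrames. A minimal
# round-trip, assuming an active SparkSession named `spark`:
#
#   df = spark.createDataFrame([(1,), (2,), (3,)], ["x"])
#   marked = mark_as_val(df, "is_val")               # every row gets is_val == 1
#   assert split_out_val(marked, "is_val").count() == 3
#   assert split_out_train(marked, "is_val").count() == 0
#   assert "is_val" not in split_out_val(marked, "is_val").columns  # marker dropped
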

class SparkSelectionPipeline(SelectionPipeline, ABC):
    def __init__(self,
                 features_pipeline: Optional[FeaturesPipeline] = None,
                 ml_algo: Optional[Union[MLAlgo, Tuple[MLAlgo, ParamsTuner]]] = None,
                 imp_estimator: Optional[ImportanceEstimator] = None,
                 fit_on_holdout: bool = False,
                 **kwargs: Any):
        super().__init__(features_pipeline, ml_algo, imp_estimator, fit_on_holdout, **kwargs)


class SparkBaseTrainValidIterator(TrainValidIterator, Unpersistable, ABC):
    """Implements applying selection pipelines and feature pipelines to a SparkDataset."""

    TRAIN_VAL_COLUMN = VALIDATION_COLUMN

    def __init__(self, train: SparkDataset):
        assert train.folds_column in train.data.columns
        super().__init__(train)
        self.train = cast(SparkDataset, train)

    def __next__(self) -> TrainVal:
        """Define how to get the next fold.

        Returns:
            A single dataset whose rows are marked as belonging to either
            the train or the validation part via the ``TRAIN_VAL_COLUMN``
            service column.
        """
        ...

    def __getitem__(self, fold_id: int) -> SparkDataset:
        ...

    @contextmanager
    def frozen(self) -> 'SparkBaseTrainValidIterator':
        yield self.freeze()

    @abstractmethod
    def freeze(self) -> 'SparkBaseTrainValidIterator':
        ...

    @abstractmethod
    def unpersist(self, skip_val: bool = False):
        ...

    @abstractmethod
    def get_validation_data(self) -> SparkDataset:
        ...
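
    # --- Iteration contract (illustrative sketch, not part of the class) ---
    # Concrete iterators are expected to yield, per fold, a single SparkDataset
    # carrying TRAIN_VAL_COLUMN instead of a (train, valid) tuple. Assuming a
    # concrete subclass instance `train_valid` and a `fit_fold` callable (both
    # names illustrative), a consumer loop might look like:
    #
    #   for fold_ds in train_valid:                  # __next__ per fold
    #       tr = split_out_train(fold_ds.data, train_valid.TRAIN_VAL_COLUMN)
    #       vl = split_out_val(fold_ds.data, train_valid.TRAIN_VAL_COLUMN)
    #       fit_fold(tr, vl)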
    def apply_selector(self, selector: SparkSelectionPipeline) -> "SparkBaseTrainValidIterator":
        """Select features on train data.

        Checks whether the selector is already fitted. If not, fits it inside
        a child persistence context and then performs the selection; if it is
        fitted, the selection is applied directly.

        Args:
            selector: Selection pipeline used for feature selection.

        Returns:
            A shallow copy of this iterator whose train dataset is narrowed
            to the selected features.
        """
        if not selector.is_fitted:
            with self._child_persistence_context() as sel_train_valid:
                selector.fit(sel_train_valid)

        train_valid = copy(self)
        train_valid.train = selector.select(cast(SparkDataset, self.train))
        return train_valid
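
    # --- apply_selector usage (illustrative) ---
    # The iterator itself is not mutated: a shallow copy is returned with a
    # narrowed train dataset, so the original iterator remains usable.
    # Assuming `selector` is some SparkSelectionPipeline instance:
    #
    #   train_valid = train_valid.apply_selector(selector)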
    def apply_feature_pipeline(self, features_pipeline: SparkFeaturesPipeline) -> "SparkBaseTrainValidIterator":
        train_valid = copy(self)
        train_valid.train = features_pipeline.fit_transform(train_valid.train)
        return train_valid

    def _validate_fold_id(self, fold_id: int):
        assert 0 <= fold_id < len(self)

    def _split_by_fold(self, fold: int) -> Tuple[SparkDataset, SparkDataset, SparkDataset]:
        train = cast(SparkDataset, self.train)
        is_val_col = (
            sf.when(sf.col(self.train.folds_column) != fold, sf.lit(0)).otherwise(sf.lit(1))
            .alias(self.TRAIN_VAL_COLUMN)
        )
        sdf = train.data.select("*", is_val_col)
        train_part_sdf = split_out_train(sdf, self.TRAIN_VAL_COLUMN)
        valid_part_sdf = split_out_val(sdf, self.TRAIN_VAL_COLUMN)

        train_ds = cast(SparkDataset, self.train.empty())
        train_ds.set_data(sdf, self.train.features, self.train.roles, name=self.train.name)

        train_part_ds = cast(SparkDataset, self.train.empty())
        train_part_ds.set_data(
            train_part_sdf, self.train.features, self.train.roles,
            name=f"{self.train.name}_train_{fold}"
        )

        valid_part_ds = cast(SparkDataset, self.train.empty())
        valid_part_ds.set_data(
            valid_part_sdf, self.train.features, self.train.roles,
            name=f"{self.train.name}_val_{fold}"
        )

        return train_ds, train_part_ds, valid_part_ds

    @contextmanager
    def _child_persistence_context(self) -> 'SparkBaseTrainValidIterator':
        train_valid = copy(self)
        train = train_valid.train.empty()

        pm = train_valid.train.persistence_manager
        child_manager = pm.child()

        train.set_data(
            train_valid.train.data,
            train_valid.train.features,
            train_valid.train.roles,
            persistence_manager=child_manager,
            dependencies=[]
        )

        train_valid.train = train

        yield train_valid

        child_manager.unpersist_all()
        pm.remove_child(child_manager)
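

# --- _split_by_fold semantics (illustrative sketch) ---
# For a given fold k, rows with folds_column != k are marked as train (0) and
# the remaining rows as validation (1). Three datasets are returned: the full
# marked dataset plus the two split-out parts with the marker dropped.
# Schematically, for folds_column values [0, 1, 1, 2] and fold == 1:
#
#   full_ds, train_part, valid_part = train_valid._split_by_fold(1)
#   # full_ds.data -> TRAIN_VAL_COLUMN values [0, 1, 1, 0]
#   # train_part   -> rows from folds 0 and 2
#   # valid_part   -> rows from fold 1
#
# _child_persistence_context scopes intermediate caching: anything persisted
# through the child manager (e.g. while fitting a selector) is unpersisted
# when the context exits, leaving the parent persistence manager untouched.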