Source code for sparklightautoml.validation.iterators

import logging
from typing import Optional, cast, Iterable

from pyspark.sql import functions as sf

from sparklightautoml.dataset.base import SparkDataset
from sparklightautoml.utils import SparkDataFrame
from sparklightautoml.validation.base import SparkBaseTrainValidIterator, TrainVal

logger = logging.getLogger(__name__)


class SparkDummyIterator(SparkBaseTrainValidIterator):
    """Simple one-step iterator over the train part of a SparkDataset."""

    def __init__(self, train: SparkDataset):
        super().__init__(train)
        self._curr_idx = 0

    def __iter__(self) -> Iterable:
        self._curr_idx = 0
        return self

    def __len__(self) -> Optional[int]:
        return 1

    def __getitem__(self, fold_id: int) -> SparkDataset:
        self._validate_fold_id(fold_id)
        return super().__getitem__(fold_id)

    def __next__(self) -> TrainVal:
        """Define how to get next object.

        Returns:
            Train dataset with an all-zero train/validation flag column.

        """
        if self._curr_idx > 0:
            raise StopIteration

        self._curr_idx += 1

        sdf = cast(SparkDataFrame, self.train.data)
        sdf = sdf.withColumn(self.TRAIN_VAL_COLUMN, sf.lit(0))

        train_ds = cast(SparkDataset, self.train.empty())
        train_ds.set_data(sdf, self.train.features, self.train.roles, name=self.train.name)

        return train_ds

    def freeze(self) -> "SparkDummyIterator":
        return SparkDummyIterator(self.train.freeze())

    def unpersist(self, skip_val: bool = False):
        if not skip_val:
            self.train.unpersist()

    def get_validation_data(self) -> SparkDataset:
        return self.train

    def convert_to_holdout_iterator(self) -> "SparkHoldoutIterator":
        sds = cast(SparkDataset, self.train)
        assert sds.folds_column is not None, "Cannot convert to Holdout iterator when folds_column is not defined"
        return SparkHoldoutIterator(self.train)
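
# Usage sketch (illustrative, not part of the library source): a
# SparkDummyIterator makes a single pass in which the whole dataset is marked
# as train (TRAIN_VAL_COLUMN == 0). `train_ds` below is an assumed,
# already-prepared SparkDataset.
#
#     iterator = SparkDummyIterator(train_ds)
#     assert len(iterator) == 1
#     for ds in iterator:   # yields exactly one dataset, then stops
#         sdf = ds.data     # carries the all-zero TRAIN_VAL_COLUMN
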
class SparkHoldoutIterator(SparkBaseTrainValidIterator):
    """Simple one-step iterator over one fold of a SparkDataset."""

    def __init__(self, train: SparkDataset):
        assert self.TRAIN_VAL_COLUMN in train.data.columns, \
            f"Cannot accept dataset without explicit '{self.TRAIN_VAL_COLUMN}' column"
        super().__init__(train)
        self._curr_idx = 0

    def __iter__(self) -> Iterable:
        self._curr_idx = 0
        return self

    def __len__(self) -> Optional[int]:
        return 1

    def __getitem__(self, fold_id: int) -> SparkDataset:
        self._validate_fold_id(fold_id)
        return self.train

    def __next__(self) -> TrainVal:
        """Define how to get next object.

        Returns:
            Full dataset carrying the train/validation flag column.

        """
        if self._curr_idx > 0:
            raise StopIteration

        # full_ds, train_part_ds, valid_part_ds = self._split_by_fold(self._curr_idx)

        self._curr_idx += 1

        return self.train

    def freeze(self) -> "SparkHoldoutIterator":
        return SparkHoldoutIterator(self.train.freeze())

    def unpersist(self, skip_val: bool = False):
        if not skip_val:
            self.train.unpersist()

    def get_validation_data(self) -> SparkDataset:
        valid_sdf = self.train.data.where(sf.col(self.TRAIN_VAL_COLUMN) == 1).drop(self.TRAIN_VAL_COLUMN)
        valid = self.train.empty()
        valid.set_data(valid_sdf, self.train.features, self.train.roles)
        return valid

    def convert_to_holdout_iterator(self) -> "SparkHoldoutIterator":
        return self
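
# Usage sketch (illustrative, not part of the library source): the holdout
# iterator requires TRAIN_VAL_COLUMN to be present already, replays the dataset
# exactly once, and get_validation_data() selects the rows flagged with 1.
# `train_ds_with_flag` is an assumed SparkDataset whose data already contains
# the TRAIN_VAL_COLUMN column.
#
#     iterator = SparkHoldoutIterator(train_ds_with_flag)
#     valid_ds = iterator.get_validation_data()  # rows where flag == 1,
#                                                # with the flag column dropped
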
class SparkFoldsIterator(SparkBaseTrainValidIterator):
    """Classic cross-validation iterator.

    Folds should be defined in the Reader, based on the cross-validation method.
    """

    def __init__(self, train: SparkDataset, n_folds: Optional[int] = None):
        """Creates iterator.

        Args:
            train: Dataset for folding.
            n_folds: Optional upper limit on the number of folds to iterate over.

        """
        super().__init__(train)

        # TODO: PARALLEL - potential bug here
        num_folds = train.data.select(sf.max(train.folds_column).alias("max")).first()["max"]
        self.n_folds = num_folds + 1
        if n_folds is not None:
            self.n_folds = min(self.n_folds, n_folds)

        self._base_train_frozen = train.frozen
        self._train_frozen = self._base_train_frozen
        self._val_frozen = self._base_train_frozen

    def __len__(self) -> int:
        """Get length of the iterator.

        Returns:
            Number of folds.

        """
        return self.n_folds

    def __getitem__(self, fold_id: int) -> SparkDataset:
        self._validate_fold_id(fold_id)
        full_ds_with_is_val_col, _, _ = self._split_by_fold(fold_id)
        return full_ds_with_is_val_col

    def __iter__(self) -> "SparkFoldsIterator":
        """Set counter to 0 and return self.

        Returns:
            Iterator for folds.

        """
        logger.debug("Creating folds iterator")
        self._curr_idx = 0
        return self

    def __next__(self) -> TrainVal:
        """Define how to get next object.

        Returns:
            Full dataset with the train/validation flag column for the current fold.

        """
        logger.debug(f"The next valid fold num: {self._curr_idx}")

        if self._curr_idx == self.n_folds:
            logger.debug("No more folds to continue, stopping iterations")
            raise StopIteration

        full_ds_with_is_val_col, _, _ = self._split_by_fold(self._curr_idx)
        self._curr_idx += 1

        return full_ds_with_is_val_col

    def freeze(self) -> "SparkFoldsIterator":
        return SparkFoldsIterator(self.train.freeze(), n_folds=self.n_folds)

    def unpersist(self, skip_val: bool = False):
        if not skip_val:
            self.train.unpersist()

    def get_validation_data(self) -> SparkDataset:
        return self.train

    def convert_to_holdout_iterator(self) -> SparkHoldoutIterator:
        """Convert iterator to a hold-out iterator.

        Fold 0 is used for validation; everything else is used for training.

        Returns:
            New hold-out iterator.

        """
        full_with_is_val_column, _, _ = self._split_by_fold(0)
        return SparkHoldoutIterator(full_with_is_val_column)
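
# Usage sketch (illustrative, not part of the library source): iterating a
# SparkFoldsIterator yields one dataset per fold, each produced by
# _split_by_fold(fold_id) and carrying TRAIN_VAL_COLUMN (1 marks the current
# validation fold). `train_ds` is an assumed SparkDataset whose folds_column
# was filled by the reader.
#
#     iterator = SparkFoldsIterator(train_ds, n_folds=5)
#     for fold_ds in iterator:
#         ...  # fit on rows with flag == 0, score rows with flag == 1
#
#     # Reuse fold 0 as a single holdout split:
#     holdout = iterator.convert_to_holdout_iterator()
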