Source code for sparklightautoml.pipelines.ml.nested_ml_pipe

from typing import Union, Tuple, Optional, Sequence

from lightautoml.ml_algo.tuning.base import ParamsTuner, DefaultTuner
from lightautoml.pipelines.ml.nested_ml_pipe import (
    NestedTabularMLPipeline as LAMANestedTabularMLPipeline,
    NestedTabularMLAlgo,
)

from sparklightautoml.computations.base import ComputationsSettings
from sparklightautoml.ml_algo.base import SparkTabularMLAlgo
from sparklightautoml.pipelines.features.base import SparkFeaturesPipeline
from sparklightautoml.pipelines.ml.base import SparkMLPipeline
from sparklightautoml.pipelines.selection.base import SparkSelectionPipelineWrapper


class SparkNestedTabularMLPipeline(SparkMLPipeline, LAMANestedTabularMLPipeline):
    """Nested tabular ML pipeline built on top of ``SparkMLPipeline``.

    Inherits from both ``SparkMLPipeline`` and LAMA's
    ``NestedTabularMLPipeline``. The constructor optionally wraps every
    supplied algorithm in a ``NestedTabularMLAlgo`` and then delegates the
    rest of the initialisation to the parent ``__init__``.

    Args:
        ml_algos: Sequence whose entries are either a bare algorithm or an
            ``(algorithm, tuner)`` pair.
        force_calc: Forwarded unchanged to the parent constructor.
        pre_selection: Forwarded unchanged to the parent constructor.
        features_pipeline: Forwarded unchanged to the parent constructor.
        post_selection: Forwarded unchanged to the parent constructor.
        cv: When greater than 1, every entry of ``ml_algos`` is wrapped in
            a ``NestedTabularMLAlgo``; otherwise ``ml_algos`` is passed on
            as given.
        n_folds: Handed to ``NestedTabularMLAlgo``; only read when
            ``cv > 1``.
        inner_tune: If true, the tuner is given to the
            ``NestedTabularMLAlgo`` wrapper itself. If false, the wrapper
            is created without a tuner and paired with the tuner in a
            tuple instead.
        refit_tuner: Handed to ``NestedTabularMLAlgo``; only read when
            ``cv > 1`` and ``inner_tune`` is true.
        computations_settings: Forwarded unchanged to the parent
            constructor.
    """

    def __init__(
        self,
        ml_algos: Sequence[Union[SparkTabularMLAlgo, Tuple[SparkTabularMLAlgo, ParamsTuner]]],
        force_calc: Union[bool, Sequence[bool]] = True,
        pre_selection: Optional[SparkSelectionPipelineWrapper] = None,
        features_pipeline: Optional[SparkFeaturesPipeline] = None,
        post_selection: Optional[SparkSelectionPipelineWrapper] = None,
        cv: int = 1,
        n_folds: Optional[int] = None,
        inner_tune: bool = False,
        refit_tuner: bool = False,
        computations_settings: Optional[ComputationsSettings] = None
    ):
        if cv > 1:
            nested_algos = []

            for entry in ml_algos:
                # An entry that cannot be unpacked into exactly two items is
                # treated as a bare algorithm and gets a default tuner.
                try:
                    algo, params_tuner = entry
                except (TypeError, ValueError):
                    algo, params_tuner = entry, DefaultTuner()

                if inner_tune:
                    # The tuner travels inside the nested wrapper.
                    nested = NestedTabularMLAlgo(algo, params_tuner, refit_tuner, cv, n_folds)
                else:
                    # The wrapper is built without a tuner; the tuner stays
                    # alongside it as the second element of the pair.
                    nested = (NestedTabularMLAlgo(algo, None, True, cv, n_folds), params_tuner)

                nested_algos.append(nested)

            ml_algos = nested_algos

        # NOTE(review): computations_settings is passed positionally as the
        # sixth argument - confirm this matches the parameter order of the
        # parent __init__ that super() resolves to.
        super().__init__(
            ml_algos,
            force_calc,
            pre_selection,
            features_pipeline,
            post_selection,
            computations_settings,
        )