# Source code for sparklightautoml.transformers.scala_wrappers.laml_string_indexer

from pyspark import since, keyword_only, SparkContext
from pyspark.ml.param.shared import (
    HasInputCol,
    HasOutputCol,
    HasInputCols,
    HasOutputCols,
    HasHandleInvalid,
    TypeConverters,
    Param,
    Params,
)
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
from pyspark.ml.common import inherit_doc

from sparklightautoml.mlwriters import CommonJavaToPythonMLReadable


class _StringIndexerParams(JavaParams, HasHandleInvalid, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols):
    """
    Params for :py:class:`LAMLStringIndexer`.

    On top of the shared single-/multi-column input and output params this
    mixin declares ``stringOrderType``, an overridden ``handleInvalid`` and the
    extra options ``minFreqs``, ``defaultValue``, ``freqLabel`` and
    ``nanLast``. A getter is provided for every param declared here.
    """

    stringOrderType = Param(
        Params._dummy(),
        "stringOrderType",
        "How to order labels of string column. The first label after "
        + "ordering is assigned an index of 0. Supported options: "
        + "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc. "
        + "Default is frequencyDesc. In case of equal frequency when "
        + "under frequencyDesc/Asc, the strings are further sorted "
        + "alphabetically",
        typeConverter=TypeConverters.toString,
    )

    # Re-declared (shadowing the one from HasHandleInvalid) to carry an
    # indexer-specific description.
    handleInvalid = Param(
        Params._dummy(),
        "handleInvalid",
        "how to handle invalid data (unseen "
        + "or NULL values) in features and label column of string type. "
        + "Options are 'skip' (filter out rows with invalid data), "
        + "error (throw an error), or 'keep' (put invalid data "
        + "in a special additional bucket, at index numLabels).",
        typeConverter=TypeConverters.toString,
    )

    minFreqs = Param(
        Params._dummy(),
        "minFreqs",
        "The minimum number of the element occurrences not to skip it.",
        typeConverter=TypeConverters.toListInt,
    )

    defaultValue = Param(
        Params._dummy(),
        "defaultValue",
        "The index for unknown labels and labels that are too rare.",
        typeConverter=TypeConverters.toFloat,
    )

    freqLabel = Param(
        Params._dummy(),
        "freqLabel",
        "If true, label will be transformed to its occurrences",
        typeConverter=TypeConverters.toBoolean,
    )

    nanLast = Param(
        Params._dummy(),
        "nanLast",
        "If true, appends 'NaN' label to the end of each mapping",
        typeConverter=TypeConverters.toBoolean,
    )

    def __init__(self, *args):
        """
        Forward ``args`` to the parent initializer and register the default
        value of every param declared on this class.
        """
        super(_StringIndexerParams, self).__init__(*args)
        self._setDefault(
            handleInvalid="error",
            stringOrderType="frequencyDesc",
            minFreqs=tuple(),
            defaultValue=0.0,
            freqLabel=False,
            nanLast=False,
        )

    @since("2.3.0")
    def getStringOrderType(self):
        """
        Gets the value of :py:attr:`stringOrderType` or its default value 'frequencyDesc'.
        """
        return self.getOrDefault(self.stringOrderType)

    def getMinFreqs(self):
        """
        Gets the value of :py:attr:`minFreqs` or its default value (empty).
        """
        return self.getOrDefault(self.minFreqs)

    def getDefaultValue(self):
        """
        Gets the value of :py:attr:`defaultValue` or its default value 0.0.
        """
        return self.getOrDefault(self.defaultValue)

    def getFreqLabel(self):
        """
        Gets the value of :py:attr:`freqLabel` or its default value False.
        """
        return self.getOrDefault(self.freqLabel)

    def getNanLast(self):
        """
        Gets the value of :py:attr:`nanLast` or its default value False.
        """
        return self.getOrDefault(self.nanLast)


class _StringIndexerModelParams(JavaParams, HasHandleInvalid, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols):
    """
    Params for :py:class:`LAMLStringIndexerModel`.

    Declares the overridden ``handleInvalid`` plus ``defaultValue``,
    ``freqLabel`` and ``nanLast`` on top of the shared single-/multi-column
    input and output params, and provides a getter for each of the three
    extra options.
    """

    # NOTE(review): ``stringOrderType`` and ``minFreqs`` exist on the estimator
    # params (``_StringIndexerParams``) but are deliberately not declared here;
    # their definitions had been disabled for the model. Presumably they only
    # matter while fitting -- confirm against the Scala model before re-adding.

    # Re-declared (shadowing the one from HasHandleInvalid) to carry an
    # indexer-specific description.
    handleInvalid = Param(
        Params._dummy(),
        "handleInvalid",
        "how to handle invalid data (unseen "
        + "or NULL values) in features and label column of string type. "
        + "Options are 'skip' (filter out rows with invalid data), "
        + "error (throw an error), or 'keep' (put invalid data "
        + "in a special additional bucket, at index numLabels).",
        typeConverter=TypeConverters.toString,
    )

    defaultValue = Param(
        Params._dummy(),
        "defaultValue",
        "The index for unknown labels and labels that are too rare.",
        typeConverter=TypeConverters.toFloat,
    )

    freqLabel = Param(
        Params._dummy(),
        "freqLabel",
        "If true, label will be transformed to its occurrences",
        typeConverter=TypeConverters.toBoolean,
    )

    nanLast = Param(
        Params._dummy(),
        "nanLast",
        "If true, appends 'NaN' label to the end of each mapping",
        typeConverter=TypeConverters.toBoolean,
    )

    def __init__(self, *args):
        """
        Forward ``args`` to the parent initializer and register the default
        value of every param declared on this class.
        """
        super(_StringIndexerModelParams, self).__init__(*args)
        self._setDefault(handleInvalid="error", defaultValue=0.0, freqLabel=False, nanLast=False)

    def getDefaultValue(self):
        """
        Gets the value of :py:attr:`defaultValue` or its default value 0.0.
        """
        return self.getOrDefault(self.defaultValue)

    def getFreqLabel(self):
        """
        Gets the value of :py:attr:`freqLabel` or its default value False.
        """
        return self.getOrDefault(self.freqLabel)

    def getNanLast(self):
        """
        Gets the value of :py:attr:`nanLast` or its default value False.
        """
        return self.getOrDefault(self.nanLast)


@inherit_doc
class LAMLStringIndexer(JavaEstimator, _StringIndexerParams, JavaMLReadable, JavaMLWritable):
    """
    Custom implementation of PySpark StringIndexer wrapper.

    Python-side estimator backed by the JVM class
    ``org.apache.spark.ml.feature.lightautoml.LAMLStringIndexer``. Besides the
    usual StringIndexer options it accepts ``minFreqs``, ``defaultValue``,
    ``freqLabel`` and ``nanLast``. Fitting produces a
    :py:class:`LAMLStringIndexerModel`.
    """

    @keyword_only
    def __init__(
        self,
        *,
        inputCol=None,
        outputCol=None,
        inputCols=None,
        outputCols=None,
        handleInvalid="error",
        stringOrderType="frequencyDesc",
        minFreqs=None,
        defaultValue=0.0,
        freqLabel=False,
        nanLast=False
    ):
        """
        __init__(self, \\*, inputCol=None, outputCol=None, inputCols=None, outputCols=None, \
                 handleInvalid="error", stringOrderType="frequencyDesc", minFreqs=None, \
                 defaultValue=0.0, freqLabel=False, nanLast=False)
        """
        super(LAMLStringIndexer, self).__init__()
        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.lightautoml.LAMLStringIndexer", self.uid)
        # ``keyword_only`` stores the keyword arguments it received in
        # ``_input_kwargs``; these are forwarded to ``setParams`` as-is.
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(
        self,
        *,
        inputCol=None,
        outputCol=None,
        inputCols=None,
        outputCols=None,
        handleInvalid="error",
        stringOrderType="frequencyDesc",
        minFreqs=None,
        defaultValue=0.0,
        freqLabel=False,
        nanLast=False
    ):
        """
        setParams(self, \\*, inputCol=None, outputCol=None, inputCols=None, outputCols=None, \
                  handleInvalid="error", stringOrderType="frequencyDesc", minFreqs=None, \
                  defaultValue=0.0, freqLabel=False, nanLast=False)
        Sets params for this StringIndexer.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        """
        Wraps the given Java model object in a :py:class:`LAMLStringIndexerModel`.
        """
        return LAMLStringIndexerModel(java_model)

    @since("2.3.0")
    def setStringOrderType(self, value):
        """
        Sets the value of :py:attr:`stringOrderType`.
        """
        return self._set(stringOrderType=value)

    def setInputCol(self, value):
        """
        Sets the value of :py:attr:`inputCol`.
        """
        return self._set(inputCol=value)

    @since("3.0.0")
    def setInputCols(self, value):
        """
        Sets the value of :py:attr:`inputCols`.
        """
        return self._set(inputCols=value)

    def setOutputCol(self, value):
        """
        Sets the value of :py:attr:`outputCol`.
        """
        return self._set(outputCol=value)

    @since("3.0.0")
    def setOutputCols(self, value):
        """
        Sets the value of :py:attr:`outputCols`.
        """
        return self._set(outputCols=value)

    def setHandleInvalid(self, value):
        """
        Sets the value of :py:attr:`handleInvalid`.
        """
        return self._set(handleInvalid=value)

    @since("3.2.0")
    def setMinFreqs(self, value):
        """
        Sets the value of :py:attr:`minFreqs`.
        """
        return self._set(minFreqs=value)

    @since("3.2.0")
    def setDefaultValue(self, value):
        """
        Sets the value of :py:attr:`defaultValue`.
        """
        return self._set(defaultValue=value)

    @since("3.2.0")
    def setFreqLabel(self, value):
        """
        Sets the value of :py:attr:`freqLabel`.
        """
        return self._set(freqLabel=value)

    @since("3.2.0")
    def setNanLast(self, value):
        """
        Sets the value of :py:attr:`nanLast`.
        """
        return self._set(nanLast=value)


class LAMLStringIndexerModel(
    JavaModel, _StringIndexerModelParams, CommonJavaToPythonMLReadable, JavaMLWritable
):
    """
    Model fitted by :py:class:`LAMLStringIndexer`.

    Python-side wrapper of the JVM class
    ``org.apache.spark.ml.feature.lightautoml.LAMLStringIndexerModel``.

    .. versionadded:: 1.4.0
    """

    def setInputCol(self, value):
        """
        Sets the value of :py:attr:`inputCol`.
        """
        return self._set(inputCol=value)

    @since("3.0.0")
    def setInputCols(self, value):
        """
        Sets the value of :py:attr:`inputCols`.
        """
        return self._set(inputCols=value)

    def setOutputCol(self, value):
        """
        Sets the value of :py:attr:`outputCol`.
        """
        return self._set(outputCol=value)

    @since("3.0.0")
    def setOutputCols(self, value):
        """
        Sets the value of :py:attr:`outputCols`.
        """
        return self._set(outputCols=value)

    @since("2.4.0")
    def setHandleInvalid(self, value):
        """
        Sets the value of :py:attr:`handleInvalid`.
        """
        return self._set(handleInvalid=value)

    @since("3.2.0")
    def setDefaultValue(self, value):
        """
        Sets the value of :py:attr:`defaultValue`.
        """
        return self._set(defaultValue=value)

    @since("3.2.0")
    def setFreqLabel(self, value):
        """
        Sets the value of :py:attr:`freqLabel`.
        """
        return self._set(freqLabel=value)

    @since("3.2.0")
    def setNanLast(self, value):
        """
        Sets the value of :py:attr:`nanLast`.
        """
        return self._set(nanLast=value)

    @classmethod
    @since("2.4.0")
    def from_labels(
        cls, labels, inputCol, outputCol=None, handleInvalid=None, defaultValue=0.0, freqLabel=False, nanLast=False
    ):
        """
        Construct the model directly from an array of label strings,
        requires an active SparkContext.

        ``outputCol`` and ``handleInvalid`` are only set on the model when they
        are not None; ``defaultValue``, ``freqLabel`` and ``nanLast`` are
        always set.
        """
        sc = SparkContext._active_spark_context
        java_class = sc._gateway.jvm.java.lang.String
        jlabels = LAMLStringIndexerModel._new_java_array(labels, java_class)
        model = LAMLStringIndexerModel._create_from_java_class(
            "org.apache.spark.ml.feature.lightautoml.LAMLStringIndexerModel", jlabels
        )
        model.setInputCol(inputCol)
        if outputCol is not None:
            model.setOutputCol(outputCol)
        if handleInvalid is not None:
            model.setHandleInvalid(handleInvalid)
        model.setDefaultValue(defaultValue)
        model.setFreqLabel(freqLabel)
        model.setNanLast(nanLast)
        return model

    @classmethod
    @since("3.0.0")
    def from_arrays_of_labels(
        cls,
        arrayOfLabels,
        inputCols,
        outputCols=None,
        handleInvalid=None,
        defaultValue=0.0,
        freqLabel=False,
        nanLast=False,
    ):
        """
        Construct the model directly from an array of array of label strings,
        requires an active SparkContext.

        ``outputCols`` and ``handleInvalid`` are only set on the model when
        they are not None; ``defaultValue``, ``freqLabel`` and ``nanLast`` are
        always set. ``nanLast`` is a trailing keyword with default False,
        mirroring :py:meth:`from_labels`.
        """
        sc = SparkContext._active_spark_context
        java_class = sc._gateway.jvm.java.lang.String
        jlabels = LAMLStringIndexerModel._new_java_array(arrayOfLabels, java_class)
        model = LAMLStringIndexerModel._create_from_java_class(
            "org.apache.spark.ml.feature.lightautoml.LAMLStringIndexerModel", jlabels
        )
        model.setInputCols(inputCols)
        if outputCols is not None:
            model.setOutputCols(outputCols)
        if handleInvalid is not None:
            model.setHandleInvalid(handleInvalid)
        model.setDefaultValue(defaultValue)
        model.setFreqLabel(freqLabel)
        model.setNanLast(nanLast)
        return model

    @property
    @since("1.5.0")
    def labels(self):
        """
        Ordered list of labels, corresponding to indices to be assigned.

        .. deprecated:: 3.1.0
            It will be removed in future versions. Use `labelsArray` method instead.
        """
        return self._call_java("labels")

    @property
    @since("3.0.2")
    def labelsArray(self):
        """
        Array of ordered list of labels, corresponding to indices to be assigned
        for each input column.
        """
        # Unlike ``labels``, this is served by the Java method ``getStringLabels``.
        return self._call_java("getStringLabels")

    @staticmethod
    def _from_java(java_stage):
        """
        Given a Java object, create and return a Python wrapper of it.
        Used for ML persistence.

        Meta-algorithms such as Pipeline should override this method as a classmethod.

        The Python class is taken from a fixed dotted name rather than derived
        from ``java_stage``'s class name.
        """

        def _get_class(clazz):
            """
            Loads Python class from its name.
            """
            parts = clazz.split(".")
            module = ".".join(parts[:-1])
            # __import__ returns the top-level package, so walk the remaining
            # dotted components down to the class object.
            m = __import__(module)
            for comp in parts[1:]:
                m = getattr(m, comp)
            return m

        stage_name = "sparklightautoml.transformers.scala_wrappers.laml_string_indexer.LAMLStringIndexerModel"
        # Generate a default new instance from the stage_name class.
        py_type = _get_class(stage_name)
        if issubclass(py_type, JavaParams):
            # Load information from java_stage to the instance.
            py_stage = py_type()
            py_stage._java_obj = java_stage
            py_stage._resetUid(java_stage.uid())
            py_stage._transfer_params_from_java()
        elif hasattr(py_type, "_from_java"):
            py_stage = py_type._from_java(java_stage)
        else:
            raise NotImplementedError("This Java stage cannot be loaded into Python currently: %r" % stage_name)
        return py_stage