#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import json
import os
import time
import uuid
from typing import (
    Any,
    Dict,
    Generic,
    List,
    Optional,
    Sequence,
    Type,
    TypeVar,
    cast,
    TYPE_CHECKING,
)

from pyspark import SparkContext, since
from pyspark.ml.common import inherit_doc
from pyspark.sql import SparkSession
from pyspark.util import VersionUtils

if TYPE_CHECKING:
    from py4j.java_gateway import JavaGateway, JavaObject
    from pyspark.ml._typing import PipelineStage
    from pyspark.ml.base import Params
    from pyspark.ml.wrapper import JavaWrapper

# Type variables shared by the reader/writer class hierarchy in this module.
T = TypeVar("T")
RW = TypeVar("RW", bound="BaseReadWrite")
W = TypeVar("W", bound="MLWriter")
JW = TypeVar("JW", bound="JavaMLWriter")
RL = TypeVar("RL", bound="MLReadable")
JR = TypeVar("JR", bound="JavaMLReader")


def _jvm() -> "JavaGateway":
    """
    Return the JVM view held by :py:class:`SparkContext`.

    Must be called after a SparkContext has been initialized.

    Raises
    ------
    AttributeError
        If ``SparkContext._jvm`` is not set (i.e. is falsy).
    """
    gateway = SparkContext._jvm
    if not gateway:
        raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")
    return gateway
class Identifiable:
    """
    Object with a unique ID.

    The ID is generated once, at construction time, and is also what
    ``repr()`` returns.
    """

    def __init__(self) -> None:
        #: A unique id for the object.
        self.uid = self._randomUID()

    def __repr__(self) -> str:
        return self.uid

    @classmethod
    def _randomUID(cls) -> str:
        """
        Generate a unique string id for the object. The default implementation
        concatenates the class name, "_", and 12 random hex chars.
        """
        # Last 12 hex digits of a random UUID4.
        suffix = uuid.uuid4().hex[-12:]
        return f"{cls.__name__}_{suffix}"
@inherit_doc
class BaseReadWrite:
    """
    Base class for MLWriter and MLReader. Stores information about the SparkContext
    and SparkSession.

    A session given through :py:meth:`session` is used as-is; otherwise the active
    session is looked up (or created) the first time :py:attr:`sparkSession` is read.

    .. versionadded:: 2.3.0
    """

    def __init__(self) -> None:
        # ``None`` means "no session chosen yet"; resolved lazily by ``sparkSession``.
        self._sparkSession: Optional[SparkSession] = None

    def session(self: RW, sparkSession: SparkSession) -> RW:
        """
        Sets the Spark Session to use for saving/loading.

        Returns ``self`` so calls can be chained.
        """
        self._sparkSession = sparkSession
        return self

    @property
    def sparkSession(self) -> SparkSession:
        """
        Returns the user-specified Spark Session or the default.
        """
        current = self._sparkSession
        if current is None:
            # Remember the fallback so later accesses reuse the same session.
            current = SparkSession._getActiveSessionOrCreate()
            self._sparkSession = current
        assert current is not None
        return current

    @property
    def sc(self) -> SparkContext:
        """
        Returns the underlying `SparkContext`.
        """
        session = self.sparkSession
        assert session is not None
        return session.sparkContext
@inherit_doc
class MLWriter(BaseReadWrite):
    """
    Utility class that can save ML instances.

    .. versionadded:: 2.0.0
    """

    def __init__(self) -> None:
        super().__init__()
        # Set to True by overwrite(); checked by save().
        self.shouldOverwrite: bool = False
        # Writer options, keyed by lower-cased name, with values stored as strings.
        self.optionMap: Dict[str, Any] = {}

    def _handleOverwrite(self, path: str) -> None:
        # Delegates clean-up of existing output at ``path`` to the JVM helper
        # org.apache.spark.ml.util.FileSystemOverwrite.
        from pyspark.ml.wrapper import JavaWrapper

        helper = JavaWrapper(
            JavaWrapper._new_java_obj("org.apache.spark.ml.util.FileSystemOverwrite")
        )
        helper._call_java("handleOverwrite", path, True, self.sparkSession._jsparkSession)

    def save(self, path: str) -> None:
        """Save the ML instance to the input path."""
        if self.shouldOverwrite:
            self._handleOverwrite(path)
        self.saveImpl(path)

    def saveImpl(self, path: str) -> None:
        """
        save() handles overwriting and then calls this method. Subclasses should
        override this method to implement the actual saving of the instance.
        """
        raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))

    def overwrite(self) -> "MLWriter":
        """Overwrites if the output path already exists."""
        self.shouldOverwrite = True
        return self

    def option(self, key: str, value: Any) -> "MLWriter":
        """
        Adds an option to the underlying MLWriter. See the documentation for the specific
        model's writer for possible options. The option name (key) is case-insensitive.
        """
        normalized_key = key.lower()
        self.optionMap[normalized_key] = str(value)
        return self
@inherit_doc
class GeneralMLWriter(MLWriter):
    """
    Utility class that can save ML instances in different formats.

    .. versionadded:: 2.4.0
    """

    def format(self, source: str) -> "GeneralMLWriter":
        """
        Specifies the format of ML export ("pmml", "internal", or the fully qualified class
        name for export).

        The value is kept on ``self.source`` and ``self`` is returned for chaining.
        """
        self.source = source
        return self
@inherit_doc
class JavaMLWriter(MLWriter):
    """
    (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaParams` types
    """

    # JVM-side writer obtained from the instance's Java peer; the methods below
    # forward to it.
    _jwrite: "JavaObject"

    def __init__(self, instance: "JavaMLWritable"):
        super().__init__()
        java_peer = instance._to_java()  # type: ignore[attr-defined]
        self._jwrite = java_peer.write()

    def save(self, path: str) -> None:
        """Save the ML instance to the input path."""
        if not isinstance(path, str):
            raise TypeError("path should be a string, got type %s" % type(path))
        self._jwrite.save(path)

    def overwrite(self) -> "JavaMLWriter":
        """Overwrites if the output path already exists."""
        self._jwrite.overwrite()
        return self

    def option(self, key: str, value: str) -> "JavaMLWriter":
        # Forwarded to the JVM writer unchanged: no key lower-casing or str() here.
        self._jwrite.option(key, value)
        return self

    def session(self, sparkSession: SparkSession) -> "JavaMLWriter":
        """Sets the Spark Session to use for saving."""
        self._jwrite.session(sparkSession._jsparkSession)
        return self


@inherit_doc
class GeneralJavaMLWriter(JavaMLWriter):
    """
    (Private) Specialization of :py:class:`GeneralMLWriter` for :py:class:`JavaParams` types
    """

    def __init__(self, instance: "JavaMLWritable"):
        super().__init__(instance)

    def format(self, source: str) -> "GeneralJavaMLWriter":
        """
        Specifies the format of ML export ("pmml", "internal", or the fully qualified class
        name for export).
        """
        self._jwrite.format(source)
        return self
@inherit_doc
class MLWritable:
    """
    Mixin for ML instances that provide :py:class:`MLWriter`.

    .. versionadded:: 2.0.0
    """

    def write(self) -> MLWriter:
        """Returns an MLWriter instance for this ML instance."""
        raise NotImplementedError("MLWritable is not yet implemented for type: %r" % type(self))

    def save(self, path: str) -> None:
        """Save this ML instance to the given path, a shortcut of 'write().save(path)'."""
        writer = self.write()
        writer.save(path)
@inherit_doc
class JavaMLWritable(MLWritable):
    """
    (Private) Mixin for ML instances that provide :py:class:`JavaMLWriter`.
    """

    def write(self) -> JavaMLWriter:
        """Returns an MLWriter instance for this ML instance."""
        writer = JavaMLWriter(self)
        return writer


@inherit_doc
class GeneralJavaMLWritable(JavaMLWritable):
    """
    (Private) Mixin for ML instances that provide :py:class:`GeneralJavaMLWriter`.
    """

    def write(self) -> GeneralJavaMLWriter:
        """Returns a GeneralMLWriter instance for this ML instance."""
        writer = GeneralJavaMLWriter(self)
        return writer
@inherit_doc
class MLReader(BaseReadWrite, Generic[RL]):
    """
    Utility class that can load ML instances.

    .. versionadded:: 2.0.0
    """

    def __init__(self) -> None:
        super().__init__()

    def load(self, path: str) -> RL:
        """Load the ML instance from the input path."""
        raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self))
@inherit_doc
class JavaMLReader(MLReader[RL]):
    """
    (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaParams` types
    """

    def __init__(self, clazz: Type["JavaMLReadable[RL]"]) -> None:
        super().__init__()
        self._clazz = clazz
        # Resolve the JVM companion of ``clazz`` and keep its reader.
        java_companion = self._load_java_obj(clazz)
        self._jread = java_companion.read()

    def load(self, path: str) -> RL:
        """Load the ML instance from the input path."""
        if not isinstance(path, str):
            raise TypeError("path should be a string, got type %s" % type(path))
        loaded = self._jread.load(path)
        # Note: the JVM-side load runs before the Python-side capability check.
        if not hasattr(self._clazz, "_from_java"):
            raise NotImplementedError(
                "This Java ML type cannot be loaded into Python currently: %r" % self._clazz
            )
        return self._clazz._from_java(loaded)  # type: ignore[attr-defined]

    def session(self: JR, sparkSession: SparkSession) -> JR:
        """Sets the Spark Session to use for loading."""
        self._jread.session(sparkSession._jsparkSession)
        return self

    @classmethod
    def _java_loader_class(cls, clazz: Type["JavaMLReadable[RL]"]) -> str:
        """
        Returns the full class name of the Java ML instance. The default
        implementation replaces "pyspark" by "org.apache.spark" in
        the Python full class name.
        """
        simple_name = clazz.__name__
        java_package = clazz.__module__.replace("pyspark", "org.apache.spark")
        if simple_name in ("Pipeline", "PipelineModel"):
            # Remove the last package name "pipeline" for Pipeline and PipelineModel.
            java_package = java_package.rpartition(".")[0]
        return f"{java_package}.{simple_name}"

    @classmethod
    def _load_java_obj(cls, clazz: Type["JavaMLReadable[RL]"]) -> "JavaObject":
        """Load the peer Java object of the ML instance."""
        qualified_name = cls._java_loader_class(clazz)
        target = _jvm()
        # Walk the dotted name attribute by attribute from the JVM view.
        for segment in qualified_name.split("."):
            target = getattr(target, segment)
        return target
@inherit_doc
class MLReadable(Generic[RL]):
    """
    Mixin for instances that provide :py:class:`MLReader`.

    .. versionadded:: 2.0.0
    """

    @classmethod
    def read(cls) -> MLReader[RL]:
        """Returns an MLReader instance for this class."""
        raise NotImplementedError("MLReadable.read() not implemented for type: %r" % cls)

    @classmethod
    def load(cls, path: str) -> RL:
        """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
        reader = cls.read()
        return reader.load(path)
@inherit_doc
class JavaMLReadable(MLReadable[RL]):
    """
    (Private) Mixin for instances that provide JavaMLReader.
    """

    @classmethod
    def read(cls) -> JavaMLReader[RL]:
        """Returns an MLReader instance for this class."""
        reader: JavaMLReader[RL] = JavaMLReader(cls)
        return reader
@inherit_doc
class DefaultParamsWritable(MLWritable):
    """
    Helper trait for making simple :py:class:`Params` types writable. If a :py:class:`Params`
    class stores all data as :py:class:`Param` values, then extending this trait will provide
    a default implementation of writing saved instances of the class.
    This only handles simple :py:class:`Param` types; e.g., it will not handle
    :py:class:`pyspark.sql.DataFrame`. See :py:class:`DefaultParamsReadable`, the counterpart
    to this class.

    .. versionadded:: 2.3.0
    """

    def write(self) -> MLWriter:
        """
        Returns a DefaultParamsWriter instance for this class.

        Raises
        ------
        TypeError
            If this object does not also extend :py:class:`Params`.
        """
        from pyspark.ml.param import Params

        if not isinstance(self, Params):
            # Format the type into the message. Previously the type was passed as a
            # second exception argument, so the "%s" placeholder was never filled in.
            raise TypeError(
                (
                    "Cannot use DefaultParamsWritable with type %s because it does not "
                    "extend Params."
                )
                % type(self)
            )
        return DefaultParamsWriter(self)
@inherit_doc
class DefaultParamsWriter(MLWriter):
    """
    Specialization of :py:class:`MLWriter` for :py:class:`Params` types

    Class for writing Estimators and Transformers whose parameters are JSON-serializable.

    .. versionadded:: 2.3.0
    """

    def __init__(self, instance: "Params"):
        super(DefaultParamsWriter, self).__init__()
        self.instance = instance

    def saveImpl(self, path: str) -> None:
        """
        Write the wrapped instance's metadata (class, uid, params, ...) under ``path``.

        Without this override, :py:meth:`MLWriter.save` would call the base
        ``saveImpl`` and raise ``NotImplementedError``.
        """
        DefaultParamsWriter.saveMetadata(self.instance, path, self.sc)

    @staticmethod
    def saveMetadata(
        instance: "Params",
        path: str,
        sc: SparkContext,
        extraMetadata: Optional[Dict[str, Any]] = None,
        paramMap: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Saves metadata + Params to: path + "/metadata"

        - class
        - timestamp
        - sparkVersion
        - uid
        - paramMap
        - defaultParamMap (since 2.4.0)
        - (optionally, extra metadata)

        Parameters
        ----------
        extraMetadata : dict, optional
            Extra metadata to be saved at same level as uid, paramMap, etc.
        paramMap : dict, optional
            If given, this is saved in the "paramMap" field.
        """
        metadataPath = os.path.join(path, "metadata")
        metadataJson = DefaultParamsWriter._get_metadata_to_save(
            instance, sc, extraMetadata, paramMap
        )
        # A single-partition RDD, so the metadata is written as a single text part file.
        sc.parallelize([metadataJson], 1).saveAsTextFile(metadataPath)

    @staticmethod
    def _get_metadata_to_save(
        instance: "Params",
        sc: SparkContext,
        extraMetadata: Optional[Dict[str, Any]] = None,
        paramMap: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Helper for :py:meth:`DefaultParamsWriter.saveMetadata` which extracts the JSON to save.
        This is useful for ensemble models which need to save metadata for many sub-models.

        Notes
        -----
        See :py:meth:`DefaultParamsWriter.saveMetadata` for details on what this includes.
        Returns a compact JSON string (no whitespace after separators).
        """
        uid = instance.uid
        cls = instance.__module__ + "." + instance.__class__.__name__

        # User-supplied param values; an explicit paramMap takes precedence.
        if paramMap is not None:
            jsonParams = paramMap
        else:
            params = instance._paramMap
            jsonParams = {p.name: params[p] for p in params}

        # Default param values
        defaults = instance._defaultParamMap
        jsonDefaultParams = {p.name: defaults[p] for p in defaults}

        basicMetadata = {
            "class": cls,
            # Milliseconds since the epoch.
            "timestamp": int(round(time.time() * 1000)),
            "sparkVersion": sc.version,
            "uid": uid,
            "paramMap": jsonParams,
            "defaultParamMap": jsonDefaultParams,
        }
        if extraMetadata is not None:
            basicMetadata.update(extraMetadata)
        return json.dumps(basicMetadata, separators=(",", ":"))
@inherit_doc
class DefaultParamsReadable(MLReadable[RL]):
    """
    Helper trait for making simple :py:class:`Params` types readable.
    If a :py:class:`Params` class stores all data as :py:class:`Param` values,
    then extending this trait will provide a default implementation of reading saved
    instances of the class. This only handles simple :py:class:`Param` types;
    e.g., it will not handle :py:class:`pyspark.sql.DataFrame`. See
    :py:class:`DefaultParamsWritable`, the counterpart to this class.

    .. versionadded:: 2.3.0
    """

    @classmethod
    def read(cls) -> "DefaultParamsReader[RL]":
        """Returns a DefaultParamsReader instance for this class."""
        reader: "DefaultParamsReader[RL]" = DefaultParamsReader(cls)
        return reader
@inherit_doc
class DefaultParamsReader(MLReader[RL]):
    """
    Specialization of :py:class:`MLReader` for :py:class:`Params` types

    Default :py:class:`MLReader` implementation for transformers and estimators that
    contain basic (json-serializable) params and no data. This will not handle
    more complex params or types with data (e.g., models with coefficients).

    .. versionadded:: 2.3.0
    """

    def __init__(self, cls: Type[DefaultParamsReadable[RL]]):
        super(DefaultParamsReader, self).__init__()
        self.cls = cls

    @staticmethod
    def __get_class(clazz: str) -> Type[RL]:
        """
        Loads Python class from its fully qualified name (``"pkg.module.ClassName"``).
        """
        parts = clazz.split(".")
        module = ".".join(parts[:-1])
        # A non-empty ``fromlist`` makes __import__ return the named (leaf) module
        # itself instead of the top-level package.
        m = __import__(module, fromlist=[parts[-1]])
        return getattr(m, parts[-1])

    def load(self, path: str) -> RL:
        """
        Load an instance saved with :py:meth:`DefaultParamsWriter.saveMetadata`.

        The class named in the metadata is instantiated with no arguments, its uid is
        reset to the saved uid, and the saved params are applied.
        """
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        py_type: Type[RL] = DefaultParamsReader.__get_class(metadata["class"])
        instance = py_type()
        cast("Params", instance)._resetUid(metadata["uid"])
        DefaultParamsReader.getAndSetParams(instance, metadata)
        return instance

    @staticmethod
    def loadMetadata(path: str, sc: SparkContext, expectedClassName: str = "") -> Dict[str, Any]:
        """
        Load metadata saved using :py:meth:`DefaultParamsWriter.saveMetadata`

        Parameters
        ----------
        path : str
        sc : :py:class:`pyspark.SparkContext`
        expectedClassName : str, optional
            If non empty, this is checked against the loaded metadata.
        """
        metadataPath = os.path.join(path, "metadata")
        # Only the first line of the metadata text file is read.
        metadataStr = sc.textFile(metadataPath, 1).first()
        loadedVals = DefaultParamsReader._parseMetaData(metadataStr, expectedClassName)
        return loadedVals

    @staticmethod
    def _parseMetaData(metadataStr: str, expectedClassName: str = "") -> Dict[str, Any]:
        """
        Parse metadata JSON string produced by
        :py:meth:`DefaultParamsWriter._get_metadata_to_save`.
        This is a helper function for :py:meth:`DefaultParamsReader.loadMetadata`.

        Parameters
        ----------
        metadataStr : str
            JSON string of metadata
        expectedClassName : str, optional
            If non empty, this is checked against the loaded metadata.
        """
        metadata = json.loads(metadataStr)
        className = metadata["class"]
        if len(expectedClassName) > 0:
            assert className == expectedClassName, (
                "Error loading metadata: Expected "
                + "class name {} but found class name {}".format(expectedClassName, className)
            )
        return metadata

    @staticmethod
    def getAndSetParams(
        instance: RL, metadata: Dict[str, Any], skipParams: Optional[List[str]] = None
    ) -> None:
        """
        Extract Params from metadata, and set them in the instance.

        Parameters
        ----------
        instance : :py:class:`Params`
            Instance whose params are set in place.
        metadata : dict
            Parsed metadata, as returned by :py:meth:`loadMetadata`.
        skipParams : list of str, optional
            Names in ``metadata["paramMap"]`` that are neither looked up nor set.
        """
        params = cast("Params", instance)
        # Set user-supplied param values. Skipped names are filtered out before the
        # lookup, so a skipped name that the instance does not define is not an error.
        for paramName, paramValue in metadata["paramMap"].items():
            if skipParams is not None and paramName in skipParams:
                continue
            params.set(params.getParam(paramName), paramValue)

        # Set default param values
        majorAndMinorVersions = VersionUtils.majorMinorVersion(metadata["sparkVersion"])
        major = majorAndMinorVersions[0]
        minor = majorAndMinorVersions[1]

        # For metadata file prior to Spark 2.4, there is no default section.
        if major > 2 or (major == 2 and minor >= 4):
            assert "defaultParamMap" in metadata, (
                "Error loading metadata: Expected " + "`defaultParamMap` section not found"
            )

            for paramName, paramValue in metadata["defaultParamMap"].items():
                params._setDefault(**{paramName: paramValue})

    @staticmethod
    def isPythonParamsInstance(metadata: Dict[str, Any]) -> bool:
        """
        Return True when the saved class name is already a ``pyspark.ml.`` Python path.

        :py:meth:`loadParamsInstance` calls this; other class names are mapped
        from ``org.apache.spark`` to ``pyspark`` before lookup.
        """
        return metadata["class"].startswith("pyspark.ml.")

    @staticmethod
    def loadParamsInstance(path: str, sc: SparkContext) -> RL:
        """
        Load a :py:class:`Params` instance from the given path, and return it.
        This assumes the instance inherits from :py:class:`MLReadable`.
        """
        metadata = DefaultParamsReader.loadMetadata(path, sc)
        if DefaultParamsReader.isPythonParamsInstance(metadata):
            pythonClassName = metadata["class"]
        else:
            pythonClassName = metadata["class"].replace("org.apache.spark", "pyspark")
        py_type: Type[RL] = DefaultParamsReader.__get_class(pythonClassName)
        instance = py_type.load(path)
        return instance
@inherit_doc
class HasTrainingSummary(Generic[T]):
    """
    Base class for models that provides Training summary.

    Both members delegate to the Java model through ``_call_java``.

    .. versionadded:: 3.0.0
    """

    @property  # type: ignore[misc]
    @since("2.1.0")
    def hasSummary(self) -> bool:
        """
        Indicates whether a training summary exists for this model
        instance.
        """
        java_model = cast("JavaWrapper", self)
        return java_model._call_java("hasSummary")

    @property  # type: ignore[misc]
    @since("2.1.0")
    def summary(self) -> T:
        """
        Gets summary of the model trained on the training set. An exception is thrown if
        no summary exists.
        """
        java_model = cast("JavaWrapper", self)
        return java_model._call_java("summary")
class MetaAlgorithmReadWrite:
    """
    Helpers for reading/writing meta-algorithms (pipelines, OneVsRest, validators)
    whose stages contain other stages.
    """

    @staticmethod
    def isMetaEstimator(pyInstance: Any) -> bool:
        """
        Return True if ``pyInstance`` is a Pipeline, a OneVsRest, or an Estimator that
        also extends ``pyspark.ml.tuning._ValidatorParams``.
        """
        from pyspark.ml import Estimator, Pipeline
        from pyspark.ml.tuning import _ValidatorParams
        from pyspark.ml.classification import OneVsRest

        return (
            isinstance(pyInstance, Pipeline)
            or isinstance(pyInstance, OneVsRest)
            or (isinstance(pyInstance, Estimator) and isinstance(pyInstance, _ValidatorParams))
        )

    @staticmethod
    def getAllNestedStages(pyInstance: Any) -> List["Params"]:
        """
        Return ``pyInstance`` followed by every stage nested inside it, recursively
        (parent before children).

        Raises
        ------
        ValueError
            If a nested stage (anything that is not a Pipeline/PipelineModel) is a
            ``_ValidatorParams`` instance.
        """
        from pyspark.ml import Pipeline, PipelineModel
        from pyspark.ml.tuning import _ValidatorParams
        from pyspark.ml.classification import OneVsRest, OneVsRestModel

        # TODO: We need to handle `RFormulaModel.pipelineModel` here after Pyspark RFormulaModel
        #  support pipelineModel property.
        pySubStages: Sequence["Params"]

        if isinstance(pyInstance, Pipeline):
            pySubStages = pyInstance.getStages()
        elif isinstance(pyInstance, PipelineModel):
            pySubStages = cast(List["PipelineStage"], pyInstance.stages)
        elif isinstance(pyInstance, _ValidatorParams):
            raise ValueError("PySpark does not support nested validator.")
        elif isinstance(pyInstance, OneVsRest):
            pySubStages = [pyInstance.getClassifier()]
        elif isinstance(pyInstance, OneVsRestModel):
            pySubStages = [pyInstance.getClassifier()] + pyInstance.models  # type: ignore[assignment, operator]
        else:
            pySubStages = []

        nestedStages = []
        for pySubStage in pySubStages:
            nestedStages.extend(MetaAlgorithmReadWrite.getAllNestedStages(pySubStage))

        return [pyInstance] + nestedStages

    @staticmethod
    def getUidMap(instance: Any) -> Dict[str, "Params"]:
        """
        Map each uid of ``instance`` and its nested stages to the stage object.

        Raises
        ------
        RuntimeError
            If two stages share the same uid.
        """
        nestedStages = MetaAlgorithmReadWrite.getAllNestedStages(instance)
        uidMap = {stage.uid: stage for stage in nestedStages}
        if len(nestedStages) != len(uidMap):
            # Report every stage uid, duplicates included. Listing only the dict keys
            # would hide exactly the duplicates this error is about.
            allUids = [stage.uid for stage in nestedStages]
            raise RuntimeError(
                f"{instance.__class__.__module__}.{instance.__class__.__name__}"
                ".load found a compound estimator with stages with duplicate "
                f"UIDs. List of UIDs: {allUids}."
            )
        return uidMap