I found the same discussion in the comments section of "Create a custom Transformer in PySpark ML", but there is no clear answer. There is also an unresolved JIRA for it: https://issues.apache.org/jira/browse/SPARK-17025.
Given that the PySpark ML pipeline provides no option for saving a custom transformer written in Python, what other options are there? How can I implement the _to_java method in my Python class so that it returns a compatible Java object?
As of Spark 2.3.0 there's a much, much better way to do this.

Simply extend DefaultParamsWritable and DefaultParamsReadable, and your class will automatically have write and read methods that save your params and are used by the PipelineModel serialization system.
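To make the mixin idea concrete, here is a small pure-Python model of it that runs without Spark. It is a sketch, not PySpark's code: the Toy* classes and the metadata.json layout are hypothetical stand-ins for DefaultParamsWritable/DefaultParamsWriter and DefaultParamsReadable/DefaultParamsReader. What it illustrates is that the stage class itself contains no serialization code; write() and read() come from the two mixins, which only need the stage's uid and params.

```python
import json
import os
import tempfile


class ToyParamsWriter:
    """Hypothetical stand-in for DefaultParamsWriter: saves class name, uid and params as JSON."""

    def __init__(self, instance):
        self.instance = instance

    def save(self, path):
        cls = type(self.instance)
        metadata = {
            "class": f"{cls.__module__}.{cls.__qualname__}",  # lets a reader find the class again
            "uid": self.instance.uid,
            "paramMap": self.instance.params,
        }
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, "metadata.json"), "w") as f:
            json.dump(metadata, f)


class ToyParamsReader:
    """Hypothetical stand-in for DefaultParamsReader: default instance, then restore uid and params."""

    def __init__(self, cls):
        self.cls = cls

    def load(self, path):
        with open(os.path.join(path, "metadata.json")) as f:
            metadata = json.load(f)
        instance = self.cls()                    # build with default params
        instance.uid = metadata["uid"]           # keep the saved uid
        instance.params.update(metadata["paramMap"])
        return instance


class ToyParamsWritable:
    def write(self):
        return ToyParamsWriter(self)


class ToyParamsReadable:
    @classmethod
    def read(cls):
        return ToyParamsReader(cls)


class ToySetValueStage(ToyParamsWritable, ToyParamsReadable):
    """No serialization code here: write() and read() come from the mixins."""

    _counter = 0

    def __init__(self, outputCols=None, value=0.0):
        ToySetValueStage._counter += 1
        self.uid = f"ToySetValueStage_{ToySetValueStage._counter}"
        self.params = {"outputCols": outputCols, "value": value}


stage = ToySetValueStage(outputCols=["a", "b"], value=123.0)
path = os.path.join(tempfile.mkdtemp(), "stage")
stage.write().save(path)
restored = ToySetValueStage.read().load(path)
print(restored.uid, restored.params)
# ToySetValueStage_1 {'outputCols': ['a', 'b'], 'value': 123.0}
```

The toy reader follows the same order as the real reader appears to: create a default instance, restore the saved uid, then set the saved params. That is why a transformer whose state lives entirely in Params can round-trip without any extra code.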
The docs were not really clear, and I had to do a bit of source reading to understand that this was how deserialization worked:
1. PipelineModel.read instantiates a PipelineModelReader.
2. PipelineModelReader loads the metadata and checks whether the language is 'Python'. If it's not, the typical JavaMLReader is used (what most of these answers are designed for).
3. Otherwise, PipelineSharedReadWrite is used, which calls DefaultParamsReader.loadParamsInstance for each stage.

loadParamsInstance finds the class from the saved metadata and calls .load(path) on it. If your class extends DefaultParamsReadable, it gets the DefaultParamsReader.load method automatically. If you do have specialized deserialization logic to implement, I would look at that load method as a starting place.
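The read path above can be sketched as a simplified pure-Python model. This is an illustration under assumptions, not PySpark's source: load_pipeline_stages and load_params_instance only mimic PipelineModelReader.load and DefaultParamsReader.loadParamsInstance, the metadata.json files stand in for Spark's own metadata format, and the hypothetical CLASS_REGISTRY dict replaces the real step of importing the class by its saved, fully qualified name.

```python
import json
import os
import tempfile

# Hypothetical stand-in for importing a stage class by the name saved in its metadata.
CLASS_REGISTRY = {}


def register(cls):
    CLASS_REGISTRY[cls.__name__] = cls
    return cls


def write_metadata(path, metadata):
    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, "metadata.json"), "w") as f:
        json.dump(metadata, f)


def read_metadata(path):
    with open(os.path.join(path, "metadata.json")) as f:
        return json.load(f)


def load_params_instance(path):
    """Toy loadParamsInstance: find the class named in the metadata and call its load()."""
    py_type = CLASS_REGISTRY[read_metadata(path)["class"]]
    return py_type.load(path)


def load_pipeline_stages(path):
    """Toy PipelineModelReader.load: dispatch on the saved 'language' param."""
    metadata = read_metadata(path)
    if metadata["paramMap"].get("language") != "Python":
        # In PySpark, this is the branch where JavaMLReader takes over.
        raise NotImplementedError("not a Python pipeline; JavaMLReader would load it")
    stages_dir = os.path.join(path, "stages")
    return [load_params_instance(os.path.join(stages_dir, uid))
            for uid in metadata["paramMap"]["stageUids"]]


@register
class ToyStage:
    def __init__(self, value=0.0):
        self.value = value

    @classmethod
    def load(cls, path):
        # Specialized deserialization logic would go here.
        return cls(**read_metadata(path)["paramMap"])


# Lay out a saved "Python" pipeline by hand, then read it back.
root = tempfile.mkdtemp()
write_metadata(root, {"class": "PipelineModel",
                      "paramMap": {"language": "Python", "stageUids": ["stage0"]}})
write_metadata(os.path.join(root, "stages", "stage0"),
               {"class": "ToyStage", "paramMap": {"value": 123.0}})
stages = load_pipeline_stages(root)
print(type(stages[0]).__name__, stages[0].value)  # ToyStage 123.0
```

ToyStage.load plays the role of the load method mentioned above: a stage with specialized deserialization logic would read its extra files there.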
On the write side:

1. PipelineModel.write checks whether all stages are Java (implement JavaMLWritable). If so, the typical JavaMLWriter is used (what most of these answers are designed for).
2. Otherwise, PipelineModelWriter is used (PipelineWriter for an unfitted Pipeline), which checks that all stages implement MLWritable and calls PipelineSharedReadWrite.saveImpl.
3. PipelineSharedReadWrite.saveImpl calls .write().save(path) on each stage.

Extend DefaultParamsWritable to get the DefaultParamsWritable.write method, which saves metadata for your class and params in the right format. If you have custom serialization logic to implement, I would look at that method and DefaultParamsWriter as a starting point.
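The write path can be modeled the same way. Again, this is a hypothetical pure-Python sketch rather than PySpark's code: choose_writer returns the writer's name instead of a writer object, and the Toy* classes, the metadata.json files and the index_uid stage directory names are simplifications of what Spark actually writes.

```python
import json
import os
import tempfile


class ToyMLWritable:
    """Toy marker for stages that can write themselves (stands in for MLWritable)."""


class ToyJavaMLWritable(ToyMLWritable):
    """Toy marker for JVM-backed stages (stands in for JavaMLWritable)."""


def choose_writer(stages):
    """Toy PipelineModel.write(): all-Java pipelines go to the JVM writer."""
    if all(isinstance(s, ToyJavaMLWritable) for s in stages):
        return "JavaMLWriter"
    for s in stages:
        if not isinstance(s, ToyMLWritable):
            raise ValueError(f"{type(s).__name__} is not writable")
    return "PipelineModelWriter"


def save_impl(stages, path):
    """Toy PipelineSharedReadWrite.saveImpl: pipeline metadata, then stage.write().save()."""
    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, "metadata.json"), "w") as f:
        json.dump({"language": "Python", "stageUids": [s.uid for s in stages]}, f)
    for index, stage in enumerate(stages):
        stage.write().save(os.path.join(path, "stages", f"{index}_{stage.uid}"))


class ToyStageWriter:
    def __init__(self, stage):
        self.stage = stage

    def save(self, path):
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, "metadata.json"), "w") as f:
            json.dump({"uid": self.stage.uid, "value": self.stage.value}, f)


class ToyPythonStage(ToyMLWritable):
    def __init__(self, uid, value):
        self.uid, self.value = uid, value

    def write(self):
        return ToyStageWriter(self)


class ToyJavaStage(ToyJavaMLWritable):
    uid = "java_stage"


stages = [ToyPythonStage("svt_1", 123.0)]
root = tempfile.mkdtemp()
print(choose_writer([ToyJavaStage()]))  # JavaMLWriter
print(choose_writer(stages))            # PipelineModelWriter
save_impl(stages, root)
print(os.listdir(os.path.join(root, "stages")))  # ['0_svt_1']
```

In this model, save_impl only ever calls stage.write().save(path), so any stage that supplies a write() method can be saved, whether or not the other stages in the pipeline are Java-backed.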
OK, so finally, here is a pretty simple transformer that extends Params, with all of its parameters stored in the typical Params fashion:
from pyspark import keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param import Param, Params
from pyspark.ml.param.shared import HasOutputCols
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql.functions import lit  # for the dummy _transform


class SetValueTransformer(
    Transformer, HasOutputCols, DefaultParamsReadable, DefaultParamsWritable,
):
    value = Param(
        Params._dummy(),
        "value",
        "value to fill",
    )

    @keyword_only
    def __init__(self, outputCols=None, value=0.0):
        super(SetValueTransformer, self).__init__()
        self._setDefault(value=0.0)
        kwargs = self._input_kwargs
        self._set(**kwargs)

    @keyword_only
    def setParams(self, outputCols=None, value=0.0):
        """
        setParams(self, outputCols=None, value=0.0)
        Sets params for this SetValueTransformer.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setValue(self, value):
        """
        Sets the value of :py:attr:`value`.
        """
        return self._set(value=value)

    def getValue(self):
        """
        Gets the value of :py:attr:`value` or its default value.
        """
        return self.getOrDefault(self.value)

    def _transform(self, dataset):
        # Fill every output column with the constant value param
        for col in self.getOutputCols():
            dataset = dataset.withColumn(col, lit(self.getValue()))
        return dataset
Now we can use it:
from pyspark.ml import Pipeline, PipelineModel

svt = SetValueTransformer(outputCols=["a", "b"], value=123.0)

p = Pipeline(stages=[svt])
# sc is the SparkContext provided by the pyspark shell
df = sc.parallelize([(1, None), (2, 1.0), (3, 0.5)]).toDF(["key", "value"])
pm = p.fit(df)
pm.transform(df).show()
pm.write().overwrite().save('/tmp/example_pyspark_pipeline')
pm2 = PipelineModel.load('/tmp/example_pyspark_pipeline')
print('matches?', pm2.stages[0].extractParamMap() == pm.stages[0].extractParamMap())
pm2.transform(df).show()
Result:
+---+-----+-----+-----+
|key|value|    a|    b|
+---+-----+-----+-----+
|  1| null|123.0|123.0|
|  2|  1.0|123.0|123.0|
|  3|  0.5|123.0|123.0|
+---+-----+-----+-----+

matches? True

+---+-----+-----+-----+
|key|value|    a|    b|
+---+-----+-----+-----+
|  1| null|123.0|123.0|
|  2|  1.0|123.0|123.0|
|  3|  0.5|123.0|123.0|
+---+-----+-----+-----+