我需要实现一个动态的“自带代码”功能,以注册从我自己的代码外部创建的UDF。这是容器化的,入口点是标准的python解释器(不是pypsark)。根据启动时的配置设置,spark容器将使用以下类似内容初始化自身。我们不预先知道函数的定义,但是可以在需要时在容器上预安装依赖项。
def register_udf_module(udf_name, zip_or_py_path, file_name, function_name, return_type="int"):
# Psueduocode:
global sc, spark
sc.addPyFile(zip_or_py_path)
module_ref = some_inspect_function_1(zip_or_py_path)
file_ref = module_ref[file_name]
function_ref = module_ref[function_name]
spark.udf.register(udf_name, function_ref, return_type)
我似乎找不到有关如何完成此操作的参考。具体来说,用例是,通过运行该初始化火花群集之后,用户将需要此UDF用于SQL函数(通过Thrift JDBC连接)。我不知道JDBC / SQL连接与注册UDF的能力之间存在任何接口,因此它必须已启动并可以运行以进行SQL查询,而且我以后不能指望用户会调用'spark.udf.register在他们的身边。
我发现的解决方案是在启动开始时获取一个环境变量,该环境变量指向UDF目录,然后加载并检查该路径中的每个.py文件,并在spark中加载作为UDF函数找到的所有函数。
下面的示例工作代码:
def init_spark():
global sc
# Init spark (nothing special here)
conf = SparkConf()
spark = (
SparkSession.builder.config(conf=conf)
.master("local")
.appName("Python Spark")
.enableHiveSupport()
.getOrCreate()
)
if "SPARK_UDFS_PATH" in os.environ:
add_udf_module(os.environ.get("SPARK_UDFS_PATH"))
def add_udf_module(module_dir=None):
global sc
from inspect import getmembers, isfunction
module_dir = os.path.realpath(module_dir)
if not os.path.isdir(module_dir):
raise ValueError(f"Folder '{module_dir}' does not exist.")
for file in io.list_files(module_dir):
if file.endswith(".py"):
module = path_import(file)
for member in getmembers(module):
if isfunction(member[1]):
logging.info(f"Found module function: {member}")
func_name, func = member[0], member[1]
if func_name[:1] != "_" and func_name != "udf":
logging.info(f"Registering UDF '{func_name}':\n{func.__dict__}")
spark.udf.register(func_name, func)
def path_import(absolute_file_path):
module_name = os.path.basename(absolute_file_path)
module_name = ".".join(module_name.split(".")[:-1]) # removes '.py'
spec = importlib.util.spec_from_file_location(module_name, absolute_file_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
相关:
示例UDF python文件:
from pyspark.sql.functions import udf
from pyspark.sql import types
@udf(types.Long())
def times_five(value):
return value * 5
@udf("long")
def times_six(value):
return value * 6
示例SQL:
SELECT times_six(7) AS the_answer