The following is based on Hive 3.1.2.
There are two ways to implement a custom UDF in Hive: one is to extend org.apache.hadoop.hive.ql.exec.UDF, the other is to extend org.apache.hadoop.hive.ql.udf.generic.GenericUDF.
Whichever way you choose, the steps are the same:
- extend the specific class and implement the required method(s)
- package the code into a jar
- add the generated jar to the Hive environment
- in Hive, create a function mapped to the implementing class in the jar
First, add the pom dependency:
<dependency>
    <groupId>org.apache.hive</groupId>
    <artifactId>hive-exec</artifactId>
    <version>3.1.2</version>
</dependency>
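An optional tweak (an assumption about your deployment, not part of the original setup): since the cluster already ships Hive's classes, the dependency can be marked provided so they are not bundled into your jar:

<dependency>
    <groupId>org.apache.hive</groupId>
    <artifactId>hive-exec</artifactId>
    <version>3.1.2</version>
    <scope>provided</scope>
</dependency>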
1. UDF implementation
When extending the UDF class, you only need to implement an evaluate method. Before writing my own, I looked up the source of the built-in replace function for reference; it is pasted below:
package org.apache.hadoop.hive.ql.udf;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.hadoop.io.Text;

/**
 * UDFReplace replaces all substrings that are matched with a replacement substring.
 *
 */
@Description(name = "replace",
    value = "_FUNC_(str, search, rep) - replace all substrings of 'str' that "
    + "match 'search' with 'rep'", extended = "Example:\n"
    + " > SELECT _FUNC_('Hack and Hue', 'H', 'BL') FROM src LIMIT 1;\n"
    + " 'BLack and BLue'")
public class UDFReplace extends UDF {
  private Text result = new Text();

  public UDFReplace() {
  }

  public Text evaluate(Text s, Text search, Text replacement) {
    if (s == null || search == null || replacement == null) {
      return null;
    }
    String r = s.toString().replace(search.toString(), replacement.toString());
    result.set(r);
    return result;
  }
}
Modeled on the above, I defined my own function that behaves like Hive's built-in repeat:
package com.demo.hive;

import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;

@Description(name = "my_repeat", // describes the function name in Hive; usually kept identical to the mapped name
    value = "_FUNC_(str, n): repeat str n times", // shown by "desc function xxx"
    extended = "Example SQL: select _FUNC_('a',3);\nResult: 'aaa'") // shown by "desc function extended xxx"
public class MyUDFRepeat extends UDF {
    // for Hive char/string types, the Text class is the recommended wrapper
    private Text res = new Text();

    public Text evaluate(Text str, IntWritable n) {
        if (str == null || n == null) {
            return null;
        }
        if (n.get() <= 0) {
            // reset to empty, matching Hive's repeat(); without this reset,
            // res would still hold the value computed for the previous row
            res.set("");
            return res;
        }
        byte[] arr = str.getBytes();
        byte[] newArr = new byte[str.getLength() * n.get()];
        for (int i = 0; i < n.get(); i++) {
            System.arraycopy(arr, 0, newArr, i * str.getLength(), str.getLength());
        }
        res.set(newArr);
        return res;
    }
}
When I first wrote this function, something was wrong that no amount of logic-checking revealed; it took nearly a day to discover that Text's getBytes() behaves slightly differently from String's getBytes(): the lengths of the returned byte arrays are not necessarily equal. The root cause is that Text.getBytes() returns the object's internal backing buffer, which can be longer than the valid content (especially when the same Text instance is reused across rows), so the valid byte count must be taken from getLength(). Replacing every
str.getBytes().length
with str.getLength()
fixed it. Text class API reference: https://hadoop.apache.org/docs/r3.1.2/api/index.html
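A minimal standalone sketch of the pitfall (class name hypothetical): after a Text is reused for a shorter value, its backing buffer keeps the old size, so getBytes().length overshoots while getLength() stays correct:

import org.apache.hadoop.io.Text;

public class TextBufferDemo {
    public static void main(String[] args) {
        Text t = new Text("Hello, Hive!"); // 12 bytes of content
        t.set("ab");                       // reuse: the buffer is NOT shrunk
        System.out.println(t.getBytes().length); // 12 here -- the whole backing buffer (exact capacity is an implementation detail)
        System.out.println(t.getLength());       // 2 -- the valid byte count
        System.out.println(t.toString().getBytes().length); // 2 -- String copies only the valid bytes
    }
}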
After packaging the source above into a jar, upload it to the host running the Hive service (or onto Hadoop), then run the following from the local IDEA session:

add jar /root/HiveLib/hive_udf-1.0-SNAPSHOT.jar;  -- add the jar to the Hive environment
create temporary function my_repeat as 'com.demo.hive.MyUDFRepeat';  -- create a temporary function, valid only for the current session
Once created, you can check the function's details:
desc function extended my_repeat;
Run it against some test data to verify:
select *,my_repeat(name,2),repeat(name,2) from db_prac.employee;
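If the table isn't handy, the behavior can also be spot-checked with literals (Hive allows SELECT without FROM):

select my_repeat('ab', 3);  -- expected: 'ababab', same as repeat('ab', 3)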
2. GenericUDF implementation
Again, the source of a built-in function, length, is pasted first for reference. Implementing via GenericUDF requires overriding three abstract methods of the parent class: initialize(), evaluate(), and getDisplayString().
package org.apache.hadoop.hive.ql.udf.generic;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.vector.VectorizedExpressions;
import org.apache.hadoop.hive.ql.exec.vector.expressions.StringLength;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.lazy.LazyBinary;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorConverter;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.IntWritable;

/**
 * GenericUDFLength.
 *
 */
@Description(name = "length",
    value = "_FUNC_(str | binary) - Returns the length of str or number of bytes in binary data",
    extended = "Example:\n"
    + " > SELECT _FUNC_('Facebook') FROM src LIMIT 1;\n" + " 8")
@VectorizedExpressions({StringLength.class})
public class GenericUDFLength extends GenericUDF {
  private final IntWritable result = new IntWritable();
  private transient PrimitiveObjectInspector argumentOI;
  private transient PrimitiveObjectInspectorConverter.StringConverter stringConverter;
  private transient PrimitiveObjectInspectorConverter.BinaryConverter binaryConverter;
  private transient boolean isInputString;

  @Override
  public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
    if (arguments.length != 1) {
      throw new UDFArgumentLengthException(
          "LENGTH requires 1 argument, got " + arguments.length);
    }

    if (arguments[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentException(
          "LENGTH only takes primitive types, got " + argumentOI.getTypeName());
    }
    argumentOI = (PrimitiveObjectInspector) arguments[0];

    PrimitiveObjectInspector.PrimitiveCategory inputType = argumentOI.getPrimitiveCategory();
    ObjectInspector outputOI = null;
    switch (inputType) {
    case CHAR:
    case VARCHAR:
    case STRING:
      isInputString = true;
      stringConverter = new PrimitiveObjectInspectorConverter.StringConverter(argumentOI);
      break;

    case BINARY:
      isInputString = false;
      binaryConverter = new PrimitiveObjectInspectorConverter.BinaryConverter(argumentOI,
          PrimitiveObjectInspectorFactory.writableBinaryObjectInspector);
      break;

    default:
      throw new UDFArgumentException(
          " LENGTH() only takes STRING/CHAR/VARCHAR/BINARY types as first argument, got "
          + inputType);
    }

    outputOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
    return outputOI;
  }

  @Override
  public Object evaluate(DeferredObject[] arguments) throws HiveException {
    byte[] data = null;

    if (isInputString) {
      String val = null;
      if (arguments[0] != null) {
        val = (String) stringConverter.convert(arguments[0].get());
      }
      if (val == null) {
        return null;
      }

      data = val.getBytes();
      int len = 0;
      for (int i = 0; i < data.length; i++) {
        if (GenericUDFUtils.isUtfStartByte(data[i])) {
          len++;
        }
      }

      result.set(len);
      return result;
    } else {
      BytesWritable val = null;
      if (arguments[0] != null) {
        val = (BytesWritable) binaryConverter.convert(arguments[0].get());
      }
      if (val == null) {
        return null;
      }

      result.set(val.getLength());
      return result;
    }
  }

  @Override
  public String getDisplayString(String[] children) {
    return getStandardDisplayString("length", children);
  }
}
Modeled on the above, here is a function that checks whether one string contains another:
package com.demo.hive;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;

@Description(name = "str_contains", value = "_FUNC_(str1, str2): return true if str1 contains str2, else return false")
public class MyGenericUDFContains extends GenericUDF {
    private StringObjectInspector pos1;
    private StringObjectInspector pos2;

    @Override
    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
        // check the number of arguments
        if (arguments.length != 2) {
            throw new UDFArgumentLengthException("the function takes exactly 2 arguments");
        }
        // check the argument types (note: this accepts STRING but not CHAR/VARCHAR,
        // which use different object inspectors)
        if (!(arguments[0] instanceof StringObjectInspector) || !(arguments[1] instanceof StringObjectInspector)) {
            throw new UDFArgumentException("both arguments must be of string type");
        }
        this.pos1 = (StringObjectInspector) arguments[0];
        this.pos2 = (StringObjectInspector) arguments[1];
        // the function returns a boolean
        return PrimitiveObjectInspectorFactory.javaBooleanObjectInspector;
    }

    @Override
    public Object evaluate(DeferredObject[] arguments) throws HiveException {
        String str1 = this.pos1.getPrimitiveJavaObject(arguments[0].get());
        String str2 = this.pos2.getPrimitiveJavaObject(arguments[1].get());
        // NULL in, NULL out; without this check, str1.contains(str2) would throw an NPE
        if (str1 == null || str2 == null) {
            return null;
        }
        return str1.contains(str2) ? Boolean.TRUE : Boolean.FALSE;
    }

    @Override
    public String getDisplayString(String[] children) {
        return getStandardDisplayString("str_contains", children);
    }
}
After packaging and uploading the jar, create the mapped function:
create temporary function str_contains as 'com.demo.hive.MyGenericUDFContains';
Check the function info:
desc function extended str_contains;
Run some test data:
select name, str_contains(name,"i") from db_prac.employee;
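Or spot-check with literals:

select str_contains('hive', 'iv');  -- expected: true
select str_contains('hive', 'ivy'); -- expected: false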
Summary
- The UDF class is simple to implement: only the evaluate() method is needed, and it supports overloading (see the sketch after this list). The GenericUDF class is more involved, but offers more flexible argument checking and a richer set of argument types; choose based on the actual use case.
- The registrations above are temporary: the functions are valid only for the current session, which is generally fine for testing. For permanent registration, first upload the jar to HDFS (e.g. with hdfs dfs -put), then register it with:
create function my_repeat as 'com.demo.hive.MyUDFRepeat' using jar "hdfs:/user/hive/lib/hive_udf-1.0-SNAPSHOT.jar";
- Drop a registered function with: drop [temporary] function xxx;
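A minimal sketch of evaluate() overloading (a hypothetical class, not from the examples above): Hive resolves among the overloads by the argument types of the call:

package com.demo.hive;

import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;

public class MyOverloadedUDF extends UDF {
    // matched by a string argument, e.g. my_func('abc')
    public Text evaluate(Text s) {
        return s == null ? null : new Text(s.toString().toUpperCase());
    }

    // matched by an int argument, e.g. my_func(3)
    public IntWritable evaluate(IntWritable n) {
        return n == null ? null : new IntWritable(n.get() * 2);
    }
}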