在StarRocks中使用UDWF计算用户留存
1. 问题定义
假设用户的table schema是下面这样的:
- uid 表示用户id
- date 表示用户在这天活跃
然后我们要计算,T这天活跃的用户,在T+1, T+2, T+3 后面几天的活跃数量。
create table user_activity (uid int, date string)
engine = olap duplicate key (`uid`)
distributed by hash(`uid`)
properties ("replication_num" = "1");
2. 测试数据
使用python脚本来生成测试数据
def main():
data = []
today = datetime.date.today()
for user_id in range(3):
for d in range(10):
if random.random() < 0.5:
date = today - datetime.timedelta(days=d)
data.append((date.strftime('%Y-%m-%d'), user_id))
with open('user_activity.csv', 'w', newline='') as csvfile:
fieldnames = ['date', 'user_id']
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
# writer.writeheader()
for row in data:
writer.writerow({'date': row[0], 'user_id': row[1]})
main()
然后使用 stream load 将这些数据导入到 SR 里面
curl --location-trusted -u root: -H "label:124" \
-H "column_separator:," \
-H "columns: date,uid" \
-T user_activity.csv -XPUT \
http://127.0.0.1:41001/api/zya/user_activity/_stream_load
大概数就就是这样的
MySQL [zya]> select * from user_activity; +------+------------+ | uid | date | +------+------------+ | 0 | 2023-07-10 | | 0 | 2023-07-09 | | 0 | 2023-07-08 | | 0 | 2023-07-06 | | 0 | 2023-07-05 | | 1 | 2023-07-12 | | 1 | 2023-07-10 | | 1 | 2023-07-09 | | 1 | 2023-07-07 | | 1 | 2023-07-05 | | 1 | 2023-07-04 | | 2 | 2023-07-12 | | 2 | 2023-07-10 | | 2 | 2023-07-08 | | 2 | 2023-07-06 | | 2 | 2023-07-03 | +------+------------+
3. 留存SQL
这里我们需要使用到窗口函数 Window function @ Window_function @ StarRocks Docs
但是窗口函数有个限制,就是一行数据只能输出一个值,但是我们可能想知道 T 这天之后,T+1, T+2 … 这些天的留存情况。
所以这里我们使用了 bit trick,最多计算 63 天的留存情况。我们使用一个 long类型的字段 `encoded` 来进行编码
- 假设 T 天uid有活跃,那么在 bit 0 进行设置 `encoded |= (1 << 0)`
- 假设 T+x 天 uid 依然有活跃,那么 `encoded |= (1 << x)`
然后在统计的时候,我们使用不同的 bit right shift 来统计不同天数的留存情况,其中 `RetentionEncode` 是我们要实现的UDWF(user define window function)
with Rte as (select uid, date as date , RetentionEncode(date) over (partition by uid order by date desc) as encoded from user_activity)
select date, count(1) as cohort,
sum(bitand(bit_shift_right(encoded, 1), 1)) as day1,
sum(bitand(bit_shift_right(encoded, 2), 1)) as day2 from Rte
group by date order by date;
4. 测试UDWF
使用 global function 机制将这个UDWF测试到SR里面
https://docs.starrocks.io/en-us/latest/sql-reference/sql-functions/JAVA_UDF#java-udfs
drop global function RetentionEncode(string);
CREATE GLOBAL AGGREGATE FUNCTION RetentionEncode(string)
RETURNS bigint
properties
(
"analytic" = "true",
"symbol" = "com.starrocks.udf.retentionEncode",
"type" = "StarrocksJar",
"file" = "http://127.0.0.1:41006/data/sr-udf-1.0-SNAPSHOT-jar-with-dependencies.jar"
);
然后我们看一下 1,2 天的留存情况
MySQL [zya]> select * from user_activity; +------+------------+ | uid | date | +------+------------+ | 0 | 2023-07-10 | | 0 | 2023-07-09 | | 0 | 2023-07-08 | | 0 | 2023-07-06 | | 0 | 2023-07-05 | | 1 | 2023-07-12 | | 1 | 2023-07-10 | | 1 | 2023-07-09 | | 1 | 2023-07-07 | | 1 | 2023-07-05 | | 1 | 2023-07-04 | | 2 | 2023-07-12 | | 2 | 2023-07-10 | | 2 | 2023-07-08 | | 2 | 2023-07-06 | | 2 | 2023-07-03 | +------+------------+ 16 rows in set (0.070 sec) MySQL [zya]> with Rte as (select uid, date as date , RetentionEncode(date) over (partition by uid order by date desc) as encoded from user_activity) select date, count(1) as cohort, sum(bitand(bit_shift_right(encoded, 1), 1)) as day1, sum(bitand(bit_shift_right(encoded, 2), 1)) as day2 from Rte group by date order by date; +------------+--------+------+------+ | date | cohort | day1 | day2 | +------------+--------+------+------+ | 2023-07-03 | 1 | 0 | 0 | | 2023-07-04 | 1 | 1 | 0 | | 2023-07-05 | 2 | 1 | 1 | | 2023-07-06 | 2 | 0 | 2 | | 2023-07-07 | 1 | 0 | 1 | | 2023-07-08 | 2 | 1 | 2 | | 2023-07-09 | 2 | 2 | 0 | | 2023-07-10 | 3 | 0 | 2 | | 2023-07-12 | 2 | 0 | 0 | +------------+--------+------+------+ 9 rows in set (0.385 sec)
这里验证一下效果,以 07-08 这天为例
- 在 07-08 这天,uid = 0, 2 有活跃,所以 cohort = 2
- 在 07-09 这天,uid = 0 有活跃, 所以 day1 = 1
- 在 07-10 这天,uid= 0, 2 有活跃,所以 day2 = 2
5. UDWF实现
https://docs.starrocks.io/en-us/latest/sql-reference/sql-functions/JAVA_UDF#use-a-udwf
代码如下,大致思路是:
- 我们会看到一系列dates, 这些dates是降序排列的
- 我们首先看到T+x, 然后在看到T
- 使用 `TreeSet` 维护63天以内的所有的日期
- 遍历这个 `TreeSet` 来生成对应的 encoded
package com.starrocks.udf;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.TreeSet;
import java.util.concurrent.TimeUnit;
public class retentionEncode {
public static class State {
TreeSet<Long> buffer = new TreeSet<>();
Long encoded = 0L;
@Override
public String toString() {
return String.format("State(encoded = 0x%x, buffer = %d)", encoded, buffer.size());
}
public int serializeLength() {
return 4;
}
}
public State create() {
return new State();
}
public void destroy(State state) {
}
public void update(State state, String val) {
}
public void serialize(State state, java.nio.ByteBuffer buff) {
}
public void merge(State state, java.nio.ByteBuffer buffer) {
}
public Long finalize(State state) {
System.out.println("finalize: " + state.toString());
return state.encoded;
}
public void reset(State state) {
System.out.println("reset: " + state.toString());
state.buffer.clear();
state.encoded = 0L;
}
private static Date stringToDate(String val) {
SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd");
try {
Date date = dateFormat.parse(val);
return date;
} catch (Exception e) {
e.printStackTrace();
return null;
}
}
private void updateEncoded(State state, Date val) {
state.encoded = 0L;
if (val == null) {
return;
}
long now = val.getTime();
state.buffer.add(now);
final int maxDays = 8;
// pop out date >= maxDays
while (state.buffer.size() > 0) {
long last = state.buffer.last();
long diffInMs = last - now;
long diffInDays = TimeUnit.DAYS.convert(diffInMs, TimeUnit.MILLISECONDS);
if (diffInDays >= maxDays) {
state.buffer.pollLast();
} else {
break;
}
}
for (long t : state.buffer) {
long diffInMs = t - now;
long diffInDays = TimeUnit.DAYS.convert(diffInMs, TimeUnit.MILLISECONDS);
assert diffInDays < maxDays;
state.encoded |= (1L << diffInDays);
}
}
public void windowUpdate(State state, int peer_group_start, int peer_group_end, int frame_start, int frame_end,
String[] dates) {
Date val = stringToDate(dates[frame_start]);
updateEncoded(state, val);
}
}