1. 问题定义
假设用户的table schema是下面这样的:
- uid 表示用户id
- date 表示用户在这天活跃
然后我们要计算:T 这天活跃的用户,在 T+1、T+2、T+3 等后续几天中仍然活跃的数量(即留存)。
-- Activity log: one row per (user, day) on which the user was active.
create table user_activity (
    uid int,     -- user id
    date string  -- active day, stored as a 'yyyy-MM-dd' string
)
engine = olap
duplicate key (`uid`)
distributed by hash(`uid`)
properties (
    "replication_num" = "1"
);
2. 测试数据
import csv
import datetime
import random


def main():
    """Generate random activity data and write it to ``user_activity.csv``.

    For each of 3 users (ids 0..2) and each of the last 10 days (today
    included), the user is marked active with probability 0.5. Every active
    (day, user) pair becomes one CSV row in the order ``date,user_id``, with
    the date formatted as ``YYYY-MM-DD``.

    The file is created in (or overwrites an existing file in) the current
    working directory. Output is random on every run because the generator
    is not seeded here.
    """
    today = datetime.date.today()
    data = []
    for user_id in range(3):
        for d in range(10):
            if random.random() < 0.5:
                date = today - datetime.timedelta(days=d)
                data.append((date.strftime('%Y-%m-%d'), user_id))

    with open('user_activity.csv', 'w', newline='') as csvfile:
        fieldnames = ['date', 'user_id']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        # No header row is written on purpose -- presumably because the load
        # step would otherwise ingest the header as a data row.
        for row in data:
            writer.writerow({'date': row[0], 'user_id': row[1]})


if __name__ == '__main__':
    main()
然后使用 stream load 将这些数据导入到 SR 里面
curl --location-trusted -u root: -H "label:124" \ -H "column_separator:," \ -H "columns: date,uid" \ -T user_activity.csv -XPUT \
MySQL [zya]> select * from user_activity; +------+------------+ | uid | date | +------+------------+ | 0 | 2023-07-10 | | 0 | 2023-07-09 | | 0 | 2023-07-08 | | 0 | 2023-07-06 | | 0 | 2023-07-05 | | 1 | 2023-07-12 | | 1 | 2023-07-10 | | 1 | 2023-07-09 | | 1 | 2023-07-07 | | 1 | 2023-07-05 | | 1 | 2023-07-04 | | 2 | 2023-07-12 | | 2 | 2023-07-10 | | 2 | 2023-07-08 | | 2 | 2023-07-06 | | 2 | 2023-07-03 | +------+------------+
3. 留存SQL
这里我们需要使用到窗口函数(参见 StarRocks Docs 中的 Window function 文档)。
但是窗口函数有个限制,就是一行数据只能输出一个值,但是我们可能想知道 T 这天之后,T+1, T+2 … 这些天的留存情况。
所以这里我们使用了 bit trick,最多计算 63 天的留存情况。我们使用一个 long类型的字段 `encoded` 来进行编码
- 假设 T 天uid有活跃,那么在 bit 0 进行设置 `encoded |= (1 << 0)`
- 假设 T+x 天 uid 依然有活跃,那么 `encoded |= (1 << x)`
然后在统计的时候,我们对 `encoded` 右移不同的位数(bit right shift)来统计不同天数的留存情况,其中 `RetentionEncode` 是我们要实现的 UDWF(user-defined window function)
-- Step 1: for every (uid, date) row, encode which later days the same user
-- was active on; bit x of `encoded` corresponds to day T+x.
with encoded_activity as (
    select
        uid,
        date,
        RetentionEncode(date) over (
            partition by uid
            order by date desc
        ) as encoded
    from user_activity
)
-- Step 2: per cohort day, count the users and how many of them have the
-- T+1 / T+2 bit set.
select
    date,
    count(*) as cohort,
    sum(bitand(bit_shift_right(encoded, 1), 1)) as day1,
    sum(bitand(bit_shift_right(encoded, 2), 1)) as day2
from encoded_activity
group by date
order by date;
4. 测试UDWF
使用 global function 机制将这个 UDWF 注册到 SR 里面
-- Remove any previously registered version of the function.
drop global function RetentionEncode(string);

-- Register the Java class as a global window ("analytic") function that
-- takes a date string and returns the bit-encoded retention value.
create global aggregate function RetentionEncode(string)
returns bigint
properties (
    "analytic" = "true",
    "symbol" = "com.starrocks.udf.retentionEncode",
    "type" = "StarrocksJar",
    "file" = ""
);
然后我们看一下 1,2 天的留存情况
MySQL [zya]> select * from user_activity; +------+------------+ | uid | date | +------+------------+ | 0 | 2023-07-10 | | 0 | 2023-07-09 | | 0 | 2023-07-08 | | 0 | 2023-07-06 | | 0 | 2023-07-05 | | 1 | 2023-07-12 | | 1 | 2023-07-10 | | 1 | 2023-07-09 | | 1 | 2023-07-07 | | 1 | 2023-07-05 | | 1 | 2023-07-04 | | 2 | 2023-07-12 | | 2 | 2023-07-10 | | 2 | 2023-07-08 | | 2 | 2023-07-06 | | 2 | 2023-07-03 | +------+------------+ 16 rows in set (0.070 sec) MySQL [zya]> with Rte as (select uid, date as date , RetentionEncode(date) over (partition by uid order by date desc) as encoded from user_activity) select date, count(1) as cohort, sum(bitand(bit_shift_right(encoded, 1), 1)) as day1, sum(bitand(bit_shift_right(encoded, 2), 1)) as day2 from Rte group by date order by date; +------------+--------+------+------+ | date | cohort | day1 | day2 | +------------+--------+------+------+ | 2023-07-03 | 1 | 0 | 0 | | 2023-07-04 | 1 | 1 | 0 | | 2023-07-05 | 2 | 1 | 1 | | 2023-07-06 | 2 | 0 | 2 | | 2023-07-07 | 1 | 0 | 1 | | 2023-07-08 | 2 | 1 | 2 | | 2023-07-09 | 2 | 2 | 0 | | 2023-07-10 | 3 | 0 | 2 | | 2023-07-12 | 2 | 0 | 0 | +------------+--------+------+------+ 9 rows in set (0.385 sec)
这里验证一下效果,以 07-08 这天为例
- 在 07-08 这天,uid = 0, 2 有活跃,所以 cohort = 2
- 在 07-09 这天,uid = 0 有活跃, 所以 day1 = 1
- 在 07-10 这天,uid= 0, 2 有活跃,所以 day2 = 2
5. UDWF实现
- 我们会看到一系列dates, 这些dates是降序排列的
- 我们首先看到 T+x,然后再看到 T
- 使用 `TreeSet` 维护63天以内的所有的日期
- 遍历这个 `TreeSet` 来生成对应的 encoded
package com.starrocks.udf; import java.text.SimpleDateFormat; import java.util.Date; import java.util.TreeSet; import java.util.concurrent.TimeUnit; public class retentionEncode { public static class State { TreeSet<Long> buffer = new TreeSet<>(); Long encoded = 0L; @Override public String toString() { return String.format("State(encoded = 0x%x, buffer = %d)", encoded, buffer.size()); } public int serializeLength() { return 4; } } public State create() { return new State(); } public void destroy(State state) { } public void update(State state, String val) { } public void serialize(State state, java.nio.ByteBuffer buff) { } public void merge(State state, java.nio.ByteBuffer buffer) { } public Long finalize(State state) { System.out.println("finalize: " + state.toString()); return state.encoded; } public void reset(State state) { System.out.println("reset: " + state.toString()); state.buffer.clear(); state.encoded = 0L; } private static Date stringToDate(String val) { SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd"); try { Date date = dateFormat.parse(val); return date; } catch (Exception e) { e.printStackTrace(); return null; } } private void updateEncoded(State state, Date val) { state.encoded = 0L; if (val == null) { return; } long now = val.getTime(); state.buffer.add(now); final int maxDays = 8; // pop out date >= maxDays while (state.buffer.size() > 0) { long last = state.buffer.last(); long diffInMs = last - now; long diffInDays = TimeUnit.DAYS.convert(diffInMs, TimeUnit.MILLISECONDS); if (diffInDays >= maxDays) { state.buffer.pollLast(); } else { break; } } for (long t : state.buffer) { long diffInMs = t - now; long diffInDays = TimeUnit.DAYS.convert(diffInMs, TimeUnit.MILLISECONDS); assert diffInDays < maxDays; state.encoded |= (1L << diffInDays); } } public void windowUpdate(State state, int peer_group_start, int peer_group_end, int frame_start, int frame_end, String[] dates) { Date val = stringToDate(dates[frame_start]); 
updateEncoded(state, val); } }