在StarRocks中使用UDWF计算用户留存

Table of Contents

1. 问题定义

假设用户的table schema是下面这样的:

  • uid 表示用户id
  • date 表示用户在这天活跃

然后我们要计算,T这天活跃的用户,在T+1, T+2, T+3 后面几天的活跃数量。

create table user_activity (uid int, date string)
    engine = olap duplicate key (`uid`)
    distributed by hash(`uid`)
    properties ("replication_num" = "1");

2. 测试数据

使用python脚本来生成测试数据

def main():
    data = []
    today = datetime.date.today()
    for user_id in range(3):
        for d in range(10):
            if random.random() < 0.5:
                date = today - datetime.timedelta(days=d)
                data.append((date.strftime('%Y-%m-%d'), user_id))

    with open('user_activity.csv', 'w', newline='') as csvfile:
        fieldnames = ['date', 'user_id']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        # writer.writeheader()
        for row in data:
            writer.writerow({'date': row[0], 'user_id': row[1]})

main()

然后使用 stream load 将这些数据导入到 SR 里面

curl --location-trusted -u root: -H "label:124" \
    -H "column_separator:," \
    -H "columns: date,uid" \
    -T user_activity.csv -XPUT \
    http://127.0.0.1:41001/api/zya/user_activity/_stream_load

大概数就就是这样的

MySQL [zya]> select * from user_activity;
+------+------------+
| uid  | date       |
+------+------------+
|    0 | 2023-07-10 |
|    0 | 2023-07-09 |
|    0 | 2023-07-08 |
|    0 | 2023-07-06 |
|    0 | 2023-07-05 |
|    1 | 2023-07-12 |
|    1 | 2023-07-10 |
|    1 | 2023-07-09 |
|    1 | 2023-07-07 |
|    1 | 2023-07-05 |
|    1 | 2023-07-04 |
|    2 | 2023-07-12 |
|    2 | 2023-07-10 |
|    2 | 2023-07-08 |
|    2 | 2023-07-06 |
|    2 | 2023-07-03 |
+------+------------+

3. 留存SQL

这里我们需要使用到窗口函数 Window function @ Window_function @ StarRocks Docs

但是窗口函数有个限制,就是一行数据只能输出一个值,但是我们可能想知道 T 这天之后,T+1, T+2 … 这些天的留存情况。

所以这里我们使用了 bit trick,最多计算 63 天的留存情况。我们使用一个 long类型的字段 `encoded` 来进行编码

  • 假设 T 天uid有活跃,那么在 bit 0 进行设置 `encoded |= (1 << 0)`
  • 假设 T+x 天 uid 依然有活跃,那么 `encoded |= (1 << x)`

然后在统计的时候,我们使用不同的 bit right shift 来统计不同天数的留存情况,其中 `RetentionEncode` 是我们要实现的UDWF(user define window function)

with Rte as (select uid, date as date , RetentionEncode(date)  over (partition by uid order by date desc) as encoded from user_activity)
     select date, count(1) as cohort,
          sum(bitand(bit_shift_right(encoded, 1), 1)) as day1,
          sum(bitand(bit_shift_right(encoded, 2), 1)) as day2 from Rte
          group by date order by date;

4. 测试UDWF

使用 global function 机制将这个UDWF测试到SR里面

https://docs.starrocks.io/en-us/latest/sql-reference/sql-functions/JAVA_UDF#java-udfs

drop global function RetentionEncode(string);

CREATE GLOBAL AGGREGATE FUNCTION RetentionEncode(string)
RETURNS bigint
properties
(
    "analytic" = "true",
    "symbol" = "com.starrocks.udf.retentionEncode",
    "type" = "StarrocksJar",
    "file" = "http://127.0.0.1:41006/data/sr-udf-1.0-SNAPSHOT-jar-with-dependencies.jar"
);

然后我们看一下 1,2 天的留存情况

MySQL [zya]> select * from user_activity;
+------+------------+
| uid  | date       |
+------+------------+
|    0 | 2023-07-10 |
|    0 | 2023-07-09 |
|    0 | 2023-07-08 |
|    0 | 2023-07-06 |
|    0 | 2023-07-05 |
|    1 | 2023-07-12 |
|    1 | 2023-07-10 |
|    1 | 2023-07-09 |
|    1 | 2023-07-07 |
|    1 | 2023-07-05 |
|    1 | 2023-07-04 |
|    2 | 2023-07-12 |
|    2 | 2023-07-10 |
|    2 | 2023-07-08 |
|    2 | 2023-07-06 |
|    2 | 2023-07-03 |
+------+------------+
16 rows in set (0.070 sec)

MySQL [zya]> with Rte as (select uid, date as date , RetentionEncode(date)  over (partition by uid order by date desc) as encoded from user_activity)    select date, count(1) as cohort, sum(bitand(bit_shift_right(encoded, 1), 1)) as day1, sum(bitand(bit_shift_right(encoded, 2), 1)) as day2 from Rte group by date order by date;
+------------+--------+------+------+
| date       | cohort | day1 | day2 |
+------------+--------+------+------+
| 2023-07-03 |      1 |    0 |    0 |
| 2023-07-04 |      1 |    1 |    0 |
| 2023-07-05 |      2 |    1 |    1 |
| 2023-07-06 |      2 |    0 |    2 |
| 2023-07-07 |      1 |    0 |    1 |
| 2023-07-08 |      2 |    1 |    2 |
| 2023-07-09 |      2 |    2 |    0 |
| 2023-07-10 |      3 |    0 |    2 |
| 2023-07-12 |      2 |    0 |    0 |
+------------+--------+------+------+
9 rows in set (0.385 sec)

这里验证一下效果,以 07-08 这天为例

  • 在 07-08 这天,uid = 0, 2 有活跃,所以 cohort = 2
  • 在 07-09 这天,uid = 0 有活跃, 所以 day1 = 1
  • 在 07-10 这天,uid= 0, 2 有活跃,所以 day2 = 2

5. UDWF实现

https://docs.starrocks.io/en-us/latest/sql-reference/sql-functions/JAVA_UDF#use-a-udwf

代码如下,大致思路是:

  • 我们会看到一系列dates, 这些dates是降序排列的
  • 我们首先看到T+x, 然后在看到T
  • 使用 `TreeSet` 维护63天以内的所有的日期
  • 遍历这个 `TreeSet` 来生成对应的 encoded
package com.starrocks.udf;

import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.TreeSet;
import java.util.concurrent.TimeUnit;

public class retentionEncode {
    public static class State {
        TreeSet<Long> buffer = new TreeSet<>();

        Long encoded = 0L;

        @Override
        public String toString() {
            return String.format("State(encoded = 0x%x, buffer = %d)", encoded, buffer.size());
        }

        public int serializeLength() {
            return 4;
        }
    }

    public State create() {
        return new State();
    }

    public void destroy(State state) {

    }

    public void update(State state, String val) {
    }

    public void serialize(State state, java.nio.ByteBuffer buff) {
    }

    public void merge(State state, java.nio.ByteBuffer buffer) {
    }

    public Long finalize(State state) {
        System.out.println("finalize: " + state.toString());
        return state.encoded;
    }

    public void reset(State state) {
        System.out.println("reset: " + state.toString());
        state.buffer.clear();
        state.encoded = 0L;
    }

    private static Date stringToDate(String val) {
        SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd");
        try {
            Date date = dateFormat.parse(val);
            return date;
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    private void updateEncoded(State state, Date val) {
        state.encoded = 0L;
        if (val == null) {
            return;
        }

        long now = val.getTime();
        state.buffer.add(now);

        final int maxDays = 8;

        // pop out date >= maxDays
        while (state.buffer.size() > 0) {
            long last = state.buffer.last();
            long diffInMs = last - now;
            long diffInDays = TimeUnit.DAYS.convert(diffInMs, TimeUnit.MILLISECONDS);
            if (diffInDays >= maxDays) {
                state.buffer.pollLast();
            } else {
                break;
            }
        }

        for (long t : state.buffer) {
            long diffInMs = t - now;
            long diffInDays = TimeUnit.DAYS.convert(diffInMs, TimeUnit.MILLISECONDS);
            assert diffInDays < maxDays;
            state.encoded |= (1L << diffInDays);
        }
    }

    public void windowUpdate(State state, int peer_group_start, int peer_group_end, int frame_start, int frame_end,
                             String[] dates) {
        Date val = stringToDate(dates[frame_start]);
        updateEncoded(state, val);
    }
}