字典树学习笔记

字典树学习笔记

Fri Nov 22 2024
7 分钟

字典树,顾名思义,就是一棵能当字典的树。先来看板子:

P8306 【模板】字典树#

nn 个模式串和 qq 次询问,对于每次询问,求有多少个模式串满足询问串是模式串的前缀。

字典树的概念不必说(因为这是写给我自己看的笔记),对每一个点记录一个值 cnt[i] ,在插入时,路径上的每一个点 cnt[i] ++ ,表示这个节点表示的字符串又多当了一个字符串的前缀。在查询时,如果这个查询串走到一半往下走不动了,意味着它并不是某一个字符串的前缀,就可以 return 0 了。

AC code:

CPP
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include<bits/extc++.h>
using namespace std;
const int maxn = 3e6 + 5;
int n,q;
int tr[maxn][70],cnt[maxn],idx;
char s[maxn];
inline int toint(char x)
{
    if (x >= 'A' && x <= 'Z')
        return x - 'A' + 1;
    else if (x >= 'a' && x <= 'z')
        return x - 'a' + 26 + 1;
    else
        return x - '0' + 52 + 1;
}
inline void add(char str[])
{
    int pos = 0,len = strlen(str);
    for (int i = 0; i < len; i++)
    {
        int ch = toint(str[i]);
        if (!tr[pos][ch])//理解:每一条边代表一个字符
            tr[pos][ch] = ++idx;
        pos = tr[pos][ch];
        cnt[pos] ++;
    }
}
inline int find(char str[])
{
    int pos = 0,len = strlen(str);
    for (int i = 0; i < len; i++)
    {
        int ch = toint(str[i]);
        if (!tr[pos][ch])
            return 0;
        pos = tr[pos][ch];
    }
    return cnt[pos];
}
void solve()
{
    for (int i = 0; i <= idx; i++)
        for (int j = 1; j <= 65; j++)   
            tr[i][j] = 0;
    for (int i = 0; i <= idx; i++)
        cnt[i] = 0;
    idx = 0;
    scanf("%d%d",&n,&q);
    for (int i = 1; i <= n; i++)
    {
        scanf("%s",s);
        add(s);
    }
    for (int i = 1; i <= q; i++)
    {
        scanf("%s",s);
        printf("%d\n",find(s));
    }
}
signed main()
{
    int t;
    scanf("%d",&t);
    while (t--)
        solve();
    return 0;
}

P4551 最长异或路径#

求最长的异或路径,就是路径上边权的异或和。

dis(u,v)dis(u,v) 表示 uuvv 的异或路径。所以有:

dis(u,v)dis(u,v)

== dis(u,LCA(u,v))dis(u,LCA(u,v)) \oplus dis(v,LCA(u,v))dis(v,LCA(u,v))

== dis(u,LCA(u,v))dis(u,LCA(u,v)) \oplus dis(v,LCA(u,v))dis(v,LCA(u,v)) \oplus dis(LCA(u,v),Root)dis(LCA(u,v),Root) \oplus dis(LCA(u,v),Root)dis(LCA(u,v),Root)

== dis(u,LCA(u,v))dis(u,LCA(u,v)) \oplus dis(LCA(u,v),Root)dis(LCA(u,v),Root) \oplus dis(v,LCA(u,v))dis(v,LCA(u,v)) \oplus dis(LCA,Root)dis(LCA,Root)

== dis(u,Root)dis(u,Root) \oplus dis(v,Root)dis(v,Root)

所以我们可以预处理出所有的 dis(u,Root)dis(u,Root),然后扔到01字典树里。枚举每一个 uu,在字典树上,如果 disudis_u 的第 xx 位取反有对应的子树,就往哪里走并记录答案,因为这样能让那一位异或后变成 11。否则就往另一个子树走,不更新答案

AC code:

CPP
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include<bits/extc++.h>
using namespace std;
const int maxn = 1e5 + 10;
int n;
int head[maxn],idx = 1;
int dis[maxn],tr[maxn << 4][2],id = 1;
struct edge
{
    int v,w,nxt;
    edge(int v = 0,int w = 0,int nxt = 0):v(v),w(w),nxt(nxt){};
}e[maxn << 1];
void adde(int u,int v,int w)
{
    e[++idx] = edge(v,w,head[u]);
    head[u] = idx;
}
void dfs(int u,int fa)//预处理dis[u]
{
    for (int i = head[u]; i; i = e[i].nxt)
    {
        int v = e[i].v,w = e[i].w;
        if (v == fa)
            continue;
        dis[v] = (dis[u] ^ w);
        dfs(v,u);
    }
}
void insert(int x)
{
    int p = 1;
    for (int i = 31; i >= 0; i--)
    {
        int ch = ((x >> i) & 1);
        if (!tr[p][ch])
            tr[p][ch] = ++id;
        p = tr[p][ch];
    }
}
int find(int x)
{
    int p = 1,ret = 0;
    for (int i = 31; i >= 0; i--)
    {
        int ch = ((x >> i) & 1);
        if (tr[p][ch ^ 1])//这一位取反有就往下走并更新答案
        {
            ret += (1 << i);
            p = tr[p][ch ^ 1];
        }
        else//没有就往旁边走。
            p = tr[p][ch];
    }
    return ret;
}
signed main()
{
    scanf("%d",&n);
    for (int i = 1; i < n; i++)
    {
        int u,v,w;
        scanf("%d%d%d",&u,&v,&w);
        adde(u,v,w);
        adde(v,u,w);
    }
    dis[1] = 0;
    dfs(1,0);
    for (int i = 1; i <= n; i++)
        insert(dis[i]);
    int ans = 0;
    for (int i = 1; i <= n; i++)    
        ans = max(ans,find(dis[i]));
    cout << ans;
    return 0;
}