主页
搜索
最近更新
数据统计
申请密钥
批量保存
开发版网站(新前端)
系统公告
1
/
1
请查看完所有公告
吃最正宗的shi--------shi山平衡树
最后更新于 2025-08-27 17:39:43
作者
RecursionMH
分类
算法·理论
复制 Markdown
查看原文
转到新前端
删除文章
更新内容
```cpp #include<bits/stdc++.h> using namespace std; #define int long long const int N = 1300005, INF = 1e18; int n, tot = 0, root = 0, m, last = 0, ans_xor = 0; struct tree { int ch[2], a, fa; int cnt = 0; int size = 0; } x[N]; /*int find (int a) {//在平衡树中尝试寻找a节点 int now = root; while(now) { if (a < x[now].a) { if (x[now].ch[0])now = x[now].ch[0]; else { splay(now, 0); return; } }else { if (a > x[now].a) { if (x[now].ch[1])now = x[now].ch[1]; else { splay(now, 0); return 0; } }else { splay(now, 0); return a; } } } return 0; }*/ //没用find bool get(int a) {//查询该节点是其父亲的左节点还是右节点,0:左,1:右 return x[x[a].fa].ch[1] == a; } inline void pushup(int u) {//维护size x[u].size = (x[u].ch[0] ? x[x[u].ch[0]].size : 0) + (x[u].ch[1] ? x[x[u].ch[1]].size : 0) + x[u].cnt; } inline void rotate(int p) {//上旋x //将节点x向上旋到其父节点位置 //y=x[p].fa,z=x[p].fa; //get(x) get(y); { int y = x[p].fa, z = x[y].fa; int gc = get(p); x[y].ch[gc] = x[p].ch[gc ^ 1]; if (x[p].ch[gc ^ 1]) x[x[p].ch[gc ^ 1]].fa = y; x[p].ch[gc ^ 1] = y; x[y].fa = p; x[p].fa = z; if (z) { if (x[z].ch[0] == y) x[z].ch[0] = p; else x[z].ch[1] = p; } pushup(y); pushup(p); } inline void splay(int a, int goal) {//伸展节点a至其成为goal的子节点 //如果只需要上旋一次,肯定rotate就可以了 //如果距离目标位置不止一层,则每次都是双旋 //zig-zag和zag-zig形连续旋两次x //zig-zig形和zag-zag形,先旋x父节点y,再旋x while (x[a].fa != goal) { int fa = x[a].fa, gfa = x[fa].fa; if (gfa != goal) { if (get(a) == get(fa)) rotate(fa); else rotate(a); } rotate(a); } if (goal == 0) root = a; } void insert(int a) {//插入 if (!root) { root = ++tot; x[root].a = a; x[root].cnt = 1; x[root].size = 1; return; } int now = root, fa = 0; while (true) { if (a == x[now].a) { x[now].cnt++; x[now].size++; splay(now, 0); return; } fa = now; now = x[now].ch[a > x[now].a]; if (!now) { x[++tot].a = a; x[tot].cnt = 1; x[tot].size = 1; x[tot].fa = fa; x[fa].ch[a > x[fa].a] = tot; splay(tot, 0); return; } } } void erase(int a) {//删除 int now = root; while (now) { if (a < x[now].a) now = x[now].ch[0]; else if (a > x[now].a) now = x[now].ch[1]; else break; } if (!now) return; splay(now, 0); if (x[now].cnt > 1) { x[now].cnt--; pushup(now); return; } int l = x[now].ch[0], r = x[now].ch[1]; if (!l && !r) { root = 0; } else if (!l) { root = r; x[r].fa = 0; } else if (!r) { root = l; x[l].fa = 0; } else { int p = l; while (x[p].ch[1]) p = x[p].ch[1]; splay(p, now); x[p].ch[1] = r; x[r].fa = p; root = p; x[p].fa = 0; pushup(p); } } int ran(int a) {//查找a从小到大的排名 int now = root, res = 0; while (now) { if (a < x[now].a) { now = x[now].ch[0]; } else { res += (x[now].ch[0] ? x[x[now].ch[0]].size : 0); if (a == x[now].a) { splay(now, 0); return res + 1; } res += x[now].cnt; now = x[now].ch[1]; } } return res + 1; } int getk(int k) {//查找排名为k的num int now = root; while (now) { int left_size = (x[now].ch[0] ? x[x[now].ch[0]].size : 0); if (k <= left_size) { now = x[now].ch[0]; } else if (k <= left_size + x[now].cnt) { splay(now, 0); return x[now].a; } else { k -= left_size + x[now].cnt; now = x[now].ch[1]; } } return -1; } int before(int a) {//前驱 int now = root, res = -INF, abaaba = 0; while (now) { if (x[now].a < a) { if (x[now].a > res) { res = x[now].a; abaaba = now; } now = x[now].ch[1]; } else { now = x[now].ch[0]; } } if (abaaba) splay(abaaba, 0); return res == -INF ? 0 : res; } int after(int a) {//后继 int now = root, res = INF, abaaba = 0; while (now) { if (x[now].a > a) { if (x[now].a < res) { res = x[now].a; abaaba = now; } now = x[now].ch[0]; } else { now = x[now].ch[1]; } } if (abaaba) splay(abaaba, 0); return res == INF ? 0 : res; } inline int read() { int x = 0, f = 1; char ch = getchar(); while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); } while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); } return x * f; } signed main() { n = read(), m = read(); for (int i = 1; i <= n; ++i) { int a = read(); insert(a); } for (int i = 1; i <= m; ++i) { int opt = read(); int val = read(); if (opt >= 3 && opt <= 6) { val ^= last; } else { val ^= last; } if (opt == 1) insert(val); else if (opt == 2) erase(val); else if (opt == 3) last = ran(val); else if (opt == 4) last = getk(val); else if (opt == 5) last = before(val); else if (opt == 6) last = after(val); if (opt >= 3) ans_xor ^= last; } printf("%lld\n", ans_xor); return 0; } ``` 版本2 ```cpp #include <bits/stdc++.h> using namespace std; #define int long long #define ls(x) t[x].ch[0] #define rs(x) t[x].ch[1] #define fa(x) t[x].fa const int N = 1000001; int n, tot = 0, root = 0; struct tree { int ch[2], fa; int size, cnt, a; } t[N]; inline void pushup(int x) { if (x) t[x].size = t[ls(x)].size + t[rs(x)].size + t[x].cnt; } int newnode(int num, int fa) { t[++tot].a = num; t[tot].cnt = t[tot].size = 1; t[tot].fa = fa; return tot; } bool get(int x) { return rs(fa(x)) == x; } inline void rotate(int x) { int y = fa(x), z = fa(y); int tx = get(x), ty = get(y); if (z) t[z].ch[ty] = x; fa(x) = z; t[y].ch[tx] = t[x].ch[tx ^ 1]; if (t[x].ch[tx ^ 1]) fa(t[x].ch[tx ^ 1]) = y; t[x].ch[tx ^ 1] = y; fa(y) = x; pushup(y); pushup(x); } inline void splay(int x) { while (fa(x)) { int y = fa(x); if (fa(y)) { if (get(y) == get(x)) rotate(y); else rotate(x); } rotate(x); } root = x; } void find(int num) { int p = root, last = 0; if (p == 0) return; while (p) { last = p; if (t[p].a > num) p = ls(p); else if (t[p].a < num) p = rs(p); else break; } if (p) splay(p); else if (last) splay(last); } void insert(int num) { int p = root, last = 0; while (p) { last = p; if (t[p].a == num) { t[p].cnt++; splay(p); return; } else if (t[p].a > num) p = ls(p); else p = rs(p); } if (!last) { root = newnode(num, 0); return; } if (t[last].a > num) { ls(last) = newnode(num, last); splay(ls(last)); } else { rs(last) = newnode(num, last); splay(rs(last)); } } void del(int num) { find(num); if (root == 0 || t[root].a != num) return; if (t[root].cnt > 1) { t[root].cnt--; pushup(root); return; } if (!ls(root) && !rs(root)) { root = 0; return; } if (!ls(root)) { int tmp = root; root = rs(root); fa(root) = 0; return; } if (!rs(root)) { int tmp = root; root = ls(root); fa(root) = 0; return; } int tmp = root; int p = rs(root); while (ls(p)) p = ls(p); splay(p); ls(p) = ls(tmp); fa(ls(tmp)) = p; pushup(p); } int queryrank(int num) { find(num); if (t[root].a == num) return t[ls(root)].size + 1; else if (t[root].a < num) return t[ls(root)].size + t[root].cnt + 1; else return t[ls(root)].size + 1; } int querynum(int k) { int p = root; while (p) { if (t[ls(p)].size >= k) p = ls(p); else if (t[ls(p)].size + t[p].cnt >= k) break; else { k -= t[ls(p)].size + t[p].cnt; p = rs(p); } } if (p) splay(p); return t[p].a; } int pre(int num) { find(num); if (t[root].a < num) return t[root].a; int p = ls(root); if (!p) return -1e18; while (rs(p)) p = rs(p); splay(p); return t[p].a; } int suf(int num) { find(num); if (t[root].a > num) return t[root].a; int p = rs(root); if (!p) return 1e18; while (ls(p)) p = ls(p); splay(p); return t[p].a; } signed main() { cin >> n; for (int i = 1; i <= n; ++i) { int opt, val; cin >> opt >> val; if (opt == 1) insert(val); else if (opt == 2) del(val); else if (opt == 3) cout << queryrank(val) << '\n'; else if (opt == 4) cout << querynum(val) << '\n'; else if (opt == 5) cout << pre(val) << '\n'; else if (opt == 6) cout << suf(val) << '\n'; } return 0; } ```
正在渲染内容...
点赞
1
收藏
0