1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153
| #include <bits/stdc++.h> using namespace std;
#define gc getchar
typedef long long ll;
const int N = 5e5 + 10,inf = 1e9;
ll read() { ll a = 0,b = 0;char c = gc(); while(c < '0'||c > '9') b = (c == '-'),c = gc(); while(c >= '0'&&c <= '9') a = a * 10 + c - 48,c = gc(); return b ? -a : a; }
struct node { int val,pri,siz; int ls,rs; };
struct pir{int l,r;};
struct treap { node tr[N]; int idx = 0; void upd(int u) { if(!u) return ; node &p = tr[u]; p.siz = tr[p.ls].siz + tr[p.rs].siz + 1; } pir split(int rt,int val) { if(!rt) return pir{0,0}; node &p = tr[rt]; pir o; if(p.val < val) { o = split(p.rs,val); p.rs = o.l; o.l = rt; } else { o = split(p.ls,val); p.ls = o.r; o.r = rt; } upd(rt); return o; } int merge(int l,int r) { if(!l||!r) return l + r; node &lp = tr[l],&rp = tr[r]; if(lp.pri < rp.pri) { rp.ls = merge(l,rp.ls); upd(r); return r; } else { lp.rs = merge(lp.rs,r); upd(l); return l; } } int ins(int rt,int val) { pir o = split(rt,val); tr[++idx] = node{val,rand(),1,0,0}; o.l = merge(o.l,idx); return merge(o.l,o.r); } int del(int rt,int val) { node &p = tr[rt]; if(p.val < val) p.rs = del(p.rs,val); else if(p.val == val) { rt = merge(p.ls,p.rs); upd(rt); return rt; } else p.ls = del(p.ls,val); upd(rt); return rt; } int find_kth(int rt,int val) { if(!rt) return 1; node p = tr[rt]; if(p.val < val) return find_kth(p.rs,val) + tr[p.ls].siz + 1; else return find_kth(p.ls,val); } int find_val(int rt,int k) { if(!rt) return -1; node p = tr[rt]; if(tr[p.ls].siz + 1 > k) return find_val(p.ls,k); else if(tr[p.ls].siz + 1 == k) return p.val; else return find_val(p.rs,k - tr[p.ls].siz - 1); } int find_pre(int rt,int val) { int pre = -inf; while(rt) { if(tr[rt].val < val) pre = tr[rt].val,rt = tr[rt].rs; else rt = tr[rt].ls; } return pre; } int find_suc(int rt,int val) { int suc = inf; while(rt) { if(tr[rt].val > val) suc = tr[rt].val,rt = tr[rt].ls; else rt = tr[rt].rs; } return suc; } }t; int main() { int T = read(),rt = 0; while(T--) { int op = read(),p = read(); if(op == 1) rt = t.ins(rt,p); if(op == 2) rt = t.del(rt,p); if(op == 3) printf("%d\n",t.find_kth(rt,p)); if(op == 4) printf("%d\n",t.find_val(rt,p)); if(op == 5) printf("%d\n",t.find_pre(rt,p)); if(op == 6) printf("%d\n",t.find_suc(rt,p)); } return 0; }
|