forked from EndlessCheng/codeforces-go
-
Notifications
You must be signed in to change notification settings - Fork 0
/
odt.go
99 lines (86 loc) · 1.92 KB
/
odt.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
package copypasta
import "sort"
// See odt_bst.go for the full explanation of this structure.
// This variant stores the blocks in a plain slice (author's note: O(n log n)).
// NOTE(review): every split/merge shifts the slice, so a single operation costs
// O(number of blocks); the stated bound presumably relies on the block count
// staying small for the target problem's data.
// Problem: https://codeforces.com/problemset/problem/896/C
// Write-up: https://www.luogu.com.cn/blog/endlesscheng/solution-cf896c

type (
	// odtBlock is a run of consecutive positions that all hold the same value.
	odtBlock struct {
		l   int   // first position covered by the block
		r   int   // last position covered by the block (inclusive)
		val int64 // value shared by every position in [l, r]
	}

	// odt is the list of blocks, kept ordered by position with adjacent,
	// non-overlapping ranges (newODT starts with one block per index, and
	// split/merge preserve that layout).
	odt []odtBlock
)
// newODT builds the initial structure: every index i of arr becomes its own
// single-position block [i, i] holding arr[i]. The result has len == cap ==
// len(arr).
func newODT(arr []int64) odt {
	blocks := make(odt, 0, len(arr))
	for pos, v := range arr {
		blocks = append(blocks, odtBlock{l: pos, r: pos, val: v})
	}
	return blocks
}
// split makes sure some block starts exactly at position mid and returns its
// index:
//
//	[l, r] => [l, mid-1] [mid, r]
//
// If a block already starts at mid, nothing changes and its index is returned.
// If no block contains mid (e.g. mid is one past the last position), len(*t)
// is returned and nothing changes.
//
// Blocks are sorted and contiguous, so the only candidate block (the first one
// whose right end reaches mid) is found by binary search.
func (t *odt) split(mid int) int {
	ot := *t
	i := sort.Search(len(ot), func(j int) bool { return ot[j].r >= mid })
	if i == len(ot) || ot[i].l > mid {
		return len(ot)
	}
	if ot[i].l == mid {
		return i
	}

	// Here ot[i].l < mid <= ot[i].r: insert [mid, r] right after block i.
	// Grow first and do every write through the grown slice. If append has to
	// reallocate, a write made through the old slice header would be lost.
	b := ot[i]
	ot = append(ot, odtBlock{})
	copy(ot[i+2:], ot[i+1:]) // shift the tail right by one (copy handles overlap)
	ot[i+1] = odtBlock{mid, b.r, b.val}
	ot[i].r = mid - 1
	*t = ot
	return i + 1
}
// prepare isolates the positions [l, r] (assumes l <= r): afterwards the
// blocks with indices in [begin, end) cover exactly [l, r].
//
// l is split before r+1 on purpose. Splitting at r+1 only inserts after the
// block that starts at l, so the index returned for l stays valid. Splitting at
// l second could insert a block in front of the index returned for r+1 and
// shift it.
func (t *odt) prepare(l, r int) (begin, end int) {
	lo := t.split(l)
	hi := t.split(r + 1)
	return lo, hi
}
// The methods below take begin and end as returned by an earlier call to
// t.prepare(l, r); the blocks with indices in [begin, end) then cover exactly
// the positions [l, r].

// merge assigns val to the whole range by collapsing blocks [begin, end) into
// the single block at index begin. That block already starts at l (prepare
// split there), so only its right end is set, to r. r is expected to be the r
// that was passed to prepare. The blocks after the range slide down in place,
// and *t is shortened accordingly.
func (t *odt) merge(begin, end, r int, val int64) {
	blocks := *t
	blocks[begin].r, blocks[begin].val = r, val
	if end <= begin+1 {
		return // the range is already a single block
	}
	kept := copy(blocks[begin+1:], blocks[end:])
	*t = blocks[:begin+1+kept]
}
// add increases the value of every block with index in [begin, end) by val.
// The receiver is a slice header sharing the caller's backing array, so the
// update is visible to the caller.
func (t odt) add(begin, end int, val int64) {
	for k := begin; k < end; k++ {
		blk := &t[k]
		blk.val += val
	}
}
// kth returns the k-th smallest value (k is 1-based) over the positions
// covered by blocks [begin, end). A block [l, r] counts as r-l+1 copies of its
// value. The receiver is left untouched: sorting happens on a private copy.
// It panics if k exceeds the number of covered positions.
func (t odt) kth(begin, end, k int) int64 {
	src := t[begin:end]
	byVal := make(odt, len(src))
	copy(byVal, src)
	sort.Slice(byVal, func(a, b int) bool { return byVal[a].val < byVal[b].val })

	rank := k - 1 // positions still to skip before reaching the answer
	for _, blk := range byVal {
		width := blk.r - blk.l + 1
		if rank < width {
			return blk.val
		}
		rank -= width
	}
	panic(rank)
}
// pow returns x^n mod mod by binary exponentiation, as a value in [0, mod).
// A negative x (e.g. a block value after add with a negative delta) is first
// reduced into [0, mod). Go's % keeps the sign of the dividend, so without
// this step the result would be a negative residue. n <= 0 yields 1 % mod.
// mod must be positive, and (mod-1)*(mod-1) must fit in an int64 (roughly
// mod <= 3.03e9); otherwise the intermediate products overflow.
func (odt) pow(x int64, n int, mod int64) int64 {
	x %= mod
	if x < 0 {
		x += mod
	}
	res := int64(1) % mod
	for ; n > 0; n >>= 1 {
		if n&1 == 1 {
			res = res * x % mod
		}
		x = x * x % mod
	}
	return res
}
// powSum returns the sum of val^n over every position covered by blocks
// [begin, end), taken mod mod, as a value in [0, mod). A block [l, r]
// contributes (r-l+1) * val^n. The count and the running sum are reduced
// modulo mod at every step, so the total cannot overflow as long as
// (mod-1)*(mod-1) fits in an int64.
func (t odt) powSum(begin, end int, n int, mod int64) (res int64) {
	for _, b := range t[begin:end] {
		cnt := int64(b.r-b.l+1) % mod
		res = (res + cnt*t.pow(b.val, n, mod)) % mod
	}
	// Go's % keeps the dividend's sign; normalize in case a term was negative.
	if res < 0 {
		res += mod
	}
	return res
}