Pythonで抽象化セグメントツリーを生やした
経緯
- 先日、有志コンを開催した eeicpc #1 参加者のみなさんありがとうございました!
- testerをやる際、C++だけでなくPypy3でも通したくなった(なんとなく)
- その中に非可換モノイドを載せる問題があった(これ)
- Python用のセグメント木(以下「セグ木」)を探したが、非可換モノイドの載るセグ木が見当たらない
- (max,min,addなど可換でよければじゅっぴーさんのセグ木がよさそう)
- 仕方がないので自分で生やそう!
ソースコードと使い方
- 使用は自己責任でお願いします
- (2020/03/27 18:22 指摘を受けて少し改良)
(2020/04/02 immutable版を作成、改良。2ベキセグ木の方が時間の定数倍が軽そうなので2ベキセグ木に。)
セグ木に載せるオブジェクトがmutableな場合
import copy class SegmentTree: def __init__(self, N, func, I): self.sz = 2**(N-1).bit_length() self.func = copy.deepcopy(func) self.I = copy.deepcopy(I) self.seg = [copy.deepcopy(I) for i in range(self.sz * 2)] def assign(self, k, x): self.seg[k + self.sz] = copy.deepcopy(x) def build(self): for i in range(self.sz - 1, 0, -1): self.seg[i] = self.func(self.seg[2 * i], self.seg[2 * i + 1]) def update(self, k, x): k += self.sz self.seg[k] = copy.deepcopy(x) while k > 1: k >>= 1 self.seg[k] = self.func(self.seg[2 * k], self.seg[2 * k + 1]) def query(self, a, b): L = copy.deepcopy(self.I) R = copy.deepcopy(self.I) a += self.sz b += self.sz while a < b: if a & 1: L = self.func(L, self.seg[a]) a += 1 if b & 1: b -= 1 R = self.func(self.seg[b], R) a >>= 1 b >>= 1 return self.func(L, R)
- immutableな場合
class SegmentTree: def __init__(self, N, func, I): self.sz = 2**(N-1).bit_length() self.func = func self.I = I self.seg = [I] * (self.sz * 2) def assign(self, k, x): self.seg[k + self.sz] = x def build(self): for i in range(self.sz - 1, 0, -1): self.seg[i] = self.func(self.seg[2 * i], self.seg[2 * i + 1]) def update(self, k, x): k += self.sz self.seg[k] = x while k > 1: k >>= 1 self.seg[k] = self.func(self.seg[2 * k], self.seg[2 * k + 1]) def query(self, a, b): L = self.I R = self.I a += self.sz b += self.sz while a < b: if a & 1: L = self.func(L, self.seg[a]) a += 1 if b & 1: b -= 1 R = self.func(self.seg[b], R) a >>= 1 b >>= 1 return self.func(L, R)
- 使用例 : 区間最大クエリ
- (抽象化してあるのでライブラリをいじらずに書ける)
# 要素数10のセグ木を宣言 seg = SegmentTree(10, max, -(2 ** 60)) # リストaで初期化 for i in range(10): seg.assign(i, a[i]) # 構築 seg.build() # 0番目の要素を10に更新 seg.update(0, 10) # 半開区間[3 , 6)に含まれる要素の最大値を取得 ma = seg.query(3, 6)
使ってみる
注意:行列+セグ木で有名な企業コンの問題のネタバレがあります
~~~~ 以下、ネタバレ防止のため改行 ~~~~
- 例1 : DISCO!
- 「TL : 13s」という仰々しいTLのわりに解法はシンプルな問題
- 行列をセグ木に乗せれば解ける(逆行列を用いた線形解法もあるようだ)
- セグ木の要素数がと最大級、これが通れば大体通りそう
- AtCoderなのでPython3+numpyで挑んでみる
# importとセグメント木は省略、適宜補ってください read = sys.stdin.buffer.read readline = sys.stdin.buffer.readline S = readline().rstrip().decode() Q = int(readline()) m = map(int, read().split()) LR = zip(m, m) seg = SegmentTree(len(S), np.dot, np.identity(6, dtype=np.uint32)) for i in range(len(S)): if S[i] == 'D': seg.seg[i + seg.sz][0][1] = 1 elif S[i] == 'I': seg.seg[i + seg.sz][1][2] = 1 elif S[i] == 'S': seg.seg[i + seg.sz][2][3] = 1 elif S[i] == 'C': seg.seg[i + seg.sz][3][4] = 1 elif S[i] == 'O': seg.seg[i + seg.sz][4][5] = 1 seg.build() for L, R in LR: mat = seg.query(L - 1, R) print(mat[0][5])
- ライブラリを貼ってちょろっと書くだけ、簡単だな!
→TLE どうして…
仕方がないので重そうな部分を適当に書き替える 具体的には
# __init__関数内 self.seg = [copy.deepcopy(I) for i in range(self.sz * 2)]
- が明らかにヤバそうなので、これを
self.seg = np.zeros(self.sz * 2 * 6 * 6, dtype=np.uint32).reshape(self.sz * 2, 6, 6) for i in range(N): self.seg[i + self.sz] = I
- に直す
- →ギリギリ通りました( 12155ms ) 俺の勝ち! 提出
- 例2 : さっき貼った問題
- Pypy3でやる(ソースコード略)と、3.5sで通る(TL : 6s)
- 思ったより早い(C++だと0.3s程度なので10倍程度で済んでいる)
- セグ木の要素数が少ない(N <= )のが大きい
- 例3:GCD on Blackboard
-実装例
read = sys.stdin.buffer.read readline = sys.stdin.buffer.readline def main(): N = int(readline()) m = map(int, read().split()) seg = SegmentTree(N, gcd, 0) for i in range(N): seg.assign(i, next(m)) seg.build() ans = 1 for i in range(N): ans = max(ans, gcd(seg.query(0, i), seg.query(i+1, N))) print(ans) main()
- 今回はセグ木に載るオブジェクトがimmutableなのでimmutable版を使う
- mutable提出 952ms
- immutable提出 587ms
- copy.deepcopy(I)が無くなった分だけはっきり軽くなっている