Nyaanの日記

競プロ用(精進記録など) 乱文失礼します

Pythonで抽象化セグメントツリーを生やした

経緯

  • 先日、有志コンを開催した eeicpc #1 参加者のみなさんありがとうございました!
  • testerをやる際、C++だけでなくPypy3でも通したくなった(なんとなく)
  • その中に非可換モノイドを載せる問題があった(これ)
  • Python用のセグメント木(以下「セグ木」)を探したが、非可換モノイドの載るセグ木が見当たらない
  • (max,min,addなど可換でよければじゅっぴーさんのセグ木がよさそう)
  • 仕方がないので自分で生やそう!

ソースコードと使い方

  • 使用は自己責任でお願いします
  • (2020/03/27 18:22 指摘を受けて少し改良)
  • (2020/04/02 immutable版を作成、改良。2ベキセグ木の方が時間の定数倍が軽そうなので2ベキセグ木に。)

  • セグ木に載せるオブジェクトがmutableな場合

import copy
class SegmentTree:
    def __init__(self, N, func, I):
        self.sz = 2**(N-1).bit_length()
        self.func = copy.deepcopy(func)
        self.I = copy.deepcopy(I)
        self.seg = [copy.deepcopy(I) for i in range(self.sz * 2)]

    def assign(self, k, x):
        self.seg[k + self.sz] = copy.deepcopy(x)

    def build(self):
        for i in range(self.sz - 1, 0, -1):
            self.seg[i] = self.func(self.seg[2 * i], self.seg[2 * i + 1])

    def update(self, k, x):
        k += self.sz
        self.seg[k] = copy.deepcopy(x)
        while k > 1:
            k >>= 1
            self.seg[k] = self.func(self.seg[2 * k], self.seg[2 * k + 1])

    def query(self, a, b):
        L = copy.deepcopy(self.I)
        R = copy.deepcopy(self.I)
        a += self.sz
        b += self.sz
        while a < b:
            if a & 1:
                L = self.func(L, self.seg[a])
                a += 1
            if b & 1:
                b -= 1
                R = self.func(self.seg[b], R)
            a >>= 1
            b >>= 1
        return self.func(L, R)
  • immutableな場合
class SegmentTree:
    def __init__(self, N, func, I):
        self.sz = 2**(N-1).bit_length()
        self.func = func
        self.I = I
        self.seg = [I] * (self.sz * 2)
 
    def assign(self, k, x):
        self.seg[k + self.sz] = x
 
    def build(self):
        for i in range(self.sz - 1, 0, -1):
            self.seg[i] = self.func(self.seg[2 * i], self.seg[2 * i + 1])
 
    def update(self, k, x):
        k += self.sz
        self.seg[k] = x
        while k > 1:
            k >>= 1
            self.seg[k] = self.func(self.seg[2 * k], self.seg[2 * k + 1])
 
    def query(self, a, b):
        L = self.I
        R = self.I
        a += self.sz
        b += self.sz
        while a < b:
            if a & 1:
                L = self.func(L, self.seg[a])
                a += 1
            if b & 1:
                b -= 1
                R = self.func(self.seg[b], R)
            a >>= 1
            b >>= 1
        return self.func(L, R)
  • 使用例 : 区間最大クエリ
  • (抽象化してあるのでライブラリをいじらずに書ける)
# 要素数10のセグ木を宣言
seg = SegmentTree(10, max, -(2 ** 60))
# リストaで初期化
for i in range(10):
    seg.assign(i, a[i])
# 構築
seg.build()
# 0番目の要素を10に更新
seg.update(0, 10)
# 半開区間[3 , 6)に含まれる要素の最大値を取得
ma = seg.query(3, 6)

使ってみる

注意:行列+セグ木で有名な企業コンの問題のネタバレがあります
~~~~ 以下、ネタバレ防止のため改行 ~~~~



















  • 例1 : DISCO!
  • 「TL : 13s」という仰々しいTLのわりに解法はシンプルな問題
  • 行列をセグ木に乗せれば解ける(逆行列を用いた線形解法もあるようだ)
  • セグ木の要素数 10 ^ 6 と最大級、これが通れば大体通りそう
  • AtCoderなのでPython3+numpyで挑んでみる
# importとセグメント木は省略、適宜補ってください
read = sys.stdin.buffer.read
readline = sys.stdin.buffer.readline
S = readline().rstrip().decode()
Q = int(readline())
m = map(int, read().split())
LR = zip(m, m)
seg = SegmentTree(len(S), np.dot, np.identity(6, dtype=np.uint32))

for i in range(len(S)):
    if S[i] == 'D':
        seg.seg[i + seg.sz][0][1] = 1
    elif S[i] == 'I':
        seg.seg[i + seg.sz][1][2] = 1
    elif S[i] == 'S':
        seg.seg[i + seg.sz][2][3] = 1
    elif S[i] == 'C':
        seg.seg[i + seg.sz][3][4] = 1
    elif S[i] == 'O':
        seg.seg[i + seg.sz][4][5] = 1
seg.build()

for L, R in LR:
    mat = seg.query(L - 1, R)
    print(mat[0][5])
  • ライブラリを貼ってちょろっと書くだけ、簡単だな!
  • →TLE どうして…

  • 仕方がないので重そうな部分を適当に書き替える 具体的には

# __init__関数内
self.seg = [copy.deepcopy(I) for i in range(self.sz * 2)]
  • が明らかにヤバそうなので、これを
self.seg = np.zeros(self.sz * 2 * 6 * 6, dtype=np.uint32).reshape(self.sz * 2, 6, 6)
for i in range(N):
  self.seg[i + self.sz] = I
  • に直す
  • →ギリギリ通りました( 12155ms ) 俺の勝ち! 提出



read = sys.stdin.buffer.read
readline = sys.stdin.buffer.readline
def main():
 N = int(readline())
    m = map(int, read().split())
    seg = SegmentTree(N, gcd, 0)
    for i in range(N):
        seg.assign(i, next(m))
    seg.build()
    ans = 1
    for i in range(N):
        ans = max(ans, gcd(seg.query(0, i), seg.query(i+1, N)))
    print(ans)

main()
  • 今回はセグ木に載るオブジェクトがimmutableなのでimmutable版を使う
  • mutable提出 952ms
  • immutable提出 587ms
  • copy.deepcopy(I)が無くなった分だけはっきり軽くなっている

おわりに

  • C++最高!一番好きな競プロ用言語です!(は?)
  • 冗談はともかく、セグ木を使わざるを得ない問題は高速な言語を使った方がよさそう
  • (追記:なぜかというと、僕はpythonでDISCO!を通すのに2時間弱かかったので)
  • 高速に動かすのがかなり難しく、適切に書かないと \mathrm {O} ( N \log N ) , N \leq 2 \cdot 10 ^ 5がTLEする
  • Python使いならセグ木を使わない方向で強くなった方がよさそう
  • (例えばDISCO!は更新クエリがないので逆行列を使って線形で解くべき)
  • ただ、セグ木+Pythonがメチャクチャ弱いというわけでもなさそうなのが書いてみてわかったので記事にした
  • ライブラリの使用は自己責任でお願いします、問題点があったら教えてください