いろいろ備忘録日記

主に .NET とか Go とか Flutter とか Python絡みのメモを公開しています。

Pythonメモ-08 (Python loop optimization, ループ最適化, 時間計算量, Time-Complexity)

今回、あんまり python 関係ない話題ですが、面白かったのでメモメモ。

どの言語にも言える話です。

概要

stackoverflow の Python カテゴリ見てたら、以下の内容を発見。

stackoverflow.com

トピックの投稿者さんは、処理をPythonで書いてるけどめっちゃ時間かかるからなんとかしたいとのこと。内容は

  • 200万件以上のリストがある
    • 1データは [a, b, c, d] のようにリストとなっている。これが200万件以上
  • データ処理用の dict がある。
    • キーの数が2000オーバー。
    • 1キーに少なくとも50データ入ってるリストが設定されている。
  • 200万オーバーのリストをループして、1データの[0]の要素の値と dict の value のリストを比べる。
    • 値があったら (in) 、[0] の要素値を dict の key に置き換える。

みたいなことをやっているみたい。これが超時間かかるとのこと。

なーんか、実務でもよく見るパターンですねぇw。こんなのよくあります。

回答者の内容

で、回答者の人の内容が以下。

計算量がめっちゃデカイから、少なくせぃ!!

改善ポイントは、以下の部分。

  • Listの in は、線形探索で時間計算量はO(n)なので、これやめる。
    • ソース見ると、listのループ毎に 50 * 2000 とかなってるのでこれは時間かかる。
    • 代わりに set で in する。これは、時間計算量がO(1)。
  • もっと速くしたいんだったら、現在の dict のキーと値を逆にする。
    • これで、[0]の要素値を一発で dict に当てることができる。

みたいな感じ。

基本的に、リストでサーチする処理は、線形探索になるので量が増えると遅いです。 そういう場合は、集合とか辞書とかつかってヒットさせるようにすると一気に速くなります。

大きなデータを処理する場合に、必ずリストなりなんなりのコンテナを利用しますが その際重要なのが

  • 順序が必要かどうか?
  • 重複を許容するかどうか?

ですね。単に値の集まりとして処理するのだったら、集合や辞書使ったほうが効率がいいときが多いです。

なんでもかんでも、リストでやろうとするとデータ量が増えたときにアワワワワってなるときがあります。

データの性質を捉えて、適切なコレクションを選ぶのは、python でも C# でも Java でも同じです。

今回のケースだと、最終的に

It is just roughly 100.000 times faster in this case :-)

くらい速くなったみたいですね。

サンプル

見てて、面白かったので、なんかサンプルでも作ろうかなっておもって python で書いてみました。

やってることに全然意味がないクソサンプルですが。大量のデータが欲しかったので郵便局の郵便番号データを

利用しました。通称 ken_all さん。200万ではないですが、これでも12万4000行以上あります。

# coding: utf-8
"""
Pythonでのループ最適化のサンプルです。
以下のURLの情報にインスパイアされてサンプルつくりました。

http://stackoverflow.com/questions/43827281/python-loop-optimization
"""
import collections
import csv
import pathlib
import zipfile as zip
from timeit import timeit
from typing import List, Dict

import requests


class PrepareProc:
    def __init__(self) -> None:
        super().__init__()

        self.zip_file_name = r'ken_all.zip'
        self.csv_file_name = r'ken_all.csv'

        self.work_dir = pathlib.Path(r'/tmp')
        self.zip_file_path = self.work_dir / self.zip_file_name
        self.csv_file_path = self.work_dir / self.csv_file_name

        # 郵便番号データダウンロードURL
        self.data_url = r'http://www.post.japanpost.jp/zipcode/dl/kogaki/zip/ken_all.zip'
        # 郵便番号データファイルのエンコーディング
        self.csv_encoding = 'sjis'

    def download(self) -> None:
        if self.zip_file_path.exists():
            return

        with open(self.zip_file_path, mode='wb') as writer:
            writer.write(requests.get(self.data_url).content)

    def extract(self) -> None:
        if self.csv_file_path.exists():
            return

        with zip.ZipFile(str(self.zip_file_path.absolute()), mode='r') as z:
            z.extractall(self.work_dir)

    def read(self) -> List[List[str]]:
        with open(self.csv_file_path, mode='rt', encoding=self.csv_encoding, newline='') as f:
            reader = csv.reader(f)
            return [line for line in reader]


# noinspection PyUnresolvedReferences
class _ProcValidateMixin:
    def _pre_validate(self) -> None:
        assert self._lines[0][0] != '北海道'

    def _post_validate(self) -> None:
        assert self._lines[0][0] == '北海道'


class SlowProc(_ProcValidateMixin):
    def __init__(self, lines: List[List[str]]) -> None:
        super().__init__()
        self._lines = lines
        self._mapping = self._make_mapping()

    def __call__(self, *args, **kwargs) -> None:
        for line in self._lines:
            for key in self._mapping:
                if line[0] in self._mapping[key]:
                    line[0] = key

    def _make_mapping(self) -> Dict[str, list]:
        mapping = collections.defaultdict(list)
        for line in self._lines:
            mapping[line[6]].append(line[0])

        return mapping


class NormalProc(_ProcValidateMixin):
    def __init__(self, lines: List[List[str]]) -> None:
        super().__init__()
        self._lines = lines
        self._mapping = self._make_mapping()

    def __call__(self, *args, **kwargs) -> None:
        for line in self._lines:
            for key in self._mapping:
                if line[0] in self._mapping[key]:
                    line[0] = key

    def _make_mapping(self) -> Dict[str, set]:
        mapping = collections.defaultdict(set)
        for line in self._lines:
            mapping[line[6]].add(line[0])

        return mapping


class FastProc(_ProcValidateMixin):
    def __init__(self, lines: List[List[str]]) -> None:
        super().__init__()
        self._lines = lines
        self._mapping = self._make_mapping()

    def __call__(self, *args, **kwargs) -> None:
        for line in self._lines:
            item = self._mapping[line[0]]
            if item:
                line[0] = item

    def _make_mapping(self) -> Dict[str, str]:
        mapping = collections.defaultdict(str)
        for line in self._lines:
            mapping[line[0]] = line[6]

        return mapping


if __name__ == '__main__':
    prepare = PrepareProc()
    prepare.download()
    prepare.extract()

    # 全件処理させると、とても時間がかかるので20000行に絞って実施。
    # 実際の行数は、2017/05/10時点で124115行ある。
    slow = SlowProc(prepare.read()[:20000])
    slow._pre_validate()
    print(f'slow={round(timeit(slow, number=1), 3)}')
    slow._post_validate()

    normal = NormalProc(prepare.read())
    normal._pre_validate()
    print(f'normal={round(timeit(normal, number=1), 3)}')
    normal._post_validate()

    fast = FastProc(prepare.read())
    fast._pre_validate()
    print(f'fast={round(timeit(fast, number=1), 3)}')
    fast._post_validate()

実行すると、以下のようになります。

slow=9.404
normal=0.836
fast=0.025

slowなやつは、全件でやると、いつまでも終わらないので2万件だけでこの時間です。

ソースは、以下でも見れます。

try-python/loop_optimization01.py at master · devlights/try-python · GitHub

参考になる情報


過去の記事については、以下のページからご参照下さい。

サンプルコードは、以下の場所で公開しています。