argparseで未知のコマンドライン引数に対応する

皆さんこんにちは。機械学習エンジニアのwasatingです。

.pyファイルでモデルの学習をしたい時にコマンドラインでハイパラを決めたい!何てことがたまにあると思います。 そんな時にいちいち全パラメータに対して

add_argument('--param1')
add_argument('--param2')
add_argument('--param3')
...

としていくのはさすがに手間な上にモデルを変えたりした際にまた書き直す必要もあり、さすがに現実的ではないですよね。

というわけで今回は
1. 未知のコマンドライン引数に柔軟に対応
2. 既知のコマンドライン引数と同様の扱いができるようにしたいと思います

未知のコマンドラインへの対応

こちらはご存じの方も多いかもしれませんが、端的に言うとargparse.ArgumentParser().parse_known_args()を使用します。
例えば以下のようなコードがあったとして

# parser_test.py
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str)
    parser.add_argument('--output_dir', type=str)

    return parser.parse_args()

args = parse_args()

この時

python parser_test.py --model_dir /path/to/model/dir --output_dir /path/to/output/dir

であれば問題ありませんが、

python parser_test.py --model_dir /path/to/model/dir --output_dir /path/to/output/dir --checkpoint_dir /path/to/check/point/dir

のようにadd_argumentしていない引数があると

parser_test.py: error: unrecognized arguments: --checkpoint_dir /path/to/check/point/dir

といったエラーが出ます。

これに対し、

# parser_test_unk.py
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str)
    parser.add_argument('--output_dir', type=str)

    return parser.parse_known_args()

args, unknown = parse_args()

とすることで、

print(f'args: {args}')
#  Namespace(model_dir=/path/to/model/dir, output_dir=/path/to/output/dir)
print(f'unknown: {unknown}')
# ['--checkpoint_dir', '/path/to/check/point/dir']

という様に、add_argumentしたものはこれまで通りNameSpaceに、していないものは[--arg val ...]のようにlistとして受け取ることができます。

既知のコマンドライン引数と同様の扱いをする

では先ほど得られた未知の引数ですが、このままlistとして使ってもいいですが、せっかくなので既知のargsと同等の扱いができるようにしたいと思います。

def parse_unknown_args(unknowns: list):
    parser = argparse.ArgumentParser()
    [parser.add_argument(v) for v in unknowns if v.startswith('--')]

    return args.parse_args(unknowns)

parse_unknown_args(unknown)

上記のようにすることで、未知のコマンドライン引数も既知のもの同様に扱うことができます。

が、実際の使用方法としてはこのまま使うのではなく、

def cast_int_or_float(value):
    if re.fullmatch(r'-?\d+', value):
        value = int(value)
    elif re.match(r'-?\d+\.\d+', value):
        value = float(value)

    return value


def parse_unknown_args(unknowns: list):
    parser = argparse.ArgumentParser()
    [parser.add_argument(v) for v in unknowns if v.startswith('--')]

    args = parser.parse_args(unknowns)
    args = vars(args)
    args = {k: cast_int_or_float(v) for k, v in args.items()}

    return args

という様にdictとして扱うことが多いかと思います。 ポイントとしては

  • args.parse_args(unknowns)とすることで、最初のparse_known_argsでunknownとしたもののみを見る
  • unknownはすべてstringとして扱われるので内容でintやfloatにキャスト

の二点です

一つ目に関してはこれを忘れると既知の引数も含まれ、本来既知であった引数がparse_unknown_args内では未知のものとして扱われます。(2時間ぐらいここで溶かした)


以上argparseで未知のコマンドライン引数に対応する方法でした。
この投稿が誰かの助けになれば幸いです。