皆さんこんにちは。機械学習エンジニアのwasatingです。
.pyファイルでモデルの学習をしたい時にコマンドラインでハイパラを決めたい!何てことがたまにあると思います。 そんな時にいちいち全パラメータに対して
add_argument('--param1') add_argument('--param2') add_argument('--param3') ...
としていくのはさすがに手間な上にモデルを変えたりした際にまた書き直す必要もあり、さすがに現実的ではないですよね。
というわけで今回は
1. 未知のコマンドライン引数に柔軟に対応し
2. 既知のコマンドライン引数と同様の扱いができるようにしたいと思います
未知のコマンドラインへの対応
こちらはご存じの方も多いかもしれませんが、端的に言うとargparse.ArgumentParser().parse_known_args()
を使用します。
例えば以下のようなコードがあったとして
# parser_test.py def parse_args(): parser = argparse.ArgumentParser() parser.add_argument('--model_dir', type=str) parser.add_argument('--output_dir', type=str) return parser.parse_args() args = parse_args()
この時
python parser_test.py --model_dir /path/to/model/dir --output_dir /path/to/output/dir
であれば問題ありませんが、
python parser_test.py --model_dir /path/to/model/dir --output_dir /path/to/output/dir --checkpoint_dir /path/to/check/point/dir
のようにadd_argument
していない引数があると
parser_test.py: error: unrecognized arguments: --checkpoint_dir /path/to/check/point/dir
といったエラーが出ます。
これに対し、
# parser_test_unk.py def parse_args(): parser = argparse.ArgumentParser() parser.add_argument('--model_dir', type=str) parser.add_argument('--output_dir', type=str) return parser.parse_known_args() args, unknown = parse_args()
とすることで、
print(f'args: {args}') # Namespace(model_dir=/path/to/model/dir, output_dir=/path/to/output/dir) print(f'unknown: {unknown}') # ['--checkpoint_dir', '/path/to/check/point/dir']
という様に、add_argument
したものはこれまで通りNameSpaceに、していないものは[--arg val ...]のようにlistとして受け取ることができます。
既知のコマンドライン引数と同様の扱いをする
では先ほど得られた未知の引数ですが、このままlistとして使ってもいいですが、せっかくなので既知のargsと同等の扱いができるようにしたいと思います。
def parse_unknown_args(unknowns: list): parser = argparse.ArgumentParser() [parser.add_argument(v) for v in unknowns if v.startswith('--')] return args.parse_args(unknowns) parse_unknown_args(unknown)
上記のようにすることで、未知のコマンドライン引数も既知のもの同様に扱うことができます。
が、実際の使用方法としてはこのまま使うのではなく、
def cast_int_or_float(value): if re.fullmatch(r'-?\d+', value): value = int(value) elif re.match(r'-?\d+\.\d+', value): value = float(value) return value def parse_unknown_args(unknowns: list): parser = argparse.ArgumentParser() [parser.add_argument(v) for v in unknowns if v.startswith('--')] args = parser.parse_args(unknowns) args = vars(args) args = {k: cast_int_or_float(v) for k, v in args.items()} return args
という様にdictとして扱うことが多いかと思います。 ポイントとしては
args.parse_args(unknowns)
とすることで、最初のparse_known_args
でunknownとしたもののみを見る- unknownはすべてstringとして扱われるので内容でintやfloatにキャスト
の二点です
一つ目に関してはこれを忘れると既知の引数も含まれ、本来既知であった引数がparse_unknown_args
内では未知のものとして扱われます。(2時間ぐらいここで溶かした)
以上argparseで未知のコマンドライン引数に対応する方法でした。
この投稿が誰かの助けになれば幸いです。