皆さんこんにちは。
機械学習エンジニアのwsatingです。
最近openAIが公開したAssistant APIでこれまで以上にRAGが簡単にできるようになってきましたね。
そんな最近の潮流から逆行(?)して今回はRetrieveでよく使われるsentence BERTのfine tuningに使われるTripletをpolarsで作る話をしたいと思います。
Triplet Loss、Tripletとは
Sentence BERTの学習に用いるLossの一つであるTriplet Lossというものがあります。詳細は省きますが、
- anchor
- positive
- negative
の3要素(Triplet)を用いて、anchorとpositiveの距離を近く、anchorとnegativeの距離を遠ざけるような学習をさせることができます。
学習時にnegativeをpositiveと近いもの、例えば同じカテゴリに属するものとする、hard negativeとすること、でより良い学習ができるとされています。
ちなみにTriplet Lossの詳細についてはこちらの記事がとても参考になるので、ご興味ある方はぜひ一度ご覧ください。
polarsでTripletの作成処理を実装
では早速作っていきます
が、その前に…
# 実行環境 import polars as pl pl.show_versions() --------Version info--------- Polars: 0.18.15 Index type: UInt32 Platform: Linux-5.10.16.3-microsoft-standard-WSL2-x86_64-with-glibc2.31 Python: 3.9.11 (main, Mar 18 2022, 16:45:24) [GCC 10.2.1 20210110] ----Optional dependencies---- adbc_driver_sqlite: <not installed> cloudpickle: <not installed> connectorx: <not installed> deltalake: <not installed> fsspec: 2022.11.0 matplotlib: 3.6.0 numpy: 1.23.4 pandas: 1.5.1 pyarrow: 12.0.1 pydantic: 1.10.12 sqlalchemy: <not installed> xlsx2csv: <not installed> xlsxwriter: <not installed>
というわけで以下から処理を実装していきます。
# stringの表示数を調整 pl.Config.set_fmt_str_lengths(100) # idx, category, anchor, positive colを持ったDataFrameの作成 df = ( pl.DataFrame(data={'category': np.random.randint(10, size=2000)}) .with_row_count(name='idx') .with_columns( anchor=pl.concat_str([pl.lit('anchor'), pl.lit('cat'), pl.col('category'), pl.col('idx')], separator='_'), positive=pl.concat_str([pl.lit('positive'), pl.lit('cat'), pl.col('category'), pl.col('idx')], separator='_'), ) ) df.head() shape: (5, 4) idx category anchor positive u32 i64 str str 0 3 "anchor_cat_3_0" "positive_cat_3_0" 1 1 "anchor_cat_1_1" "positive_cat_1_1" 2 7 "anchor_cat_7_2" "positive_cat_7_2" 3 8 "anchor_cat_8_3" "positive_cat_8_3" 4 5 "anchor_cat_5_4" "positive_cat_5_4"
データを用意するのが面倒positiveとnegativeの比較時にわかりやすいようにanchor, positiveの内容は{col_name}_cat_{category_id}_{idx}
ということにします。
ではこれをもとに同じカテゴリ属しているほかのpositiveの文章をnegativeとして設定していきます。
n_negative = 5 shuffle_num = 10 shuffle_arg = { 'fixed_seed': True, } sample_args = [{**sample_arg, **{'seed': seed}} for seed in range(shuffle_num)] df_neg = ( df .select( pl.col('idx'), negative_idx=pl.concat_list([ pl.col('idx').sample(**args).over('category') for args in sample_args ]) .list.set_difference(pl.col('idx').over('idx', mapping_strategy='join')) .list.unique() .list.head(n_negative) ) .explode('negative_idx') ) df_neg = df_neg.join( df.select('idx', pl.col('positive').alias('negative')), left_on='negative_idx', right_on='idx' ) df = df.join(df_neg, on='idx', how='inner').sort('idx') df.head() shape: (5, 6) idx category anchor positive negative_idx negative u32 i64 str str u32 str 0 3 "anchor_cat_3_0" "positive_cat_3_0" 740 "positive_cat_3_740" 0 3 "anchor_cat_3_0" "positive_cat_3_0" 754 "positive_cat_3_754" 0 3 "anchor_cat_3_0" "positive_cat_3_0" 761 "positive_cat_3_761" 0 3 "anchor_cat_3_0" "positive_cat_3_0" 1049 "positive_cat_3_1049" 0 3 "anchor_cat_3_0" "positive_cat_3_0" 1124 "positive_cat_3_1124"
これだけです。
と言われてもとなると思うので、解説していきます。
sample_args
複数のseed
が設定されたargを作成しています。
これは後述するsample
を実行時に、seedが単一の値だと同じ結果しか得られないため、このような形にしています(あんまりきれいではないですが…)
ちなみにfixed_seed
をTrueにする必要があるのはv0.18までのようで、v0.19からはfixed_seed
は引数から消去されていました。negative_idx
始めにidx
をshuffleによってランダムにnegative用のidxに割り当てていきます。
この時、.over
を使用することで、category
でgroup byした状態でshuffleを行えるようになっています。次に
set_difference
によって、shuffleの結果含まれてしまったidx
と同値のものを除去しています。
なおset_difference
はdtypeがlistのもの同士である必要があるため、.over(mapping_strategy='join')
でlistに変換しています。 そして、unique
によって重複したnegative_idx
を除去し、head
でn_negative
で設定した数の要素を取得しています。
これ以降はnegative_idx
に紐づけてpositive
をnegative
として扱い、元のdfに結合させています。
おわりに
以上hard negativeのtripletを作るまででした。 tripletの作成だけでなく、他のpolarsを使った処理実装のご参考になれば幸いです