polarsでhard negativeのtripletを作成する

皆さんこんにちは。
機械学習エンジニアのwsatingです。

最近openAIが公開したAssistant APIでこれまで以上にRAGが簡単にできるようになってきましたね。
そんな最近の潮流から逆行(?)して今回はRetrieveでよく使われるsentence BERTのfine tuningに使われるTripletをpolarsで作る話をしたいと思います。

Triplet Loss、Tripletとは

Sentence BERTの学習に用いるLossの一つであるTriplet Lossというものがあります。詳細は省きますが、
- anchor
- positive
- negative
の3要素(Triplet)を用いて、anchorとpositiveの距離を近く、anchorとnegativeの距離を遠ざけるような学習をさせることができます。
学習時にnegativeをpositiveと近いもの、例えば同じカテゴリに属するものとする、hard negativeとすること、でより良い学習ができるとされています。

ちなみにTriplet Lossの詳細についてはこちらの記事がとても参考になるので、ご興味ある方はぜひ一度ご覧ください。

qiita.com

polarsでTripletの作成処理を実装

では早速作っていきます
が、その前に…

# 実行環境
import polars as pl
pl.show_versions()

--------Version info---------
Polars:              0.18.15
Index type:          UInt32
Platform:            Linux-5.10.16.3-microsoft-standard-WSL2-x86_64-with-glibc2.31
Python:              3.9.11 (main, Mar 18 2022, 16:45:24) 
[GCC 10.2.1 20210110]

----Optional dependencies----
adbc_driver_sqlite:  <not installed>
cloudpickle:         <not installed>
connectorx:          <not installed>
deltalake:           <not installed>
fsspec:              2022.11.0
matplotlib:          3.6.0
numpy:               1.23.4
pandas:              1.5.1
pyarrow:             12.0.1
pydantic:            1.10.12
sqlalchemy:          <not installed>
xlsx2csv:            <not installed>
xlsxwriter:          <not installed>


というわけで以下から処理を実装していきます。

# stringの表示数を調整
pl.Config.set_fmt_str_lengths(100)
# idx, category, anchor, positive colを持ったDataFrameの作成
df = (
    pl.DataFrame(data={'category': np.random.randint(10, size=2000)})
    .with_row_count(name='idx')
    .with_columns(
        anchor=pl.concat_str([pl.lit('anchor'), pl.lit('cat'), pl.col('category'), pl.col('idx')], separator='_'),
        positive=pl.concat_str([pl.lit('positive'), pl.lit('cat'), pl.col('category'), pl.col('idx')], separator='_'),
    )
)
df.head()
shape: (5, 4)
idx category    anchor  positive
u32 i64 str  str
0  3  "anchor_cat_3_0" "positive_cat_3_0"
1  1  "anchor_cat_1_1" "positive_cat_1_1"
2  7  "anchor_cat_7_2" "positive_cat_7_2"
3  8  "anchor_cat_8_3" "positive_cat_8_3"
4  5  "anchor_cat_5_4" "positive_cat_5_4"

データを用意するのが面倒positiveとnegativeの比較時にわかりやすいようにanchor, positiveの内容は{col_name}_cat_{category_id}_{idx}ということにします。

ではこれをもとに同じカテゴリ属しているほかのpositiveの文章をnegativeとして設定していきます。

n_negative = 5
shuffle_num = 10
shuffle_arg = {
    'fixed_seed': True,
}
sample_args = [{**sample_arg, **{'seed': seed}} for seed in range(shuffle_num)]

df_neg = (
    df
    .select(
        pl.col('idx'),
        negative_idx=pl.concat_list([
            pl.col('idx').sample(**args).over('category')
            for args in sample_args
        ])
        .list.set_difference(pl.col('idx').over('idx', mapping_strategy='join'))
        .list.unique()
        .list.head(n_negative)
    )
    .explode('negative_idx')
)


df_neg = df_neg.join(
    df.select('idx', pl.col('positive').alias('negative')),
    left_on='negative_idx', right_on='idx'
)
df = df.join(df_neg, on='idx', how='inner').sort('idx')
df.head()
shape: (5, 6)
idx category    anchor  positive    negative_idx    negative
u32 i64 str  str  u32 str
0  3  "anchor_cat_3_0" "positive_cat_3_0"   740    "positive_cat_3_740"
0  3  "anchor_cat_3_0" "positive_cat_3_0"   754    "positive_cat_3_754"
0  3  "anchor_cat_3_0" "positive_cat_3_0"   761    "positive_cat_3_761"
0  3  "anchor_cat_3_0" "positive_cat_3_0"   1049   "positive_cat_3_1049"
0  3  "anchor_cat_3_0" "positive_cat_3_0"   1124   "positive_cat_3_1124"

これだけです。
と言われてもとなると思うので、解説していきます。

  • sample_args
    複数のseedが設定されたargを作成しています。
    これは後述するsampleを実行時に、seedが単一の値だと同じ結果しか得られないため、このような形にしています(あんまりきれいではないですが…)
    ちなみにfixed_seedをTrueにする必要があるのはv0.18までのようで、v0.19からはfixed_seedは引数から消去されていました。

  • negative_idx
    始めにidxをshuffleによってランダムにnegative用のidxに割り当てていきます。
    この時、.overを使用することで、categoryでgroup byした状態でshuffleを行えるようになっています。

    次にset_differenceによって、shuffleの結果含まれてしまったidxと同値のものを除去しています。
    なおset_differenceはdtypeがlistのもの同士である必要があるため、.over(mapping_strategy='join')でlistに変換しています。 そして、uniqueによって重複したnegative_idxを除去し、headn_negativeで設定した数の要素を取得しています。

これ以降はnegative_idxに紐づけてpositivenegativeとして扱い、元のdfに結合させています。

おわりに

以上hard negativeのtripletを作るまででした。 tripletの作成だけでなく、他のpolarsを使った処理実装のご参考になれば幸いです