S4 Weights
GPUで可逆圧縮するチェックポイント
PyTorch学習チェックポイント(重み + オプティマイザ状態)向けのGPU可逆圧縮コーデック。各テンソルをバイトプレーン分割し、指数部をANS・仮数部をGDeflate(nvCOMP)へGPU上で振り分け、bf16 / fp16 / fp32でbit-exact(可逆)を保ったまま圧縮します。圧縮済みチェックポイントは自分のS3バケットに保存。g6 / g6e GPU AMIとして提供。
S4 Weights は、学習が書き出すチェックポイント(重みとオプティマイザ状態)を透過的かつ可逆に圧縮するコーデックです。復元はbyte-for-byteで一致し、NaN / ±Inf / 非正規化数 / -0.0 といった敵対的ビットパターンに対し全対応dtypeで検証済。頻繁にチェックポイントを取るランでは連続チェックポイント間のbyte-XOR差分も圧縮します。AMIのビルド自体が、ビルドGPU上で圧縮→復元のbit-exactラウンドトリップに失敗するとfailするため、壊れたプレーン再構成が顧客イメージに到達しません。チェックポイントは自分のVPC・自分のS3から外に出ません。
課題
大規模な PyTorch 学習では、モデルの重みとオプティマイザの状態を含むチェックポイントを頻繁に書き出すため、Amazon S3 上のストレージ量と転送バイト数が膨らみ続けます。チェックポイントが大きいほど保存・読み込みのたびに学習が止まる時間も長くなり、ストレージコストと GPU の待ち時間の両方が積み上がります。これらのチェックポイントは bf16 / fp16 / fp32 の数値そのものであり、欠損や量子化が許されないため、可逆(ロスレス)でなければ圧縮を信頼して使えません。
仕組み
- 1
GPU 上でバイトプレーンに分割
各テンソルを GPU 上で符号・指数・仮数のバイトプレーンに分割し、それぞれに適したコーデックへ振り分けます。指数プレーンは ANS エントロピー符号化、仮数プレーンは GDeflate、符号プレーンはビットパッキングで処理します(NVIDIA nvCOMP を基盤に実装)。
- 2
チェックポイント間の差分を圧縮
頻繁にチェックポイントを取る学習では、連続するチェックポイント間のバイト XOR 差分を保存して圧縮します。保存ごとに重みがほとんど変化しない場合、この差分は元のチェックポイントよりはるかに小さくなります。
- 3
ビット完全に復元し自分の S3 へ保存
復元は常にバイト単位で完全一致し、NaN・±Inf・非正規化数・-0.0 といった敵対的なビットパターンに対しても全対応 dtype で検証済みです。圧縮済みチェックポイントはお客様自身が指定した S3 バケットへ保存され、アカウント外に出ることはありません。
特長
可逆(byte-for-byte)をbf16 / fp16 / fp32で保証。NaN / ±Inf / 非正規化数 / -0.0 の敵対的パターンに対して検証済。
バイトプレーン分割: 指数部→ANS、仮数部→GDeflate をGPU(nvCOMP)上で実行。チェックポイント間の差分連鎖も圧縮。
圧縮済みチェックポイントは自分のVPC内の自分のS3バケットへ。g6 / g6e GPU AMI。
含まれるもの
- GPU ロスレスチェックポイントコーデック — PyTorch 学習のチェックポイント(モデルの重みとオプティマイザの状態)を bf16 / fp16 / fp32 でビット完全に圧縮
- バイトプレーン分割データプレーン(指数プレーン→ANS、仮数プレーン→GDeflate、符号プレーン→ビットパッキング、NVIDIA nvCOMP を基盤に GPU 上で実行)
- 連続チェックポイント間のバイト XOR 差分チェーン — 頻繁な保存でさらに圧縮率を高め、ブロブを小さな固定ヘッダ分以上に膨張させない
- PyTorch へのドロップイン API — 透過的な s4weights.save / s4weights.load と、ベース→差分のチェックポイントストア(save_checkpoint / load_checkpoint)
- 敵対的ビットパターン検証 — NaN・±Inf・非正規化数・-0.0 を全対応 dtype で検証し、AMI ビルド自体がビルド GPU 上での圧縮→展開のラウンドトリップがビット完全でなければ失敗する
- Ubuntu 22.04 ベースの g6 / g6e GPU AMI と、エンドツーエンドで構成する CloudFormation テンプレート(deploy/cfn-train-runner.yaml)。runner は非 root の systemd サービスとして TCP 8080 でヘルスエンドポイントを提供
- お客様自身の S3 レジストリバケットへの保存(データはアカウント外に出ない)+起動時のフェイルクローズドな RegisterUsage エンタイトルメント検証
こんな用途に
チェックポイントを頻繁に書き出し、S3 のストレージコストと転送バイト数を抑えたい大規模 PyTorch 学習
オプティマイザの状態が学習を支配するワークロードで、保存・読み込み時の停止時間を短縮したいケース
数値の欠損や量子化が許されず、復元がビット完全であることを保証したい学習パイプライン
all-bf16 や低精度オプティマイザのチェックポイントなど、圧縮が効きやすいデータを扱うチーム
よくある質問
圧縮は本当にロスレスですか?
はい。復元は常にバイト単位で完全一致し、bf16 / fp16 / fp32 の重みと fp32 のオプティマイザ状態について、NaN・±Inf・非正規化数・-0.0 といった敵対的なビットパターンで検証しています。さらに AMI ビルド自体が、ビルド GPU 上での圧縮→展開のラウンドトリップがビット完全でなければ失敗するため、プレーン再構成が壊れたコーデックがお客様のイメージに届くことはありません。
どれくらい圧縮できますか?
削減量はデータ次第です。all-bf16 や低精度オプティマイザのチェックポイントはよく圧縮され、規模が大きいほど効果が高まります。一方で、間隔を空けて保存される fp32 主体のチェックポイントはあまり圧縮されません。当社はどこで効果が出るかを正直に示しており、固定の圧縮率は主張しません。なお圧縮は常にロスレスで、ブロブを小さな固定ヘッダ分を超えて膨張させることはありません。
チェックポイントはどこに保存されますか?
圧縮済みチェックポイントは、お客様自身が指定した Amazon S3 レジストリバケットに保存され、アカウント外に出ることはありません。AMI はお客様自身の VPC 内で起動し、PyTorch の学習コードが s4weights.save / s4weights.load(または差分チェーンの save_checkpoint / load_checkpoint)でチェックポイントを書き出すと、各テンソルが GPU 上で圧縮され、ビット完全な圧縮チェックポイントが S3 レジストリへ保存されます。
どのインスタンスで動作し、どのように課金されますか?
g6 または g6e の GPU インスタンスで動作し、付属の CloudFormation テンプレート(deploy/cfn-train-runner.yaml)がエンドツーエンドで構成します。課金はインスタンスタイプごとの時間課金で、年額オプションもあります。AWS が稼働中のインスタンス時間を自動計測し、runner は起動時に一度だけ RegisterUsage をフェイルクローズドなエンタイトルメント検証として呼び出します(エンタイトルメントのないインスタンスは起動を拒否します)。
PyTorch の学習コードへの組み込みは簡単ですか?
ドロップインで使えます。透過的な s4weights.save / s4weights.load でチェックポイントを書き出すか、ベース→差分のチェックポイントストアを使う場合は save_checkpoint / load_checkpoint を利用します。各テンソルは GPU 上で圧縮され、頻繁に保存する学習では連続するチェックポイント間のバイト XOR 差分も保存・圧縮されます。
料金モデル
時間課金のソフトウェア利用料 + GPU EC2(g6 / g6e)。インスタンスタイプ別の従量課金、年額オプションあり。