FlashKDA dari Moonshot AI hadir sebagai solusi kernel CUDA high-performance untuk Kimi Delta Attention. Temukan cara implementasi drop-in ini memberikan speedup 1.72x-2.22x dan mengoptimasi inference LLM.

Tim riset Moonshot AI (yang bikin Kimi.ai) baru aja kasih kontribusi besar buat infrastruktur AI open-source. Mereka merilis FlashKDA, yaitu implementasi kernel CUDA berbasis CUTLASS untuk mekanisme Kimi Delta Attention (KDA).

FlashKDA ini udah tersedia di GitHub dengan lisensi MIT. Yang menarik, library ini ngasih speedup prefill 1.72x sampai 2.22x dibanding baseline flash-linear-attention di GPU NVIDIA H20. Plus, ini bisa langsung dipakai sebagai backend drop-in untuk library flash-linear-attention yang populer.

Buat ngerti kenapa FlashKDA penting, kita perlu lihat dulu landscape attention di LLM. Attention softmax standar punya kompleksitas kuadratik terhadap panjang sequence. Artinya, makin panjang konteks yang kamu masukkin, biaya komputasinya naik drastis.

Advertisement

Nah, ini yang bikin banyak peneliti fokus ke linear attention mechanisms. Mekanisme ini ngapproximate atau nganti operasi softmax buat dapetin scaling linear. Kimi Delta Attention (KDA) adalah kontribusi Moonshot AI di bidang ini.

KDA itu refinement dari Gated DeltaNet dengan mekanisme gating yang lebih fine-grained, channel-wise. Hasilnya, model bisa pake finite-state RNN memory lebih efektif.

Yang perlu dicatat, KDA bukan cuma prototipe riset. Mekanisme ini jadi core attention di Kimi Linear, model hybrid open-source dari Moonshot AI dengan 48B total parameter dan 3B activated parameter.

Kimi Linear pake rasio 3:1 KDA-to-MLA (Multi-Head Latent Attention). Ada tiga layer KDA untuk setiap satu layer global attention. Dengan setup ini, penggunaan KV cache bisa berkurang sampai 75% saat generasi sequence panjang.

Throughput decoding-nya naik sampai 6x lebih tinggi di konteks 1 juta token dibanding full attention. FlashKDA adalah kernel CUDA production-grade yang bikin arsitektur itu kenceng saat prefill.

Forward pass KDA menerima queries (q), keys (k), values (v), gate sebelum aktivasi (g), dan beta logits (beta). Ada juga scale factor, output tensor, dan parameter gate kayak A_log, dt_bias, dan lower_bound.

Kernel ini secara internal nerapin sigmoid activation di beta. Mekanismenya juga support initial dan final recurrent states opsional. Fitur ini berguna buat multi-turn inference di mana kamu mau bawa state across requests.

FlashKDA dibangun di atas CUTLASS, library open-source dari NVIDIA berisi template abstraksi CUDA C++ untuk linear algebra dan custom kernel development. CUTLASS bikin developer bisa nulis kernel yang maksimalin Tensor Core architecture.

Library ini target SM90 ke atas, artinya arsitektur NVIDIA Hopper (H100, H20) dan yang lebih baru. Minimum requirements-nya adalah CUDA 12.9 dan PyTorch 2.4.

Codebase-nya dominan CUDA (56.4%), dengan binding Python (36.2%) dan glue code C++ (6.7%). Core API-nya adalah flash_kda.fwd yang menerima input q, k, v, g dalam bf16 dengan shape [B, T, H, K] atau [B, T, H, V].

Parameter beta juga bf16 dengan shape [B, T, H], scale dalam fp32, dan output tensor bf16. Ada juga parameter gate A_log, dt_bias, lower_bound dalam fp32, serta optional initial_state dan final_state dalam bf16 atau fp32.

Yang penting, ada support cu_seqlens untuk variable-length batching. Ini critical buat production use karena di real inference serving, request dalam satu batch jarang punya sequence length sama.

Hasil benchmark (per April 2026) membandingkan flash_kda dengan fla_chunk_kda di sequence length T=8192 dan head dimension D=128. Untuk H=96, speedup-nya 1.72x di fixed length, 1.95x di variable length dengan sequence lengths beragam, dan 2.22x di variable length uniform (1024×8).

Untuk H=64, speedup-nya 1.83x di fixed length, 1.80x di variable length beragam, dan 2.18x di uniform variable length. Speedup tertinggi 2.22x muncul di kasus variable-length uniform.

Secara konsisten, FlashKDA ngalahin baseline flash-linear-attention dengan margin signifikan. Satu aspek praktis dari FlashKDA adalah integrasinya yang seamless.

Setelah di-install, FlashKDA auto-dispatched dari chunk_kda di flash-linear-attention. Artinya, codebase yang udah pake flash-linear-attention gak perlu wiring manual buat dapetin kernel yang lebih cepat.

Instalasinya straightforward pakai git clone dan pip install. Ada juga test suite yang verifikasi exact-match against PyTorch reference implementation.

Ada beberapa constraint yang perlu diperhatiin. Saat ini kernel membutuhkan head dimension K = V = 128. Jadi belum flexible untuk ukuran lain. Hardware requirement juga spesifik ke SM90+, jadi kamu butuh GPU Hopper kayak H100 atau H20.

Buat kamu yang develop atau deploy model dengan arsitektur linear attention, FlashKDA ini game-changer. Kamu bisa dapetin speedup hampir 2x tanpa harus ubah architecture atau codebase yang sudah ada.

Cukup install FlashKDA dan pastikan environment kamu pakai CUDA 12.9+ serta PyTorch 2.4+. Kalau kamu pakai flash-linear-attention, upgrade-nya otomatis karena FlashKDA jadi backend drop-in replacement.

Ini solusi ideal buat high-throughput inference serving, apalagi dengan support variable-length batching yang bikin handling real-world request patterns jadi lebih efisien. Worth banget buat dicoba kalau kamu kerja di bidang LLM optimization.

AI Updates lagi bergerak cepat, jadi jangan cuma lihat headline.

MarkTechPost

Catatan redaksi

Kalau lo cuma ambil satu hal dari artikel ini

AI Updates update dari MarkTechPost.

Sumber asli

Artikel ini merupakan rewrite editorial dari laporan MarkTechPost.

Baca artikel asli di MarkTechPost
#AIUpdates#MarkTechPost#rss