Numba 教學:加速 Python 科學計算(上)
你能找到最好的中文教學!
鑑於繁體中文資源匱乏,最近剛好又重新看了一下文檔,於是整理資訊分享給大家。本篇的目標讀者是沒學過計算機組織的初階用戶到中階用戶都可以讀,筆者能非常肯定的說這篇文章絕對是你能找到最好的教學。
舊版文檔內容缺失,查看文檔時注意左上角版本號是否為 Stable,偏偏舊版文檔 Google 搜 尋在前面,一不小心就點進去了。
Numba 簡介與比較
Numba 是一個針對 Python 數值和科學計算,使用 LLVM 函式庫優化效能的即時編譯器 (JIT compiler),能顯著提升 Python 執行速度。
Python 速度慢的原因是身為動態語言,運行時需要額外開銷來進行類型檢查,轉譯成字節碼在虛擬機上執行1又多了一層開銷,還有 GIL 的限制進一步影響效能2,於是 Numba 就針對這些問題來解決,以下是它的優化機制:
- 靜態類型推斷:Numba 在編譯時分析程式碼推斷變數類型,避免型別檢查影響效能。
- 即時編譯:將函式編譯3成針對當前 CPU 架構優化的機器碼,並且以 LLVM 優化效能。
- 向量化與平行化:透過 LLVM 使用 SIMD 進行向量化運算,並支援多核平行計算和 CUDA 運算。
Numba 適用於大量包含迴圈的 Numpy 數值運算,但不適合如 pandas 的 I/O 操作。除了 Numba 以外還有其他加速套件,那我們是否該選擇 Numba 呢?這裡列出常見的競爭選項,有 Cython、pybind11、Pythran 和 CuPy,我們從特點討論到性能,最後做出結論。
-
特點
-
效能 從 Python 加速符文 這篇文章中我們可以看到效能4相差不大,除此之外,你能確定文章作者真的知道如何正確該套件嗎5?因此,我們應該考量套件的限制和可維護性,而非單純追求效能極限,否則直接用 C 寫就可以了。
-
結論
經過這些討論我們可以總結成以下- Numba:簡單高效,適合不熟悉程式優化技巧的用戶。缺點是因為太方便所以運作起來像是黑盒子,有時會感到不安心。
- Pythran:搜尋結果只有一萬筆資料,不要折磨自己。
- Cython:麻煩又不見得比較快。最大也是唯一的優點是支援更多 Python 語法,以及對程式行為有更多控制。
- pybind11:適合極限性能要求,對程式行為有完全掌控的用戶。
- CuPy:使用 CUDA,針對大量平行計算場景的最佳選擇。
安裝
安裝 Numba 以及相關的加速套件,包括 SVML (short vector math library) 向量化套件和 tbb/openmp 多線程套件,安裝後不需設定,Numba 會自行調用。
# conda
conda install numba
conda install intel-cmplr-lib-rt
conda install tbb
conda install anaconda::intel-openmp
# pip
pip install numba
pip install intel-cmplr-lib-rt
pip install tbb
pip install intel-openmp
安裝完成後重新啟動終端,使用 numba -s | grep SVML
檢查 SVML 是否成功被 Numba 偵測到,如果沒有,Linux 用戶可以用 sudo ldconfig
刷新 lib 連結。
基礎使用
說是基礎使用,但是已經包含七成的使用情境。
一分鐘學會 Numba
比官方的五分鐘教學又快五倍,夠狠吧。這個範例測試對陣列開根號後加總的速度,比較有沒有使用 Numba 和使用陣列/迴圈這四種方法的執行時間。
import numpy as np
import time
from numba import jit, prange
# Numba Loop
@jit(nopython=True, fastmath=True, parallel=True, nogil=True)
def numba_loop(arr):
bias = 2
total = 0.0
for x in prange(len(arr)): # Numba likes loops
total += np.sqrt(x) # Numba likes numpy
return bias + total # Numba likes broadcasting
# Python Loop,沒有使用裝飾器
def python_loop(arr):
bias = 2
total = 0.0
for x in arr:
total += np.sqrt(x)
return bias + total
# Numba Vector
@jit(nopython=True, fastmath=True, parallel=True, nogil=True)
def numba_arr(arr):
bias = 2
return bias + np.sum(np.sqrt(arr))
# Python Vector,沒有使用裝飾器
def python_arr(arr):
bias = 2
return bias + np.sum(np.sqrt(arr))
n_runs = 1000
n = 10000000
arr = np.arange(n)
# 第一次運行的初始化,第二次以後才是單純的執行時間
result_python_arr = python_arr(arr)
result_numba_arr = numba_arr(arr)
result_numba = numba_loop(arr)
start = time.time()
result_python = python_loop(arr)
end = time.time()
print(f"Python迴圈版本執行時間: {end - start} 秒")
start = time.time()
result_python_arr = python_arr(arr)
end = time.time()
print(f"Python陣列版本執行時間: {end - start} 秒")
start = time.time()
for _ in range(n_runs):
result_numba = numba_loop(arr)
end = time.time()
print(f"Numba迴圈版本執行時間: {(end - start)/n_runs} 秒")
start = time.time()
for _ in range(n_runs):
result_numba_arr = numba_arr(arr)
end = time.time()
print(f"Numba陣列版本執行時間: {(end - start)/n_runs} 秒")
print("Are the outputs equal?", np.isclose(result_numba, result_python))
print("Are the outputs equal?", np.isclose(result_numba_arr, result_python_arr))
# Python迴圈版本執行時間: 9.418870210647583 秒
# Python陣列版本執行時間: 0.021904706954956055 秒
# Numba迴圈版本執行時間: 0.0013016948699951171 秒
# Numba陣列版本執行時間: 0.0024524447917938235 秒
# Are the outputs equal? True
# Are the outputs equal? True
可以看到使用方式很簡單,僅需在想要優化的函式前加上 @jit
裝飾器,接著在要平行化處理的地方顯式的改為 prange
就完成了。裝飾器的選項有以下幾個6:
參數 | 說明 |
---|---|
nopython | 是否嚴格忽略 Python C API 此參數是整篇文章中影響速度最大的因素,使用 @njit 等價於啟用此參數 |
fastmath | 是否放寬 IEEE 754 的精度限制以獲得額外性能 |
parallel | 是否使用迴圈平行運算 |
cache | 是否將編譯結果寫入快取,避免每次呼叫 Python 程式時都需要編譯 |
nogil | 是否關閉全局鎖,關閉後允許在多線程中同時執行多個函式實例 |
效能方面,在這個測試中我們可以看到使用 Numba 後速度可以提升約兩倍,也發現一個有趣的事實:「迴圈版本比陣列版本更快」,這引導我們到第一個重點 Numba likes loops,另外兩個是 Numpy 和 matrix broadcasting。
- 這些選項的效能差異依照函式而有所不同。
- Numba 每次編譯後全域變數會變為常數,在程式中修改該變數不會被函式察覺。
對於暫時不想處理競爭危害的用戶,請先不要使用 parallel
和 nogil
選項。
- 開啟 parallel/nogil 選項時必須小心競爭危害 (race condition)。
簡單解釋競爭危害,兩個線程一起處理一個運算x += 1
,兩個一起取值,結果分別寫回 x 的值都是x+1
導致最終結果是x+1
而不是預期的x+2
。 - 雖然上面的範例顯示結果一致,但還是一定要 避免任何可能 的多線程問題!
進一步優化效能
基礎使用章節已經涵蓋官方文檔中的所有效能優化技巧,這裡補充進階的優化方式。
-
使用 Numba 反而變慢
- 別忘了扣掉首次執行需要消耗的編譯時間。
- 檢查 I/O 瓶頸,不要放任何需要 I/O 的程式碼在函式中。
- 總計算量太小。
- 宣告後就不要修改矩陣維度或型別。
- 語法越簡單越好,不要使用語法等等包裝,因為你不知道 Numba 是否支援。
- 記憶體問題 The wrong way to speed up your code with numba。
- 分支預測問題 Understanding CPUs can help speed up Numba and NumPy code
-
使用
@vectorize
或@guvectorize
向量化
中文教學幾乎沒人提到向量化到底在做什麼。向量化裝飾器除了使函式支援 ufunc 以外還可以大幅提升效能,詳細說明請見教學。 -
使用第三方套件進行效能分析。
fastmath
筆者在這裡簡單的討論一下 fastmath 選項。
雖然 fastmath 在文檔中只說他放寬了 IEEE 754 的精度限制,沒有說到的是他和 SVML 掛勾,但筆者以此 Github issue 進行測試,如果顯示機器碼 movabsq $__svml_atan24
代表安裝成功,此時我們將 fastmath 關閉後發現向量化失敗,偵錯訊息顯示 LV: Found FP op with unsafe algebra.
。
為甚麼敢說本篇是最正確的教學,對於其他文章我就問一句話, 效能測試時有裝 SVML 嗎? 這甚至都不用改程式就可以帶來極大幅度的效能提升,但是筆者從來沒看過任何文章提到過。
如何除錯
Numba 官方文檔有如何除錯的教學,使用 @jit(debug=True)
,詳情請見 Troubleshooting and tips。
另外一個是筆者的土砲方法,當年在寫 Numba 在出現錯誤時 Numba 的報錯資訊不明確,那時的土砲方法是「找到錯誤行數的方式是二分法直接刪程式碼到 Numba 不報錯」
錯誤通常來自於使用 Numba 不支援的函式,除錯請先看函式是否支援以免當冤大頭,再來就是檢查變數型別錯誤,例如誤用不支援相加的不同的變數型別。