Python là ngôn ngữ dễ học, có cú pháp sáng sủa nhưng nó có một khuyết điểm là không tối ưu hiệu năng. Điều đó là hoàn toàn chính xác nhưng đó là một sự đánh đổi. Với những ngôn ngữ như C++, Go hay Java đem lại hiệu năng cao nhưng lại không phù hợp với những công việc đòi hỏi tốc độ phát triển nhanh hay trong những công việc đòi hỏi cú pháp phải sáng sủa, dễ đọc như làm data science, machine learning,... Tuy nhiên, nếu bạn hiểu Python, bạn có thể tối ưu được hiệu năng khá nhiều mà không cần phải tích hợp với các thư viện hỗ trợ hay những kĩ thuật cấp cao. Trong bài viết này, mình sẽ thực hiện tối ưu hiệu năng code Python, qua đó giúp các bạn có thêm kinh nghiệm tối ưu code của bản thân.
Trong bài viết này, mình sẽ trình bày cách mình tối ưu code của một bài toán cụ thể, qua đó mình mong các bạn rút ra được kinh nghiệm cho bản thân. Mình sẽ không hướng tới việc đóng khuôn "phương pháp tối ưu" nào cả. Bài toán mình đưa ra sẽ hoàn toàn là bài toán dạng CPU bound, nghĩa là thiên về xử lí logic bằng CPU chứ không phải IO. Let's go.
Tìm đường đi ngắn nhất. Có lẽ đây là bài toán cơ bản mà bạn đã được học hay ít nhất đã được nghe nói từ khi còn trên giảng đường đại học. Bài toán có để như sau: Cho một ma trận A với m x n phần tử, A[i][j] là phần tử hàng i cột j có giá trị là 0 hoặc 1. Bạn được cho vị trí xuất phát (x1, y1) và vị trí đích (x2, y2). Nhiệm vụ của bạn là kiểm tra xem có thể tìm được đường đi ngắn nhất từ (x1, y1) đến (x2, y2) được không. Lưu ý: bạn chỉ được đi vào những ô A[i][j] = 0 và chỉ được đi theo 4 hướng trên, dưới, trái, phải.
Bài toán này là bài toán điển hình cho thuật toán loang BFS (Breadth First Search) chi tiết cách giải các bạn có thể tìm ở trên Google nhé :smile:.
Mình sẽ cài đặt nó như sau:
from queue import LifoQueue
import time
d = [(0, 1), (1, 0), (-1, 0), (0, -1)]
matrix = [
[0, 1, 1, 1, 1, 0, 0, 0, 1, 1],
[0, 0, 0, 1, 1, 0, 1, 0, 1, 1],
[1, 1, 0, 1, 0, 0, 1, 0, 1, 1],
[1, 1, 0, 1, 0, 1, 1, 0, 1, 1],
[1, 1, 0, 0, 0, 1, 1, 0, 1, 1],
[1, 1, 0, 1, 1, 1, 1, 0, 1, 1],
[1, 1, 0, 0, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 1, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
]
def find(start, finish):
visited = [[False] * len(matrix[0]) for _ in range(len(matrix))]
q = LifoQueue()
q.put(start)
while not q.empty():
current = q.get()
if current == finish:
return True
visited[current[0]][current[1]] = True
for dx, dy in d:
nx = current[0] + dx
ny = current[1] + dy
if nx < 0 or nx >= len(matrix):
continue
if ny < 0 or ny >= len(matrix[0]):
continue
if visited[nx][ny] or matrix[nx][ny] == 1:
continue
q.put((nx, ny))
return False
def main():
n = 100_000
start = time.perf_counter()
for i in range(n):
find((0, 0), (9, 9))
print(f"{time.perf_counter()-start:.2f}")
if __name__ == "__main__":
main()
Mình thực hiện chạy thử 100.000 lần và mất khoảng thời gian là 13.77
giây. Vậy hãy cùng phân tích xem, code của mình chậm ở đâu.
Python có một thư viện hỗ trợ debug từng dòng code là line_profiler. Mình sẽ chạy thử:
Bạn nhớ profile hàm find
như này nhé:
@profile
def find(start, finish):
...
$ kernprod -lv main.py
Chúng ta được kết quả như sau:
Wrote profile results to main_1.py.lprof
Timer unit: 1e-06 s
Total time: 83.9397 s
File: main_1.py
Function: find at line 25
Line # Hits Time Per Hit % Time Line Contents
==============================================================
25 @profile
26 def find(start, finish):
27 100000 664982.0 6.6 0.8 visited = [[False] * len(matrix[0]) for _ in range(len(matrix))]
28
29 100000 1866792.0 18.7 2.2 q = LifoQueue()
30 100000 546686.0 5.5 0.7 q.put(start)
31
32 3600000 6476214.0 1.8 7.7 while not q.empty():
33 3600000 18110234.0 5.0 21.6 current = q.get()
34 3500000 1152041.0 0.3 1.4 if current == finish:
35 100000 30313.0 0.3 0.0 return True
36
37 3500000 1755616.0 0.5 2.1 visited[current[0]][current[1]] = True
38 14000000 7216479.0 0.5 8.6 for dx, dy in d:
39 14000000 4656879.0 0.3 5.5 nx = current[0] + dx
40 14000000 4468209.0 0.3 5.3 ny = current[1] + dy
41 13600000 6531169.0 0.5 7.8 if nx < 0 or nx >= len(matrix):
42 400000 94673.0 0.2 0.1 continue
43 13100000 6495185.0 0.5 7.7 if ny < 0 or ny >= len(matrix[0]):
44 500000 120301.0 0.2 0.1 continue
45 9600000 4569949.0 0.5 5.4 if visited[nx][ny] or matrix[nx][ny] == 1:
46 9600000 2290745.0 0.2 2.7 continue
47
48 3500000 16893247.0 4.8 20.1 q.put((nx, ny))
49
50 return False