Tối ưu code Python mà không cần dùng thư viện bên thứ ba

Thứ tư, ngày 5 tháng 7 năm 2023

Python là ngôn ngữ dễ học, có cú pháp sáng sủa nhưng nó có một khuyết điểm là không tối ưu hiệu năng. Điều đó là hoàn toàn chính xác nhưng đó là một sự đánh đổi. Với những ngôn ngữ như C++, Go hay Java đem lại hiệu năng cao nhưng lại không phù hợp với những công việc đòi hỏi tốc độ phát triển nhanh hay trong những công việc đòi hỏi cú pháp phải sáng sủa, dễ đọc như làm data science, machine learning,... Tuy nhiên, nếu bạn hiểu Python, bạn có thể tối ưu được hiệu năng khá nhiều mà không cần phải tích hợp với các thư viện hỗ trợ hay những kĩ thuật cấp cao. Trong bài viết này, mình sẽ thực hiện tối ưu hiệu năng code Python, qua đó giúp các bạn có thêm kinh nghiệm tối ưu code của bản thân.

Trong bài viết này, mình sẽ trình bày cách mình tối ưu code của một bài toán cụ thể, qua đó mình mong các bạn rút ra được kinh nghiệm cho bản thân. Mình sẽ không hướng tới việc đóng khuôn "phương pháp tối ưu" nào cả. Bài toán mình đưa ra sẽ hoàn toàn là bài toán dạng CPU bound, nghĩa là thiên về xử lí logic bằng CPU chứ không phải IO. Let's go.

Bài toán

Tìm đường đi ngắn nhất. Có lẽ đây là bài toán cơ bản mà bạn đã được học hay ít nhất đã được nghe nói từ khi còn trên giảng đường đại học. Bài toán có để như sau: Cho một ma trận A với m x n phần tử, A[i][j] là phần tử hàng i cột j có giá trị là 0 hoặc 1. Bạn được cho vị trí xuất phát (x1, y1) và vị trí đích (x2, y2). Nhiệm vụ của bạn là kiểm tra xem có thể tìm được đường đi ngắn nhất từ (x1, y1) đến (x2, y2) được không. Lưu ý: bạn chỉ được đi vào những ô A[i][j] = 0 và chỉ được đi theo 4 hướng trên, dưới, trái, phải.

Bài toán này là bài toán điển hình cho thuật toán loang BFS (Breadth First Search) chi tiết cách giải các bạn có thể tìm ở trên Google nhé :smile:.

Mình sẽ cài đặt nó như sau:

from queue import LifoQueue
import time


d = [(0, 1), (1, 0), (-1, 0), (0, -1)]
matrix = [
    [0, 1, 1, 1, 1, 0, 0, 0, 1, 1],
    [0, 0, 0, 1, 1, 0, 1, 0, 1, 1],
    [1, 1, 0, 1, 0, 0, 1, 0, 1, 1],
    [1, 1, 0, 1, 0, 1, 1, 0, 1, 1],
    [1, 1, 0, 0, 0, 1, 1, 0, 1, 1],
    [1, 1, 0, 1, 1, 1, 1, 0, 1, 1],
    [1, 1, 0, 0, 1, 1, 1, 0, 0, 0],
    [1, 1, 1, 0, 1, 1, 1, 1, 1, 0],
    [1, 1, 1, 0, 0, 0, 0, 0, 1, 0],
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
]


def find(start, finish):
    visited = [[False] * len(matrix[0]) for _ in range(len(matrix))]

    q = LifoQueue()
    q.put(start)

    while not q.empty():
        current = q.get()
        if current == finish:
            return True

        visited[current[0]][current[1]] = True
        for dx, dy in d:
            nx = current[0] + dx
            ny = current[1] + dy
            if nx < 0 or nx >= len(matrix):
                continue
            if ny < 0 or ny >= len(matrix[0]):
                continue
            if visited[nx][ny] or matrix[nx][ny] == 1:
                continue

            q.put((nx, ny))

    return False


def main():
    n = 100_000

    start = time.perf_counter()
    for i in range(n):
        find((0, 0), (9, 9))

    print(f"{time.perf_counter()-start:.2f}")


if __name__ == "__main__":
    main()

Mình thực hiện chạy thử 100.000 lần và mất khoảng thời gian là 13.77 giây. Vậy hãy cùng phân tích xem, code của mình chậm ở đâu.

Python có một thư viện hỗ trợ debug từng dòng code là line_profiler. Mình sẽ chạy thử:

Bạn nhớ profile hàm find như này nhé:

@profile
def find(start, finish):
    ...
$ kernprod -lv main.py

Chúng ta được kết quả như sau:

Wrote profile results to main_1.py.lprof
Timer unit: 1e-06 s

Total time: 83.9397 s
File: main_1.py
Function: find at line 25

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    25                                           @profile
    26                                           def find(start, finish):
    27    100000     664982.0      6.6      0.8      visited = [[False] * len(matrix[0]) for _ in range(len(matrix))]
    28
    29    100000    1866792.0     18.7      2.2      q = LifoQueue()
    30    100000     546686.0      5.5      0.7      q.put(start)
    31
    32   3600000    6476214.0      1.8      7.7      while not q.empty():
    33   3600000   18110234.0      5.0     21.6          current = q.get()
    34   3500000    1152041.0      0.3      1.4          if current == finish:
    35    100000      30313.0      0.3      0.0              return True
    36
    37   3500000    1755616.0      0.5      2.1          visited[current[0]][current[1]] = True
    38  14000000    7216479.0      0.5      8.6          for dx, dy in d:
    39  14000000    4656879.0      0.3      5.5              nx = current[0] + dx
    40  14000000    4468209.0      0.3      5.3              ny = current[1] + dy
    41  13600000    6531169.0      0.5      7.8              if nx < 0 or nx >= len(matrix):
    42    400000      94673.0      0.2      0.1                  continue
    43  13100000    6495185.0      0.5      7.7              if ny < 0 or ny >= len(matrix[0]):
    44    500000     120301.0      0.2      0.1                  continue
    45   9600000    4569949.0      0.5      5.4              if visited[nx][ny] or matrix[nx][ny] == 1:
    46   9600000    2290745.0      0.2      2.7                  continue
    47
    48   3500000   16893247.0      4.8     20.1              q.put((nx, ny))
    49
    50                                               return False