gRPC 服务端多进程优化

文章目录
  1. 1. 例说:一个 OCR 服务
    1. 1.1. gRPC 定义
    2. 1.2. 服务器端
      1. 1.2.1. 保留端口函数 _reserve_port()
      2. 1.2.2. 启动一个服务进程 _run_server()
      3. 1.2.3. 真正的服务 OCRService
    3. 1.3. 客户端
      1. 1.3.1. RPC 调用 compute_detections()
      2. 1.3.2. 初始化
      3. 1.3.3. 单个 RPC 调用 _run_worker_query()

本文通过一个 GitHub 上的仓库一窥如何在 Python 上实现 gRPC 多进程优化,从而提升整体 RPC 速度。

例说:一个 OCR 服务

我们试图从一份最简代码中了解多进程的最佳实践。本节将梳理这份代码的逻辑。

gRPC 定义

我们这个最简示例是一个 OCR 服务。客户端接收图片,发给远端,远端识别好得到文本后回传客户端,结束。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
syntax = "proto3";

package ocr;

// A candidate image for OCR extraction.
message OcrCandidate {
// binaries of the openCV image.
bytes image = 1;
}

// The detected text of the requested image candidate.
message OcrResult {
// text detected on the image
string text = 1;
}

// Service to perform OCR over image.
service OCR {
// Determines the text on an input image
rpc Detect (OcrCandidate) returns (OcrResult) {}
}

所以也只有一个 RPC 函数:接收一张图片,返回一段文本——极其简单。

服务器端

gRPC-multiprocessing/server.py at main · fpaupier/gRPC-multiprocessing (github.com)

切入点很简单,直接看 main 函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def main():
"""
Inspired from https://github.com/grpc/grpc/blob/master/examples/python/multiprocessing/server.py
"""
logger.info(f"Initializing server with {NUM_WORKERS} workers")
with _reserve_port() as port:
bind_address = f"[::]:{port}"
logger.info(f"Binding to {bind_address}")
sys.stdout.flush()
workers = []
for _ in range(NUM_WORKERS):
worker = multiprocessing.Process(target=_run_server, args=(bind_address,))
worker.start()
workers.append(worker)
for worker in workers:
worker.join()

其中 _reserve_port() 是一个保留端口的函数,他会返回一个预定好的端口,后面会细说。_run_server 是启动一个 gRPC 服务的函数,会拉起一个内含一个处理线程的 gRPC 服务端对象。bind_address 其实就是把所有 IP 的保留那个端口都锁住了。

保留端口函数 _reserve_port()

直接上代码。

1
2
3
4
5
6
7
8
9
10
11
12
@contextlib.contextmanager
def _reserve_port():
"""Find and reserve a port for all subprocesses to use"""
sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0:
raise RuntimeError("Failed to set SO_REUSEPORT.")
sock.bind(("", 13000))
try:
yield sock.getsockname()[1]
finally:
sock.close()

这里默认取 13000 端口作为保留端口。你应该已经意识到了,服务器虽然拉起了一大堆 gRPC 服务进程,但他们都是共用同一个端口的,这个函数除了预定 13000 端口外,还在系统中声明这个端口可被多个进程复用。

启动一个服务进程 _run_server()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def _run_server(bind_address):
logger.debug(f"Server started. Awaiting jobs...")
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=1),
options=[
("grpc.max_send_message_length", -1),
("grpc.max_receive_message_length", -1),
("grpc.so_reuseport", 1),
("grpc.use_local_subchannel_pool", 1),
],
)
image_ocr_pb2_grpc.add_OCRServicer_to_server(OCRService, server)
server.add_insecure_port(bind_address)
server.start()
server.wait_for_termination()

可以看到,这个和我们一般的启动 gRPC 的方法不一样。它是先启动了一个通用的白板 gRPC 服务进程,然后加了一些特别的参数设置,最为关键的是允许端口复用 reuse_port。然后再把我们真正的 RPC 定义添加到白板进程里面去,最后启动之。

真正的服务 OCRService

1
2
3
4
class OCRService(image_ocr_pb2_grpc.OCRServicer):
@staticmethod
def Detect(request: image_ocr_pb2.OcrCandidate, context):
return image_ocr_pb2.OcrResult(text=get_text_from_image(request.image))

其实我还得看一下我是怎么添加 gRPC 服务的,怎么还需要手动写一个对象呢?

客户端

同理,先看主函数。我把一些 logging 的东西删掉之后,就剩两句了:

1
2
3
def run():
batch = prepare_batch()
results = compute_detections(batch)

其中 prepare_batch() 是读一系列图片,然后存成一个 List[bytes],作为一个 batch 以一起丢给远端处理——就是多进程嘛。compute_detections() 是真正的 RPC 调用。

RPC 调用 compute_detections()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def compute_detections(batch: tp.List[bytes]) -> tp.List[str]:
"""
Start a pool of process to parallelize data processing across several workers.
Args:
batch: a list of images.
Returns:
the list of detected texts.
"""
server_address = "server:13000"
with multiprocessing.Pool(
processes=NUM_CLIENTS,
initializer=_initialize_worker,
initargs=(server_address,),
) as worker_pool:
ocr_results = worker_pool.map(
_run_worker_query, [pickle.dumps(img) for img in batch]
)
return [txt for txt in ocr_results]

客户端也是多进程,每个进程都是连的同一个地址,但每个进程收到的输入是不一样的,只是 batch 中的某一行而已。关键看每个里面是怎么运作的,也就是 _run_worker_query。

同时还藏了一个 initializer=_initialize_worker,这个会在每个进程运行起来后自己进行初始化。

初始化

这个过程涉及一个全局变量 _worker_stub_singleton,它在文件的一开始定义了:

1
2
_worker_channel_singleton = None
_worker_stub_singleton = None

然后是函数主体:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def _initialize_worker(server_address: str) -> None:
"""
Setup a grpc stub if not available.
Args:
server_address (str)
Returns:
None
"""
global _worker_channel_singleton
global _worker_stub_singleton
_worker_channel_singleton = grpc.insecure_channel(
server_address,
options=[
("grpc.max_send_message_length", -1),
("grpc.max_receive_message_length", -1),
("grpc.so_reuseport", 1),
("grpc.use_local_subchannel_pool", 1),
],
)
_worker_stub_singleton = image_ocr_pb2_grpc.OCRStub(_worker_channel_singleton)
atexit.register(_shutdown_worker)

atexit 是 import 来的:import atexit。据 Python 官方文档,atexit 模块定义清理函数的注册和反注册函数,被注册的函数会在解释器正常终止时执行.

其中关闭连接的函数 _shutdown_worker 的定义很简单:

1
2
3
4
5
6
7
8
def _shutdown_worker():
"""
Close the open gRPC channel.
Returns:
None
"""
if _worker_channel_singleton is not None:
_worker_channel_singleton.stop()

单个 RPC 调用 _run_worker_query()

1
2
3
4
5
6
7
8
9
10
11
12
def _run_worker_query(img: bytes) -> str:
"""
Execute the call to the gRPC server.
Args:
img (bytes): bytes representation of the image
Returns:
detected text on the image
"""
response: image_ocr_pb2.OcrResult = _worker_stub_singleton.Detect(
image_ocr_pb2.OcrCandidate(image=img)
)
return response.text