Ai
1 Star 0 Fork 0

李静/PaddleOCR

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
test_cls_postprocess.py 2.13 KB
一键复制 编辑 原始数据 按行查看 历史
Wang Xin 提交于 2024-05-31 11:22 +08:00 . add test for cls_postprocess (#12534)
import os
import sys
import pytest
import paddle
import numpy as np
# Absolute path of the directory that holds this test module.
current_dir = os.path.dirname(os.path.abspath(__file__))
# Append (not prepend) that directory's parent to the import search path so
# the `ppocr` import that follows can be resolved from there; entries that
# are already on sys.path keep their priority.
sys.path.append(
    os.path.abspath(os.path.join(current_dir, os.pardir))
)
from ppocr.postprocess.cls_postprocess import ClsPostProcess
# Fixtures for common test inputs
@pytest.fixture
def preds_tensor():
    """Provide a 2x3 score tensor shared by the tests in this module.

    Row 0 has its largest value (0.7) at index 1 and row 1 has its largest
    value (0.4) at index 2; the expected results in the tests below are
    built from exactly those (index, value) pairs.
    """
    scores = np.array([[0.1, 0.7, 0.2], [0.3, 0.3, 0.4]])
    return paddle.to_tensor(scores)
@pytest.fixture
def label_list():
    """Provide the index -> name mapping {0: "class0", 1: "class1", 2: "class2"}."""
    return {index: f"class{index}" for index in range(3)}
# Parameterize tests to cover multiple scenarios
@pytest.mark.parametrize(
    ("label_list", "expected"),
    [
        # With a mapping, each prediction is reported by its mapped name.
        pytest.param(
            {0: "class0", 1: "class1", 2: "class2"},
            [("class1", 0.7), ("class2", 0.4)],
        ),
        # Without a mapping, the raw class index is reported instead.
        pytest.param(None, [(1, 0.7), (2, 0.4)]),
    ],
)
def test_cls_post_process_with_and_without_label_list(
    preds_tensor, label_list, expected
):
    """Check the decoded output for a tensor input, with and without labels.

    The ``label_list`` argument here comes from the parametrization above,
    not from the module-level fixture of the same name.
    """
    classifier_post = ClsPostProcess(label_list=label_list)
    result = classifier_post(preds_tensor)

    assert isinstance(result, list), "Result should be a list"
    assert result == expected, f"Expected {expected}, got {result}"
# Test with a key in the prediction dictionary
def test_cls_post_process_with_key(preds_tensor, label_list):
    """Check that predictions wrapped in a dict are read via the ``key`` option.

    The tensor is passed as ``{"key": preds_tensor}`` to a post-processor
    built with ``key="key"``; the expected output is the same list of
    (name, score) pairs as for the bare tensor.
    """
    expected = [("class1", 0.7), ("class2", 0.4)]

    keyed_post = ClsPostProcess(label_list=label_list, key="key")
    result = keyed_post({"key": preds_tensor})

    assert isinstance(result, list), "Result should be a list"
    assert result == expected, f"Expected {expected}, got {result}"
# Test with label input
def test_cls_post_process_with_label(preds_tensor, label_list):
    """Check the two-part return value when ground-truth labels are supplied.

    Calling the post-processor with predictions and the labels ``[2, 0]`` is
    expected to yield a pair: the decoded predictions, and the labels
    rendered as (name, 1.0) tuples.
    """
    expected_result = [("class1", 0.7), ("class2", 0.4)]
    expected_label_result = [("class2", 1.0), ("class0", 1.0)]
    ground_truth = [2, 0]

    labelled_post = ClsPostProcess(label_list=label_list)
    result, label_result = labelled_post(preds_tensor, ground_truth)

    # Decoded predictions.
    assert isinstance(result, list), "Result should be a list"
    assert result == expected_result, f"Expected {expected_result}, got {result}"

    # Decoded ground-truth labels.
    assert isinstance(label_result, list), "Label result should be a list"
    assert label_result == expected_label_result, (
        f"Expected {expected_label_result}, got {label_result}"
    )
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/coder-lijing/PaddleOCR.git
git@gitee.com:coder-lijing/PaddleOCR.git
coder-lijing
PaddleOCR
PaddleOCR
main

搜索帮助