
PaddlePaddle / Paddle

Add greedy CTC evaluator python API

Completed
Created on 2021-03-26 16:42

Originally posted by GitHub user wanghaoshuang:
This issue depends on https://github.com/PaddlePaddle/Paddle/pull/7527.
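CTC evaluator = top-k_op + ctc_align_op + edit_distance_op. Greedy decoding takes the top-1 class at each time step, ctc_align_op merges repeated tokens and removes blanks, and edit_distance_op measures the distance between the decoded sequence and the label.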
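The composition can be illustrated in plain Python. The sketch below is not Paddle's implementation. It shows the greedy path (top-1 per time step, then CTC alignment that collapses consecutive repeats and drops the blank), followed by the Levenshtein distance used as the error. It assumes blank id 0, as in the test script. `greedy_ctc_decode` and `edit_distance` are hypothetical helper names.

```python
from typing import List, Sequence


def greedy_ctc_decode(probs: Sequence[Sequence[float]], blank: int = 0) -> List[int]:
    """Hypothetical helper: greedy CTC decoding of one sequence.

    probs[t][c] is the score of class c at time step t.
    """
    # top-k with k=1: pick the highest-scoring class at every time step.
    best_path = [max(range(len(row)), key=lambda c: row[c]) for row in probs]
    # CTC alignment: collapse consecutive repeats, then drop blanks.
    decoded = []
    prev = None
    for token in best_path:
        if token != prev and token != blank:
            decoded.append(token)
        prev = token
    return decoded


def edit_distance(hyp: Sequence[int], ref: Sequence[int]) -> int:
    """Hypothetical helper: Levenshtein distance, every edit costs 1."""
    prev_row = list(range(len(ref) + 1))
    for i, h in enumerate(hyp, start=1):
        row = [i]
        for j, r in enumerate(ref, start=1):
            row.append(min(prev_row[j] + 1,              # delete h
                           row[j - 1] + 1,               # insert r
                           prev_row[j - 1] + (h != r)))  # substitute or match
        prev_row = row
    return prev_row[-1]


if __name__ == "__main__":
    # 5 time steps, 4 classes; class 0 is the blank.
    scores = [
        [0.1, 0.7, 0.1, 0.1],  # 1
        [0.1, 0.7, 0.1, 0.1],  # 1 (repeat, merged)
        [0.8, 0.1, 0.0, 0.1],  # blank (dropped)
        [0.1, 0.6, 0.2, 0.1],  # 1 (new token after a blank)
        [0.1, 0.1, 0.1, 0.7],  # 3
    ]
    hyp = greedy_ctc_decode(scores)
    print(hyp)                            # [1, 1, 3]
    print(edit_distance(hyp, [1, 2, 3]))  # 1
```

A blank between two identical tokens keeps them as separate outputs. This is why the decoded sequence above contains `1` twice.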

Test script:

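from __future__ import print_function

import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
from paddle.v2.fluid import core

# x: per-time-step scores over 8 classes; lod_level=1 marks a batch of variable-length sequences.
x = fluid.layers.data(name='x', shape=[8], dtype='float32', lod_level=1)
# y: label token ids; edit distance compares integer ids, so the label is int64.
y = fluid.layers.data(name='y', shape=[1], dtype='int64', lod_level=1)
# Greedy CTC decoding: top-1 per step, merge repeats, drop the blank (id 0).
ctc_result = fluid.layers.ctc_greedy_decoder(input=x, blank=0)
# Accumulates the edit distance between the decoded sequences and the labels.
edit_distance = fluid.evaluator.EditDistance(input=ctc_result, label=y)
print("step1")

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
print("step2")
# Clear the accumulated state before the first batch.
edit_distance.reset(exe)
batch_num = 2
for i in range(batch_num):
    print("step3")
    # 3 label sequences of lengths 2, 2 and 3 (7 rows); ids start at 1 because 0 is the blank.
    y_data = np.random.randint(1, 8, [7, 1]).astype("int64")
    y_lod = [[0, 2, 4, 7]]
    y_tensor = core.LoDTensor()
    y_tensor.set(y_data, place)
    y_tensor.set_lod(y_lod)

    # 3 input sequences of lengths 3, 2 and 6 (11 rows).
    x_data = np.random.uniform(0.1, 1, [11, 8]).astype("float32")
    x_lod = [[0, 3, 5, 11]]
    x_tensor = core.LoDTensor()
    x_tensor.set(x_data, place)
    x_tensor.set_lod(x_lod)

    # Fetch the per-batch metric; running the program also updates the evaluator state.
    cost, = exe.run(fluid.default_main_program(),
                    feed={'x': x_tensor, 'y': y_tensor},
                    fetch_list=edit_distance.metrics)
    # eval() reports the error accumulated since the last reset().
    pass_error = edit_distance.eval(exe)
    print("cost: %s" % cost)
    print("batch_id=%d; pass_error=%s" % (i, str(pass_error)))

pass_error = edit_distance.eval(exe)
print("total_pass_error=%s" % str(pass_error))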
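The numbers the script prints can be checked by recomputing the bookkeeping in plain Python. The sketch below shows two things. First, one level of LoD offsets such as `[0, 2, 4, 7]` splits the flat rows into sequences. Second, a reset/update/eval cycle accumulates the error across batches. It assumes that `eval()` reports the total edit distance since `reset()` divided by the number of sequences seen; this sketch does not confirm that from the Paddle source. `split_by_lod` and `EditDistanceAccumulator` are hypothetical names. The per-sequence distances in the usage example are made-up inputs, not output from the script.

```python
from typing import List, Sequence


def split_by_lod(rows: Sequence, offsets: Sequence[int]) -> List[list]:
    """Hypothetical helper: split flat rows using one level of LoD offsets.

    offsets = [0, 2, 4, 7] yields rows[0:2], rows[2:4] and rows[4:7].
    """
    if offsets[0] != 0 or offsets[-1] != len(rows):
        raise ValueError("LoD offsets must start at 0 and end at len(rows)")
    return [list(rows[a:b]) for a, b in zip(offsets[:-1], offsets[1:])]


class EditDistanceAccumulator:
    """Hypothetical stand-in for the evaluator's reset/update/eval cycle."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.total_error = 0.0
        self.seq_num = 0

    def update(self, distances: Sequence[float]) -> float:
        # One batch: add its summed distance (the per-batch metric) to the state.
        batch_total = float(sum(distances))
        self.total_error += batch_total
        self.seq_num += len(distances)
        return batch_total

    def eval(self) -> float:
        # Average distance per sequence since the last reset().
        return self.total_error / self.seq_num if self.seq_num else 0.0


if __name__ == "__main__":
    labels = split_by_lod([5, 2, 7, 1, 3, 3, 6], [0, 2, 4, 7])
    print([len(s) for s in labels])  # [2, 2, 3]

    acc = EditDistanceAccumulator()
    print(acc.update([1, 0, 2]))  # 3.0 (batch 0)
    print(acc.update([2, 1, 0]))  # 3.0 (batch 1)
    print(acc.eval())             # 1.0 = 6 / 6 sequences
```

The per-batch `cost` and the running `pass_error` answer different questions. The first is one batch's summed distance. The second is an average over everything seen since the last `reset()`.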


PaddlePaddle-Gardener created the task
PaddlePaddle-Coordinator changed the task status from To Do to Completed
