# C-LSTM **Repository Path**: han--leng/C-LSTM ## Basic Information - **Project Name**: C-LSTM - **Description**: C 语言实现 LSTM 算法 GITHUB https://github.com/az13js-org/C-LSTM - **Primary Language**: C - **License**: MIT - **Default Branch**: master - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 1 - **Forks**: 4 - **Created**: 2021-11-14 - **Last Updated**: 2022-05-24 ## Categories & Tags **Categories**: Uncategorized **Tags**: None ## README # C 语言实现 LSTM 算法 ## LSTM 算法简介 LSTM 全称是 "Long Short-Term Memory",一种用来学习大量时序序列中隐含的相关性并用于预测其可能的趋势的机器学习算法。它的应用范围包括但不局限于价格走势预测、估计剩余寿命、分析语言的情感趋向、自动写作和语音合成。 算法描述了一个用于计算的工作单元,它按照时间顺序接受自然数作为输入,通过计算得到对应的输出。当一系列输入计算完成时,也就得到了对应的一个输出序列。单个 LSTM 单元的学习能力是有限的,可以将输出的序列作为输入序列给另外一个 LSTM 作为输入,通过这种方式利用多个 LSTM 单元的组合提高整体学习能力。 ## LSTM 的计算过程 作为机器学习算法的一种,LSTM 的应用包括利用大量数据进行训练和根据训练得到的参数预测两个步骤。其预测过程使用正向传播算法,训练过程采用误差反向传播算法。 ### LSTM 的正向传播算法 一个处理 1 维序列的 LSTM 单元有 12 个参数。这里将这 12 个参数表示为:  假设输入序列 `x` 和输出序列 `h` 分别有 n 个元素,表示为:  每次计算中还会产生以下的临时变量:  最后 LSTM 的计算可以表示为:  里面的一个函数符号表示 Sigmoid 函数:  ### LSTM 的误差反向传播算法 假设期望输出为:  采用均方误差(MSE,Mean Squard Error)来评估实际输出与期望输出的误差:  那么在一次正向传播后,LSTM 输出序列 h 的每个元素对 E 的影响可以用下面的一阶偏导数表示:  进行误差反向传播需要使用 E 对 12 个参数的每一个的偏导数。与普通的神经网络算法不同的是,LSTM 利用 C(t-1) 和 h(t-1) 参与第 t 次的计算,使得第 t 次之前的计算结果会对第 t 次的输出 h(t) 产生影响。 由于:  所以 t = 1 时:  当 t > 1 时:             为方便计算,激活函数导数可取:  最后:  一般情况下学习率取值:  采用简单的梯度下降,可以在正向传播后修正参数:  ## C 实现方法 用结构体来保存计算过程所需的变量,并提供一个函数用来初始化并返回这个结构体。后续提供一系列的函数用于操作这个结构体。 ### 数据结构 结构体中变量名称和算法的参数之间的对应关系是:
| 结构体变量 | 定义 | 对应算法中的变量 |
|---|---|---|
| 整数 length | int length; | 表示 LSTM 计算序列长度 |
| 浮点数指针 x | double *x; | 输入序列 x |
| 浮点数指针 h | double *h; | 输出序列 h |
| 浮点数指针 f | double *f; | 中间变量序列 f |
| 浮点数指针 i | double *i; | 中间变量序列 i |
| 浮点数指针 tilde_C | double *tilde_C; | 中间变量序列 ![]() |
| 浮点数指针 C | double *C; | 中间变量序列 C |
| 浮点数指针 o | double *o; | 中间变量序列 o |
| 浮点数指针 hat_h | double *hat_h; | 期望输出序列 ![]() |
| 浮点数 W_fh | double W_fh; | 参数 ![]() |
| 浮点数 W_fx | double W_fx; | 参数 ![]() |
| 浮点数 b_f | double b_f; | 参数 ![]() |
| 浮点数 W_ih | double W_ih; | 参数 ![]() |
| 浮点数 W_ix | double W_ix; | 参数 ![]() |
| 浮点数 b_i | double b_i; | 参数 ![]() |
| 浮点数 W_Ch | double W_Ch; | 参数 ![]() |
| 浮点数 W_Cx | double W_Cx; | 参数 ![]() |
| 浮点数 b_C | double b_C; | 参数 ![]() |
| 浮点数 W_oh | double W_oh; | 参数 ![]() |
| 浮点数 W_ox | double W_ox; | 参数 ![]() |
| 浮点数 b_o | double b_o; | 参数 ![]() |
| 结构体变量 | 定义 | 说明 |
|---|---|---|
| 整数 error_no | int error_no; | 错误号,无错误默认0。用于记录最后一次程序发生的错误。 |
| 字符指针 error_msg | char *error_msg; | 发生的错误的文字说明,默认无错误,内容为指向字符串"\0"的指针。 |
| 错误号 | 信息 | 说明 |
|---|---|---|
| 0 | "\0" | 无错误。 |
| 1 | "not enough memory" | 内存不足。 |