1 Star 70 Fork 32

John-逍遥 / android_plugin_readme

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
README_linear_gradient_descent.md 11.27 KB
一键复制 编辑 原始数据 按行查看 历史
John-逍遥 提交于 2021-08-11 15:04 . improve md

梯度下降线性拟合 代码展示

以下只展示关键部分

工程

build.gradle

repositories {
    // Aliyun-hosted Maven repositories; the declaration order is kept as-is.
    maven { url 'https://maven.aliyun.com/repository/google' }
    maven { url 'https://maven.aliyun.com/repository/jcenter' }
    maven { url 'https://maven.aliyun.com/repository/central' }
    maven { url 'https://maven.aliyun.com/repository/gradle-plugin' }
}

Gradle 依赖引用

dependencies {
    // Charting library used to draw the scatter points and the fitted line.
    implementation group: 'com.github.testpress', name: 'MPAndroidChart', version: 'v3.0.0-beta2'
}
MainActivity.kt
package com.lujianfei.plugin12_2

import android.content.Context
import android.content.Intent
import android.graphics.Color
import android.net.Uri
import android.os.Handler
import android.os.Looper
import android.view.View
import android.view.inputmethod.InputMethodManager
import android.widget.Button
import android.widget.FrameLayout
import android.widget.TextView
import com.github.testpress.mikephil.charting.charts.LineChart
import com.github.testpress.mikephil.charting.data.Entry
import com.github.testpress.mikephil.charting.data.LineData
import com.github.testpress.mikephil.charting.data.LineDataSet
import com.github.testpress.mikephil.charting.interfaces.datasets.ILineDataSet
import com.lujianfei.module_plugin_base.base.BasePluginActivity
import com.lujianfei.module_plugin_base.beans.PluginActivityBean
import com.lujianfei.module_plugin_base.utils.DensityUtils
import com.lujianfei.module_plugin_base.widget.PluginToolBar

/**
 * Screen that visualises fitting a straight line to scattered points by gradient descent.
 *
 * The chart holds two data sets: index 0 is the scatter data to fit, index 1 is the
 * current fitted line. Each iteration (manual button or the auto-iteration timer)
 * runs one gradient-descent step, refreshes the line and updates the text read-outs.
 */
class MainActivity : BasePluginActivity() {
    companion object {
        const val TAG = "MainActivity"
        const val BUBBLE_SIZE = 10f

        /** Position of the scatter data set inside the chart's data. */
        private const val SCATTER_INDEX = 0

        /** Position of the fitted-line data set inside the chart's data. */
        private const val LINE_INDEX = 1

        /** Delay between two automatic iterations, in milliseconds. */
        private const val AUTO_ITERATION_DELAY_MS = 10L
    }

    /**
     * Line chart that shows the scatter points and the fitted line.
     */
    private var lineChart: LineChart? = null

    /**
     * The linear fitter holding theta0/theta1 and performing gradient descent.
     */
    private val mLinearFit by lazy { LinearFit() }

    private var bt_iteration: View? = null
    private var bt_auto_iteration: Button? = null
    private var txt_formula: TextView? = null
    private var txt_learning_rate: TextView? = null
    private var txt_epoch: TextView? = null
    private var txt_error: TextView? = null
    private val mHandler by lazy { Handler(Looper.getMainLooper()) }

    /** Scatter points to fit, generated once on first access. */
    private val testdata by lazy { MathHelper.getTestData() }

    private var chartContainer: FrameLayout? = null

    /** Self-rescheduling task that performs one iteration every [AUTO_ITERATION_DELAY_MS]. */
    private val taskAutoIteration = object : Runnable {
        override fun run() {
            updateLineData()
            updateDisplayParams()
            mHandler.postDelayed(this, AUTO_ITERATION_DELAY_MS)
        }
    }

    override fun resouceId(): Int = R.layout.activity_main

    override fun initView() {
        chartContainer = findViewById(R.id.chartContainer)
        bt_iteration = findViewById(R.id.bt_iteration)
        bt_auto_iteration = findViewById(R.id.bt_auto_iteration)
        txt_formula = findViewById(R.id.txt_formula)
        txt_learning_rate = findViewById(R.id.txt_learning_rate)
        txt_epoch = findViewById(R.id.txt_epoch)
        txt_error = findViewById(R.id.txt_error)

        // The chart does not exist yet at this point, so the error line starts as "--".
        updateDisplayParams()
        initChartView()
        initChartData()
    }

    override fun initEvent() {
        bt_iteration?.setOnClickListener {
            updateLineData()
            updateDisplayParams()
        }
        bt_auto_iteration?.setOnClickListener { button ->
            // The button's selected state doubles as the "auto iteration running" flag.
            val autoRunning = !button.isSelected
            button.isSelected = autoRunning
            if (autoRunning) {
                bt_auto_iteration?.text = "停止自动迭代"
                mHandler.postDelayed(taskAutoIteration, AUTO_ITERATION_DELAY_MS)
            } else {
                bt_auto_iteration?.text = "自动迭代"
                mHandler.removeCallbacks(taskAutoIteration)
            }
        }
    }

    override fun onPluginDestroy() {
        // Stop the self-rescheduling task so it does not outlive the screen.
        mHandler.removeCallbacks(taskAutoIteration)
        super.onPluginDestroy()
    }

    /**
     * Returns the chart data set at [index] as a [LineDataSet], or null when the chart,
     * its data, or that entry is missing or of another type.
     */
    private fun dataSetAt(index: Int): LineDataSet? =
        lineChart?.data?.dataSets?.getOrNull(index) as? LineDataSet

    /**
     * Refreshes the four text read-outs (error, formula, learning rate, epoch)
     * from the current fitter state. The error shows "--" while no scatter data is available.
     */
    private fun updateDisplayParams() {
        val scatter = dataSetAt(SCATTER_INDEX)
        txt_error?.text = if (scatter != null) {
            "方差:error = ${mLinearFit.variance(scatter.values)}"
        } else {
            "方差:error = --"
        }
        txt_formula?.text = "拟合解析式:f(x) = ${mLinearFit.theta0} + ${mLinearFit.theta1} x"
        txt_learning_rate?.text = "学习率:a = ${mLinearFit.learningRate}"
        txt_epoch?.text = "迭代次数:epoch = ${mLinearFit.epoch}"
    }

    /**
     * Creates the chart view and adds it to the container. Does nothing when `that`
     * is null (presumably the host context provided by the base class; confirm).
     */
    private fun initChartView() {
        val context = that ?: return
        val chart = LineChart(context)
        DensityUtils.getScreenHeight()?.let { screenHeight ->
            // Half the screen height, with a 20dp gap above the chart.
            val lp = FrameLayout.LayoutParams(FrameLayout.LayoutParams.MATCH_PARENT, screenHeight / 2)
            lp.topMargin = DensityUtils.dip2px(20f)
            chart.layoutParams = lp
        }
        lineChart = chart
        chartContainer?.addView(chart)
    }

    /**
     * Builds both data sets (scatter first, fitted line second) and hands them to the chart.
     */
    private fun initChartData() {
        val dataSets = arrayListOf<ILineDataSet>()
        dataSets.add(initBubbleData()) // index 0: scatter data
        dataSets.add(initLineData()) // index 1: fitted line
        val mLineData = LineData(dataSets)
        mLineData.setDrawValues(false)
        lineChart?.legend?.setCustom(arrayListOf(Color.GREEN, Color.RED), arrayListOf("散点", "直线"))
        lineChart?.data = mLineData
    }

    /** Builds the green scatter data set from the generated test points. */
    private fun initBubbleData(): LineDataSet {
        val values = testdata.map { p -> Entry(p.x, p.y, BUBBLE_SIZE) }
        return LineDataSet(values, "").apply {
            setCircleColor(Color.GREEN)
        }
    }

    /** Builds the red data set holding the fitter's current prediction at every test x. */
    private fun initLineData(): LineDataSet {
        val values = testdata.map { p -> Entry(p.x, mLinearFit.f(p.x)) }
        return LineDataSet(values, "").apply {
            setCircleColor(Color.RED)
            circleRadius = 1f
        }
    }

    /**
     * Runs one gradient-descent step on the scatter data and redraws the fitted line.
     * Does nothing when the chart or either data set is unavailable.
     */
    private fun updateLineData() {
        val scatter = dataSetAt(SCATTER_INDEX) ?: return
        val fitted = dataSetAt(LINE_INDEX) ?: return

        // Gradient descent updates theta0 and theta1.
        mLinearFit.gradientDescent(scatter.values)

        // Recompute the line's y values with the updated parameters.
        val values = fitted.values
        for (value in values) {
            value.y = mLinearFit.f(value.x)
        }
        fitted.values = values
        lineChart?.data?.notifyDataChanged()
        lineChart?.notifyDataSetChanged()
        lineChart?.invalidate()
    }
}
activity_main.xml
<?xml version="1.0" encoding="utf-8"?>
<!-- Screen layout: chart container at the top, read-out texts above the button row,
     button row pinned to the bottom. -->
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".MainActivity">

    <!-- Empty host for the chart view, which MainActivity adds at runtime. -->
    <FrameLayout
        android:id="@+id/chartContainer"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"/>

    <!-- Text read-outs; MainActivity replaces each placeholder text at runtime. -->
    <LinearLayout
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_above="@id/control_panel"
        android:paddingStart="10dp"
        android:paddingEnd="10dp"
        android:paddingBottom="10dp"
        android:orientation="vertical">

        <TextView
            android:id="@+id/txt_formula"
            android:layout_width="match_parent"
            android:layout_height="wrap_content"
            android:textSize="18sp"
            android:text="拟合解析式:"/>

        <TextView
            android:id="@+id/txt_learning_rate"
            android:layout_width="match_parent"
            android:layout_height="wrap_content"
            android:textSize="18sp"
            android:text="学习率:"/>

        <TextView
            android:id="@+id/txt_epoch"
            android:layout_width="match_parent"
            android:layout_height="wrap_content"
            android:textSize="18sp"
            android:text="迭代次数:"/>

        <TextView
            android:id="@+id/txt_error"
            android:layout_width="match_parent"
            android:layout_height="wrap_content"
            android:textSize="18sp"
            android:text="方差:"/>

    </LinearLayout>

    <!-- Button row: manual single step on the left, auto-iteration toggle on the right. -->
    <LinearLayout
        android:id="@+id/control_panel"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:orientation="horizontal"
        android:layout_alignParentBottom="true"
        android:layout_marginBottom="30dp"
        android:layout_marginStart="10dp"
        android:layout_marginEnd="10dp">
        <com.lujianfei.module_plugin_base.widget.PluginButton
            android:id="@+id/bt_iteration"
            android:layout_width="0dp"
            android:layout_height="wrap_content"
            android:gravity="center"
            android:layout_weight="1"
            android:text="手动迭代"
            />
        <com.lujianfei.module_plugin_base.widget.PluginButton
            android:id="@+id/bt_auto_iteration"
            android:layout_width="0dp"
            android:layout_height="wrap_content"
            android:gravity="center"
            android:layout_weight="1"
            android:layout_marginStart="10dp"
            android:text="自动迭代"
            />
    </LinearLayout>
</RelativeLayout>
MathHelper.kt
package com.lujianfei.plugin12_2

import android.graphics.PointF

object MathHelper {

    /** Slope of the target line. */
    private const val SLOPE = 0.5f

    /** Intercept of the target line. */
    private const val INTERCEPT = 5f

    /**
     * The target line that the generated sample points are scattered around.
     */
    fun f(x: Float): Float = SLOPE * x + INTERCEPT

    /**
     * Generates the scatter data to fit: one point for every integer x in 0..50,
     * with y taken from [f] plus a random integer offset in -2..2.
     */
    fun getTestData(): List<PointF> =
        (0..50).map { x ->
            val xf = x.toFloat()
            PointF(xf, f(xf) + (-2..2).random())
        }
}
核心类
LinearFit.kt
package com.lujianfei.plugin12_2

import com.github.testpress.mikephil.charting.data.Entry

/**
 * Fits the line `f(x) = theta0 + theta1 * x` to a set of points by gradient descent.
 *
 * Both parameters start at a random integer value in -10..10; every call to
 * [gradientDescent] performs one update step with a fixed [learningRate].
 */
class LinearFit {

    /** Step size applied to the averaged gradient on every iteration. */
    val learningRate = 0.001f

    /** Intercept of the fitted line. */
    var theta0 = (-10..10).random().toFloat()

    /** Slope of the fitted line. */
    var theta1 = (-10..10).random().toFloat()

    /** Number of gradient-descent steps performed so far. */
    var epoch = 0

    /**
     * Evaluates the current fitted line at [x].
     */
    fun f(x: Float): Float = theta0 + theta1 * x

    /**
     * Performs one gradient-descent step, updating [theta0] and [theta1] together.
     *
     * Both gradients are computed from the parameters as they were before the step.
     * An empty [dataset] is ignored: without this guard the division by its size
     * would turn both parameters into NaN for good, and no epoch is counted.
     */
    fun gradientDescent(dataset: List<Entry>) {
        if (dataset.isEmpty()) return

        val step = learningRate * 1f / dataset.size
        // Evaluate both sums before touching theta0/theta1 (simultaneous update).
        val gradient0 = sigma(dataset) { f(it.x) - it.y }
        val gradient1 = sigma(dataset) { (f(it.x) - it.y) * it.x }

        theta0 -= step * gradient0
        theta1 -= step * gradient1

        epoch++
    }

    /**
     * Returns the mean of the squared differences between [f] and the data's y values.
     *
     * For an empty [dataset] the result is NaN (0f divided by zero).
     */
    fun variance(dataset: List<Entry>): Float {
        val squaredErrorSum = sigma(dataset) { entry ->
            val residual = f(entry.x) - entry.y
            residual * residual
        }
        return squaredErrorSum / dataset.size
    }

    /**
     * Sums [algorithm] applied to every entry of [dataset], in list order.
     */
    private fun sigma(dataset: List<Entry>, algorithm: ((Entry) -> Float)): Float =
        dataset.fold(0f) { acc, entry -> acc + algorithm(entry) }
}
Android
1
https://gitee.com/lujianfei/android_plugin_readme.git
git@gitee.com:lujianfei/android_plugin_readme.git
lujianfei
android_plugin_readme
android_plugin_readme
master

搜索帮助