代码拉取完成,页面将自动刷新
以下只展示关键部分
build.gradle
// Repositories used to resolve this module's dependencies (Aliyun mirrors of the public repos).
repositories {
    maven { url 'https://maven.aliyun.com/repository/google' }
    maven { url 'https://maven.aliyun.com/repository/jcenter' }
    maven { url "https://maven.aliyun.com/repository/central" }
    maven { url "https://maven.aliyun.com/repository/gradle-plugin" }
    // `com.github.<user>:<repo>:<tag>` coordinates (like the chart library below) are published
    // by JitPack, which the mirrors above do not appear to proxy; without this repository the
    // dependency may fail to resolve unless the root project already declares it.
    maven { url 'https://jitpack.io' }
}
dependencies {
    // MPAndroidChart fork (testpress) providing the LineChart used by MainActivity.
    implementation 'com.github.testpress:MPAndroidChart:v3.0.0-beta2'
}
package com.lujianfei.plugin12_2
import android.content.Context
import android.content.Intent
import android.graphics.Color
import android.net.Uri
import android.os.Handler
import android.os.Looper
import android.view.View
import android.view.inputmethod.InputMethodManager
import android.widget.Button
import android.widget.FrameLayout
import android.widget.TextView
import com.github.testpress.mikephil.charting.charts.LineChart
import com.github.testpress.mikephil.charting.data.Entry
import com.github.testpress.mikephil.charting.data.LineData
import com.github.testpress.mikephil.charting.data.LineDataSet
import com.github.testpress.mikephil.charting.interfaces.datasets.ILineDataSet
import com.lujianfei.module_plugin_base.base.BasePluginActivity
import com.lujianfei.module_plugin_base.beans.PluginActivityBean
import com.lujianfei.module_plugin_base.utils.DensityUtils
import com.lujianfei.module_plugin_base.widget.PluginToolBar
/**
 * Demo screen that fits a straight line to noisy sample points with [LinearFit]'s gradient descent.
 *
 * The chart holds two data sets: the sample points from [MathHelper.getTestData]
 * (position [INDEX_SAMPLES]) and the current fitted line evaluated with [LinearFit.f]
 * (position [INDEX_FITTED_LINE]). Each iteration, manual or automatic, runs one
 * gradient-descent step, re-evaluates the fitted line and refreshes the parameter labels.
 */
class MainActivity : BasePluginActivity() {

    companion object {
        const val TAG = "MainActivity"

        // NOTE(review): passed as the third Entry constructor argument below; in MPAndroidChart 3.x
        // that argument appears to be an arbitrary data payload, so it may not change the drawn
        // circle size. Confirm against the library version in use.
        const val BUBBLE_SIZE = 10f

        /** Position of the sample-point data set inside the chart data. */
        private const val INDEX_SAMPLES = 0

        /** Position of the fitted-line data set inside the chart data. */
        private const val INDEX_FITTED_LINE = 1

        /** Delay between automatic iterations, in milliseconds. */
        private const val AUTO_ITERATION_DELAY_MS = 10L
    }

    /** Chart showing both the sample points and the fitted line; created in [initChartView]. */
    private var lineChart: LineChart? = null

    /** Linear model being fitted. */
    private val mLinearFit by lazy { LinearFit() }

    private var bt_iteration: View? = null
    private var bt_auto_iteration: Button? = null
    private var txt_formula: TextView? = null
    private var txt_learning_rate: TextView? = null
    private var txt_epoch: TextView? = null
    private var txt_error: TextView? = null

    /** Main-looper handler that schedules automatic iterations. */
    private val mHandler by lazy { Handler(Looper.getMainLooper()) }

    /** Sample points to fit; generated once, on first access. */
    private val testdata by lazy { MathHelper.getTestData() }

    private var chartContainer: FrameLayout? = null

    override fun resouceId(): Int = R.layout.activity_main

    override fun initView() {
        chartContainer = findViewById(R.id.chartContainer)
        bt_iteration = findViewById(R.id.bt_iteration)
        bt_auto_iteration = findViewById(R.id.bt_auto_iteration)
        txt_formula = findViewById(R.id.txt_formula)
        txt_learning_rate = findViewById(R.id.txt_learning_rate)
        txt_epoch = findViewById(R.id.txt_epoch)
        txt_error = findViewById(R.id.txt_error)
        initChartView()
        initChartData()
        // Refresh the labels only after the chart data exists, so the initial error value is
        // shown instead of the "--" placeholder.
        updateDisplayParams()
    }

    /**
     * Refreshes the formula, learning-rate, epoch and error labels from the current model state.
     * The error label falls back to "--" when the sample data set is not available.
     */
    private fun updateDisplayParams() {
        val samples = lineDataSetAt(INDEX_SAMPLES)?.values
        txt_error?.text = if (samples != null) {
            "方差:error = ${mLinearFit.variance(samples)}"
        } else {
            "方差:error = --"
        }
        txt_formula?.text = "拟合解析式:f(x) = ${mLinearFit.theta0} + ${mLinearFit.theta1} x"
        txt_learning_rate?.text = "学习率:a = ${mLinearFit.learningRate}"
        txt_epoch?.text = "迭代次数:epoch = ${mLinearFit.epoch}"
    }

    /**
     * Creates the chart, sizes it to half the screen height (when that height is available)
     * with a 20dp top margin, and adds it to [chartContainer]. Does nothing if `that` is null.
     */
    private fun initChartView() {
        val hostContext = that ?: return
        val chart = LineChart(hostContext)
        DensityUtils.getScreenHeight()?.let { screenHeight ->
            val lp = FrameLayout.LayoutParams(FrameLayout.LayoutParams.MATCH_PARENT, screenHeight / 2)
            lp.topMargin = DensityUtils.dip2px(20f)
            chart.layoutParams = lp
        }
        lineChart = chart
        chartContainer?.addView(chart)
    }

    /** Installs the sample and fitted-line data sets into the chart, with a custom legend. */
    private fun initChartData() {
        val dataSets = arrayListOf<ILineDataSet>(
            initBubbleData(), // INDEX_SAMPLES: points to fit
            initLineData()    // INDEX_FITTED_LINE: current fitted line
        )
        val mLineData = LineData(dataSets)
        mLineData.setDrawValues(false)
        lineChart?.legend?.setCustom(arrayListOf(Color.GREEN, Color.RED), arrayListOf("散点", "直线"))
        lineChart?.data = mLineData
    }

    /** Builds the sample-point data set (green circles) from [testdata]. */
    private fun initBubbleData(): LineDataSet {
        val values = arrayListOf<Entry>()
        for (p in testdata) {
            values.add(Entry(p.x, p.y, BUBBLE_SIZE))
        }
        val lineDataSet = LineDataSet(values, "")
        lineDataSet.setCircleColor(Color.GREEN)
        return lineDataSet
    }

    /** Builds the fitted-line data set by evaluating the current model at each sample's x. */
    private fun initLineData(): LineDataSet {
        val values = arrayListOf<Entry>()
        for (p in testdata) {
            values.add(Entry(p.x, mLinearFit.f(p.x)))
        }
        val mLineDataSet = LineDataSet(values, "")
        mLineDataSet.setCircleColor(Color.RED)
        mLineDataSet.circleRadius = 1f
        return mLineDataSet
    }

    override fun initEvent() {
        bt_iteration?.setOnClickListener { iterateOnce() }
        bt_auto_iteration?.setOnClickListener { view ->
            // The button's selected state doubles as the "auto-iteration running" flag.
            view.isSelected = !view.isSelected
            if (view.isSelected) {
                bt_auto_iteration?.text = "停止自动迭代"
                mHandler.postDelayed(taskAutoIteration, AUTO_ITERATION_DELAY_MS)
            } else {
                bt_auto_iteration?.text = "自动迭代"
                mHandler.removeCallbacks(taskAutoIteration)
            }
        }
    }

    /** Runs one iteration and reschedules itself until removed from [mHandler]. */
    private val taskAutoIteration = object : Runnable {
        override fun run() {
            iterateOnce()
            mHandler.postDelayed(this, AUTO_ITERATION_DELAY_MS)
        }
    }

    /** Performs one gradient-descent step and refreshes the chart and labels. */
    private fun iterateOnce() {
        updateLineData()
        updateDisplayParams()
    }

    /**
     * Returns the chart data set at [index] as a [LineDataSet], or null if the chart, its data,
     * or that data set is missing (or is not a [LineDataSet]).
     */
    private fun lineDataSetAt(index: Int): LineDataSet? =
        lineChart?.data?.dataSets?.getOrNull(index) as? LineDataSet

    /**
     * Runs one gradient-descent step over the sample points, then re-evaluates every point of
     * the fitted line and redraws the chart. Does nothing when either data set is unavailable
     * (previously this crashed on an unsafe cast).
     */
    private fun updateLineData() {
        val samples = lineDataSetAt(INDEX_SAMPLES)?.values ?: return
        val fittedLine = lineDataSetAt(INDEX_FITTED_LINE) ?: return
        // Update theta0 / theta1 with one gradient-descent step.
        mLinearFit.gradientDescent(samples)
        // Move the fitted line's points onto the updated model.
        val linePoints = fittedLine.values
        for (point in linePoints) {
            point.y = mLinearFit.f(point.x)
        }
        fittedLine.values = linePoints
        lineChart?.data?.notifyDataChanged()
        lineChart?.notifyDataSetChanged()
        lineChart?.invalidate()
    }

    override fun onPluginDestroy() {
        // Stop automatic iteration so the runnable does not outlive the plugin.
        mHandler.removeCallbacks(taskAutoIteration)
        super.onPluginDestroy()
    }
}
<?xml version="1.0" encoding="utf-8"?>
<!--
    Layout for MainActivity: a container at the top that receives the chart at runtime,
    a column of parameter labels stacked directly above the control buttons, and the
    control buttons pinned to the bottom of the screen.
    (android:orientation was removed from the root: it is not a RelativeLayout attribute.)
-->
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".MainActivity">

    <!-- Empty placeholder; MainActivity adds the chart view here. -->
    <FrameLayout
        android:id="@+id/chartContainer"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"/>

    <!--
        Parameter labels. "@+id/" is used for control_panel because that id is declared
        further down in this file (forward reference).
    -->
    <LinearLayout
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_above="@+id/control_panel"
        android:paddingStart="10dp"
        android:paddingEnd="10dp"
        android:paddingBottom="10dp"
        android:orientation="vertical">

        <TextView
            android:id="@+id/txt_formula"
            android:layout_width="match_parent"
            android:layout_height="wrap_content"
            android:textSize="18sp"
            android:text="拟合解析式:"/>

        <TextView
            android:id="@+id/txt_learning_rate"
            android:layout_width="match_parent"
            android:layout_height="wrap_content"
            android:textSize="18sp"
            android:text="学习率:"/>

        <TextView
            android:id="@+id/txt_epoch"
            android:layout_width="match_parent"
            android:layout_height="wrap_content"
            android:textSize="18sp"
            android:text="迭代次数:"/>

        <TextView
            android:id="@+id/txt_error"
            android:layout_width="match_parent"
            android:layout_height="wrap_content"
            android:textSize="18sp"
            android:text="方差:"/>
    </LinearLayout>

    <!-- Manual / automatic iteration buttons, sharing the width equally. -->
    <LinearLayout
        android:id="@+id/control_panel"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:orientation="horizontal"
        android:layout_alignParentBottom="true"
        android:layout_marginBottom="30dp"
        android:layout_marginStart="10dp"
        android:layout_marginEnd="10dp">

        <com.lujianfei.module_plugin_base.widget.PluginButton
            android:id="@+id/bt_iteration"
            android:layout_width="0dp"
            android:layout_height="wrap_content"
            android:gravity="center"
            android:layout_weight="1"
            android:text="手动迭代"
            />

        <com.lujianfei.module_plugin_base.widget.PluginButton
            android:id="@+id/bt_auto_iteration"
            android:layout_width="0dp"
            android:layout_height="wrap_content"
            android:gravity="center"
            android:layout_weight="1"
            android:layout_marginStart="10dp"
            android:text="自动迭代"
            />
    </LinearLayout>
</RelativeLayout>
package com.lujianfei.plugin12_2
import android.graphics.PointF
/**
 * Generates the synthetic sample data that the linear-fit demo tries to recover.
 */
object MathHelper {

    /**
     * Target line `y = k * x + b` around which sample points are generated.
     *
     * @param x input value
     * @param k slope (defaults to the original hard-coded 0.5)
     * @param b intercept (defaults to the original hard-coded 5)
     * @return `k * x + b`
     */
    fun f(x: Float, k: Float = 0.5f, b: Float = 5f): Float = k * x + b

    /**
     * Returns one point per integer x in [xRange]. Each y is [f] (with its default k and b)
     * plus a uniformly random integer offset in `[-noise, noise]`.
     *
     * With all defaults this matches the original behaviour: x in 0..50 with integer noise in -2..2.
     *
     * @param xRange x values to sample
     * @param noise maximum absolute integer offset added to each y; must be >= 0
     * @param rng random source; pass a seeded instance for reproducible data
     * @throws IllegalArgumentException if [noise] is negative
     */
    fun getTestData(
        xRange: IntRange = 0..50,
        noise: Int = 2,
        rng: kotlin.random.Random = kotlin.random.Random.Default
    ): List<PointF> {
        // A negative noise would make the offset range empty and random() would throw.
        require(noise >= 0) { "noise must be non-negative: $noise" }
        val offsets = -noise..noise
        return xRange.map { x ->
            val xf = x.toFloat()
            PointF(xf, f(xf) + offsets.random(rng))
        }
    }
}
package com.lujianfei.plugin12_2
import com.github.testpress.mikephil.charting.data.Entry
/**
* 线性拟合
*/
/**
 * Simple linear model `f(x) = theta0 + theta1 * x` trained with batch gradient descent.
 *
 * @param learningRate step size applied by [gradientDescent] (defaults to the original 0.001)
 * @param rng random source for the initial parameters; pass a seeded instance for reproducible runs
 */
class LinearFit(
    val learningRate: Float = 0.001f,
    rng: kotlin.random.Random = kotlin.random.Random.Default
) {
    /** Intercept; starts as a random integer in [-10, 10]. */
    var theta0 = (-10..10).random(rng).toFloat()

    /** Slope; starts as a random integer in [-10, 10]. */
    var theta1 = (-10..10).random(rng).toFloat()

    /** Number of gradient-descent steps performed so far. */
    var epoch = 0

    /** Evaluates the current model at [x]. */
    fun f(x: Float): Float = theta0 + theta1 * x

    /**
     * Performs one batch gradient-descent step on the squared error over [dataset], updating
     * [theta0] and [theta1] and incrementing [epoch].
     *
     * An empty [dataset] is ignored: dividing by its size would turn both parameters into NaN
     * and corrupt every later step.
     */
    fun gradientDescent(dataset: List<Entry>) {
        if (dataset.isEmpty()) return
        val scale = learningRate / dataset.size
        // Both gradients are evaluated with the current parameters before either one changes
        // (simultaneous update).
        val gradTheta0 = sigma(dataset) { f(it.x) - it.y }
        val gradTheta1 = sigma(dataset) { (f(it.x) - it.y) * it.x }
        theta0 -= scale * gradTheta0
        theta1 -= scale * gradTheta1
        epoch++
    }

    /**
     * Mean squared error of the current model over [dataset] (named "variance" for historical
     * reasons). Returns [Float.NaN] for an empty dataset, as the mean is undefined there.
     */
    fun variance(dataset: List<Entry>): Float {
        if (dataset.isEmpty()) return Float.NaN
        return sigma(dataset) {
            val residual = f(it.x) - it.y
            residual * residual
        } / dataset.size
    }

    /** Sums [algorithm] over every entry of [dataset], in order. */
    private fun sigma(dataset: List<Entry>, algorithm: ((Entry) -> Float)): Float =
        dataset.fold(0f) { sum, entry -> sum + algorithm(entry) }
}
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。