From 3749e01e5a68368d39bffa2ad74d6188c8924aea Mon Sep 17 00:00:00 2001 From: wangfeng Date: Wed, 8 Feb 2023 15:45:00 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E5=A2=9E=E5=8A=A03=E4=B8=AA=E5=87=BD?= =?UTF-8?q?=E6=95=B0,=20ArgMax,=20ArgMin,=20Median,=20=E5=85=B6=E4=B8=ADMe?= =?UTF-8?q?dian=E6=9C=AA=E5=8A=A0=E9=AA=8C=E8=AF=81,=20=E5=85=88=E5=B8=A6?= =?UTF-8?q?=E4=B8=8A.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- stat/argmax.go | 33 +++++++++++++++++++++++++++++++++ stat/argmax_test.go | 15 +++++++++++++++ stat/argmin.go | 33 +++++++++++++++++++++++++++++++++ stat/argmin_test.go | 15 +++++++++++++++ stat/median.go | 21 +++++++++++++++++++++ 5 files changed, 117 insertions(+) create mode 100644 stat/argmax.go create mode 100644 stat/argmax_test.go create mode 100644 stat/argmin.go create mode 100644 stat/argmin_test.go create mode 100644 stat/median.go diff --git a/stat/argmax.go b/stat/argmax.go new file mode 100644 index 0000000..27b4029 --- /dev/null +++ b/stat/argmax.go @@ -0,0 +1,33 @@ +package stat + +import ( + "github.com/viterin/vek" + "github.com/viterin/vek/vek32" +) + +// ArgMax Returns the indices of the maximum values along an axis. +// 返回轴上最大值的索引 +func ArgMax[T Number](v []T) int { + var vv any = v + switch values := vv.(type) { + case []float32: + return vek32.ArgMax(values) + case []float64: + return vek.ArgMax(values) + default: + return __arg_max(v) + } + +} + +func __arg_max[T Number](x []T) int { + max := x[0] + idx := 0 + for i, v := range x[1:] { + if v > max { + max = v + idx = 1 + i + } + } + return idx +} diff --git a/stat/argmax_test.go b/stat/argmax_test.go new file mode 100644 index 0000000..c780001 --- /dev/null +++ b/stat/argmax_test.go @@ -0,0 +1,15 @@ +package stat + +import ( + "fmt" + "testing" +) + +func TestArgMax(t *testing.T) { + n1 := []float32{1.1, 2.2, 1.3, 1.4} + n2 := []float64{1.2, 1.2, 3.3} + n3 := []int64{11, 12, 33} + fmt.Println(ArgMax(n1)) + fmt.Println(ArgMax(n2)) + fmt.Println(ArgMax(n3)) +} diff --git a/stat/argmin.go b/stat/argmin.go new file mode 100644 index 0000000..b9b8b0f --- /dev/null +++ b/stat/argmin.go @@ -0,0 +1,33 @@ +package stat + +import ( + "github.com/viterin/vek" + "github.com/viterin/vek/vek32" +) + +// ArgMin Returns the indices of the minimum values along an axis. +// 返回轴上最小值的索引 +func ArgMin[T Number](v []T) int { + var vv any = v + switch values := vv.(type) { + case []float32: + return vek32.ArgMin(values) + case []float64: + return vek.ArgMin(values) + default: + return __arg_min(v) + } + +} + +func __arg_min[T Number](x []T) int { + min := x[0] + idx := 0 + for i, v := range x[1:] { + if v < min { + min = v + idx = 1 + i + } + } + return idx +} diff --git a/stat/argmin_test.go b/stat/argmin_test.go new file mode 100644 index 0000000..1f61bb7 --- /dev/null +++ b/stat/argmin_test.go @@ -0,0 +1,15 @@ +package stat + +import ( + "fmt" + "testing" +) + +func TestArgMin(t *testing.T) { + n1 := []float32{1.1, 2.2, 1.3, 1.4} + n2 := []float64{1.2, 1.2, 0.3} + n3 := []int64{11, 12, 33} + fmt.Println(ArgMin(n1)) + fmt.Println(ArgMin(n2)) + fmt.Println(ArgMin(n3)) +} diff --git a/stat/median.go b/stat/median.go new file mode 100644 index 0000000..6be0964 --- /dev/null +++ b/stat/median.go @@ -0,0 +1,21 @@ +package stat + +// Median returns median value of series. +// Linear interpolation is used for odd length. +// TODO:未加验证 +func Median[T StatType](values []T) DType { + if len(values) == 0 { + return DTypeNaN + } + + if len(values) == 1 { + return DType(0) + } + + if len(values)%2 == 0 { + i := len(values) / 2 + return DType(values[i-1]+values[i]) / 2 + } + + return DType(values[len(values)/2]) +} -- Gitee From bae9267ab5bd45d98e0d2096898daf0a4356e09c Mon Sep 17 00:00:00 2001 From: wangfeng Date: Wed, 8 Feb 2023 15:46:38 +0800 Subject: [PATCH 2/2] =?UTF-8?q?#I6DOBI=20=E5=AE=9E=E7=8E=B0BARSSINCEN?= =?UTF-8?q?=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- formula/README.md | 2 +- formula/barssincen.go | 26 ++++++++++++++++++++++++++ formula/barssincen_test.go | 25 +++++++++++++++++++++++++ generic.go | 9 ++++++++- generic_apply.go | 23 +++++++++++++++++++++++ generic_rolling.go | 18 ++++++++++++++++-- series.go | 6 +++++- series_bool.go | 3 ++- series_float64.go | 3 ++- series_int64.go | 3 ++- series_xstring.go | 4 +++- stat/type_dtype.go | 10 ++++++++++ 12 files changed, 123 insertions(+), 9 deletions(-) create mode 100644 formula/barssincen.go create mode 100644 formula/barssincen_test.go diff --git a/formula/README.md b/formula/README.md index e576b57..e4e9976 100644 --- a/formula/README.md +++ b/formula/README.md @@ -36,7 +36,7 @@ formula | 1 | FILTER | FILTER函数,S满足条件后,将其后N周期内的数据置为0 | FILTER(CLOSE>LOW,5) | [X] | [X] | | 1 | BARSLAST | 上一次条件成立到当前的周期数 | BARSLAST(X) | [√] | [√] | | 1 | BARSLASTCOUNT | 统计连续满足S条件的周期数 | BARSLASTCOUNT(X) | [X] | [X] | -| 1 | BARSSINCEN | N周期内第一次S条件成立到现在的周期数 | BARSSINCEN(S,N) | [X] | [X] | +| 1 | BARSSINCEN | N周期内第一次S条件成立到现在的周期数 | BARSSINCEN(S,N) | [√] | [√] | | 1 | CROSS | 判断向上金叉穿越,两个序列互换就是判断向下死叉穿越 | CROSS(MA(C,5),MA(C,10)) | [X] | [X] | | 1 | LONGCROSS | 两条线维持一定周期后交叉,S1在N周期内都小于S2,本周期从S1下方向上穿过S2时返回1,否则返回0 | LONGCROSS(MA(C,5),MA(C,10),5) | [X] | [X] | | 1 | VALUEWHEN | 当S条件成立时,取X的当前值,否则取VALUEWHEN的上个成立时的X值 | VALUEWHEN(S,X) | [X] | [X] | diff --git a/formula/barssincen.go b/formula/barssincen.go new file mode 100644 index 0000000..decbd3e --- /dev/null +++ b/formula/barssincen.go @@ -0,0 +1,26 @@ +package formula + +import ( + "gitee.com/quant1x/pandas" + "gitee.com/quant1x/pandas/stat" +) + +// BARSSINCEN N周期内第一次S条件成立到现在的周期数,N为常量 +func BARSSINCEN(S pandas.Series, N any) []stat.Int { + ret := S.Rolling(N).Apply(func(X pandas.Series, M stat.DType) stat.DType { + x := X.DTypes() + n := int(M) + argMax := stat.ArgMax(x) + r := 0 + if argMax != 0 || x[0] != 0 { + r = n - 1 - argMax + + } else { + r = 0 + } + return stat.DType(r) + }) + r1 := ret.FillNa(0, true) + r2 := r1.AsInt() + return r2 +} diff --git a/formula/barssincen_test.go b/formula/barssincen_test.go new file mode 100644 index 0000000..9187079 --- /dev/null +++ b/formula/barssincen_test.go @@ -0,0 +1,25 @@ +package formula + +import ( + "fmt" + "gitee.com/quant1x/pandas" + "gitee.com/quant1x/pandas/stat" + "testing" +) + +func TestBARSSINCEN(t *testing.T) { + f1 := []int64{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4} + s1 := pandas.NewSeries(pandas.SERIES_TYPE_DTYPE, "", f1) + df := pandas.NewDataFrame(s1) + fmt.Println(df) + + b1 := s1.Logic(func(idx int, v any) bool { + f := v.(stat.DType) + return f > 3 + }) + df = df.Join(pandas.NewSeries(pandas.SERIES_TYPE_BOOL, "r", b1)) + fmt.Println(df) + //c1 = df > 3 + r1 := BARSSINCEN(df.Col("r"), 4) + fmt.Println(r1) +} diff --git a/generic.go b/generic.go index d419982..823a7ec 100644 --- a/generic.go +++ b/generic.go @@ -161,6 +161,12 @@ func (self *NDFrame) DTypes() []stat.DType { return stat.Slice2DType(self.Values()) } +// AsInt 强制转换成整型 +func (self *NDFrame) AsInt() []stat.Int { + values := self.DTypes() + return stat.DType2Int(values) +} + func (self *NDFrame) Empty() Series { var frame NDFrame if self.type_ == SERIES_TYPE_STRING { @@ -311,7 +317,7 @@ func (self *NDFrame) Std() stat.DType { return stdDev } -func (self *NDFrame) FillNa(v any, inplace bool) { +func (self *NDFrame) FillNa(v any, inplace bool) Series { values := self.Values() switch rows := values.(type) { case []string: @@ -339,4 +345,5 @@ func (self *NDFrame) FillNa(v any, inplace bool) { } } } + return self } diff --git a/generic_apply.go b/generic_apply.go index b311406..6faf9f3 100644 --- a/generic_apply.go +++ b/generic_apply.go @@ -22,3 +22,26 @@ func (self *NDFrame) Apply(f func(idx int, v any)) { // 其它类型忽略 } } + +func (self *NDFrame) Logic(f func(idx int, v any) bool) []bool { + x := make([]bool, self.Len()) + vv := reflect.ValueOf(self.values) + vk := vv.Kind() + switch vk { + case reflect.Invalid: // {interface} nil + //series.assign(idx, size, Nil2Float64) + case reflect.Slice: // 切片, 不定长 + for i := 0; i < vv.Len(); i++ { + tv := vv.Index(i).Interface() + x[i] = f(i, tv) + } + case reflect.Array: // 数组, 定长 + for i := 0; i < vv.Len(); i++ { + tv := vv.Index(i).Interface() + x[i] = f(i, tv) + } + default: + // 其它类型忽略 + } + return x +} diff --git a/generic_rolling.go b/generic_rolling.go index 8c0494d..caf86d3 100644 --- a/generic_rolling.go +++ b/generic_rolling.go @@ -48,8 +48,7 @@ func (r RollingAndExpandingMixin) getBlocks() (blocks []Series) { return } -// Apply 接受一个回调 -func (r RollingAndExpandingMixin) Apply(f func(S Series, N stat.DType) stat.DType) (s Series) { +func (r RollingAndExpandingMixin) Apply_v1(f func(S Series, N stat.DType) stat.DType) (s Series) { s = r.series.Empty() for i, block := range r.getBlocks() { if block.Len() == 0 { @@ -61,3 +60,18 @@ func (r RollingAndExpandingMixin) Apply(f func(S Series, N stat.DType) stat.DTyp } return } + +// Apply 接受一个回调 +func (r RollingAndExpandingMixin) Apply(f func(S Series, N stat.DType) stat.DType) (s Series) { + values := make([]stat.DType, r.series.Len()) + for i, block := range r.getBlocks() { + if block.Len() == 0 { + values[i] = stat.DTypeNaN + continue + } + v := f(block, r.window[i]) + values[i] = v + } + s = NewSeries(SERIES_TYPE_DTYPE, r.series.Name(), values) + return +} diff --git a/series.go b/series.go index 95b4620..0bbd2a1 100644 --- a/series.go +++ b/series.go @@ -43,6 +43,8 @@ type Series interface { Float() []float32 // DTypes 强制转[]stat.DType DTypes() []stat.DType + // 强制转换成整型 + AsInt() []stat.Int // sort.Interface @@ -76,7 +78,7 @@ type Series interface { // StdDev calculates the standard deviation of a series StdDev() stat.DType // FillNa Fill NA/NaN values using the specified method. - FillNa(v any, inplace bool) + FillNa(v any, inplace bool) Series // Max 找出最大值 Max() any // Min 找出最小值 @@ -87,6 +89,8 @@ type Series interface { Append(values ...any) // Apply 接受一个回调函数 Apply(f func(idx int, v any)) + // Logic 逻辑处理 + Logic(f func(idx int, v any) bool) []bool // Diff 元素的第一个离散差 Diff(param any) (s Series) // Ref 引用其它周期的数据 diff --git a/series_bool.go b/series_bool.go index 49d8091..bba6d2e 100644 --- a/series_bool.go +++ b/series_bool.go @@ -161,7 +161,7 @@ func (self *SeriesBool) StdDev() float64 { } // FillNa bool类型不可能在导入series还是NaN -func (self *SeriesBool) FillNa(v any, inplace bool) { +func (self *SeriesBool) FillNa(v any, inplace bool) Series { //values := self.Values() //switch rows := values.(type) { //case []bool: @@ -171,4 +171,5 @@ func (self *SeriesBool) FillNa(v any, inplace bool) { // } // } //} + return self } diff --git a/series_float64.go b/series_float64.go index cf18de5..6808b40 100644 --- a/series_float64.go +++ b/series_float64.go @@ -221,7 +221,7 @@ func (self *SeriesFloat64) StdDev() float64 { return stdDev } -func (self *SeriesFloat64) FillNa(v any, inplace bool) { +func (self *SeriesFloat64) FillNa(v any, inplace bool) Series { values := self.Values() switch rows := values.(type) { case []float64: @@ -231,4 +231,5 @@ func (self *SeriesFloat64) FillNa(v any, inplace bool) { } } } + return self } diff --git a/series_int64.go b/series_int64.go index d16c926..6ed732b 100644 --- a/series_int64.go +++ b/series_int64.go @@ -180,7 +180,7 @@ func (self *SeriesInt64) StdDev() float64 { } // FillNa int64没有NaN -func (self *SeriesInt64) FillNa(v any, inplace bool) { +func (self *SeriesInt64) FillNa(v any, inplace bool) Series { values := self.Values() switch rows := values.(type) { case []int64: @@ -190,4 +190,5 @@ func (self *SeriesInt64) FillNa(v any, inplace bool) { } } } + return self } diff --git a/series_xstring.go b/series_xstring.go index eee9a6d..6d65396 100644 --- a/series_xstring.go +++ b/series_xstring.go @@ -165,7 +165,7 @@ func (self *SeriesString) StdDev() float64 { panic("implement me") } -func (self *SeriesString) FillNa(v any, inplace bool) { +func (self *SeriesString) FillNa(v any, inplace bool) Series { values := self.Values() switch rows := values.(type) { case []string: @@ -175,4 +175,6 @@ func (self *SeriesString) FillNa(v any, inplace bool) { } } } + + return self } diff --git a/stat/type_dtype.go b/stat/type_dtype.go index c87da20..26bbf5e 100644 --- a/stat/type_dtype.go +++ b/stat/type_dtype.go @@ -1,6 +1,11 @@ package stat +import ( + "github.com/viterin/vek" +) + type DType = float64 +type Int = int32 // DTypeIsNaN 判断DType是否NaN func DTypeIsNaN(d DType) bool { @@ -16,3 +21,8 @@ func Slice2DType(v any) []DType { func Any2DType(v any) DType { return AnyToFloat64(v) } + +// DType切片转int32切片 +func DType2Int(d []DType) []Int { + return vek.ToInt32(d) +} -- Gitee