diff --git a/formula/README.md b/formula/README.md index e576b578f2b827702f0d8868e6c38d7f9d815492..e4e9976f78ea53c8358deb831c9dca11a0efa00b 100644 --- a/formula/README.md +++ b/formula/README.md @@ -36,7 +36,7 @@ formula | 1 | FILTER | FILTER函数,S满足条件后,将其后N周期内的数据置为0 | FILTER(CLOSE>LOW,5) | [X] | [X] | | 1 | BARSLAST | 上一次条件成立到当前的周期数 | BARSLAST(X) | [√] | [√] | | 1 | BARSLASTCOUNT | 统计连续满足S条件的周期数 | BARSLASTCOUNT(X) | [X] | [X] | -| 1 | BARSSINCEN | N周期内第一次S条件成立到现在的周期数 | BARSSINCEN(S,N) | [X] | [X] | +| 1 | BARSSINCEN | N周期内第一次S条件成立到现在的周期数 | BARSSINCEN(S,N) | [√] | [√] | | 1 | CROSS | 判断向上金叉穿越,两个序列互换就是判断向下死叉穿越 | CROSS(MA(C,5),MA(C,10)) | [X] | [X] | | 1 | LONGCROSS | 两条线维持一定周期后交叉,S1在N周期内都小于S2,本周期从S1下方向上穿过S2时返回1,否则返回0 | LONGCROSS(MA(C,5),MA(C,10),5) | [X] | [X] | | 1 | VALUEWHEN | 当S条件成立时,取X的当前值,否则取VALUEWHEN的上个成立时的X值 | VALUEWHEN(S,X) | [X] | [X] | diff --git a/formula/barssincen.go b/formula/barssincen.go new file mode 100644 index 0000000000000000000000000000000000000000..decbd3ee8c586d3428265eed5a6dab0239df2ec0 --- /dev/null +++ b/formula/barssincen.go @@ -0,0 +1,26 @@ +package formula + +import ( + "gitee.com/quant1x/pandas" + "gitee.com/quant1x/pandas/stat" +) + +// BARSSINCEN N周期内第一次S条件成立到现在的周期数,N为常量 +func BARSSINCEN(S pandas.Series, N any) []stat.Int { + ret := S.Rolling(N).Apply(func(X pandas.Series, M stat.DType) stat.DType { + x := X.DTypes() + n := int(M) + argMax := stat.ArgMax(x) + r := 0 + if argMax != 0 || x[0] != 0 { + r = n - 1 - argMax + + } else { + r = 0 + } + return stat.DType(r) + }) + r1 := ret.FillNa(0, true) + r2 := r1.AsInt() + return r2 +} diff --git a/formula/barssincen_test.go b/formula/barssincen_test.go new file mode 100644 index 0000000000000000000000000000000000000000..918707982b692dc60ceb030e989efac3ec361714 --- /dev/null +++ b/formula/barssincen_test.go @@ -0,0 +1,25 @@ +package formula + +import ( + "fmt" + "gitee.com/quant1x/pandas" + "gitee.com/quant1x/pandas/stat" + "testing" +) + +func TestBARSSINCEN(t *testing.T) { + f1 := []int64{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4} + s1 := pandas.NewSeries(pandas.SERIES_TYPE_DTYPE, "", f1) + df := pandas.NewDataFrame(s1) + fmt.Println(df) + + b1 := s1.Logic(func(idx int, v any) bool { + f := v.(stat.DType) + return f > 3 + }) + df = df.Join(pandas.NewSeries(pandas.SERIES_TYPE_BOOL, "r", b1)) + fmt.Println(df) + //c1 = df > 3 + r1 := BARSSINCEN(df.Col("r"), 4) + fmt.Println(r1) +} diff --git a/generic.go b/generic.go index d41998298ba07d6f0629847c58c60711b60f9fde..823a7ec873b64bf401d6f74d088923866a390199 100644 --- a/generic.go +++ b/generic.go @@ -161,6 +161,12 @@ func (self *NDFrame) DTypes() []stat.DType { return stat.Slice2DType(self.Values()) } +// AsInt 强制转换成整型 +func (self *NDFrame) AsInt() []stat.Int { + values := self.DTypes() + return stat.DType2Int(values) +} + func (self *NDFrame) Empty() Series { var frame NDFrame if self.type_ == SERIES_TYPE_STRING { @@ -311,7 +317,7 @@ func (self *NDFrame) Std() stat.DType { return stdDev } -func (self *NDFrame) FillNa(v any, inplace bool) { +func (self *NDFrame) FillNa(v any, inplace bool) Series { values := self.Values() switch rows := values.(type) { case []string: @@ -339,4 +345,5 @@ func (self *NDFrame) FillNa(v any, inplace bool) { } } } + return self } diff --git a/generic_apply.go b/generic_apply.go index b311406490e54151a93930794b2e687b0a921db6..6faf9f37e93197b29ce9286910baa3f890e548fa 100644 --- a/generic_apply.go +++ b/generic_apply.go @@ -22,3 +22,26 @@ func (self *NDFrame) Apply(f func(idx int, v any)) { // 其它类型忽略 } } + +func (self *NDFrame) Logic(f func(idx int, v any) bool) []bool { + x := make([]bool, self.Len()) + vv := reflect.ValueOf(self.values) + vk := vv.Kind() + switch vk { + case reflect.Invalid: // {interface} nil + //series.assign(idx, size, Nil2Float64) + case reflect.Slice: // 切片, 不定长 + for i := 0; i < vv.Len(); i++ { + tv := vv.Index(i).Interface() + x[i] = f(i, tv) + } + case reflect.Array: // 数组, 定长 + for i := 0; i < vv.Len(); i++ { + tv := vv.Index(i).Interface() + x[i] = f(i, tv) + } + default: + // 其它类型忽略 + } + return x +} diff --git a/generic_rolling.go b/generic_rolling.go index 8c0494d5ad9cd99b97728157b124e1330d69edeb..caf86d38c7c263c53821ca9862b53375ebcc851d 100644 --- a/generic_rolling.go +++ b/generic_rolling.go @@ -48,8 +48,7 @@ func (r RollingAndExpandingMixin) getBlocks() (blocks []Series) { return } -// Apply 接受一个回调 -func (r RollingAndExpandingMixin) Apply(f func(S Series, N stat.DType) stat.DType) (s Series) { +func (r RollingAndExpandingMixin) Apply_v1(f func(S Series, N stat.DType) stat.DType) (s Series) { s = r.series.Empty() for i, block := range r.getBlocks() { if block.Len() == 0 { @@ -61,3 +60,18 @@ func (r RollingAndExpandingMixin) Apply(f func(S Series, N stat.DType) stat.DTyp } return } + +// Apply 接受一个回调 +func (r RollingAndExpandingMixin) Apply(f func(S Series, N stat.DType) stat.DType) (s Series) { + values := make([]stat.DType, r.series.Len()) + for i, block := range r.getBlocks() { + if block.Len() == 0 { + values[i] = stat.DTypeNaN + continue + } + v := f(block, r.window[i]) + values[i] = v + } + s = NewSeries(SERIES_TYPE_DTYPE, r.series.Name(), values) + return +} diff --git a/series.go b/series.go index 95b462028baf25dce17adce3ae2c4ddafcbdba64..0bbd2a139fb34b3fc260f8aa5a9518b41c9e9158 100644 --- a/series.go +++ b/series.go @@ -43,6 +43,8 @@ type Series interface { Float() []float32 // DTypes 强制转[]stat.DType DTypes() []stat.DType + // 强制转换成整型 + AsInt() []stat.Int // sort.Interface @@ -76,7 +78,7 @@ type Series interface { // StdDev calculates the standard deviation of a series StdDev() stat.DType // FillNa Fill NA/NaN values using the specified method. - FillNa(v any, inplace bool) + FillNa(v any, inplace bool) Series // Max 找出最大值 Max() any // Min 找出最小值 @@ -87,6 +89,8 @@ type Series interface { Append(values ...any) // Apply 接受一个回调函数 Apply(f func(idx int, v any)) + // Logic 逻辑处理 + Logic(f func(idx int, v any) bool) []bool // Diff 元素的第一个离散差 Diff(param any) (s Series) // Ref 引用其它周期的数据 diff --git a/series_bool.go b/series_bool.go index 49d809184eabf74526f5203ffa67c91a90f7bf73..bba6d2e24ce35d86b6dccebdb253f8e0979e9a9a 100644 --- a/series_bool.go +++ b/series_bool.go @@ -161,7 +161,7 @@ func (self *SeriesBool) StdDev() float64 { } // FillNa bool类型不可能在导入series还是NaN -func (self *SeriesBool) FillNa(v any, inplace bool) { +func (self *SeriesBool) FillNa(v any, inplace bool) Series { //values := self.Values() //switch rows := values.(type) { //case []bool: @@ -171,4 +171,5 @@ func (self *SeriesBool) FillNa(v any, inplace bool) { // } // } //} + return self } diff --git a/series_float64.go b/series_float64.go index cf18de5d0740235ebd23855df89463c0d89925a1..6808b40f12c76fb4daa94b8ab0111d650d3bab7b 100644 --- a/series_float64.go +++ b/series_float64.go @@ -221,7 +221,7 @@ func (self *SeriesFloat64) StdDev() float64 { return stdDev } -func (self *SeriesFloat64) FillNa(v any, inplace bool) { +func (self *SeriesFloat64) FillNa(v any, inplace bool) Series { values := self.Values() switch rows := values.(type) { case []float64: @@ -231,4 +231,5 @@ func (self *SeriesFloat64) FillNa(v any, inplace bool) { } } } + return self } diff --git a/series_int64.go b/series_int64.go index d16c9266b00402874aa1243b41afcf8d8fb0359a..6ed732baf9dca9af6b0e29cc126c50f8e3def4a8 100644 --- a/series_int64.go +++ b/series_int64.go @@ -180,7 +180,7 @@ func (self *SeriesInt64) StdDev() float64 { } // FillNa int64没有NaN -func (self *SeriesInt64) FillNa(v any, inplace bool) { +func (self *SeriesInt64) FillNa(v any, inplace bool) Series { values := self.Values() switch rows := values.(type) { case []int64: @@ -190,4 +190,5 @@ func (self *SeriesInt64) FillNa(v any, inplace bool) { } } } + return self } diff --git a/series_xstring.go b/series_xstring.go index eee9a6dd9363639c632188392deef4c14a594747..6d653963f892a3d5cf74aaead111ab036bf18c3b 100644 --- a/series_xstring.go +++ b/series_xstring.go @@ -165,7 +165,7 @@ func (self *SeriesString) StdDev() float64 { panic("implement me") } -func (self *SeriesString) FillNa(v any, inplace bool) { +func (self *SeriesString) FillNa(v any, inplace bool) Series { values := self.Values() switch rows := values.(type) { case []string: @@ -175,4 +175,6 @@ func (self *SeriesString) FillNa(v any, inplace bool) { } } } + + return self } diff --git a/stat/argmax.go b/stat/argmax.go new file mode 100644 index 0000000000000000000000000000000000000000..27b4029a4f1760972b98c4d5e5860f1434648bf2 --- /dev/null +++ b/stat/argmax.go @@ -0,0 +1,33 @@ +package stat + +import ( + "github.com/viterin/vek" + "github.com/viterin/vek/vek32" +) + +// ArgMax Returns the indices of the maximum values along an axis. +// 返回轴上最大值的索引 +func ArgMax[T Number](v []T) int { + var vv any = v + switch values := vv.(type) { + case []float32: + return vek32.ArgMax(values) + case []float64: + return vek.ArgMax(values) + default: + return __arg_max(v) + } + +} + +func __arg_max[T Number](x []T) int { + max := x[0] + idx := 0 + for i, v := range x[1:] { + if v > max { + max = v + idx = 1 + i + } + } + return idx +} diff --git a/stat/argmax_test.go b/stat/argmax_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c7800019b975e866a60b826f92fd8dcbda0e1be1 --- /dev/null +++ b/stat/argmax_test.go @@ -0,0 +1,15 @@ +package stat + +import ( + "fmt" + "testing" +) + +func TestArgMax(t *testing.T) { + n1 := []float32{1.1, 2.2, 1.3, 1.4} + n2 := []float64{1.2, 1.2, 3.3} + n3 := []int64{11, 12, 33} + fmt.Println(ArgMax(n1)) + fmt.Println(ArgMax(n2)) + fmt.Println(ArgMax(n3)) +} diff --git a/stat/argmin.go b/stat/argmin.go new file mode 100644 index 0000000000000000000000000000000000000000..b9b8b0ff02c0d0a6c12876de768ae4c16048afb6 --- /dev/null +++ b/stat/argmin.go @@ -0,0 +1,33 @@ +package stat + +import ( + "github.com/viterin/vek" + "github.com/viterin/vek/vek32" +) + +// ArgMin Returns the indices of the minimum values along an axis. +// 返回轴上最小值的索引 +func ArgMin[T Number](v []T) int { + var vv any = v + switch values := vv.(type) { + case []float32: + return vek32.ArgMin(values) + case []float64: + return vek.ArgMin(values) + default: + return __arg_min(v) + } + +} + +func __arg_min[T Number](x []T) int { + min := x[0] + idx := 0 + for i, v := range x[1:] { + if v < min { + min = v + idx = 1 + i + } + } + return idx +} diff --git a/stat/argmin_test.go b/stat/argmin_test.go new file mode 100644 index 0000000000000000000000000000000000000000..1f61bb721328fe37ed42e84a8130c11c4b19f68a --- /dev/null +++ b/stat/argmin_test.go @@ -0,0 +1,15 @@ +package stat + +import ( + "fmt" + "testing" +) + +func TestArgMin(t *testing.T) { + n1 := []float32{1.1, 2.2, 1.3, 1.4} + n2 := []float64{1.2, 1.2, 0.3} + n3 := []int64{11, 12, 33} + fmt.Println(ArgMin(n1)) + fmt.Println(ArgMin(n2)) + fmt.Println(ArgMin(n3)) +} diff --git a/stat/median.go b/stat/median.go new file mode 100644 index 0000000000000000000000000000000000000000..6be0964b85e2b17df0c9ae98b5d8ace58546128d --- /dev/null +++ b/stat/median.go @@ -0,0 +1,21 @@ +package stat + +// Median returns median value of series. +// Linear interpolation is used for odd length. +// TODO:未加验证 +func Median[T StatType](values []T) DType { + if len(values) == 0 { + return DTypeNaN + } + + if len(values) == 1 { + return DType(0) + } + + if len(values)%2 == 0 { + i := len(values) / 2 + return DType(values[i-1]+values[i]) / 2 + } + + return DType(values[len(values)/2]) +} diff --git a/stat/type_dtype.go b/stat/type_dtype.go index c87da202eb76630cac9682759c334ceb22b81f28..26bbf5e6a27997e6f1ab1d6d319424581861c420 100644 --- a/stat/type_dtype.go +++ b/stat/type_dtype.go @@ -1,6 +1,11 @@ package stat +import ( + "github.com/viterin/vek" +) + type DType = float64 +type Int = int32 // DTypeIsNaN 判断DType是否NaN func DTypeIsNaN(d DType) bool { @@ -16,3 +21,8 @@ func Slice2DType(v any) []DType { func Any2DType(v any) DType { return AnyToFloat64(v) } + +// DType切片转int32切片 +func DType2Int(d []DType) []Int { + return vek.ToInt32(d) +}