diff --git a/formula/max.go b/formula/max.go index 2ce13743619f571633f9bb41364d6562b4fc389c..1da9e8fc66744a744562ab682180ba664091a1f2 100644 --- a/formula/max.go +++ b/formula/max.go @@ -1,13 +1,26 @@ package formula import ( - "gitee.com/quant1x/pandas" "gitee.com/quant1x/pandas/stat" ) // MAX 两个序列横向对比 -func MAX(S1, S2 stat.Series) stat.Series { - d := stat.Maximum(S1.Floats(), S2.Floats()) - return pandas.NewSeries(stat.SERIES_TYPE_FLOAT32, "", d) +func MAX(S1 stat.Series, S2 any) stat.Series { + length := S1.Len() + var b []stat.DType + switch sx := S2.(type) { + case stat.Series: + b = sx.DTypes() + case int: + b = stat.Repeat[stat.DType](stat.DType(sx), length) + case stat.DType: + b = stat.Repeat[stat.DType](sx, length) + //case int8, uint8, int16, uint16, int32, uint32, int64, uint64, int, uint, uintptr, float32, float64: + // b = Repeat[DType](DType(sx), length) + default: + panic(stat.Throw(S2)) + } + d := stat.Maximum(S1.DTypes(), b) + return stat.NewSeries[stat.DType](d...) } diff --git a/formula/min.go b/formula/min.go index 2303e4f81fe25e7eaaeb1f2453c5165778896a05..dafaa434cd66cefe50edf959aff32f75f2f9c799 100644 --- a/formula/min.go +++ b/formula/min.go @@ -1,13 +1,26 @@ package formula import ( - "gitee.com/quant1x/pandas" "gitee.com/quant1x/pandas/stat" ) // MIN 两个序列横向对比 -func MIN(S1, S2 stat.Series) stat.Series { - d := stat.Minimum(S1.Floats(), S2.Floats()) - return pandas.NewSeries(stat.SERIES_TYPE_FLOAT32, "", d) +func MIN(S1 stat.Series, S2 any) stat.Series { + length := S1.Len() + var b []stat.DType + switch sx := S2.(type) { + case stat.Series: + b = sx.DTypes() + case int: + b = stat.Repeat[stat.DType](stat.DType(sx), length) + case stat.DType: + b = stat.Repeat[stat.DType](sx, length) + //case int8, uint8, int16, uint16, int32, uint32, int64, uint64, int, uint, uintptr, float32, float64: + // b = Repeat[DType](DType(sx), length) + default: + panic(stat.Throw(S2)) + } + d := stat.Minimum(S1.DTypes(), b) + return stat.NewSeries[stat.DType](d...) } diff --git a/formula/ref.go b/formula/ref.go index b74adf3bccc09891200946a2fb765ce5965fb44f..315b3f2fbfd542a52fd47669d590c03aaf28f701 100644 --- a/formula/ref.go +++ b/formula/ref.go @@ -6,19 +6,18 @@ import ( ) // REF 引用前N的序列 -func REF(S stat.Series, N any) any { - var X []float32 +func REF(S stat.Series, N any) stat.Series { + var X []stat.DType switch v := N.(type) { case int: - X = stat.Repeat[float32](float32(v), S.Len()) + X = stat.Repeat[stat.DType](stat.DType(v), S.Len()) case stat.Series: - vs := v.Values() - X = stat.SliceToFloat32(vs) - X = stat.Align(X, stat.Nil2Float32, S.Len()) + vs := v.DTypes() + X = stat.Align(vs, stat.DTypeNaN, S.Len()) default: panic(exception.New(1, "error window")) } - return S.Ref(X).Values() + return S.Ref(X) } func REF2[T stat.GenericType](S []T, N any) []T { @@ -28,7 +27,7 @@ func REF2[T stat.GenericType](S []T, N any) []T { case int: X = stat.Repeat[stat.DType](stat.DType(v), sLen) case []stat.DType: - X = stat.Align(X, stat.Nil2Float64, sLen) + X = stat.Align(v, stat.DTypeNaN, sLen) default: panic(exception.New(1, "error window")) } diff --git a/formula/sma.go b/formula/sma.go index 9cf4abda0c9465f5d6cfc78d7d58512cf2f92cf8..5bd96931ba6671de2c824fc27586aaa9fc0ae06e 100644 --- a/formula/sma.go +++ b/formula/sma.go @@ -6,7 +6,7 @@ import ( ) // SMA 中国式的SMA,至少需要120周期才精确 (雪球180周期) alpha=1/(1+com) -func SMA(S stat.Series, N any, M int) any { +func SMA(S stat.Series, N any, M int) stat.Series { if M == 0 { M = 1 } @@ -21,7 +21,8 @@ func SMA(S stat.Series, N any, M int) any { default: panic(exception.New(1, "error window")) } - x := S.EWM(stat.EW{Alpha: float64(M) / float64(X), Adjust: false}).Mean().Values() + //x := S.EWM(stat.EW{Alpha: float64(M) / float64(X), Adjust: false}).Mean().Values() + x := S.EWM(stat.EW{Alpha: float64(M) / float64(X), Adjust: false}).Mean() return x } diff --git a/indicator/rsi.go b/indicator/rsi.go new file mode 100644 index 0000000000000000000000000000000000000000..464cb22fda3ee56c7ba7224c62a08ec2b650dc8b --- /dev/null +++ b/indicator/rsi.go @@ -0,0 +1,37 @@ +package indicator + +import ( + "gitee.com/quant1x/pandas" + . "gitee.com/quant1x/pandas/formula" +) + +// RSI 指标 +// +// LC:=REF(CLOSE,1); +// LC赋值:1日前的收盘价 +// RSI1:SMA(MAX(CLOSE-LC,0),N1,1)/SMA(ABS(CLOSE-LC),N1,1)*100; +// 输出RSI1:收盘价-LC和0的较大值的N1日[1日权重]移动平均/收盘价-LC的绝对值的N1日[1日权重]移动平均*100 +// RSI2:SMA(MAX(CLOSE-LC,0),N2,1)/SMA(ABS(CLOSE-LC),N2,1)*100; +// 输出RSI2:收盘价-LC和0的较大值的N2日[1日权重]移动平均/收盘价-LC的绝对值的N2日[1日权重]移动平均*100 +// RSI3:SMA(MAX(CLOSE-LC,0),N3,1)/SMA(ABS(CLOSE-LC),N3,1)*100; +// 输出RSI3:收盘价-LC和0的较大值的N3日[1日权重]移动平均/收盘价-LC的绝对值的N3日[1日权重]移动平均*100 +// +// 系统默认参数6,12,24 +func RSI(df pandas.DataFrame, N1, N2, N3 int) pandas.DataFrame { + var ( + CLOSE = df.ColAsNDArray("close") + //HIGH = df.ColAsNDArray("high") + //LOW = df.ColAsNDArray("low") + ) + // LC:=REF(CLOSE,1); + LC := REF(CLOSE, 1) + cls := CLOSE.Sub(LC) + // RSI1:SMA(MAX(CLOSE-LC,0),N1,1)/SMA(ABS(CLOSE-LC),N1,1)*100; + RSI1 := SMA(MAX(cls, 0), N1, 1).Div(SMA(ABS(cls), N1, 1)).Mul(100) + // RSI2:SMA(MAX(CLOSE-LC,0),N2,1)/SMA(ABS(CLOSE-LC),N2,1)*100; + RSI2 := SMA(MAX(cls, 0), N2, 1).Div(SMA(ABS(cls), N2, 1)).Mul(100) + // RSI3:SMA(MAX(CLOSE-LC,0),N3,1)/SMA(ABS(CLOSE-LC),N3,1)*100; + RSI3 := SMA(MAX(cls, 0), N3, 1).Div(SMA(ABS(cls), N3, 1)).Mul(100) + + return pandas.NewDataFrame(RSI1, RSI2, RSI3) +} diff --git a/indicator/rsi_test.go b/indicator/rsi_test.go new file mode 100644 index 0000000000000000000000000000000000000000..64dc3202fcf5022ba34e9346e27fc43a3291a4be --- /dev/null +++ b/indicator/rsi_test.go @@ -0,0 +1,14 @@ +package indicator + +import ( + "fmt" + "gitee.com/quant1x/pandas/data/cache" + "testing" +) + +func TestRSI(t *testing.T) { + df := cache.KLine("sz002528") + fmt.Println(df) + df1 := RSI(df, 6, 12, 24) + fmt.Println(df1) +} diff --git a/stat/type_float32.go b/stat/type_float32.go index 1fb7250321ed0275789fe6bacbd9032d5abdb023..38b5fec19d71a2480f10360c439548ea9eaea627 100644 --- a/stat/type_float32.go +++ b/stat/type_float32.go @@ -5,7 +5,6 @@ import ( "gitee.com/quant1x/pandas/exception" "github.com/mymmsc/gox/logger" "github.com/viterin/vek/vek32" - "golang.org/x/exp/slices" "math" "reflect" "strconv" @@ -66,7 +65,8 @@ func SliceToFloat32(v any) []float32 { case []uint: return slice_any_to_float32(values) case []float32: // 克隆 - return slices.Clone(values) + //return slices.Clone(values) + return values case []float64: // 加速 return vek32.FromFloat64(values) case []bool: diff --git a/stat/type_float64.go b/stat/type_float64.go index d583634da16d5afce0d188b7c842014fd2bc2ad0..251ce30650fd8e5ff165946dbad3ccbafad7ce43 100644 --- a/stat/type_float64.go +++ b/stat/type_float64.go @@ -5,7 +5,6 @@ import ( "gitee.com/quant1x/pandas/exception" "github.com/mymmsc/gox/logger" "github.com/viterin/vek" - "golang.org/x/exp/slices" "math" "reflect" "strconv" @@ -110,7 +109,8 @@ func SliceToFloat64(v any) []float64 { case []float32: // 加速 return vek.FromFloat32(values) case []float64: // 克隆 - return slices.Clone(values) + //return slices.Clone(values) + return values case []bool: count := len(values) if count == 0 {