diff --git a/dataframe.go b/dataframe.go index 6bd625c76e6499cedccfe58216b92bf563e282bb..30c22d9c0386879ef9b21c35534f046436492ba4 100644 --- a/dataframe.go +++ b/dataframe.go @@ -2,12 +2,13 @@ package pandas import ( "fmt" + "gitee.com/quant1x/pandas/stat" "sort" ) // DataFrame 以gota的DataFrame的方法为主, 兼顾新流程, 避免单元格元素结构化 type DataFrame struct { - columns []Series + columns []stat.Series ncols int nrows int @@ -16,24 +17,24 @@ type DataFrame struct { } // NewDataFrame is the generic DataFrame constructor -func NewDataFrame(se ...Series) DataFrame { +func NewDataFrame(se ...stat.Series) DataFrame { if se == nil || len(se) == 0 { return DataFrame{Err: fmt.Errorf("empty DataFrame")} } - columns := make([]Series, len(se)) + columns := make([]stat.Series, len(se)) for i, s := range se { - var d Series - if s.Type() == SERIES_TYPE_INT64 { - d = NewSeries(SERIES_TYPE_INT64, s.Name(), s.Values()) - } else if s.Type() == SERIES_TYPE_BOOL { - d = NewSeries(SERIES_TYPE_BOOL, s.Name(), s.Values()) - } else if s.Type() == SERIES_TYPE_STRING { - d = NewSeries(SERIES_TYPE_STRING, s.Name(), s.Values()) - } else if s.Type() == SERIES_TYPE_FLOAT32 { - d = NewSeries(SERIES_TYPE_FLOAT32, s.Name(), s.Values()) + var d stat.Series + if s.Type() == stat.SERIES_TYPE_INT64 { + d = NewSeries(stat.SERIES_TYPE_INT64, s.Name(), s.Values()) + } else if s.Type() == stat.SERIES_TYPE_BOOL { + d = NewSeries(stat.SERIES_TYPE_BOOL, s.Name(), s.Values()) + } else if s.Type() == stat.SERIES_TYPE_STRING { + d = NewSeries(stat.SERIES_TYPE_STRING, s.Name(), s.Values()) + } else if s.Type() == stat.SERIES_TYPE_FLOAT32 { + d = NewSeries(stat.SERIES_TYPE_FLOAT32, s.Name(), s.Values()) } else { - d = NewSeries(SERIES_TYPE_FLOAT64, s.Name(), s.Values()) + d = NewSeries(stat.SERIES_TYPE_FLOAT64, s.Name(), s.Values()) } columns[i] = d } @@ -77,7 +78,7 @@ func (self DataFrame) Error() error { } // 检查列的尺寸 -func checkColumnsDimensions(se ...Series) (nrows, ncols int, err error) { +func checkColumnsDimensions(se ...stat.Series) (nrows, ncols int, err error) { ncols = len(se) nrows = -1 if se == nil || ncols == 0 { diff --git a/dataframe_csv.go b/dataframe_csv.go index 19a7a6fa0cf0ba78890f969fb7b8fe2eec3b808e..a258e63e70b359f3a7ed9a64dd1193002aed8385 100644 --- a/dataframe_csv.go +++ b/dataframe_csv.go @@ -2,6 +2,7 @@ package pandas import ( "encoding/csv" + "gitee.com/quant1x/pandas/stat" "github.com/mymmsc/gox/api" "github.com/mymmsc/gox/logger" "github.com/mymmsc/gox/util/homedir" @@ -24,7 +25,7 @@ func ReadCSV(in any, options ...LoadOption) DataFrame { filename = param } - if !IsEmpty(filename) { + if !stat.IsEmpty(filename) { filepath, err := homedir.Expand(filename) if err != nil { logger.Errorf("%s, error=%+v\n", filename, err) @@ -74,7 +75,7 @@ func (self DataFrame) WriteCSV(out any, options ...WriteOption) error { filename = param } - if !IsEmpty(filename) { + if !stat.IsEmpty(filename) { filepath, err := homedir.Expand(filename) if err != nil { return err diff --git a/dataframe_excel.go b/dataframe_excel.go index 7365926aaf7daa1109438a8261c8e6dea7a8b3cc..74f9b752dc8ccb83ccc401c2705c86147dea6746 100644 --- a/dataframe_excel.go +++ b/dataframe_excel.go @@ -2,6 +2,7 @@ package pandas import ( "fmt" + "gitee.com/quant1x/pandas/stat" "github.com/mymmsc/gox/logger" "github.com/mymmsc/gox/util/homedir" xlsv1 "github.com/tealeg/xlsx" @@ -9,9 +10,9 @@ import ( "strings" ) -// 读取excel文件 +// ReadExcel 读取excel文件 func ReadExcel(filename string, options ...LoadOption) DataFrame { - if IsEmpty(filename) { + if stat.IsEmpty(filename) { return DataFrame{Err: fmt.Errorf("filaname is empty")} } diff --git a/dataframe_indexes.go b/dataframe_indexes.go index 6030cdfff44260402f9cdcd019dd21d1b8f8d8e0..c8a49290a3b1b656d52b4ca8b4e8addeb1f0268e 100644 --- a/dataframe_indexes.go +++ b/dataframe_indexes.go @@ -1,6 +1,9 @@ package pandas -import "fmt" +import ( + "fmt" + "gitee.com/quant1x/pandas/stat" +) func parseSelectIndexes(l int, indexes SelectIndexes, colnames []string) ([]int, error) { var idx []int @@ -85,7 +88,7 @@ func (df DataFrame) Select(indexes SelectIndexes) DataFrame { if err != nil { return DataFrame{Err: fmt.Errorf("can't select columns: %v", err)} } - columns := make([]Series, len(idx)) + columns := make([]stat.Series, len(idx)) for k, i := range idx { if i < 0 || i >= df.ncols { return DataFrame{Err: fmt.Errorf("can't select columns: index out of range")} diff --git a/dataframe_join.go b/dataframe_join.go index 6f8645eacabaa1a1b4a45a1e38889a5e757f9998..8f35d650e15c737fa8abc79fc4d8ce065c7d0d3f 100644 --- a/dataframe_join.go +++ b/dataframe_join.go @@ -2,8 +2,8 @@ package pandas import "gitee.com/quant1x/pandas/stat" -func (self DataFrame) align(ss ...Series) []Series { - defaultValue := []Series{} +func (self DataFrame) align(ss ...stat.Series) []stat.Series { + defaultValue := []stat.Series{} sLen := len(ss) if sLen == 0 { return defaultValue @@ -17,7 +17,7 @@ func (self DataFrame) align(ss ...Series) []Series { if maxLength <= 0 { return defaultValue } - cols := make([]Series, sLen) + cols := make([]stat.Series, sLen) for i, v := range ss { vt := v.Type() vn := v.Name() @@ -25,16 +25,16 @@ func (self DataFrame) align(ss ...Series) []Series { // 声明any的ns变量用于接收逻辑分支的输出 // 切片数据不能直接对齐, 需要根据类型指定Nil和NaN默认值 var ns any - if vt == SERIES_TYPE_BOOL { + if vt == stat.SERIES_TYPE_BOOL { ns = stat.Align(vs.([]bool), stat.Nil2Bool, int(maxLength)) - } else if vt == SERIES_TYPE_INT64 { + } else if vt == stat.SERIES_TYPE_INT64 { ns = stat.Align(vs.([]int64), stat.Nil2Int64, int(maxLength)) - } else if vt == SERIES_TYPE_STRING { + } else if vt == stat.SERIES_TYPE_STRING { ns = stat.Align(vs.([]string), stat.Nil2String, int(maxLength)) - } else if vt == SERIES_TYPE_FLOAT32 { - ns = stat.Align(vs.([]float32), Nil2Float32, int(maxLength)) - } else if vt == SERIES_TYPE_FLOAT64 { - ns = stat.Align(vs.([]float64), Nil2Float64, int(maxLength)) + } else if vt == stat.SERIES_TYPE_FLOAT32 { + ns = stat.Align(vs.([]float32), stat.Nil2Float32, int(maxLength)) + } else if vt == stat.SERIES_TYPE_FLOAT64 { + ns = stat.Align(vs.([]float64), stat.Nil2Float64, int(maxLength)) } cols[i] = NewSeries(vt, vn, ns) } @@ -42,12 +42,12 @@ func (self DataFrame) align(ss ...Series) []Series { } // Join 默认右连接, 加入一个series -func (self DataFrame) Join(series Series) DataFrame { +func (self DataFrame) Join(series stat.Series) DataFrame { if series.Len() < 0 { return self } nCol := self.Ncol() - cols := make([]Series, nCol+1) + cols := make([]stat.Series, nCol+1) cols[len(cols)-1] = series for i, s := range self.columns { cols[i] = s diff --git a/dataframe_matrix.go b/dataframe_matrix.go index 66c0c976476ea43ad29674ee18b00a057b32b2ae..91466f3cc4b2c03b28038c8ddc660f274b692431 100644 --- a/dataframe_matrix.go +++ b/dataframe_matrix.go @@ -1,18 +1,21 @@ package pandas -import "gonum.org/v1/gonum/mat" +import ( + "gitee.com/quant1x/pandas/stat" + "gonum.org/v1/gonum/mat" +) // LoadMatrix loads the given Matrix as a DataFrame // TODO: Add Loadoptions func LoadMatrix(mat mat.Matrix) DataFrame { nrows, ncols := mat.Dims() - columns := make([]Series, ncols) + columns := make([]stat.Series, ncols) for i := 0; i < ncols; i++ { floats := make([]float64, nrows) for j := 0; j < nrows; j++ { floats[j] = mat.At(j, i) } - columns[i] = NewSeries(SERIES_TYPE_FLOAT64, "", floats) + columns[i] = NewSeries(stat.SERIES_TYPE_FLOAT64, "", floats) } nrows, ncols, err := checkColumnsDimensions(columns...) if err != nil { diff --git a/dataframe_options.go b/dataframe_options.go index 0d2e6a3d0a29c18763f66572073c352d3a347181..3f437ce39f5e9b3e69787248ff4f0b0ab795bc85 100644 --- a/dataframe_options.go +++ b/dataframe_options.go @@ -1,8 +1,10 @@ package pandas +import "gitee.com/quant1x/pandas/stat" + type loadOptions struct { // Specifies which is the default type in case detectTypes is disabled. - defaultType Type + defaultType stat.Type // If set, the type of each column will be automatically detected unless // otherwise specified. @@ -28,11 +30,11 @@ type loadOptions struct { comment rune // The types of specific columns can be specified via column name. - types map[string]Type + types map[string]stat.Type } // DefaultType sets the defaultType option for loadOptions. -func DefaultType(t Type) LoadOption { +func DefaultType(t stat.Type) LoadOption { return func(c *loadOptions) { c.defaultType = t } @@ -67,7 +69,7 @@ func NaNValues(nanValues []string) LoadOption { } // WithTypes sets the types option for loadOptions. -func WithTypes(coltypes map[string]Type) LoadOption { +func WithTypes(coltypes map[string]stat.Type) LoadOption { return func(c *loadOptions) { c.types = coltypes } diff --git a/dataframe_records.go b/dataframe_records.go index eb1ec305fb278211a1aa37785d37693dca5c8964..46a430303e4b7451351dce6aa46799768619cd55 100644 --- a/dataframe_records.go +++ b/dataframe_records.go @@ -10,7 +10,7 @@ import ( func LoadRecords(records [][]string, options ...LoadOption) DataFrame { // Set the default load options cfg := loadOptions{ - defaultType: SERIES_TYPE_STRING, + defaultType: stat.SERIES_TYPE_STRING, detectTypes: true, hasHeader: true, nanValues: stat.PossibleNaOfString, @@ -44,7 +44,7 @@ func LoadRecords(records [][]string, options ...LoadOption) DataFrame { headers = cfg.names } - types := make([]Type, len(headers)) + types := make([]stat.Type, len(headers)) rawcols := make([][]string, len(headers)) for i, colname := range headers { rawcol := make([]string, len(records)) @@ -69,7 +69,7 @@ func LoadRecords(records [][]string, options ...LoadOption) DataFrame { types[i] = t } - columns := make([]Series, len(headers)) + columns := make([]stat.Series, len(headers)) for i, colname := range headers { cols := rawcols[i] col := NewSeries(types[i], colname, cols) diff --git a/dataframe_remove.go b/dataframe_remove.go index 107c5af54d639c78c7b38030fc560a1bec93d81a..b50c8f82eb23bdad0f2b69fadce5a32041b1d476 100644 --- a/dataframe_remove.go +++ b/dataframe_remove.go @@ -9,7 +9,7 @@ func (self DataFrame) Remove(p stat.ScopeLimit) DataFrame { if err != nil { return self } - columns := []Series{} + columns := []stat.Series{} for i := range self.columns { ht := self.columns[i].Subset(0, start, true) tail := self.columns[i].Subset(end+1, rowLen).Values() diff --git a/dataframe_select.go b/dataframe_select.go index d88f8b0f0a8f00b638260d1c5044c3ac565456ec..0c2dc0df37cd8354bc85c2321497a63730f06e19 100644 --- a/dataframe_select.go +++ b/dataframe_select.go @@ -1,17 +1,20 @@ package pandas -import "fmt" +import ( + "fmt" + "gitee.com/quant1x/pandas/stat" +) // Col returns a copy of the Series with the given column name contained in the DataFrame. // 选取一列 -func (self DataFrame) Col(colname string) Series { +func (self DataFrame) Col(colname string) stat.Series { if self.Err != nil { - return NewSeriesWithType(SERIES_TYPE_INVAILD, "") + return NewSeriesWithType(stat.SERIES_TYPE_INVAILD, "") } // Check that colname exist on dataframe idx := findInStringSlice(colname, self.Names()) if idx < 0 { - return NewSeriesWithType(SERIES_TYPE_INVAILD, "") + return NewSeriesWithType(stat.SERIES_TYPE_INVAILD, "") } return self.columns[idx].Copy() } diff --git a/dataframe_struct.go b/dataframe_struct.go index 45d35d3237df1ff36c9f938f62c9ff8bed218b0f..265c18aca92a70bf99a52fda979599983f7b8abc 100644 --- a/dataframe_struct.go +++ b/dataframe_struct.go @@ -46,7 +46,7 @@ func LoadStructs(i interface{}, options ...LoadOption) DataFrame { // Set the default load options cfg := loadOptions{ - defaultType: SERIES_TYPE_STRING, + defaultType: stat.SERIES_TYPE_STRING, detectTypes: true, hasHeader: true, nanValues: stat.PossibleNaOfString, @@ -69,7 +69,7 @@ func LoadStructs(i interface{}, options ...LoadOption) DataFrame { } numFields := val.Index(0).Type().NumField() - var columns []Series + var columns []stat.Series for j := 0; j < numFields; j++ { // Extract field metadata if !val.Index(0).Field(j).CanInterface() { @@ -100,7 +100,7 @@ func LoadStructs(i interface{}, options ...LoadOption) DataFrame { } // Handle `types` option - var t Type + var t stat.Type if cfgtype, ok := cfg.types[fieldName]; ok { t = cfgtype } else { @@ -136,17 +136,17 @@ func LoadStructs(i interface{}, options ...LoadOption) DataFrame { elements = append(tmp, elements...) fieldName = "" } - if t == SERIES_TYPE_STRING { - columns = append(columns, NewSeries(SERIES_TYPE_STRING, fieldName, elements)) - } else if t == SERIES_TYPE_BOOL { - columns = append(columns, NewSeries(SERIES_TYPE_BOOL, fieldName, elements)) - } else if t == SERIES_TYPE_INT64 { - columns = append(columns, NewSeries(SERIES_TYPE_INT64, fieldName, elements)) - } else if t == SERIES_TYPE_FLOAT32 { - columns = append(columns, NewSeries(SERIES_TYPE_FLOAT32, fieldName, elements)) + if t == stat.SERIES_TYPE_STRING { + columns = append(columns, NewSeries(stat.SERIES_TYPE_STRING, fieldName, elements)) + } else if t == stat.SERIES_TYPE_BOOL { + columns = append(columns, NewSeries(stat.SERIES_TYPE_BOOL, fieldName, elements)) + } else if t == stat.SERIES_TYPE_INT64 { + columns = append(columns, NewSeries(stat.SERIES_TYPE_INT64, fieldName, elements)) + } else if t == stat.SERIES_TYPE_FLOAT32 { + columns = append(columns, NewSeries(stat.SERIES_TYPE_FLOAT32, fieldName, elements)) } else { // 默认float - columns = append(columns, NewSeries(SERIES_TYPE_FLOAT64, fieldName, elements)) + columns = append(columns, NewSeries(stat.SERIES_TYPE_FLOAT64, fieldName, elements)) } } return NewDataFrame(columns...) diff --git a/dataframe_subset.go b/dataframe_subset.go index 60089324dce9ba349577244b36987f00b33efb1b..b77f13cec6e8a1b8c0f36af3457b48c6cf784129 100644 --- a/dataframe_subset.go +++ b/dataframe_subset.go @@ -8,7 +8,7 @@ func (self DataFrame) Subset(start, end int) DataFrame { if self.Err != nil { return self } - columns := make([]Series, self.ncols) + columns := make([]stat.Series, self.ncols) for i, column := range self.columns { s := column.Subset(start, end) columns[i] = s @@ -26,7 +26,7 @@ func (self DataFrame) Subset(start, end int) DataFrame { // Select 选择一段记录 func (self DataFrame) SelectRows(p stat.ScopeLimit) DataFrame { - columns := []Series{} + columns := []stat.Series{} for i := range self.columns { columns = append(columns, self.columns[i].Select(p)) } diff --git a/dataframe_test.go b/dataframe_test.go index f1ef0733b183b07f286b83a4fcf9365bd3877cad..08023483aeb8caadef53b062997ce663773720be 100644 --- a/dataframe_test.go +++ b/dataframe_test.go @@ -5,24 +5,6 @@ import ( "testing" ) -func TestDataFrameT0(t *testing.T) { - var s1 Series - s1 = NewSeriesFloat64("sales", nil, 50.3, 23.4, 56.2) - fmt.Println(s1) - expected := 4 - - if s1.Len() != expected { - t.Errorf("wrong val: expected: %v actual: %v", expected, s1.Len()) - } - s2 := s1.Shift(-2) - df := NewDataFrame(s1, s2) - fmt.Println(df) - df.FillNa(0.00, true) - fmt.Println(df) - - _ = s2 -} - func TestLoadStructs(t *testing.T) { type testStruct struct { A string diff --git a/dataframe_type.go b/dataframe_type.go new file mode 100644 index 0000000000000000000000000000000000000000..7e725f03c90db0700f4575ace2d61b35787ae220 --- /dev/null +++ b/dataframe_type.go @@ -0,0 +1,127 @@ +package pandas + +import ( + "fmt" + "gitee.com/quant1x/pandas/exception" + "gitee.com/quant1x/pandas/stat" + "reflect" + "strconv" + "strings" +) + +const ( + MAX_FLOAT32_PRICE = float32(9999.9999) // float32的价最大阀值触发扩展到float64 +) + +var ( + ErrUnsupportedType = exception.New(0, "Unsupported type") +) + +func mustFloat64(f float32) bool { + if f > MAX_FLOAT32_PRICE { + return true + } + return false +} + +func findTypeByString(arr []string) (stat.Type, error) { + var hasFloats, hasInts, hasBools, hasStrings bool + var useFloat32, useFloat64 bool + var stringLengthEqual = -1 + var stringLenth = -1 + for _, str := range arr { + if str == "" || str == "NaN" { + continue + } + tLen := len(str) + if strings.HasPrefix(str, "0") { + stringLengthEqual = 0 + } + if stringLenth < 1 { + if stringLengthEqual == -1 { + stringLenth = tLen + } + } else if stringLengthEqual >= 0 && tLen != stringLenth { + stringLengthEqual += 1 + } + + if _, err := strconv.Atoi(str); err == nil { + hasInts = true + continue + } + if f, err := strconv.ParseFloat(str, 64); err == nil { + hasFloats = true + if float32(f) < stat.MaxFloat32 { + if mustFloat64(float32(f)) { + useFloat64 = true + } else { + useFloat32 = true + } + } + continue + } + if str == "true" || str == "false" { + hasBools = true + continue + } + hasStrings = true + } + if stringLengthEqual == 0 { + hasStrings = true + } + // 类型优先级, string > bool > float > int, string 为默认类型 + switch { + case hasStrings: + return stat.SERIES_TYPE_STRING, nil + case hasBools: + return stat.SERIES_TYPE_BOOL, nil + case useFloat32 && !useFloat64: + return stat.SERIES_TYPE_FLOAT32, nil + case hasFloats: + return stat.SERIES_TYPE_FLOAT64, nil + case hasInts: + return stat.SERIES_TYPE_INT64, nil + default: + return stat.SERIES_TYPE_STRING, fmt.Errorf("couldn't detect type") + } + +} + +func parseType(s string) (stat.Type, error) { + switch s { + case "float", "float32": + return stat.SERIES_TYPE_FLOAT32, nil + case "float64": + return stat.SERIES_TYPE_FLOAT64, nil + case "int", "int64", "int32", "int16", "int8": + return stat.SERIES_TYPE_INT64, nil + case "string": + return stat.SERIES_TYPE_STRING, nil + case "bool": + return stat.SERIES_TYPE_BOOL, nil + } + return stat.SERIES_TYPE_INVAILD, fmt.Errorf("type (%s) is not supported", s) +} + +func detectTypes[T stat.GenericType](v T) (stat.Type, any) { + var _type = stat.SERIES_TYPE_STRING + vv := reflect.ValueOf(v) + vk := vv.Kind() + switch vk { + case reflect.Invalid: + _type = stat.SERIES_TYPE_INVAILD + case reflect.Bool: + _type = stat.SERIES_TYPE_BOOL + case reflect.Int64: + _type = stat.SERIES_TYPE_INT64 + case reflect.Float32: + _type = stat.SERIES_TYPE_FLOAT32 + case reflect.Float64: + _type = stat.SERIES_TYPE_FLOAT64 + case reflect.String: + _type = stat.SERIES_TYPE_STRING + default: + panic(fmt.Errorf("unknown type, %+v", v)) + } + return _type, vv.Interface() +} diff --git a/formula/abs.go b/formula/abs.go index 081ca7dee10aea042c5a88105d54ad228b93e9c8..ffee9bda4bff85016a9225b12fef9af6fae7ab5d 100644 --- a/formula/abs.go +++ b/formula/abs.go @@ -1,13 +1,13 @@ package formula import ( - "gitee.com/quant1x/pandas" "gitee.com/quant1x/pandas/stat" ) // ABS 计算S的绝对值 -func ABS(S pandas.Series) pandas.Series { +func ABS(S stat.Series) stat.Series { s := S.DTypes() d := stat.Abs(s) - return pandas.NewSeries(pandas.SERIES_TYPE_DTYPE, "", d) + //return pandas.NewSeries(stat.SERIES_TYPE_DTYPE, "", d) + return stat.NewSeries(d...) } diff --git a/formula/abs_test.go b/formula/abs_test.go index 39a05597028eda030848d1291503fae624127abd..758984f7210d9e2f03906ca73f82625828c6b2ea 100644 --- a/formula/abs_test.go +++ b/formula/abs_test.go @@ -3,12 +3,13 @@ package formula import ( "fmt" "gitee.com/quant1x/pandas" + "gitee.com/quant1x/pandas/stat" "testing" ) func TestABS(t *testing.T) { - v1 := []int32{1, -1, 2, -2} - s := pandas.NewSeries(pandas.SERIES_TYPE_FLOAT64, "", v1) + d1 := []int32{1, -1, 2, -2} + s := pandas.NewSeries(stat.SERIES_TYPE_FLOAT64, "", d1) fmt.Println(ABS(s)) } diff --git a/formula/avedev.go b/formula/avedev.go index f0f1bf0351523f9c4bb6616bf703445f028f171e..b404535772e5a98896f2b4118ce370be029d9fe9 100644 --- a/formula/avedev.go +++ b/formula/avedev.go @@ -1,15 +1,14 @@ package formula import ( - "gitee.com/quant1x/pandas" "gitee.com/quant1x/pandas/stat" ) // AVEDEV 平均绝对偏差, (序列与其平均值的绝对差的平均值) // // AVEDEV(S,N) 返回平均绝对偏差 -func AVEDEV(S pandas.Series, N any) any { - return S.Rolling(N).Apply(func(X pandas.Series, W stat.DType) stat.DType { +func AVEDEV(S stat.Series, N any) any { + return S.Rolling(N).Apply(func(X stat.Series, W stat.DType) stat.DType { x := X.DTypes() x1 := X.Mean() r := stat.Sub(x, x1) diff --git a/formula/barslast.go b/formula/barslast.go index fd7ba80bd5af5dcd38cbd09206ec9a5b1e9c165e..a22002181acfbfaef1447334ba15f1f58557aef8 100644 --- a/formula/barslast.go +++ b/formula/barslast.go @@ -1,12 +1,11 @@ package formula import ( - "gitee.com/quant1x/pandas" "gitee.com/quant1x/pandas/stat" ) // BARSLAST 为了测试SMA,BARSLAST必须要先实现, 给SMA提供序列换参数, 以便验证, python那边还没实现 -func BARSLAST(S pandas.Series) []stat.DType { +func BARSLAST(S stat.Series) []stat.DType { fs := S.DTypes() as := stat.Repeat[stat.DType](1, S.Len()) bs := stat.Repeat[stat.DType](0, S.Len()) diff --git a/formula/barslastcount.go b/formula/barslastcount.go index a34e4aaa2db92d2fa83fb59de664def38d0d9a0e..35e6007ef454b872a5e8381af1de36befdcba22c 100644 --- a/formula/barslastcount.go +++ b/formula/barslastcount.go @@ -1,12 +1,11 @@ package formula import ( - "gitee.com/quant1x/pandas" "gitee.com/quant1x/pandas/stat" ) // BARSLASTCOUNT 统计连续满足S条件的周期数 -func BARSLASTCOUNT(S pandas.Series) []int64 { +func BARSLASTCOUNT(S stat.Series) []int64 { s := S.DTypes() slen := len(s) rt := stat.Repeat[int64](0, slen+1) diff --git a/formula/barssincen.go b/formula/barssincen.go index decbd3ee8c586d3428265eed5a6dab0239df2ec0..637b4059fe633709c90125196d8e73230d3a5d61 100644 --- a/formula/barssincen.go +++ b/formula/barssincen.go @@ -1,13 +1,12 @@ package formula import ( - "gitee.com/quant1x/pandas" "gitee.com/quant1x/pandas/stat" ) // BARSSINCEN N周期内第一次S条件成立到现在的周期数,N为常量 -func BARSSINCEN(S pandas.Series, N any) []stat.Int { - ret := S.Rolling(N).Apply(func(X pandas.Series, M stat.DType) stat.DType { +func BARSSINCEN(S stat.Series, N any) []stat.Int { + ret := S.Rolling(N).Apply(func(X stat.Series, M stat.DType) stat.DType { x := X.DTypes() n := int(M) argMax := stat.ArgMax(x) @@ -21,6 +20,6 @@ func BARSSINCEN(S pandas.Series, N any) []stat.Int { return stat.DType(r) }) r1 := ret.FillNa(0, true) - r2 := r1.AsInt() + r2 := r1.Ints() return r2 } diff --git a/formula/barssincen_test.go b/formula/barssincen_test.go index 918707982b692dc60ceb030e989efac3ec361714..8b5ac16202a351ac3993b48c6623c48b736179a1 100644 --- a/formula/barssincen_test.go +++ b/formula/barssincen_test.go @@ -9,7 +9,7 @@ import ( func TestBARSSINCEN(t *testing.T) { f1 := []int64{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4} - s1 := pandas.NewSeries(pandas.SERIES_TYPE_DTYPE, "", f1) + s1 := pandas.NewSeries(stat.SERIES_TYPE_DTYPE, "", f1) df := pandas.NewDataFrame(s1) fmt.Println(df) @@ -17,7 +17,7 @@ func TestBARSSINCEN(t *testing.T) { f := v.(stat.DType) return f > 3 }) - df = df.Join(pandas.NewSeries(pandas.SERIES_TYPE_BOOL, "r", b1)) + df = df.Join(pandas.NewSeries(stat.SERIES_TYPE_BOOL, "r", b1)) fmt.Println(df) //c1 = df > 3 r1 := BARSSINCEN(df.Col("r"), 4) diff --git a/formula/comparison.go b/formula/comparison.go index 8321c47e9162923d464e6e6d50e4fad647203d5a..2ef4a6d23a9968b9a54776265aa3f05191afcb98 100644 --- a/formula/comparison.go +++ b/formula/comparison.go @@ -1,14 +1,13 @@ package formula import ( - "gitee.com/quant1x/pandas" "gitee.com/quant1x/pandas/exception" "gitee.com/quant1x/pandas/stat" "github.com/viterin/vek" ) -func AND(a, b []bool) []bool { - return vek.And(a, b) +func AND[T stat.Number | ~bool](a, b []T) []bool { + return stat.And(a, b) } func OR(a, b []bool) []bool { @@ -51,7 +50,7 @@ func __compare(v []stat.DType, x any, comparator func(x, y []float64) []bool) [] vlen = xlen } X = stat.Align[stat.DType](vx, defaultValue, vlen) - case pandas.Series: + case stat.Series: vs := vx.DTypes() xlen := len(vs) if vlen < xlen { diff --git a/formula/const.go b/formula/const.go index 935d3217501c00814c6a772503e0dc3f69308238..edf0c6d32f9c5505f4716546408789614c948d75 100644 --- a/formula/const.go +++ b/formula/const.go @@ -6,9 +6,9 @@ import ( ) // CONST 取S最后的值为常量 -func CONST(S pandas.Series) pandas.Series { +func CONST(S stat.Series) stat.Series { length := S.Len() - s := S.Float() + s := S.Floats() s = stat.Repeat(s[length-1], S.Len()) - return pandas.NewSeries(pandas.SERIES_TYPE_FLOAT32, "", s) + return pandas.NewSeries(stat.SERIES_TYPE_FLOAT32, "", s) } diff --git a/formula/count.go b/formula/count.go index cff857ea959500f0ffb17df7415d914888365eb2..c3d430b634b50ee590cc18905de9ac618f62da94 100644 --- a/formula/count.go +++ b/formula/count.go @@ -1,15 +1,11 @@ package formula import ( - "gitee.com/quant1x/pandas" "gitee.com/quant1x/pandas/stat" ) // COUNT 统计S为真的天数 -func COUNT(S pandas.Series, N any) pandas.Series { - return S.Rolling(N).Count() -} -func COUNT2(S []bool, N int) []int { +func COUNT(S []bool, N int) []int { xLen := len(S) x := stat.Rolling(S, N) ret := make([]int, xLen) @@ -25,10 +21,14 @@ func COUNT2(S []bool, N int) []int { return ret } +func COUNT_V2(S stat.Series, N any) stat.Series { + return S.Rolling(N).Count() +} + // COUNT_v1 一般性比较 -func COUNT_v1(S pandas.Series, N any) []stat.Int { +func COUNT_v1(S stat.Series, N any) []stat.Int { //values := S.DTypes() - return S.Rolling(N).Apply(func(X pandas.Series, W stat.DType) stat.DType { + return S.Rolling(N).Apply(func(X stat.Series, W stat.DType) stat.DType { x := X.DTypes() n := 0 for _, v := range x { @@ -37,5 +37,5 @@ func COUNT_v1(S pandas.Series, N any) []stat.Int { } } return stat.DType(n) - }).AsInt() + }).Ints() } diff --git a/formula/count_test.go b/formula/count_test.go index 1ba8001c7973d76766dce702f4e857fc8d3eda49..cddf06e2be62609b00e463eaa684cd7f31230629 100644 --- a/formula/count_test.go +++ b/formula/count_test.go @@ -10,7 +10,7 @@ func TestCOUNT(t *testing.T) { f0 := []float64{1, 2, 3, 4, 5, 6, 0, 8, 9, 10, 11, 12} i0 := CompareGte(f0, 1) s0 := pandas.NewSeriesWithoutType("f0", i0) - fmt.Println(COUNT(s0, 5)) + fmt.Println(COUNT_v1(s0, 5)) //s2 := []float64{1, 2, 3, 4, 3, 3, 2, 1, stat.DTypeNaN, stat.DTypeNaN, stat.DTypeNaN, stat.DTypeNaN} //fmt.Println(s2) ////stat.Fill(s2, 1.0, true) diff --git a/formula/cross.go b/formula/cross.go index 903878fbf7e49fc43d6516979f071733ab1ed287..8cbb1205ded7e3dde9abaf3a7d446ff4dfe5cc08 100644 --- a/formula/cross.go +++ b/formula/cross.go @@ -1,7 +1,6 @@ package formula import ( - "gitee.com/quant1x/pandas" "gitee.com/quant1x/pandas/stat" "github.com/viterin/vek" ) @@ -10,11 +9,11 @@ import ( // // 判断向上金叉穿越 CROSS(MA(C,5),MA(C,10)) // 判断向下死叉穿越 CROSS(MA(C,10),MA(C,5)) -func CROSS(S1, S2 pandas.Series) []bool { - r1 := S1.DTypes() - r2 := S2.DTypes() - r11 := S1.Ref(1).DTypes() - r12 := S2.Ref(1).DTypes() +func CROSS(S1, S2 []stat.DType) []bool { + r1 := S1 + r2 := S2 + r11 := REF2(S1, 1) + r12 := REF2(S2, 1) b1 := CompareLt(r11, r12) b2 := CompareGte(r1, r2) @@ -23,11 +22,11 @@ func CROSS(S1, S2 pandas.Series) []bool { return c } -func CROSS2(S1, S2 []stat.DType) []bool { - r1 := S1 - r2 := S2 - r11 := REF2(S1, 1) - r12 := REF2(S2, 1) +func CROSS1(S1, S2 stat.Series) []bool { + r1 := S1.DTypes() + r2 := S2.DTypes() + r11 := S1.Ref(1).DTypes() + r12 := S2.Ref(1).DTypes() b1 := CompareLt(r11, r12) b2 := CompareGte(r1, r2) diff --git a/formula/cross_test.go b/formula/cross_test.go index 3a3d293bb7e3ffe9093a88ea87e3bbaa0f76a481..a806ca42050790a7bc08579ac320848f621c1cb6 100644 --- a/formula/cross_test.go +++ b/formula/cross_test.go @@ -13,5 +13,6 @@ func TestCROSS(t *testing.T) { s1 := pandas.NewSeriesWithoutType("d1", d1) s2 := pandas.NewSeriesWithoutType("d2", d2) - fmt.Println(CROSS(s1, s2)) + fmt.Println(CROSS(d1, d2)) + fmt.Println(CROSS1(s1, s2)) } diff --git a/formula/dma.go b/formula/dma.go index 7e561fc6bc316f5dc6a494259dba5b6382c17c5f..91ea0a8371f61e7aa72c1715402b0b9a2518ebaf 100644 --- a/formula/dma.go +++ b/formula/dma.go @@ -1,7 +1,6 @@ package formula import ( - "gitee.com/quant1x/pandas" "gitee.com/quant1x/pandas/stat" ) @@ -13,10 +12,10 @@ import ( // 算法:Y=A*X+(1-A)*Y',其中Y'表示上一周期Y值,A必须大于0且小于1.A支持变量 // 例如: // DMA(CLOSE,VOL/CAPITAL)表示求以换手率作平滑因子的平均价 -func DMA(S pandas.Series, A any) []stat.DType { +func DMA(S stat.Series, A any) []stat.DType { switch N := A.(type) { case /*nil, */ int8, uint8, int16, uint16, int32, uint32, int64, uint64, int, uint, float32, float64 /*, bool, string*/ : - x := S.EWM(pandas.EW{Alpha: 1 / stat.Any2DType(N), Adjust: false}).Mean().DTypes() + x := S.EWM(stat.EW{Alpha: 1 / stat.Any2DType(N), Adjust: false}).Mean().DTypes() return x case []stat.DType: s := S.DTypes() @@ -31,7 +30,7 @@ func DMA(S pandas.Series, A any) []stat.DType { Y[i] = a*s[i] + (1-a)*Y[i-1] } return Y - case pandas.Series: + case stat.Series: s := S.DTypes() M := N.DTypes() stat.Fill(M, 1.0, true) diff --git a/formula/dma_test.go b/formula/dma_test.go index e84059408971ecaed01d276a3f5bf6a86c9175e2..96d198a152e65dd7aa85d92889993a261cfd27cc 100644 --- a/formula/dma_test.go +++ b/formula/dma_test.go @@ -25,8 +25,8 @@ func TestDMA(t *testing.T) { cs := CLOSE.Values().([]float32) REF10 := REF(CLOSE, 10).([]float32) - v1 := vek32.Div(cs, REF10) - df01 := pandas.NewSeries(pandas.SERIES_TYPE_FLOAT32, "x", v1) + d1 := vek32.Div(cs, REF10) + df01 := pandas.NewSeries(stat.SERIES_TYPE_FLOAT32, "x", d1) x0 := make([]stat.DType, CLOSE.Len()) df01.Apply(func(idx int, v any) { f := v.(float32) @@ -36,12 +36,12 @@ func TestDMA(t *testing.T) { } x0[idx] = t }) - n := BARSLAST(pandas.NewSeries(pandas.SERIES_TYPE_FLOAT32, "", x0)) + n := BARSLAST(pandas.NewSeries(stat.SERIES_TYPE_FLOAT32, "", x0)) fmt.Println(n[len(n)-10:]) - x := DMA(CLOSE, pandas.NewSeries(pandas.SERIES_TYPE_DTYPE, "", n)) + x := DMA(CLOSE, pandas.NewSeries(stat.SERIES_TYPE_DTYPE, "", n)) //x := EMA(CLOSE, 7) - sx := pandas.NewSeries(pandas.SERIES_TYPE_DTYPE, "x", x) + sx := pandas.NewSeries(stat.SERIES_TYPE_DTYPE, "x", x) df = pandas.NewDataFrame(CLOSE, sx) fmt.Println(df) } diff --git a/formula/ema.go b/formula/ema.go index 8fae77a016871bbd26df5fc0bdf69b31da4e3157..33590e775c4bf929f4b2de85f64ebc3eaf93513c 100644 --- a/formula/ema.go +++ b/formula/ema.go @@ -1,7 +1,6 @@ package formula import ( - "gitee.com/quant1x/pandas" "gitee.com/quant1x/pandas/exception" "gitee.com/quant1x/pandas/stat" ) @@ -9,19 +8,19 @@ import ( // EMA 指数移动平均,为了精度 S>4*N EMA至少需要120周期 // alpha=2/(span+1) // TODO:这个版本是对的, 通达信EMA居然实现了真的序列, 那为啥SMA不是呢?! -func EMA(S pandas.Series, N any) any { +func EMA(S stat.Series, N any) any { var X []stat.DType switch v := N.(type) { case int: X = stat.Repeat[stat.DType](stat.DType(v), S.Len()) - case pandas.Series: + case stat.Series: vs := v.DTypes() X = stat.Align(vs, stat.DTypeNaN, S.Len()) default: panic(exception.New(1, "error window")) } k := X[0] - x := S.EWM(pandas.EW{Span: stat.DTypeNaN, Callback: func(idx int) stat.DType { + x := S.EWM(stat.EW{Span: stat.DTypeNaN, Callback: func(idx int) stat.DType { j := X[idx] if j == 0 { j = 1 @@ -35,44 +34,44 @@ func EMA(S pandas.Series, N any) any { } // EMA_v2 通达信公式管理器上提示, EMA(S, N) 相当于SMA(S, N + 1, M=2), 骗子, 根本不对 -func EMA_v2(S pandas.Series, N any) any { +func EMA_v2(S stat.Series, N any) any { M := 2 var X float32 switch v := N.(type) { case int: X = float32(v) - case pandas.Series: + case stat.Series: vs := v.Values() fs := stat.SliceToFloat32(vs) X = fs[len(fs)-1] default: panic(exception.New(1, "error window")) } - x := S.EWM(pandas.EW{Alpha: float64(M) / float64(X+1), Adjust: false}).Mean().Values() + x := S.EWM(stat.EW{Alpha: float64(M) / float64(X+1), Adjust: false}).Mean().Values() return x } // EMA_v0 仿SMA实现, 错误 -func EMA_v0(S pandas.Series, N any) any { +func EMA_v0(S stat.Series, N any) any { var X float32 switch v := N.(type) { case int: X = float32(v) - case pandas.Series: + case stat.Series: vs := v.Values() fs := stat.SliceToFloat32(vs) X = fs[len(fs)-1] default: panic(exception.New(1, "error window")) } - x := S.EWM(pandas.EW{Span: stat.DType(X), Adjust: false}).Mean().Values() + x := S.EWM(stat.EW{Span: stat.DType(X), Adjust: false}).Mean().Values() return x } // EMA_v1 Rolling(N), 每个都取最后一个, 错误 -func EMA_v1(S pandas.Series, N any) any { - x := S.Rolling(N).Apply(func(S pandas.Series, N stat.DType) stat.DType { - r := S.EWM(pandas.EW{Span: N, Adjust: false}).Mean().DTypes() +func EMA_v1(S stat.Series, N any) any { + x := S.Rolling(N).Apply(func(S stat.Series, N stat.DType) stat.DType { + r := S.EWM(stat.EW{Span: N, Adjust: false}).Mean().DTypes() if len(r) == 0 { return stat.DTypeNaN } diff --git a/formula/ema_test.go b/formula/ema_test.go index 2c06b410ccd0cfccc265a3528ec3a3a24041ff74..970e5f7514d88d160d02bf742dbc5b0fb41e6cd6 100644 --- a/formula/ema_test.go +++ b/formula/ema_test.go @@ -21,8 +21,8 @@ func TestEMA(t *testing.T) { cs := CLOSE.Values().([]float32) REF10 := REF(CLOSE, 10).([]float32) - v1 := vek32.Div(cs, REF10) - df01 := pandas.NewSeries(pandas.SERIES_TYPE_FLOAT32, "x", v1) + d1 := vek32.Div(cs, REF10) + df01 := pandas.NewSeries(stat.SERIES_TYPE_FLOAT32, "x", d1) x0 := make([]stat.DType, CLOSE.Len()) df01.Apply(func(idx int, v any) { f := v.(float32) @@ -33,12 +33,12 @@ func TestEMA(t *testing.T) { x0[idx] = t }) //x := stat.Where(v2, as, bs) - n := BARSLAST(pandas.NewSeries(pandas.SERIES_TYPE_FLOAT32, "", x0)) + n := BARSLAST(pandas.NewSeries(stat.SERIES_TYPE_FLOAT32, "", x0)) fmt.Println(n[len(n)-10:]) - x := EMA(CLOSE, pandas.NewSeries(pandas.SERIES_TYPE_DTYPE, "", n)) + x := EMA(CLOSE, pandas.NewSeries(stat.SERIES_TYPE_DTYPE, "", n)) //x := EMA(CLOSE, 7) - sx := pandas.NewSeries(pandas.SERIES_TYPE_DTYPE, "x", x) + sx := pandas.NewSeries(stat.SERIES_TYPE_DTYPE, "x", x) df = pandas.NewDataFrame(CLOSE, sx) fmt.Println(df) diff --git a/formula/forcast.go b/formula/forcast.go index 3dfa29ea95815032fa202409b4d1a26151bc982d..da07bc1eff5a0bf24b078cb8c2eef57e9a078b22 100644 --- a/formula/forcast.go +++ b/formula/forcast.go @@ -1,13 +1,12 @@ package formula import ( - "gitee.com/quant1x/pandas" "gitee.com/quant1x/pandas/stat" ) // FORCAST 返回S序列N周期回线性回归后的预测值 -func FORCAST(S pandas.Series, N any) any { - return S.Rolling(N).Apply(func(X pandas.Series, W stat.DType) stat.DType { +func FORCAST(S stat.Series, N any) any { + return S.Rolling(N).Apply(func(X stat.Series, W stat.DType) stat.DType { x := X.DTypes() ws := stat.Range[float64](int(W)) c := stat.PolyFit(ws, x, 1) diff --git a/formula/hhv.go b/formula/hhv.go index ec4faa093fa7c85798ee38279736bfa825b69ab0..f15bee117f559e813004d05739e088cb5b621cbb 100644 --- a/formula/hhv.go +++ b/formula/hhv.go @@ -1,8 +1,10 @@ package formula -import "gitee.com/quant1x/pandas" +import ( + "gitee.com/quant1x/pandas/stat" +) // HHV 最近N周期的S最大值 -func HHV(S pandas.Series, N any) pandas.Series { +func HHV(S stat.Series, N any) stat.Series { return S.Rolling(N).Max() } diff --git a/formula/if.go b/formula/if.go index df67e9233d0b05c0aa53a055de2e8cb24903e474..a1119030ebb64d5f572623df85e0bc58a2b97cf9 100644 --- a/formula/if.go +++ b/formula/if.go @@ -6,20 +6,20 @@ import ( ) // IF 序列布尔判断 return=A if S==True else B -func IF(S, A, B pandas.Series) pandas.Series { +func IF(S, A, B stat.Series) stat.Series { return IFF(S, A, B) } // IFF 序列布尔判断 return=A if S==True else B -func IFF(S, A, B pandas.Series) pandas.Series { - s := S.Float() - a := A.Float() - b := B.Float() +func IFF(S, A, B stat.Series) stat.Series { + s := S.Floats() + a := A.Floats() + b := B.Floats() ret := stat.Where(s, a, b) - return pandas.NewSeries(pandas.SERIES_TYPE_FLOAT32, "", ret) + return pandas.NewSeries(stat.SERIES_TYPE_FLOAT32, "", ret) } // IFN 序列布尔判断 return=A if S==False else B -func IFN(S, A, B pandas.Series) pandas.Series { +func IFN(S, A, B stat.Series) stat.Series { return IFF(S, B, A) } diff --git a/formula/if_test.go b/formula/if_test.go index b45acf2d6763800628fffcd81cb1b9b918a40a04..196ab312bb17ff9be0c17ea001e4b11573cdd0b4 100644 --- a/formula/if_test.go +++ b/formula/if_test.go @@ -3,26 +3,27 @@ package formula import ( "fmt" "gitee.com/quant1x/pandas" + "gitee.com/quant1x/pandas/stat" "testing" ) func TestIF(t *testing.T) { - S := pandas.NewSeries(pandas.SERIES_TYPE_FLOAT32, "", []float32{1, 1, 1}) - A := pandas.NewSeries(pandas.SERIES_TYPE_FLOAT32, "", []float32{11, 12, 13}) - B := pandas.NewSeries(pandas.SERIES_TYPE_FLOAT32, "", []float32{21, 22, 23}) + S := pandas.NewSeries(stat.SERIES_TYPE_FLOAT32, "", []float32{1, 1, 1}) + A := pandas.NewSeries(stat.SERIES_TYPE_FLOAT32, "", []float32{11, 12, 13}) + B := pandas.NewSeries(stat.SERIES_TYPE_FLOAT32, "", []float32{21, 22, 23}) fmt.Println(IF(S, A, B)) } func TestIFF(t *testing.T) { - S := pandas.NewSeries(pandas.SERIES_TYPE_FLOAT32, "", []float32{1, 1, 1}) - A := pandas.NewSeries(pandas.SERIES_TYPE_FLOAT32, "", []float32{11, 12, 13}) - B := pandas.NewSeries(pandas.SERIES_TYPE_FLOAT32, "", []float32{21, 22, 23}) + S := pandas.NewSeries(stat.SERIES_TYPE_FLOAT32, "", []float32{1, 1, 1}) + A := pandas.NewSeries(stat.SERIES_TYPE_FLOAT32, "", []float32{11, 12, 13}) + B := pandas.NewSeries(stat.SERIES_TYPE_FLOAT32, "", []float32{21, 22, 23}) fmt.Println(IFF(S, A, B)) } func TestIFN(t *testing.T) { - S := pandas.NewSeries(pandas.SERIES_TYPE_FLOAT32, "", []float32{1, 0, 1}) - A := pandas.NewSeries(pandas.SERIES_TYPE_FLOAT32, "", []float32{11, 12, 13}) - B := pandas.NewSeries(pandas.SERIES_TYPE_FLOAT32, "", []float32{21, 22, 23}) + S := pandas.NewSeries(stat.SERIES_TYPE_FLOAT32, "", []float32{1, 0, 1}) + A := pandas.NewSeries(stat.SERIES_TYPE_FLOAT32, "", []float32{11, 12, 13}) + B := pandas.NewSeries(stat.SERIES_TYPE_FLOAT32, "", []float32{21, 22, 23}) fmt.Println(IFN(S, A, B)) } diff --git a/formula/last.go b/formula/last.go index 62762f37217b1b39aafdfdce6f1d4a07fa18e0d1..d76b6988dd10666e476104d5387dd4b86c3ba7d7 100644 --- a/formula/last.go +++ b/formula/last.go @@ -1,7 +1,6 @@ package formula import ( - "gitee.com/quant1x/pandas" "gitee.com/quant1x/pandas/stat" ) @@ -12,8 +11,8 @@ import ( // LAST(CLOSE>OPEN,10,5) // 表示从前10日到前5日内一直阳线 // 若A为0,表示从第一天开始,B为0,表示到最后日止 -func LAST(X pandas.Series, A, B int) pandas.Series { - s := X.Rolling(A + 1).Apply(func(S pandas.Series, N stat.DType) stat.DType { +func LAST(X stat.Series, A, B int) stat.Series { + s := X.Rolling(A + 1).Apply(func(S stat.Series, N stat.DType) stat.DType { s := S.DTypes() s = stat.Reverse(s) T := s[B:] diff --git a/formula/llv.go b/formula/llv.go index b26a7fe65e3c84e5f7d5c67ec87df463176a045a..75d81703e59af6057079f40605ce4f956c8ae904 100644 --- a/formula/llv.go +++ b/formula/llv.go @@ -1,8 +1,10 @@ package formula -import "gitee.com/quant1x/pandas" +import ( + "gitee.com/quant1x/pandas/stat" +) // LLV 最近N周期的S最小值 -func LLV(S pandas.Series, N any) pandas.Series { +func LLV(S stat.Series, N any) stat.Series { return S.Rolling(N).Min() } diff --git a/formula/ma.go b/formula/ma.go index c0a3cb272ad14d78f31142b22c465768cab255c4..4ae4fe8324b6821069a392a8752903b8281dfdb9 100644 --- a/formula/ma.go +++ b/formula/ma.go @@ -1,23 +1,11 @@ package formula import ( - "gitee.com/quant1x/pandas" "gitee.com/quant1x/pandas/stat" ) // MA 计算移动均线 // 求序列的N日简单移动平均值, 返回序列 -func MA(S pandas.Series, N any) []stat.DType { - //var X []float32 - //switch v := N.(type) { - //case int: - // X = stat.Repeat[float32](float32(v), S.Len()) - //case pandas.Series: - // vs := v.Values() - // X = pandas.SliceToFloat32(vs) - // X = stat.Align(X, pandas.Nil2Float32, S.Len()) - //default: - // panic(exception.New(1, "error window")) - //} +func MA(S stat.Series, N any) []stat.DType { return S.Rolling(N).Mean().DTypes() } diff --git a/formula/max.go b/formula/max.go index adf5618b4c90da324a7b0da87898c488977736af..2ce13743619f571633f9bb41364d6562b4fc389c 100644 --- a/formula/max.go +++ b/formula/max.go @@ -6,8 +6,8 @@ import ( ) // MAX 两个序列横向对比 -func MAX(S1, S2 pandas.Series) pandas.Series { - d := stat.Maximum(S1.Float(), S2.Float()) - return pandas.NewSeries(pandas.SERIES_TYPE_FLOAT32, "", d) +func MAX(S1, S2 stat.Series) stat.Series { + d := stat.Maximum(S1.Floats(), S2.Floats()) + return pandas.NewSeries(stat.SERIES_TYPE_FLOAT32, "", d) } diff --git a/formula/max_test.go b/formula/max_test.go index ea4fbaaaf4e6d383f88b5389fa82d7ec7b4c66f8..1a23590052f9cc70fbb5f91b9bcedfd763a60f4e 100644 --- a/formula/max_test.go +++ b/formula/max_test.go @@ -3,6 +3,7 @@ package formula import ( "fmt" "gitee.com/quant1x/pandas" + "gitee.com/quant1x/pandas/stat" "math" "testing" ) @@ -12,7 +13,7 @@ func TestMAX(t *testing.T) { fmt.Println(float64(1.4) < math.NaN()) f1 := []float32{1.1, 2.2, 1.3, 1.4} f2 := []float32{1.2, 1.2, 3.3} - s1 := pandas.NewSeries(pandas.SERIES_TYPE_FLOAT64, "x1", f1) - s2 := pandas.NewSeries(pandas.SERIES_TYPE_FLOAT64, "x2", f2) + s1 := pandas.NewSeries(stat.SERIES_TYPE_FLOAT64, "x1", f1) + s2 := pandas.NewSeries(stat.SERIES_TYPE_FLOAT64, "x2", f2) fmt.Println(MAX(s1, s2)) } diff --git a/formula/min.go b/formula/min.go index 4d5628709d0d37e00006f4ca15687d672ab5ef87..2303e4f81fe25e7eaaeb1f2453c5165778896a05 100644 --- a/formula/min.go +++ b/formula/min.go @@ -6,8 +6,8 @@ import ( ) // MIN 两个序列横向对比 -func MIN(S1, S2 pandas.Series) pandas.Series { - d := stat.Minimum(S1.Float(), S2.Float()) - return pandas.NewSeries(pandas.SERIES_TYPE_FLOAT32, "", d) +func MIN(S1, S2 stat.Series) stat.Series { + d := stat.Minimum(S1.Floats(), S2.Floats()) + return pandas.NewSeries(stat.SERIES_TYPE_FLOAT32, "", d) } diff --git a/formula/min_test.go b/formula/min_test.go index 49fa900a736e21a418a3125667ba73abf99a28cd..22a0203cfbdd58cc970382a6bf21eeb257adb75c 100644 --- a/formula/min_test.go +++ b/formula/min_test.go @@ -3,13 +3,14 @@ package formula import ( "fmt" "gitee.com/quant1x/pandas" + "gitee.com/quant1x/pandas/stat" "testing" ) func TestMIN(t *testing.T) { f1 := []float32{1.1, 2.2, 1.3, 1.4} f2 := []float32{1.2, 1.2, 3.3} - s1 := pandas.NewSeries(pandas.SERIES_TYPE_FLOAT64, "x1", f1) - s2 := pandas.NewSeries(pandas.SERIES_TYPE_FLOAT64, "x2", f2) + s1 := pandas.NewSeries(stat.SERIES_TYPE_FLOAT64, "x1", f1) + s2 := pandas.NewSeries(stat.SERIES_TYPE_FLOAT64, "x2", f2) fmt.Println(MIN(s1, s2)) } diff --git a/formula/ref.go b/formula/ref.go index 3e201c29b3dfac21efd2788221eed549230f6c47..b74adf3bccc09891200946a2fb765ce5965fb44f 100644 --- a/formula/ref.go +++ b/formula/ref.go @@ -1,21 +1,20 @@ package formula import ( - "gitee.com/quant1x/pandas" "gitee.com/quant1x/pandas/exception" "gitee.com/quant1x/pandas/stat" ) // REF 引用前N的序列 -func REF(S pandas.Series, N any) any { +func REF(S stat.Series, N any) any { var X []float32 switch v := N.(type) { case int: X = stat.Repeat[float32](float32(v), S.Len()) - case pandas.Series: + case stat.Series: vs := v.Values() X = stat.SliceToFloat32(vs) - X = stat.Align(X, pandas.Nil2Float32, S.Len()) + X = stat.Align(X, stat.Nil2Float32, S.Len()) default: panic(exception.New(1, "error window")) } diff --git a/formula/slope.go b/formula/slope.go index 2e1461c16724bf90cf7bd76c51cb8c7297dbb710..eef1e36a5417b4986aad495e94a3aa8cf38171d1 100644 --- a/formula/slope.go +++ b/formula/slope.go @@ -1,15 +1,14 @@ package formula import ( - "gitee.com/quant1x/pandas" "gitee.com/quant1x/pandas/stat" ) // SLOPE 计算周期回线性回归斜率 // // SLOPE(S,N) 返回线性回归斜率,N支持变量 -func SLOPE(S pandas.Series, N any) any { - return S.Rolling(N).Apply(func(X pandas.Series, W stat.DType) stat.DType { +func SLOPE(S stat.Series, N any) any { + return S.Rolling(N).Apply(func(X stat.Series, W stat.DType) stat.DType { x := X.DTypes() w := stat.Range[stat.DType](len(x)) c := stat.PolyFit(w, x, 1) diff --git a/formula/sma.go b/formula/sma.go index ea982e0b78a8e931af9f8ca10c0ace0695f615f0..9cf4abda0c9465f5d6cfc78d7d58512cf2f92cf8 100644 --- a/formula/sma.go +++ b/formula/sma.go @@ -1,13 +1,12 @@ package formula import ( - "gitee.com/quant1x/pandas" "gitee.com/quant1x/pandas/exception" "gitee.com/quant1x/pandas/stat" ) // SMA 中国式的SMA,至少需要120周期才精确 (雪球180周期) alpha=1/(1+com) -func SMA(S pandas.Series, N any, M int) any { +func SMA(S stat.Series, N any, M int) any { if M == 0 { M = 1 } @@ -15,19 +14,19 @@ func SMA(S pandas.Series, N any, M int) any { switch v := N.(type) { case int: X = float32(v) - case pandas.Series: + case stat.Series: vs := v.Values() fs := stat.SliceToFloat32(vs) X = fs[len(fs)-1] default: panic(exception.New(1, "error window")) } - x := S.EWM(pandas.EW{Alpha: float64(M) / float64(X), Adjust: false}).Mean().Values() + x := S.EWM(stat.EW{Alpha: float64(M) / float64(X), Adjust: false}).Mean().Values() return x } // 最接近 -func SMA_v5(S pandas.Series, N any, M int) any { +func SMA_v5(S stat.Series, N any, M int) any { if M == 0 { M = 1 } @@ -35,15 +34,15 @@ func SMA_v5(S pandas.Series, N any, M int) any { switch v := N.(type) { case int: X = stat.Repeat[float32](float32(v), S.Len()) - case pandas.Series: + case stat.Series: vs := v.Values() X = stat.SliceToFloat32(vs) - X = stat.Align(X, pandas.Nil2Float32, S.Len()) + X = stat.Align(X, stat.Nil2Float32, S.Len()) default: panic(exception.New(1, "error window")) } k := X[0] - x := S.EWM(pandas.EW{Alpha: pandas.Nil2Float64, Callback: func(idx int) stat.DType { + x := S.EWM(stat.EW{Alpha: stat.Nil2Float64, Callback: func(idx int) stat.DType { j := X[idx] if j == 0 { j = 1 @@ -57,15 +56,15 @@ func SMA_v5(S pandas.Series, N any, M int) any { } // SMA_v4 听说SMA(S, N, 1) 其实就是MA(S,N), 试验后发现是骗子 -func SMA_v4(S pandas.Series, N any, M int) any { +func SMA_v4(S stat.Series, N any, M int) any { var X []float32 switch v := N.(type) { case int: X = stat.Repeat[float32](float32(v), S.Len()) - case pandas.Series: + case stat.Series: vs := v.Values() X = stat.SliceToFloat32(vs) - X = stat.Align(X, pandas.Nil2Float32, S.Len()) + X = stat.Align(X, stat.Nil2Float32, S.Len()) default: panic(exception.New(1, "error window")) } @@ -73,12 +72,12 @@ func SMA_v4(S pandas.Series, N any, M int) any { } // SMA_v3 使用滑动窗口 -func SMA_v3(S pandas.Series, N any, M int) any { +func SMA_v3(S stat.Series, N any, M int) any { if M == 0 { M = 1 } - x := S.Rolling(N).Apply(func(S pandas.Series, N stat.DType) stat.DType { - r := S.EWM(pandas.EW{Alpha: float64(M) / float64(N), Adjust: false}).Mean().Values().([]float64) + x := S.Rolling(N).Apply(func(S stat.Series, N stat.DType) stat.DType { + r := S.EWM(stat.EW{Alpha: float64(M) / float64(N), Adjust: false}).Mean().Values().([]float64) if len(r) == 0 { return stat.DTypeNaN } @@ -88,10 +87,10 @@ func SMA_v3(S pandas.Series, N any, M int) any { } // SMA_v1 最原始的python写法 -func SMA_v1(S pandas.Series, N int, M int) any { +func SMA_v1(S stat.Series, N int, M int) any { if M == 0 { M = 1 } - x := S.EWM(pandas.EW{Alpha: float64(M) / float64(N), Adjust: false}).Mean().Values() + x := S.EWM(stat.EW{Alpha: float64(M) / float64(N), Adjust: false}).Mean().Values() return x } diff --git a/formula/sma_test.go b/formula/sma_test.go index 0ddc743340d2d57619456c1133ab3dd50e356894..8c38b069873ad8d398b1f2853ef8c9dfab3c274d 100644 --- a/formula/sma_test.go +++ b/formula/sma_test.go @@ -3,6 +3,7 @@ package formula import ( "fmt" "gitee.com/quant1x/pandas" + "gitee.com/quant1x/pandas/stat" "github.com/viterin/vek/vek32" "testing" ) @@ -20,10 +21,10 @@ func TestSMA(t *testing.T) { CLOSE := df.Col("close") cs := CLOSE.Values().([]float32) REF10 := REF(CLOSE, 10).([]float32) - v1 := vek32.Div(cs, REF10) + d1 := vek32.Div(cs, REF10) //as := stat.Repeat[float32](1, CLOSE.Len()) //bs := stat.Repeat[float32](0, CLOSE.Len()) - df01 := pandas.NewSeries(pandas.SERIES_TYPE_FLOAT32, "x", v1) + df01 := pandas.NewSeries(stat.SERIES_TYPE_FLOAT32, "x", d1) x := make([]float32, CLOSE.Len()) df01.Apply(func(idx int, v any) { f := v.(float32) @@ -34,11 +35,11 @@ func TestSMA(t *testing.T) { x[idx] = t }) //x := stat.Where(v2, as, bs) - n := BARSLAST(pandas.NewSeries(pandas.SERIES_TYPE_FLOAT32, "", x)) + n := BARSLAST(pandas.NewSeries(stat.SERIES_TYPE_FLOAT32, "", x)) fmt.Println(n[len(n)-10:]) //r1 := SMA(CLOSE, pandas.NewSeries(pandas.SERIES_TYPE_FLOAT32, "", n), 1) r1 := SMA(CLOSE, 7, 1) - s2 := pandas.NewSeries(pandas.SERIES_TYPE_FLOAT32, "sma", r1) + s2 := pandas.NewSeries(stat.SERIES_TYPE_FLOAT32, "sma", r1) df2 := pandas.NewDataFrame(s2) fmt.Println(df2) } diff --git a/formula/sqrt.go b/formula/sqrt.go index 58945556802768830d1da2d9da3ae8b0355ce474..8caccbde2a94b920fbc8579cdb67868639288207 100644 --- a/formula/sqrt.go +++ b/formula/sqrt.go @@ -1,12 +1,11 @@ package formula import ( - "gitee.com/quant1x/pandas" "gitee.com/quant1x/pandas/stat" ) // SQRT 求S的平方根 -func SQRT(S pandas.Series) []stat.DType { +func SQRT(S stat.Series) []stat.DType { fs := S.DTypes() return stat.Sqrt(fs) } diff --git a/formula/sqrt_test.go b/formula/sqrt_test.go index da950fe99b29d0ba45fd574b4232633e55b64c16..fd9db2b0dab07b7380cfd18ced2bb02e1d4434e1 100644 --- a/formula/sqrt_test.go +++ b/formula/sqrt_test.go @@ -3,14 +3,15 @@ package formula import ( "fmt" "gitee.com/quant1x/pandas" + "gitee.com/quant1x/pandas/stat" "testing" ) func TestSQRT(t *testing.T) { f1 := []float32{1.1, 2.2, 1.3, 1.4} f2 := []float64{70, 80, 75, 83, 86} - s1 := pandas.NewSeries(pandas.SERIES_TYPE_FLOAT32, "s1", f1) - s2 := pandas.NewSeries(pandas.SERIES_TYPE_FLOAT64, "s2", f2) + s1 := pandas.NewSeries(stat.SERIES_TYPE_FLOAT32, "s1", f1) + s2 := pandas.NewSeries(stat.SERIES_TYPE_FLOAT64, "s2", f2) fmt.Println(SQRT(s1)) fmt.Println(SQRT(s2)) } diff --git a/formula/std.go b/formula/std.go index 4612c708c417a8c2d981a03bb3884db083298564..795836edb88c7bdd97f1f0599f69cd697000042e 100644 --- a/formula/std.go +++ b/formula/std.go @@ -1,8 +1,8 @@ package formula -import "gitee.com/quant1x/pandas" +import "gitee.com/quant1x/pandas/stat" // STD 序列的N日标准差 -func STD(S pandas.Series, N any) pandas.Series { +func STD(S stat.Series, N any) stat.Series { return S.Rolling(N).Std() } diff --git a/formula/sum.go b/formula/sum.go index 932d30758823f808cc6ccc26f47be2d0cbcc9e7d..3716b0f4528490ddcf02969057fcfef1844d3194 100644 --- a/formula/sum.go +++ b/formula/sum.go @@ -1,13 +1,11 @@ package formula -import ( - "gitee.com/quant1x/pandas" -) +import "gitee.com/quant1x/pandas/stat" // SUM 求累和 // 如果N=0, 则从第一个有效值累加到当前 // 下一步再统一返回值 -func SUM(S pandas.Series, N any) any { +func SUM(S stat.Series, N any) any { //var X []float32 //switch v := N.(type) { //case int: diff --git a/formula/wma.go b/formula/wma.go index 0bb2f7f0db17e258d278a3240e571246c514d9e5..ca838669f2db170313326112b54ce3a8da046b06 100644 --- a/formula/wma.go +++ b/formula/wma.go @@ -1,24 +1,23 @@ package formula import ( - "gitee.com/quant1x/pandas" "gitee.com/quant1x/pandas/exception" "gitee.com/quant1x/pandas/stat" ) // WMA 通达信S序列的N日加权移动平均 Yn = (1*X1+2*X2+3*X3+...+n*Xn)/(1+2+3+...+Xn) -func WMA(S pandas.Series, N any) any { +func WMA(S stat.Series, N any) any { var X []stat.DType switch v := N.(type) { case int: X = stat.Repeat[stat.DType](stat.DType(v), S.Len()) - case pandas.Series: + case stat.Series: vs := v.DTypes() X = stat.Align(vs, stat.DTypeNaN, S.Len()) default: panic(exception.New(1, "error window")) } - return S.Rolling(X).Apply(func(S pandas.Series, N stat.DType) stat.DType { + return S.Rolling(X).Apply(func(S stat.Series, N stat.DType) stat.DType { if S.Len() == 0 { return stat.DTypeNaN } diff --git a/generic.go b/generic.go index 90c5605ff757ea92e3e140fa0f1ca6132c00b9d9..9e63db5a0d3f6c2180e1167bbb9bd40db4459db7 100644 --- a/generic.go +++ b/generic.go @@ -12,12 +12,11 @@ type NDFrame struct { lock sync.RWMutex // 读写锁 formatter stat.StringFormatter // 字符串格式化工具 name string // 帧名称 - type_ Type // values元素类型 + type_ stat.Type // values元素类型 copy_ bool // 是否副本 nilCount int // nil和nan的元素有多少, 这种统计在bool和int64类型中不会大于0, 只对float64及string有效 rows int // 行数 values any // 只能是一个一维slice, 在所有的运算中, values强制转换成float64切片 - } //""" @@ -35,7 +34,7 @@ func NewNDFrame[E stat.GenericType](name string, rows ...E) *NDFrame { frame := NDFrame{ formatter: stat.DefaultFormatter, name: name, - type_: SERIES_TYPE_INVAILD, + type_: stat.SERIES_TYPE_INVAILD, nilCount: 0, rows: 0, values: []E{}, @@ -54,20 +53,20 @@ func NewNDFrame[E stat.GenericType](name string, rows ...E) *NDFrame { // 赋值 func assign[T stat.GenericType](frame *NDFrame, idx, size int, v T) { // 检测类型 - if frame.type_ == SERIES_TYPE_INVAILD { + if frame.type_ == stat.SERIES_TYPE_INVAILD { _type, _ := detectTypes(v) - if _type != SERIES_TYPE_INVAILD { + if _type != stat.SERIES_TYPE_INVAILD { frame.type_ = _type } } _vv := reflect.ValueOf(v) _vi := _vv.Interface() // float和string类型有可能是NaN, 对nil和NaN进行计数 - if frame.Type() == SERIES_TYPE_FLOAT32 && stat.Float32IsNaN(_vi.(float32)) { + if frame.Type() == stat.SERIES_TYPE_FLOAT32 && stat.Float32IsNaN(_vi.(float32)) { frame.nilCount++ - } else if frame.Type() == SERIES_TYPE_FLOAT64 && stat.Float64IsNaN(_vi.(float64)) { + } else if frame.Type() == stat.SERIES_TYPE_FLOAT64 && stat.Float64IsNaN(_vi.(float64)) { frame.nilCount++ - } else if frame.Type() == SERIES_TYPE_STRING && stat.StringIsNaN(_vi.(string)) { + } else if frame.Type() == stat.SERIES_TYPE_STRING && stat.StringIsNaN(_vi.(string)) { frame.nilCount++ // 以下修正string的NaN值, 统一为"NaN" //_rv := reflect.ValueOf(StringNaN) @@ -120,7 +119,7 @@ func (self *NDFrame) Rename(n string) { self.name = n } -func (self *NDFrame) Type() Type { +func (self *NDFrame) Type() stat.Type { return self.type_ } @@ -146,7 +145,7 @@ func (self *NDFrame) NaN() any { } } -func (self *NDFrame) Float() []float32 { +func (self *NDFrame) Floats() []float32 { return stat.SliceToFloat32(self.values) } @@ -156,14 +155,17 @@ func (self *NDFrame) DTypes() []stat.DType { } // AsInt 强制转换成整型 -func (self *NDFrame) AsInt() []stat.Int { +func (self *NDFrame) Ints() []stat.Int { values := self.DTypes() fs := stat.Fill[stat.DType](values, stat.DType(0)) ns := stat.DType2Int(fs) return ns } -func (self *NDFrame) Empty() Series { +func (self *NDFrame) Empty(t ...stat.Type) stat.Series { + if len(t) > 0 { + self.type_ = t[0] + } var frame NDFrame if self.type_ == stat.SERIES_TYPE_STRING { frame = NDFrame{ @@ -224,7 +226,7 @@ func (self *NDFrame) Records() []string { return ret } -func (self *NDFrame) Repeat(x any, repeats int) Series { +func (self *NDFrame) Repeat(x any, repeats int) stat.Series { switch values := self.values.(type) { case []bool: _ = values @@ -245,9 +247,9 @@ func (self *NDFrame) Repeat(x any, repeats int) Series { } } -func (self *NDFrame) Shift(periods int) Series { - var d Series - d = clone(self).(Series) +func (self *NDFrame) Shift(periods int) stat.Series { + var d stat.Series + d = stat.Clone(self).(stat.Series) //return Shift[float64](&d, periods, func() float64 { // return Nil2Float64 //}) @@ -267,18 +269,18 @@ func (self *NDFrame) Shift(periods int) Series { }) case []float32: return Shift[float32](&d, periods, func() float32 { - return Nil2Float32 + return stat.Nil2Float32 }) default: //case []float64: return Shift[float64](&d, periods, func() float64 { - return Nil2Float64 + return stat.Nil2Float64 }) } } func (self *NDFrame) Mean() stat.DType { if self.Len() < 1 { - return NaN() + return stat.NaN() } fs := make([]stat.DType, 0) self.Apply(func(idx int, v any) { @@ -291,7 +293,7 @@ func (self *NDFrame) Mean() stat.DType { func (self *NDFrame) StdDev() stat.DType { if self.Len() < 1 { - return NaN() + return stat.NaN() } values := make([]stat.DType, self.Len()) self.Apply(func(idx int, v any) { @@ -303,7 +305,7 @@ func (self *NDFrame) StdDev() stat.DType { func (self *NDFrame) Std() stat.DType { if self.Len() < 1 { - return NaN() + return stat.NaN() } values := make([]stat.DType, self.Len()) self.Apply(func(idx int, v any) { @@ -313,7 +315,7 @@ func (self *NDFrame) Std() stat.DType { return stdDev } -func (self *NDFrame) FillNa(v any, inplace bool) Series { +func (self *NDFrame) FillNa(v any, inplace bool) stat.Series { values := self.Values() switch rows := values.(type) { case []string: diff --git a/generic_append.go b/generic_append.go index 5bdc9bb1434ae4ee6d69e7798177aaf81d078f30..159370803d4850661ed768cc6a585f49a1bd7e89 100644 --- a/generic_append.go +++ b/generic_append.go @@ -7,16 +7,16 @@ import ( // 插入一条记录 func (self *NDFrame) insert(idx, size int, v any) { - if self.type_ == SERIES_TYPE_BOOL { + if self.type_ == stat.SERIES_TYPE_BOOL { val := stat.AnyToBool(v) assign[bool](self, idx, size, val) - } else if self.type_ == SERIES_TYPE_INT64 { + } else if self.type_ == stat.SERIES_TYPE_INT64 { val := stat.AnyToInt64(v) assign[int64](self, idx, size, val) - } else if self.type_ == SERIES_TYPE_FLOAT32 { + } else if self.type_ == stat.SERIES_TYPE_FLOAT32 { val := stat.AnyToFloat32(v) assign[float32](self, idx, size, val) - } else if self.type_ == SERIES_TYPE_FLOAT64 { + } else if self.type_ == stat.SERIES_TYPE_FLOAT64 { val := stat.AnyToFloat64(v) assign[float64](self, idx, size, val) } else { @@ -26,7 +26,7 @@ func (self *NDFrame) insert(idx, size int, v any) { } // Append 批量增加记录 -func (self *NDFrame) Append(values ...any) { +func (self *NDFrame) Append(values ...any) stat.Series { size := 0 for idx, v := range values { switch val := v.(type) { @@ -51,4 +51,5 @@ func (self *NDFrame) Append(values ...any) { } } } + return self } diff --git a/generic_diff.go b/generic_diff.go index f9ba946f0c774605ef9daa2da1a345dc57069562..df916d756ab5d41ef57eeda73c13678786cc2173 100644 --- a/generic_diff.go +++ b/generic_diff.go @@ -9,28 +9,28 @@ import ( // First discrete difference of element. // Calculates the difference of a {klass} element compared with another // element in the {klass} (default is element in previous row). -func (self *NDFrame) Diff(param any) (s Series) { - if !(self.type_ == SERIES_TYPE_INT64 || self.type_ == SERIES_TYPE_FLOAT32 || self.type_ == SERIES_TYPE_FLOAT64) { - return NewSeries(SERIES_TYPE_INVAILD, "", "") +func (self *NDFrame) Diff(param any) (s stat.Series) { + if !(self.type_ == stat.SERIES_TYPE_INT64 || self.type_ == stat.SERIES_TYPE_FLOAT32 || self.type_ == stat.SERIES_TYPE_FLOAT64) { + return NewSeries(stat.SERIES_TYPE_INVAILD, "", "") } var N []stat.DType switch v := param.(type) { case int: N = stat.Repeat[stat.DType](stat.DType(v), self.Len()) - case Series: + case stat.Series: vs := v.DTypes() N = stat.Align(vs, stat.DTypeNaN, self.Len()) default: //periods = 1 N = stat.Repeat[stat.DType](stat.DType(1), self.Len()) } - r := RollingAndExpandingMixin{ - window: N, - series: self, + r := stat.RollingAndExpandingMixin{ + Window: N, + Series: self, } var d []stat.DType var front = stat.DTypeNaN - for _, block := range r.getBlocks() { + for _, block := range r.GetBlocks() { vs := reflect.ValueOf(block.Values()) vl := vs.Len() if vl == 0 { @@ -50,6 +50,6 @@ func (self *NDFrame) Diff(param any) (s Series) { d = append(d, diff) front = cf } - s = NewSeries(SERIES_TYPE_DTYPE, r.series.Name(), d) + s = NewSeries(stat.SERIES_TYPE_DTYPE, r.Series.Name(), d) return } diff --git a/generic_diff_test.go b/generic_diff_test.go index 5c4827bdc044e6e4d8ef16677800e49a4c9e2c73..3d15e385056f4fc1aed8740b8e1f5e7b4e6fa570 100644 --- a/generic_diff_test.go +++ b/generic_diff_test.go @@ -2,6 +2,7 @@ package pandas import ( "fmt" + "gitee.com/quant1x/pandas/stat" "testing" ) @@ -16,8 +17,8 @@ func TestNDFrame_Diff(t *testing.T) { r1 := df.Col("x").Diff(N).Values() fmt.Println("序列化结果:", r1) fmt.Println("------------------------------------------------------------") - d2 := []float64{1, 2, 3, 4, 3, 3, 2, 1, Nil2Float64, Nil2Float64, Nil2Float64, Nil2Float64} - s2 := NewSeries(SERIES_TYPE_FLOAT64, "x", d2) + d2 := []float64{1, 2, 3, 4, 3, 3, 2, 1, stat.Nil2Float64, stat.Nil2Float64, stat.Nil2Float64, stat.Nil2Float64} + s2 := NewSeries(stat.SERIES_TYPE_FLOAT64, "x", d2) fmt.Printf("序列化参数: %+v\n", s2.Values()) r2 := df.Col("x").Diff(s2).Values() fmt.Println("序列化结果:", r2) diff --git a/generic_ewm.go b/generic_ewm.go index 64b3d9dfd4c7253d49ac6ceae9ca676257279e48..2c5c7b3357dfcb2eaf642eadddac15c0d1e80c83 100644 --- a/generic_ewm.go +++ b/generic_ewm.go @@ -2,179 +2,35 @@ package pandas import ( "gitee.com/quant1x/pandas/stat" - "math" ) -type AlphaType int - -// https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.ewm.html -const ( - // AlphaAlpha Specify smoothing factor α directly, 0<α≤1. - AlphaAlpha AlphaType = iota - // AlphaCom Specify decay in terms of center of mass, α=1/(1+com), for com ≥ 0. - AlphaCom - // AlphaSpan Specify decay in terms of span, α=2/(span+1), for span ≥ 1. - AlphaSpan - // AlphaHalfLife Specify decay in terms of half-life, α=1−exp(−ln(2)/halflife), for halflife > 0. - AlphaHalfLife -) - -// EW (Factor) 指数加权(EW)计算Alpha 结构属性非0即为有效启动同名算法 -type EW struct { - Com stat.DType // 根据质心指定衰减 - Span stat.DType // 根据跨度指定衰减 - HalfLife stat.DType // 根据半衰期指定衰减 - Alpha stat.DType // 直接指定的平滑因子α - Adjust bool // 除以期初的衰减调整系数以核算 相对权重的不平衡(将 EWMA 视为移动平均线) - IgnoreNA bool // 计算权重时忽略缺失值 - Callback func(idx int) stat.DType -} - -// ExponentialMovingWindow 加权移动窗口 -type ExponentialMovingWindow struct { - data Series // 序列 - atype AlphaType // 计算方式: com/span/halflefe/alpha - param stat.DType // 参数类型为浮点 - adjust bool // 默认为真, 是否调整, 默认真时, 计算序列的EW移动平均线, 为假时, 计算指数加权递归 - ignoreNA bool // 默认为假, 计算权重时是否忽略缺失值NaN - minPeriods int // 默认为0, 窗口中具有值所需的最小观测值数,否则结果为NaN - axis int // {0,1}, 默认为0, 0跨行计算, 1跨列计算 - cb func(idx int) stat.DType -} - // EWM provides exponential weighted calculations. -func (s *NDFrame) EWM(alpha EW) ExponentialMovingWindow { - atype := AlphaAlpha +func (s *NDFrame) EWM(alpha stat.EW) stat.ExponentialMovingWindow { + atype := stat.AlphaAlpha param := 0.00 adjust := alpha.Adjust ignoreNA := alpha.IgnoreNA if alpha.Com != 0 { - atype = AlphaCom + atype = stat.AlphaCom param = alpha.Com } else if alpha.Span != 0 { - atype = AlphaSpan + atype = stat.AlphaSpan param = alpha.Span } else if alpha.HalfLife != 0 { - atype = AlphaHalfLife + atype = stat.AlphaHalfLife param = alpha.HalfLife } else { - atype = AlphaAlpha + atype = stat.AlphaAlpha param = alpha.Alpha } - dest := NewSeries(SERIES_TYPE_FLOAT64, s.name, s.Values()) - return ExponentialMovingWindow{ - data: dest, - atype: atype, - param: param, - adjust: adjust, - ignoreNA: ignoreNA, - cb: alpha.Callback, - } -} - -func (w ExponentialMovingWindow) Mean() Series { - var alpha stat.DType - - switch w.atype { - case AlphaAlpha: - if w.param <= 0 { - panic("alpha param must be > 0") - } - alpha = w.param - - case AlphaCom: - if w.param <= 0 { - panic("com param must be >= 0") - } - alpha = 1 / (1 + w.param) - - case AlphaSpan: - if w.param < 1 { - panic("span param must be >= 1") - } - alpha = 2 / (w.param + 1) - - case AlphaHalfLife: - if w.param <= 0 { - panic("halflife param must be > 0") - } - alpha = 1 - math.Exp(-math.Ln2/w.param) - } - - return w.applyMean(w.data, alpha) -} - -func (w ExponentialMovingWindow) applyMean(data Series, alpha stat.DType) Series { - if w.adjust { - w.adjustedMean(data, alpha, w.ignoreNA) - } else { - w.notadjustedMean(data, alpha, w.ignoreNA) - } - return data -} - -func (w ExponentialMovingWindow) adjustedMean(data Series, alpha stat.DType, ignoreNA bool) { - var ( - values = data.Values().([]stat.DType) - weight stat.DType = 1 - last = values[0] - ) - - alpha = 1 - alpha - for t := 1; t < len(values); t++ { - - w := alpha*weight + 1 - x := values[t] - if stat.DTypeIsNaN(x) { - if ignoreNA { - weight = w - } - values[t] = last - continue - } - - last = last + (x-last)/(w) - weight = w - values[t] = last - } -} - -func (w ExponentialMovingWindow) notadjustedMean(data Series, alpha stat.DType, ignoreNA bool) { - hasCallback := false - if stat.DTypeIsNaN(alpha) { - hasCallback = true - alpha = w.cb(0) - } - var ( - count int - values = data.Values().([]stat.DType) - //values = data.DTypes() // Dtypes有复制功能 - beta = 1 - alpha - last = values[0] - ) - if stat.Float64IsNaN(last) { - last = 0 - values[0] = last - } - for t := 1; t < len(values); t++ { - x := values[t] - - if stat.DTypeIsNaN(x) { - values[t] = last - continue - } - if hasCallback { - alpha = w.cb(t) - beta = 1 - alpha - } - // yt = (1−α)*y(t−1) + α*x(t) - last = (beta * last) + (alpha * x) - if stat.DTypeIsNaN(last) { - last = values[t-1] - } - values[t] = last - - count++ + dest := NewSeries(stat.SERIES_TYPE_FLOAT64, s.name, s.Values()) + return stat.ExponentialMovingWindow{ + Data: dest, + AType: atype, + Param: param, + Adjust: adjust, + IgnoreNA: ignoreNA, + Cb: alpha.Callback, } } diff --git a/generic_max.go b/generic_max.go index 2659f9e5dbf5622bd3490c66e0dd19d7f8337870..5f825b8de07fe20c71ec9eb9aa2d81d1741b9a7c 100644 --- a/generic_max.go +++ b/generic_max.go @@ -70,11 +70,11 @@ func (self *NDFrame) Max() any { _ = idx } if hasNan { - return Nil2Float32 + return stat.Nil2Float32 } else if i > 0 { return max } - return Nil2Float32 + return stat.Nil2Float32 //case []float32: // if self.Len() == 0 { // return Nil2Float32 @@ -96,11 +96,11 @@ func (self *NDFrame) Max() any { _ = idx } if hasNaN { - return Nil2Float64 + return stat.Nil2Float64 } else if i > 0 { return max } - return Nil2Float64 + return stat.Nil2Float64 //case []float64: // if self.Len() == 0 { // return Nil2Float64 diff --git a/generic_range.go b/generic_range.go index 118b97068df7860f25825790f16be18ce6566def..7a469bbe02c1d9dda7cbf2057bb966422e133c5b 100644 --- a/generic_range.go +++ b/generic_range.go @@ -7,12 +7,12 @@ import ( ) // Copy 复制一个副本 -func (self *NDFrame) Copy() Series { +func (self *NDFrame) Copy() stat.Series { vlen := self.Len() return self.Subset(0, vlen, true) } -func (self *NDFrame) Subset(start, end int, opt ...any) Series { +func (self *NDFrame) Subset(start, end int, opt ...any) stat.Series { // 默认不copy var __optCopy bool = false if len(opt) > 0 { @@ -46,10 +46,10 @@ func (self *NDFrame) Subset(start, end int, opt ...any) Series { default: // 其它类型忽略 } - return self.Empty() + return self.Empty(0) } -func (self *NDFrame) oldSubset(start, end int, opt ...any) Series { +func (self *NDFrame) oldSubset(start, end int, opt ...any) stat.Series { // 默认不copy var __optCopy bool = false if len(opt) > 0 { @@ -110,13 +110,13 @@ func (self *NDFrame) oldSubset(start, end int, opt ...any) Series { rows: rows, values: vs, } - var s Series + var s stat.Series s = &frame return s } // Select 选取一段记录 -func (self *NDFrame) Select(r stat.ScopeLimit) Series { +func (self *NDFrame) Select(r stat.ScopeLimit) stat.Series { start, end, err := r.Limits(self.Len()) if err != nil { return nil diff --git a/generic_ref.go b/generic_ref.go index b631f2df411e782c56cab433560a4685550b7d6f..f4e49f9027791a8c715d95fa4abaf8bcdf3cbaf8 100644 --- a/generic_ref.go +++ b/generic_ref.go @@ -5,23 +5,23 @@ import ( "gitee.com/quant1x/pandas/stat" ) -func (self *NDFrame) Ref(param any) (s Series) { +func (self *NDFrame) Ref(param any) (s stat.Series) { var N []float32 switch v := param.(type) { case int: N = stat.Repeat[float32](float32(v), self.Len()) case []float32: - N = stat.Align(v, Nil2Float32, self.Len()) - case Series: + N = stat.Align(v, stat.Nil2Float32, self.Len()) + case stat.Series: vs := v.Values() N = stat.SliceToFloat32(vs) - N = stat.Align(N, Nil2Float32, self.Len()) + N = stat.Align(N, stat.Nil2Float32, self.Len()) default: panic(exception.New(1, "error window")) } - var d Series - d = clone(self).(Series) + var d stat.Series + d = stat.Clone(self).(stat.Series) //return Shift[float64](&d, periods, func() float64 { // return Nil2Float64 //}) @@ -41,11 +41,11 @@ func (self *NDFrame) Ref(param any) (s Series) { }) case []float32: return Shift2[float32](&d, N, func() float32 { - return Nil2Float32 + return stat.Nil2Float32 }) default: //case []float64: return Shift2[float64](&d, N, func() float64 { - return Nil2Float64 + return stat.Nil2Float64 }) } diff --git a/generic_rolling.go b/generic_rolling.go index 3442ba32cd55be303773abe66f5f0a5950c2a468..8c8854d3563c7c8b963e6e334b384eb5024e76cd 100644 --- a/generic_rolling.go +++ b/generic_rolling.go @@ -5,73 +5,23 @@ import ( "gitee.com/quant1x/pandas/stat" ) -// RollingAndExpandingMixin 滚动和扩展静态横切 -type RollingAndExpandingMixin struct { - window []stat.DType - series Series -} - // Rolling RollingAndExpandingMixin -func (self *NDFrame) Rolling(param any) RollingAndExpandingMixin { +func (self *NDFrame) Rolling(param any) stat.RollingAndExpandingMixin { var N []stat.DType switch v := param.(type) { case int: N = stat.Repeat[stat.DType](stat.DType(v), self.Len()) case []stat.DType: N = stat.Align(v, stat.DTypeNaN, self.Len()) - case Series: + case stat.Series: vs := v.DTypes() N = stat.Align(vs, stat.DTypeNaN, self.Len()) default: panic(exception.New(1, "error window")) } - w := RollingAndExpandingMixin{ - window: N, - series: self, + w := stat.RollingAndExpandingMixin{ + Window: N, + Series: self, } return w } - -func (r RollingAndExpandingMixin) getBlocks() (blocks []Series) { - for i := 0; i < r.series.Len(); i++ { - N := r.window[i] - if stat.DTypeIsNaN(N) || int(N) > i+1 { - blocks = append(blocks, r.series.Empty()) - continue - } - window := int(N) - start := i + 1 - window - end := i + 1 - blocks = append(blocks, r.series.Subset(start, end, false)) - } - - return -} - -func (r RollingAndExpandingMixin) Apply_v1(f func(S Series, N stat.DType) stat.DType) (s Series) { - s = r.series.Empty() - for i, block := range r.getBlocks() { - if block.Len() == 0 { - s.Append(stat.DTypeNaN) - continue - } - v := f(block, r.window[i]) - s.Append(v) - } - return -} - -// Apply 接受一个回调 -func (r RollingAndExpandingMixin) Apply(f func(S Series, N stat.DType) stat.DType) (s Series) { - values := make([]stat.DType, r.series.Len()) - for i, block := range r.getBlocks() { - if block.Len() == 0 { - values[i] = stat.DTypeNaN - continue - } - v := f(block, r.window[i]) - values[i] = v - } - s = NewSeries(SERIES_TYPE_DTYPE, r.series.Name(), values) - return -} diff --git a/generic_shift.go b/generic_shift.go index 5ed85b3e58c3e83698e03c9be3f7c17cbbc2bc97..f71109e7015f82cd6639c0d49dce79edf5710660 100644 --- a/generic_shift.go +++ b/generic_shift.go @@ -6,9 +6,9 @@ import ( ) // Shift series切片, 使用可选的时间频率按所需的周期数移动索引 -func Shift[T stat.GenericType](s *Series, periods int, cbNan func() T) Series { - var d Series - d = clone(*s).(Series) +func Shift[T stat.GenericType](s *stat.Series, periods int, cbNan func() T) stat.Series { + var d stat.Series + d = stat.Clone(*s).(stat.Series) if periods == 0 { return d } @@ -43,9 +43,9 @@ func Shift[T stat.GenericType](s *Series, periods int, cbNan func() T) Series { } // Shift2 series切片, 使用可选的时间频率按所需的周期数移动索引 -func Shift2[T stat.GenericType](s *Series, N []float32, cbNan func() T) Series { - var d Series - d = clone(*s).(Series) +func Shift2[T stat.GenericType](s *stat.Series, N []float32, cbNan func() T) stat.Series { + var d stat.Series + d = stat.Clone(*s).(stat.Series) if len(N) == 0 { return d } diff --git a/generic_sort.go b/generic_sort.go index a7d625ab2b7e268a67118e757d949e97e94429e3..7f67d83e3c3deaf10d55b9e83cb6b376e4b2ed70 100644 --- a/generic_sort.go +++ b/generic_sort.go @@ -1,5 +1,7 @@ package pandas +import "gitee.com/quant1x/pandas/stat" + // Len 获得行数, 实现sort.Interface接口的获取元素数量方法 func (self *NDFrame) Len() int { return self.rows @@ -7,7 +9,7 @@ func (self *NDFrame) Len() int { // Less 实现sort.Interface接口的比较元素方法 func (self *NDFrame) Less(i, j int) bool { - if self.type_ == SERIES_TYPE_BOOL { + if self.type_ == stat.SERIES_TYPE_BOOL { values := self.Values().([]bool) var ( a = int(0) @@ -20,16 +22,16 @@ func (self *NDFrame) Less(i, j int) bool { b = 1 } return a < b - } else if self.type_ == SERIES_TYPE_INT64 { + } else if self.type_ == stat.SERIES_TYPE_INT64 { values := self.Values().([]int64) return values[i] < values[j] - } else if self.type_ == SERIES_TYPE_FLOAT32 { + } else if self.type_ == stat.SERIES_TYPE_FLOAT32 { values := self.Values().([]float32) return values[i] < values[j] - } else if self.type_ == SERIES_TYPE_FLOAT64 { + } else if self.type_ == stat.SERIES_TYPE_FLOAT64 { values := self.Values().([]float64) return values[i] < values[j] - } else if self.type_ == SERIES_TYPE_STRING { + } else if self.type_ == stat.SERIES_TYPE_STRING { values := self.Values().([]string) return values[i] < values[j] } else { @@ -43,19 +45,19 @@ func (self *NDFrame) Less(i, j int) bool { // Swap 实现sort.Interface接口的交换元素方法 func (self *NDFrame) Swap(i, j int) { - if self.type_ == SERIES_TYPE_BOOL { + if self.type_ == stat.SERIES_TYPE_BOOL { values := self.Values().([]bool) values[i], values[j] = values[j], values[i] - } else if self.type_ == SERIES_TYPE_INT64 { + } else if self.type_ == stat.SERIES_TYPE_INT64 { values := self.Values().([]int64) values[i], values[j] = values[j], values[i] - } else if self.type_ == SERIES_TYPE_FLOAT32 { + } else if self.type_ == stat.SERIES_TYPE_FLOAT32 { values := self.Values().([]float32) values[i], values[j] = values[j], values[i] - } else if self.type_ == SERIES_TYPE_FLOAT64 { + } else if self.type_ == stat.SERIES_TYPE_FLOAT64 { values := self.Values().([]float64) values[i], values[j] = values[j], values[i] - } else if self.type_ == SERIES_TYPE_STRING { + } else if self.type_ == stat.SERIES_TYPE_STRING { values := self.Values().([]string) values[i], values[j] = values[j], values[i] } else { diff --git a/generic_test.go b/generic_test.go index a916f1996b85b861bbe1355c3b8b8fe5802adb3f..a58049659c814699fd298f4851062a62a60ff4eb 100644 --- a/generic_test.go +++ b/generic_test.go @@ -8,22 +8,22 @@ import ( func TestSeriesFrame(t *testing.T) { data := []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - s1 := NewSeries(SERIES_TYPE_FLOAT64, "x", data) + s1 := NewSeries(stat.SERIES_TYPE_FLOAT64, "x", data) fmt.Printf("%+v\n", s1) var d1 any d1 = data - s2 := NewSeries(SERIES_TYPE_FLOAT64, "x", d1) + s2 := NewSeries(stat.SERIES_TYPE_FLOAT64, "x", d1) fmt.Printf("%+v\n", s2) - var s3 Series + var s3 stat.Series // s3 = NewSeriesBool("x", data) - s3 = NewSeries(SERIES_TYPE_BOOL, "x", data) + s3 = NewSeries(stat.SERIES_TYPE_BOOL, "x", data) fmt.Printf("%+v\n", s3.Values()) - var s4 Series + var s4 stat.Series ts4 := GenericSeries[float64]("x", data...) - ts4 = NewSeries(SERIES_TYPE_FLOAT64, "x", data) + ts4 = NewSeries(stat.SERIES_TYPE_FLOAT64, "x", data) s4 = ts4 fmt.Printf("%+v\n", s4.Values()) } @@ -43,24 +43,14 @@ func TestNDFrameNew(t *testing.T) { nd11 := nd1.Subset(1, 2, true) fmt.Println(nd11.Records()) fmt.Println(nd1.Max()) - fmt.Println(nd1.RollingV1(5).Max()) - fmt.Println(nd1.RollingV1(5).Min()) - - nd12 := nd1.RollingV1(5).Mean() - d12 := nd12.Values() - fmt.Println(d12) nd13 := nd1.Shift(3) fmt.Println(nd13.Values()) - nd14 := nd1.RollingV1(5).StdDev() - fmt.Println(nd14.Values()) // string d2 := []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "nan", "12"} nd2 := NewNDFrame[string]("x", d2...) fmt.Println(nd2) - nd21 := nd2.RollingV1(5).Max() - fmt.Println(nd21) nd2.FillNa(0, true) fmt.Println(nd2) fmt.Println(nd2.Records()) @@ -79,8 +69,8 @@ func TestRolling2(t *testing.T) { r1 := df.Col("x").Rolling(5).Mean().Values() fmt.Println("序列化结果:", r1) fmt.Println("------------------------------------------------------------") - d2 := []float64{1, 2, 3, 4, 3, 3, 2, 1, Nil2Float64, Nil2Float64, Nil2Float64, Nil2Float64} - s2 := NewSeries(SERIES_TYPE_FLOAT64, "x", d2) + d2 := []float64{1, 2, 3, 4, 3, 3, 2, 1, stat.Nil2Float64, stat.Nil2Float64, stat.Nil2Float64, stat.Nil2Float64} + s2 := NewSeries(stat.SERIES_TYPE_FLOAT64, "x", d2) fmt.Printf("序列化参数: %+v\n", s2.Values()) r2 := df.Col("x").Rolling(s2).Mean().Values() fmt.Println("序列化结果:", r2) diff --git a/num/Adder.go b/num/Adder.go deleted file mode 100644 index 9f2414702342cdc07b94cc245d0399a2725a960e..0000000000000000000000000000000000000000 --- a/num/Adder.go +++ /dev/null @@ -1,152 +0,0 @@ -package num - -import ( - "fmt" - "reflect" -) - -type _int struct { - v interface{} - t reflect.Type -} - -func (i _int) Value() interface{} { - return i.v -} - -func (i *_int) SetZero() { - switch i.t.Kind() { - case reflect.Int: - i.v = 0 - case reflect.Uint8: - i.v = uint8(0) - case reflect.Uint16: - i.v = uint16(0) - case reflect.Uint32: - i.v = uint32(0) - case reflect.Uint64: - i.v = uint64(0) - case reflect.Int8: - i.v = int8(0) - case reflect.Int16: - i.v = int16(0) - case reflect.Int32: - i.v = int32(0) - case reflect.Int64: - i.v = int64(0) - } -} - -type _float struct { - v interface{} - t reflect.Type -} - -func (f _float) Value() interface{} { - return f.v -} - -func (f *_float) SetZero() { - switch f.t.Kind() { - case reflect.Float32: - f.v = float32(0) - case reflect.Float64: - f.v = float64(0) - } -} - -func Adder(t reflect.Type) Add { - - support := []reflect.Kind{ - reflect.Int, - reflect.Uint8, - reflect.Uint16, - reflect.Uint32, - reflect.Uint64, - reflect.Int8, - reflect.Int16, - reflect.Uint32, - reflect.Int64, - reflect.Float32, - reflect.Float64, - } - vk := t.Kind() - if contain(support, vk) { - var add Add - if vk == reflect.Float64 || vk == reflect.Float32 { - add = &_float{t: t} - } else { - add = &_int{t: t} - } - add.SetZero() - return add - } - panic("not support type " + t.String()) -} - -type Add interface { - Add(value reflect.Value) - Value() interface{} - SetZero() -} - -func (i *_int) Add(value reflect.Value) { - if value.Kind() != i.t.Kind() { - panic(fmt.Sprintf("not support %s add %s", i.t.String(), value.Type().String())) - } - switch value.Type().Kind() { - case reflect.Int: - sum := i.v.(int) - sum += value.Interface().(int) - i.v = sum - case reflect.Uint8: - sum := i.v.(uint8) - sum += value.Interface().(uint8) - i.v = sum - case reflect.Uint16: - sum := i.v.(uint16) - sum += value.Interface().(uint16) - i.v = sum - case reflect.Uint32: - sum := i.v.(uint32) - sum += value.Interface().(uint32) - i.v = sum - case reflect.Uint64: - sum := i.v.(uint64) - sum += value.Interface().(uint64) - i.v = sum - case reflect.Int8: - sum := i.v.(int8) - sum += value.Interface().(int8) - i.v = sum - case reflect.Int16: - sum := i.v.(int16) - sum += value.Interface().(int16) - i.v = sum - case reflect.Int32: - sum := i.v.(int32) - sum += value.Interface().(int32) - i.v = sum - case reflect.Int64: - sum := i.v.(int64) - sum += value.Interface().(int64) - i.v = sum - } -} - -func (f *_float) Add(value reflect.Value) { - if value.Kind() != f.t.Kind() { - panic(fmt.Sprintf("not support %s add %s", f.t.String(), value.Type().String())) - } - - switch value.Type().Kind() { - case reflect.Float32: - sum := f.v.(float32) - sum += value.Interface().(float32) - f.v = sum - case reflect.Float64: - sum := f.v.(float64) - sum += value.Interface().(float64) - f.v = sum - } -} diff --git a/num/README.md b/num/README.md deleted file mode 100644 index 1f38be099d488d5e7063791e15787b73af64b5af..0000000000000000000000000000000000000000 --- a/num/README.md +++ /dev/null @@ -1,505 +0,0 @@ -# lambda - -## overview - -`lambda` is a lambda expression for go,lets you extract elements from array by lambda expression - -## Installation - -go get github.com/favar/lambda - -## Getting Started - -#### use LambdaArray returns Array interface - -```go -sa := []int{1,2,3,4,5,6,7,8,9} -arr := LambdaArray(sa) // return Array -``` - - - -#### interface Array - -```go -type Array interface { - IsSlice() bool - Join(options JoinOptions) string - Filter(express interface{}) Array - Sort(express interface{}) Array - SortMT(express interface{}) Array - Map(express interface{}) Array - Append(elements ...interface{}) Array - Max(express interface{}) interface{} - Min(express interface{}) interface{} - Any(express interface{}) bool - All(express interface{}) bool - Count(express interface{}) int - First(express interface{}) (interface{}, error) - Last(express interface{}) (interface{}, error) - index(i int) (interface{}, error) - Take(skip, count int) Array - Sum(express interface{}) interface{} - Average(express interface{}) float64 - Contains(express interface{}) bool - Pointer() interface{} -} -``` - -## Usage - -***define test struct*** - -```go -type user struct { - name string - age int -} -``` - - - -#### Join - -array join into string - -```go -type JoinOptions struct { - Symbol string // split string,default `,` - express interface{} // express match func(ele TElement) string -} -Join(options JoinOptions) string -``` - -``` -arr := []int{1,2,3,4,5} -str1 := LambdaArray(arr).Join(JoinOptions{ - express: func(e int) string { return strconv.Itoa(e) }, -}) -fmt.Println(str1) // 1,2,3,4,5 default `,` - -str2 := LambdaArray(arr).Join(JoinOptions{ - express: func(e int) string { return strconv.Itoa(e) }, - Symbol: "|", -}) -fmt.Println(str2) // 1|2|3|4|5 - - -``` - - - -#### Filter - -array filter - -```go -Filter(express interface{}) Array // express match func(ele TElement) bool -``` - -```go -arr := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} -larr := LambdaArray(arr) -ret1 := larr.Filter(func(ele int) bool { return ele > 5 }).Pointer().([]int) -fmt.Println(ret1) // [6 7 8 9 10] - -ret2 := larr.Filter(func(ele int) bool { return ele%2 == 0 }).Pointer().([]int) -fmt.Println(ret2) // [2 4 6 8 10] - -ret3 := LambdaArray(users).Filter(func(u user) bool { return u.age < 30 }).Pointer().([]user) -fmt.Println(ret3) // [{Abraham 20} {Edith 25} {Anthony 26}] -``` - - - -#### Sort - -quick sort - -```go -Sort(express interface{}) Array // express match func(e1, e2 TElement) bool -``` - -```go -arr := []int{1, 3, 8, 6, 12, 5, 9} -// order by asc -ret1 := LambdaArray(arr).Sort(func(a, b int) bool { return a < b }).Pointer().([]int) -// order by desc -ret2 := LambdaArray(arr).Sort(func(a, b int) bool { return a > b }).Pointer().([]int) - -fmt.Println(ret1) // [1 3 5 6 8 9 12] -fmt.Println(ret2) // [12 9 8 6 5 3 1] - -users := []user{ - {"Abraham", 20}, - {"Edith", 25}, - {"Charles", 40}, - {"Anthony", 26}, - {"Abel", 33}, -} -// order by user.age asc -ret3 := LambdaArray(users).Sort(func(a, b user) bool { return a.age < b.age }).Pointer().([]user) -fmt.Println(ret3) // [{Abraham 20} {Edith 25} {Anthony 26} {Abel 33} {Charles 40}] -``` - - - -#### SortMT - -sort by quick multithreading - -usage like Sort - -#### Map - -.map to new array - -```go -Map(express interface{}) Array // express match func(ele TElement) TOut -``` - -```go -arr := LambdaArray([]int{1, 2, 3, 4, 5}) -users := arr.Map(func(i int) user { - return user{name: "un:" + strconv.Itoa(i), age: i} -}).Pointer().([]user) -fmt.Println(users) // [{un:1 1} {un:2 2} {un:3 3} {un:4 4} {un:5 5}] - -``` - - - -#### Append - -.append element - -```go -Append(elements ...interface{}) Array // each of elements type must be TElmenent -``` - -```go -arr := LambdaArray([]int{1, 2, 3}) -arr.Append(4) -fmt.Println(arr.Pointer().([]int)) // [1 2 3 4] -arr.Append(5, 6) -fmt.Println(arr.Pointer().([]int)) // [1 2 3 4 5 6] -``` - -#### Max - -.maximum element of array - -```go -Max(express interface{}) interface{} -``` - -```go -users := []user{ - {"Abraham", 20}, - {"Edith", 25}, - {"Charles", 40}, - {"Anthony", 26}, - {"Abel", 33}, -} -eldest := LambdaArray(users).Max(func(u user) int { return u.age }).(user) -fmt.Println(eldest.name + " is the eldest") // Charles is the eldest - -want := []int{1, 5, 6, 3, 8, 9, 3, 12, 56, 186, 4, 9, 14} -var iArr = LambdaArray(want) -ret := iArr.Max(nil).(int) -fmt.Println(ret) // 186 -``` - - - -#### Min - -.minimum element of array - -```go -Min(express interface{}) interface{} -``` - -```go -users := []user{ - {"Abraham", 20}, - {"Edith", 25}, - {"Charles", 40}, - {"Anthony", 26}, - {"Abel", 33}, -} -eldest := LambdaArray(users).Min(func(u user) int { return u.age }).(user) -fmt.Println(eldest.name + " is the eldest") // Abraham is the Charles - -want := []int{1, 5, 6, 3, 8, 9, 3, 12, 56, 186, 4, 9, 14} -var iArr = LambdaArray(want) -ret := iArr.Min(nil).(int) -fmt.Println(ret) // 1 -``` - - - -#### Any - -.Determines whether the Array contains any elements - -```go -Any(express interface{}) bool -``` - -```go -us := []user{ - {"Abraham", 20}, - {"Edith", 25}, - {"Charles", 40}, - {"Anthony", 26}, - {"Abel", 33}, -} -ret1 := LambdaArray(us).Any(func(u user) bool { return u.age > 30 }) -fmt.Println(ret1) // true -ret2 := LambdaArray(us).Any(func(u user) bool { return u.age < 0 }) -fmt.Println(ret2) // false -``` - -#### All - -Determines whether the condition is satisfied for all elements in the Array - -```go -All(express interface{}) bool -``` - -```go -us := []user{ - {"Abraham", 20}, - {"Edith", 25}, - {"Charles", 40}, - {"Anthony", 26}, - {"Abel", 33}, -} -ret1 := LambdaArray(us).All(func(u user) bool { return u.age > 30 }) -fmt.Println(ret1) // false -ret2 := LambdaArray(us).All(func(u user) bool { return u.age > 10 }) -fmt.Println(ret2) // true -``` - -#### Count - -Returns a number indicating how many elements in the specified Array satisfy the condition - -```go -Count(express interface{}) int -``` - -```go -us := []user{ - {"Abraham", 20}, - {"Edith", 25}, - {"Charles", 40}, - {"Anthony", 26}, - {"Abel", 33}, -} -ret1 := LambdaArray(us).Count(func(u user) bool { return u.age > 30 }) -fmt.Println(ret1) // 2 -ret2 := LambdaArray(us).Count(func(u user) bool { return u.age > 20 }) -fmt.Println(ret2) // 4 -``` - - - -#### First - -Returns the first element of an Array that satisfies the condition - -```go -First(express interface{}) (interface{}, error) -``` - -```go -us := []user{ - {"Abraham", 20}, - {"Edith", 25}, - {"Charles", 40}, - {"Anthony", 26}, - {"Abel", 33}, -} -arr := LambdaArray(us) -if u, err := arr.First(func(u user) bool { return u.name == "Charles" }); err == nil { - fmt.Println(u, " found") -} else { - fmt.Println("not found") -} -// {Charles 40} found -if u, err := arr.First(func(u user) bool { return u.name == "jack" }); err == nil { - fmt.Println(u, " found") -} else { - fmt.Println("not found") -} -// not found - -``` - -#### Last - -Returns the last element of an Array that satisfies the condition - -```go -Last(express interface{}) (interface{}, error) -``` - -```go -us := []user{ - {"Abraham", 20}, - {"Edith", 25}, - {"Charles", 40}, - {"Anthony", 26}, - {"Abel", 33}, -} -arr := LambdaArray(us) -if u, err := arr.Last(func(u user) bool { return u.name == "Anthony" }); err == nil { - fmt.Println(u, " found") -} else { - fmt.Println("not found") -} -// {Anthony 26} found -if u, err := arr.Last(func(u user) bool { return u.age > 35 }); err == nil { - fmt.Println(u, " found") -} else { - fmt.Println("not found") -} -// {Charles 40} found -``` - - - -#### Index - -Returns the zero based index of the first occurrence in an Array - -```go -Index(i int) (interface{}, error) -``` - -```go -if element, err := LambdaArray([]int{1, 2, 3, 4, 5}).Index(3); err == nil { - fmt.Println(element) -} else { - fmt.Println(err) -} -// 4 -if element, err := LambdaArray([]int{1, 2, 3, 4, 5}).Index(10); err == nil { - fmt.Println(element) -} else { - fmt.Println(err) -} -// 10 out of range -``` - - - -#### Take - -take `count` elements start by `skip` - -```go -Take(skip, count int) Array -``` - -```go -ret1 := LambdaArray([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}).Take(4, 10).Pointer().([]int) -fmt.Println(ret1) // [5 6 7 8 9 10] -ret2 := LambdaArray([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}).Take(10, 10).Pointer().([]int) -fmt.Println(ret2) // [] -``` - - - -#### Sum - -sum of the values returned by the expression - -```go -Sum(express interface{}) interface{} -``` - -```go -us := []user{ - {"Abraham", 20}, - {"Edith", 25}, - {"Charles", 40}, - {"Anthony", 26}, - {"Abel", 33}, -} -arr := LambdaArray(us) -fmt.Println("total user age is ", arr.Sum(func(u user) int { return u.age })) -// total user age is 144 -``` - - - -#### Average - -average of the values returned by the expression - -```go -Average(express interface{}) float64 -``` - -```go -us := []user{ - {"Abraham", 20}, - {"Edith", 25}, - {"Charles", 40}, - {"Anthony", 26}, - {"Abel", 33}, -} -arr := LambdaArray(us) -fmt.Println("all user average age is", arr.Average(func(u user) int { return u.age })) -// all user average age is 28.8 -``` - - - -#### Contains - -Determines whether the array contains the specified element - -```go -Contains(express interface{}) bool -``` - -```go -us := []user{ - {"Abraham", 20}, - {"Edith", 25}, - {"Charles", 40}, - {"Anthony", 26}, - {"Abel", 33}, -} -arr2 := LambdaArray(us) -fmt.Println(arr2.Contains(func(u user) bool { return u.age > 25 })) //true - -fmt.Println(LambdaArray([]int{1, 2, 3, 4, 5, 6, 7, 8, 9}).Contains(9)) // true -fmt.Println(LambdaArray([]int{1, 2, 3, 4, 5, 6, 7, 8, 9}).Contains(0)) // false -``` - -#### Pointer - -array or slice pointer - -```go -Pointer() interface{} -``` - - - -## Tutorial - -Usage - -## Questions - -Please let me know if you have any questions. - - - diff --git a/num/array.go b/num/array.go deleted file mode 100644 index 063ab1554e9b4d56dab146b91f4fe4d4627068d2..0000000000000000000000000000000000000000 --- a/num/array.go +++ /dev/null @@ -1,678 +0,0 @@ -package num - -import ( - "errors" - "fmt" - "reflect" - "strings" -) - -// LambdaArray make Array from source(TIn[] type) -// source support array or slice type -func LambdaArray(source interface{}) Array { - t := reflect.TypeOf(source) - arr := _array{source, t, t.Elem(), reflect.ValueOf(source)} - if !arr.IsSlice() { - err := fmt.Errorf("source type is %s, not array ", arr.arrayType.Kind()) - panic(err) - } - return &arr -} - -type Array interface { - - // IsSlice return true when obj is slice - IsSlice() bool - - // Join array join into string - // eg - // JoinOptions.express func(u user) string {return u.name } - // JoinOptions.Symbol default `,` - Join(options JoinOptions) string - - // Filter array filter - // eg: arr.Filter(func(ele int) bool{ return ele>10}) - Filter(express interface{}) Array - - // Sort sort by quick - // eg - Sort(express interface{}) Array - - // SortMT sort by quick multithreading - SortMT(express interface{}) Array - - // Map map to new array - // express func(el T) T{ return T } - Map(express interface{}) Array - - // Append append element - Append(elements ...interface{}) Array - - // Max maximum of array - // express eg: express func(ele TIn) TOut{ return TOut },TOut must be number Type or Compare - Max(express interface{}) interface{} - - // Min minimum of array - // express eg: express func(ele TIn) TOut{ return TOut },TOut must be number Type or Compare - Min(express interface{}) interface{} - - // Any Determines whether the Array contains any elements - Any(express interface{}) bool - - // All Determines whether the condition is satisfied for all elements in the Array - All(express interface{}) bool - - // Count Returns a number indicating how many elements in the specified Array satisfy the condition - Count(express interface{}) int - - // First Returns the first element of an Array that satisfies the condition - First(express interface{}) (interface{}, error) - - // Last Returns the last element of an Array that satisfies the condition - Last(express interface{}) (interface{}, error) - - // Index Returns the zero based index of the first occurrence in an Array - Index(i int) (interface{}, error) - - // Take skip and Returns the elements - Take(skip, count int) Array - - // Sum sum of the values returned by the expression - Sum(express interface{}) interface{} - - // Average average of the values returned by the expression - Average(express interface{}) float64 - - // Contains Determines whether the array contains the specified element - // number type use default comparator - // other type can implements Compare - Contains(express interface{}) bool - - // Pointer array or slice pointer - // Array.Pointer().([]T or [n]T) - Pointer() interface{} -} - -func innerLambdaArray(value reflect.Value) Array { - t := value.Type() - arr := _array{value, t, t.Elem(), value} - return &arr -} - -type _array struct { - - // source array - source interface{} - // array type - arrayType reflect.Type - // element type - elementType reflect.Type - // value type - value reflect.Value -} - -func (p *_array) Contains(express interface{}) bool { - sz := p.Len() - if express == nil { - panic("express is null") - } - if t := reflect.TypeOf(express); t.Kind() == reflect.Func { - expType := reflect.TypeOf(express) - checkExpress(expType, []reflect.Type{p.elementType}, []reflect.Type{reflect.TypeOf(true)}) - fn := reflect.ValueOf(express) - for i := 0; i < sz; i++ { - ret := fn.Call([]reflect.Value{p.value.Index(i)}) - if ret[0].Interface().(bool) { - return true - } - } - } else if tor, err := BasicComparator(express); err == nil { - for i := 0; i < sz; i++ { - if tor.CompareTo(p.value.Index(i).Interface()) == 0 { - return true - } - } - } else if eq, ok := express.(Equal); ok { - for i := 0; i < sz; i++ { - if eq.Equals(p.value.Index(i).Interface()) { - return true - } - } - } else { - panic("unknown type " + t.String()) - } - return false -} - -func (p *_array) Average(express interface{}) float64 { - length := p.Len() - if length == 0 { - return float64(0) - } - sum := p.Sum(express) - - switch sum.(type) { - case int: - return float64(sum.(int)) / float64(length) - case uint8: - return float64(sum.(uint8)) / float64(length) - case uint16: - return float64(sum.(uint16)) / float64(length) - case uint32: - return float64(sum.(uint32)) / float64(length) - case uint64: - return float64(sum.(uint64)) / float64(length) - case int8: - return float64(sum.(int8)) / float64(length) - case int16: - return float64(sum.(int16)) / float64(length) - case int32: - return float64(sum.(int32)) / float64(length) - case int64: - return float64(sum.(int64)) / float64(length) - case float32: - return float64(sum.(float32)) / float64(length) - case float64: - return sum.(float64) / float64(length) - default: - panic("unknown type " + reflect.TypeOf(sum).String()) - } -} - -func (p *_array) Append(elements ...interface{}) Array { - ret := LambdaArray(elements).Map(func(ele interface{}) reflect.Value { - if t := reflect.TypeOf(ele); t.Kind() != p.elementType.Kind() { - panic(fmt.Sprintf("element type[%s] is not %s.", t.String(), p.elementType.String())) - } - return reflect.ValueOf(ele) - }).Pointer().([]reflect.Value) - p.value = reflect.Append(p.value, ret...) - return p -} - -func (p *_array) Any(express interface{}) bool { - if express == nil { - return p.Len() > 0 - } - checkExpress( - reflect.TypeOf(express), - []reflect.Type{p.elementType}, - []reflect.Type{reflect.TypeOf(true)}) - - length := p.Len() - fn := reflect.ValueOf(express) - for i := 0; i < length; i++ { - if fn.Call([]reflect.Value{p.value.Index(i)})[0].Interface().(bool) { - return true - } - } - return false -} - -func (p *_array) All(express interface{}) bool { - if express == nil { - return p.Len() > 0 - } - checkExpress( - reflect.TypeOf(express), - []reflect.Type{p.elementType}, - []reflect.Type{reflect.TypeOf(true)}) - length := p.Len() - fn := reflect.ValueOf(express) - for i := 0; i < length; i++ { - if !fn.Call([]reflect.Value{p.value.Index(i)})[0].Interface().(bool) { - return false - } - } - if p.Len() > 0 { - return true - } else { - return false - } -} - -func (p *_array) Count(express interface{}) int { - if express == nil { - return p.Len() - } - checkExpress( - reflect.TypeOf(express), - []reflect.Type{p.elementType}, - []reflect.Type{reflect.TypeOf(true)}) - fn := reflect.ValueOf(express) - count := 0 - p.EachV(func(v reflect.Value, _ int) { - if fn.Call([]reflect.Value{v})[0].Interface().(bool) { - count++ - } - }) - return count -} - -func (p *_array) Find(express interface{}, start, step int) (interface{}, error) { - length := p.Len() - if length == 0 { - return nil, errors.New("empty array") - } - - if express == nil { - return p.value.Index(0).Interface(), nil - } - checkExpress( - reflect.TypeOf(express), - []reflect.Type{p.elementType}, - []reflect.Type{reflect.TypeOf(true)}) - fn := reflect.ValueOf(express) - for i := start; i < length && i >= 0; i += step { - if ele := p.value.Index(i); fn.Call([]reflect.Value{ele})[0].Interface().(bool) { - return ele.Interface(), nil - } - } - return nil, errors.New("not found") -} - -func (p *_array) First(express interface{}) (interface{}, error) { - return p.Find(express, 0, 1) -} - -func (p *_array) Last(express interface{}) (interface{}, error) { - return p.Find(express, p.Len()-1, -1) -} - -func (p *_array) Index(i int) (interface{}, error) { - if i < p.Len() { - return p.value.Index(i), nil - } - return nil, errors.New(fmt.Sprintf("%d out of range", i)) -} - -func (p *_array) Take(skip, count int) Array { - length := p.value.Len() - - ret := reflect.MakeSlice(p.arrayType, 0, 0) - for i := skip; i < length; i++ { - if count > 0 { - ret = reflect.Append(ret, p.value.Index(i)) - count-- - } - if count == 0 { - break - } - } - return innerLambdaArray(ret) -} - -func (p *_array) Sum(express interface{}) interface{} { - - var add Add - if express == nil { - add = Adder(p.elementType) - } else { - checkExpress( - reflect.TypeOf(express), - []reflect.Type{p.elementType}, - nil) - add = Adder(reflect.TypeOf(express).Out(0)) - } - - length := p.Len() - if length == 0 { - return add.Value() - } - - fn := reflect.ValueOf(express) - - fv := func(i int) reflect.Value { - v := p.value.Index(i) - if express == nil { - return v - } - return fn.Call([]reflect.Value{v})[0] - } - - for i := 0; i < length; i++ { - add.Add(fv(i)) - } - return add.Value() -} - -func (p *_array) Pointer() interface{} { - return p.value.Interface() -} - -func (p *_array) IsSlice() bool { - return p.arrayType.Kind() == reflect.Slice || p.arrayType.Kind() == reflect.Array -} - -func (p *_array) Len() int { - return p.value.Len() -} - -// check the function express -// exp the express function type -// in express function parameter types -// out express function return types -func checkExpress(exp reflect.Type, in []reflect.Type, out []reflect.Type) { - if exp.Kind() != reflect.Func { - panic("express is not a func express") - } - // check in - numIn := exp.NumIn() - lenIn := len(in) - if numIn != lenIn { - panic(fmt.Errorf("lambda express parameter count must be %d", lenIn)) - } - for i := 0; i < lenIn; i++ { - if in[i].Kind() != exp.In(i).Kind() { - panic(fmt.Errorf("lambda express the %d'th parameter Type must be %s,not %s,func=%s", - i, exp.In(i).String(), in[i].String(), exp.String())) - } - } - if out == nil { - return - } - // check output - numOut := exp.NumOut() - lenOut := len(out) - if numOut != lenOut { - panic(fmt.Errorf("lambda express return Types count must be %d", lenOut)) - } - for i := 0; i < lenOut; i++ { - if out[i].Kind() != exp.Out(i).Kind() { - panic(fmt.Errorf("lambda express the %d'th return Type must be %s", i, exp.Out(i).String())) - } - } -} - -// check the function express -func checkExpressRARTO(express interface{}, in []reflect.Type) reflect.Type { - t := reflect.TypeOf(express) - if t.NumOut() == 0 { - panic("lambda express must has only one return-value.") - } - ot := t.Out(0) - checkExpress(t, in, []reflect.Type{ot}) - return ot -} - -func (p *_array) Map(express interface{}) Array { - in := []reflect.Type{p.elementType} - ot := checkExpressRARTO(express, in) - - var result reflect.Value - length := p.Len() - // slice or array - isSlice := p.arrayType.Kind() == reflect.Slice - var element reflect.Value - if isSlice { - result = reflect.MakeSlice(reflect.SliceOf(ot), p.Len(), p.Len()) - element = result - } else { - result = reflect.New(reflect.ArrayOf(length, ot)) - element = result.Elem() - } - - funcValue := reflect.ValueOf(express) - params := []reflect.Value{reflect.ValueOf(0)} - for i := 0; i < length; i++ { - params[0] = p.value.Index(i) - trans := funcValue.Call(params) - v := element.Index(i) - v.Set(trans[0]) - } - - return innerLambdaArray(result) -} - -type JoinOptions struct { - Symbol string - express interface{} -} - -func (p *_array) Join(option JoinOptions) string { - if option.express != nil { - return p.Map(option.express).Join(JoinOptions{Symbol: option.Symbol}) - } - if p.elementType.Kind() != reflect.String { - panic("the array is not string array") - } - if option.Symbol == "" { - option.Symbol = "," - } - length := p.Len() - var build strings.Builder - for i := 0; i < length; i++ { - s := p.value.Index(i).Interface().(string) - build.WriteString(s) - if i < length-1 { - build.WriteString(option.Symbol) - } - } - return build.String() -} - -func (p *_array) Filter(express interface{}) Array { - in := []reflect.Type{p.elementType} - ft := reflect.TypeOf(express) - ot := reflect.TypeOf(true) - checkExpress(ft, in, []reflect.Type{ot}) - - ret := reflect.MakeSlice(reflect.SliceOf(p.elementType), 0, 0) - funcValue := reflect.ValueOf(express) - params := []reflect.Value{reflect.ValueOf(0)} - length := p.Len() - for i := 0; i < length; i++ { - params[0] = p.value.Index(i) - trans := funcValue.Call(params) - if trans[0].Interface().(bool) { - ret = reflect.Append(ret, params[0]) - } - } - return innerLambdaArray(ret) -} - -func (p *_array) SortByBubble(express interface{}) Array { - in := []reflect.Type{p.elementType, p.elementType} - ft := reflect.TypeOf(express) - ot := reflect.TypeOf(true) - checkExpress(ft, in, []reflect.Type{ot}) - - length := p.Len() - v := reflect.ValueOf(0) - funcValue := reflect.ValueOf(express) - params := []reflect.Value{v, v} - for i := 0; i < length-1; i++ { - for j := 0; j < length-i-1; j++ { - params[0] = p.value.Index(j) - params[1] = p.value.Index(j + 1) - trans := funcValue.Call(params) - if !trans[0].Interface().(bool) { - temp := params[0].Interface() - p.value.Index(j).Set(params[1]) - p.value.Index(j + 1).Set(reflect.ValueOf(temp)) - } - } - } - - return p -} - -func (p *_array) CopyValue() reflect.Value { - var arr reflect.Value - if p.IsSlice() { - arr = reflect.MakeSlice(reflect.SliceOf(p.elementType), p.Len(), p.Len()) - } else { - arr = reflect.New(reflect.ArrayOf(p.Len(), p.elementType)) - } - reflect.Copy(arr, p.value) - return arr -} - -func (p *_array) Sort(express interface{}) Array { - in := []reflect.Type{p.elementType, p.elementType} - ft := reflect.TypeOf(express) - ot := reflect.TypeOf(true) - checkExpress(ft, in, []reflect.Type{ot}) - - length := p.Len() - v := reflect.ValueOf(0) - funcValue := reflect.ValueOf(express) - params := []reflect.Value{v, v} - - p.value = p.CopyValue() - - compare := func(a reflect.Value, b int) bool { - params[0], params[1] = a, p.value.Index(b) - return funcValue.Call(params)[0].Interface().(bool) - } - - var inner func(int, int) - // quick sort - inner = func(l, r int) { - if l < r { - i, j, x := l, r, p.value.Index(l) - x = reflect.ValueOf(x.Interface()) - for i < j { - for i < j && compare(x, j) { - j-- - } - if i < j { - p.value.Index(i).Set(p.value.Index(j)) - i++ - } - for i < j && !compare(x, i) { - i++ - } - if i < j { - p.value.Index(j).Set(p.value.Index(i)) - j-- - } - } - p.value.Index(i).Set(x) - inner(l, i-1) - inner(i+1, r) - } - } - inner(0, length-1) - return p -} - -func (p *_array) SortMT(express interface{}) Array { - in := []reflect.Type{p.elementType, p.elementType} - ft := reflect.TypeOf(express) - ot := reflect.TypeOf(true) - checkExpress(ft, in, []reflect.Type{ot}) - - funcValue := reflect.ValueOf(express) - - compare := func(a, b reflect.Value) bool { - return funcValue.Call([]reflect.Value{a, b})[0].Interface().(bool) - } - var quick func(arr reflect.Value, ch chan reflect.Value) - quick = func(arr reflect.Value, ch chan reflect.Value) { - if arr.Len() == 1 { - ch <- arr.Index(0) - close(ch) - return - } - if arr.Len() == 0 { - close(ch) - return - } - - left := reflect.MakeSlice(reflect.SliceOf(p.elementType), 0, 0) - right := reflect.MakeSlice(reflect.SliceOf(p.elementType), 0, 0) - length := arr.Len() - x := arr.Index(0) - for i := 1; i < length; i++ { - curr := arr.Index(i) - if compare(x, curr) { - left = reflect.Append(left, curr) - } else { - right = reflect.Append(right, curr) - } - } - lch := make(chan reflect.Value, left.Len()) - rch := make(chan reflect.Value, right.Len()) - go quick(left, lch) - go quick(right, rch) - for v := range lch { - ch <- v - } - ch <- x - for v := range rch { - ch <- v - } - close(ch) - } - ch := make(chan reflect.Value) - go quick(p.value, ch) - values := reflect.MakeSlice(reflect.SliceOf(p.elementType), 0, 0) - for v := range ch { - values = reflect.Append(values, v) - } - return innerLambdaArray(values) -} - -func (p *_array) maxOrMin(express interface{}, isMax bool) interface{} { - if express != nil { - in := []reflect.Type{p.elementType} - ft := reflect.TypeOf(express) - ot := reflect.TypeOf(0) - checkExpress(ft, in, []reflect.Type{ot}) - } - var m reflect.Value - var mc interface{} - - funcValue := reflect.ValueOf(express) - f := func(value reflect.Value) interface{} { - if express == nil { - return value.Interface() - } - return funcValue.Call([]reflect.Value{value})[0].Interface() - } - - p.EachV(func(v reflect.Value, index int) { - vc := f(v) - if index == 0 { - m = v - mc = vc - } else { - tor, err := BasicComparator(vc) - if err != nil { - panic(err) - } - if isMax { - if tor.CompareTo(mc) > 0 { - m = v - mc = vc - } - } else { - if tor.CompareTo(mc) < 0 { - m = v - mc = vc - } - } - } - }) - return m.Interface() -} - -func (p *_array) Max(express interface{}) interface{} { - return p.maxOrMin(express, true) -} - -func (p *_array) Min(express interface{}) interface{} { - return p.maxOrMin(express, false) -} - -func (p *_array) EachV(fn func(v reflect.Value, i int)) { - if fn == nil { - return - } - length := p.Len() - - for i := 0; i < length; i++ { - fn(p.value.Index(i), i) - } -} diff --git a/num/array_test.go b/num/array_test.go deleted file mode 100644 index b161585ca2915eaf42698e2d0a96841511058c49..0000000000000000000000000000000000000000 --- a/num/array_test.go +++ /dev/null @@ -1,505 +0,0 @@ -package num - -import ( - "fmt" - "math/rand" - "strconv" - "testing" - "time" -) - -type user struct { - name string - age int -} - -type account struct { - name string - age int -} - -const count = 10000 - -func makeIntArray() []int { - want := make([]int, count) - for i := 0; i < count; i++ { - want[i] = i + 1 - } - return want -} - -func makeUserArray() []user { - want := make([]user, count) - for i := 0; i < count; i++ { - want[i] = user{"un:" + strconv.Itoa(i+1), i + 1} - } - return want -} - -func report(t *testing.T, start time.Time) { - end := time.Now() - ms := float32(end.Nanosecond()-start.Nanosecond()) / float32(1e6) - t.Log(fmt.Sprintf("run time %.2f ms", ms)) -} - -func isTrue(tv interface{}, v bool) { - t := tv.(*testing.T) - if !v { - t.Fail() - panic(v) - } -} - -func isFalse(tv interface{}, v bool) { - t := tv.(*testing.T) - if v { - t.Fail() - panic(v) - } -} - -func Test__array_Join(t *testing.T) { - defer report(t, time.Now()) - result := LambdaArray(makeIntArray()).Join(JoinOptions{ - express: func(e int) string { return strconv.Itoa(e) }, - }) - t.Log("string length", len(result)) - - arr := []int{1, 2, 3, 4, 5} - str1 := LambdaArray(arr).Join(JoinOptions{ - express: func(e int) string { return strconv.Itoa(e) }, - }) - fmt.Println(str1) - - str2 := LambdaArray(arr).Join(JoinOptions{ - express: func(e int) string { return strconv.Itoa(e) }, - Symbol: "|", - }) - fmt.Println(str2) -} - -func Test__array_Filter(t *testing.T) { - defer report(t, time.Now()) - want := makeIntArray() - ret := LambdaArray(want).Filter( - func(ele int) bool { return ele%3 == 0 }).Pointer().([]int) - isTrue(t, len(ret) == count/3) - - arr := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - larr := LambdaArray(arr) - ret1 := larr.Filter(func(ele int) bool { return ele > 5 }).Pointer().([]int) - fmt.Println(ret1) - - ret2 := larr.Filter(func(ele int) bool { return ele%2 == 0 }).Pointer().([]int) - fmt.Println(ret2) - - users := []user{ - {"Abraham", 20}, - {"Edith", 25}, - {"Charles", 40}, - {"Anthony", 26}, - {"Abel", 33}, - } - ret3 := LambdaArray(users).Filter(func(u user) bool { return u.age < 30 }).Pointer().([]user) - fmt.Println(ret3) -} - -func Test__array_Sort_Quick(t *testing.T) { - defer report(t, time.Now()) - want := make([]int, count) - rand.Seed(time.Now().UnixNano()) - for i := 0; i < count; i++ { - want[i] = rand.Intn(count * 10) - } - t.Log(want[:10], "...", want[count-10:], " count=", len(want)) - ret := LambdaArray(want).Sort(func(e1, e2 int) bool { - return e1 > e2 - }).Pointer().([]int) - t.Log(ret[:10], "...", ret[count-10:], " count=", len(ret)) - - arr := []int{1, 3, 8, 6, 12, 5, 9} - // order by asc - ret1 := LambdaArray(arr).Sort(func(a, b int) bool { return a < b }).Pointer().([]int) - // order by desc - ret2 := LambdaArray(arr).Sort(func(a, b int) bool { return a > b }).Pointer().([]int) - - fmt.Println(ret1) - fmt.Println(ret2) - - users := []user{ - {"Abraham", 20}, - {"Edith", 25}, - {"Charles", 40}, - {"Anthony", 26}, - {"Abel", 33}, - } - ret3 := LambdaArray(users).Sort(func(a, b user) bool { return a.age < b.age }).Pointer().([]user) - fmt.Println(ret3) -} - -func Test__array_Sort_QuickMT(t *testing.T) { - defer report(t, time.Now()) - want := make([]int, count) - rand.Seed(time.Now().UnixNano()) - for i := 0; i < count; i++ { - want[i] = rand.Intn(count * 10) - } - t.Log(want[:10], "...", want[count-10:], " count=", len(want)) - ret := LambdaArray(want).SortMT(func(e1, e2 int) bool { - return e1 > e2 - }).Pointer().([]int) - t.Log(ret[:10], "...", ret[count-10:], " count=", len(ret)) -} - -func Test__array_Map(t *testing.T) { - defer report(t, time.Now()) - - result := LambdaArray(makeIntArray()).Map(func(e int) int { - return e + 1 - }).Pointer().([]int) - - isTrue(t, len(result) == count) - - arr := LambdaArray([]int{1, 2, 3, 4, 5}) - users := arr.Map(func(i int) user { - return user{name: "un:" + strconv.Itoa(i), age: i} - }).Pointer().([]user) - fmt.Println(users) -} - -func Test__array_Append(t *testing.T) { - defer report(t, time.Now()) - want := LambdaArray(makeIntArray()) - want.Append(count + 1) - isTrue(t, count+1 == want.Count(nil)) - - arr := LambdaArray([]int{1, 2, 3}) - arr.Append(4) - fmt.Println(arr.Pointer().([]int)) - arr.Append(5, 6) - fmt.Println(arr.Pointer().([]int)) -} - -func (p account) CompareTo(a interface{}) int { - return p.age - a.(account).age -} - -func Test__array_Max(t *testing.T) { - defer report(t, time.Now()) - want := []int{1, 5, 6, 3, 8, 9, 3, 12, 56, 186, 4, 9, 14} - - var iArr = LambdaArray(want) - - ret := iArr.Max(nil).(int) - t.Log(ret) - ret = iArr.Max(func(ele int) int { return ele }).(int) - t.Log(ret) - - wantUsers := iArr.Map(func(ele int) account { - s := fmt.Sprintf("%d", ele) - return account{"zzz" + s, ele} - }) - - ret2 := wantUsers.Max(func(u account) int { return u.age }) - t.Log(ret2) - - ret3 := wantUsers.Max(nil) - t.Log(ret3) - - users := []user{ - {"Abraham", 20}, - {"Edith", 25}, - {"Charles", 40}, - {"Anthony", 26}, - {"Abel", 33}, - } - eldest := LambdaArray(users).Max(func(u user) int { return u.age }).(user) - fmt.Println(eldest.name + " is the eldest") -} - -func Test__array_Sort_Min(t *testing.T) { - - defer report(t, time.Now()) - - want := []int{1, 5, 6, 3, 8, 9, 3, 12, 56, 186, 4, 9, 14} - - var iArr = LambdaArray(want) - - ret := iArr.Min(nil).(int) - t.Log(ret) - ret = iArr.Min(func(ele int) int { return ele }).(int) - t.Log(ret) - - wantUsers := iArr.Map(func(ele int) account { - s := fmt.Sprintf("%d", ele) - return account{"zzz" + s, ele} - }) - - ret2 := wantUsers.Min(func(u account) int { return u.age }) - t.Log(ret2) - - ret3 := wantUsers.Min(nil) - t.Log(ret3) - - users := []user{ - {"Abraham", 20}, - {"Edith", 25}, - {"Charles", 40}, - {"Anthony", 26}, - {"Abel", 33}, - } - eldest := LambdaArray(users).Min(func(u user) int { return u.age }).(user) - fmt.Println(eldest.name + " is the Charles") -} - -func Test__array_Any(t *testing.T) { - defer report(t, time.Now()) - ints := LambdaArray(makeIntArray()) - users := LambdaArray(makeUserArray()) - ret := []bool{ - ints.Any(nil), - ints.Any(func(ele int) bool { return ele > 99999999 }), - users.Any(func(u user) bool { return u.name == "un:1997" }), - } - isTrue(t, ret[0]) - isFalse(t, ret[1]) - isTrue(t, ret[2]) - - us := []user{ - {"Abraham", 20}, - {"Edith", 25}, - {"Charles", 40}, - {"Anthony", 26}, - {"Abel", 33}, - } - ret1 := LambdaArray(us).Any(func(u user) bool { return u.age > 30 }) - fmt.Println(ret1) - ret2 := LambdaArray(us).Any(func(u user) bool { return u.age < 0 }) - fmt.Println(ret2) -} - -func Test__array_All(t *testing.T) { - defer report(t, time.Now()) - ints := LambdaArray(makeIntArray()) - users := LambdaArray(makeUserArray()) - ret := []bool{ - ints.All(nil), - ints.All(func(ele int) bool { return ele > 0 }), - users.All(func(u user) bool { return u.name == "un:1997" }), - } - isTrue(t, ret[0]) - isTrue(t, ret[1]) - isFalse(t, ret[2]) - - us := []user{ - {"Abraham", 20}, - {"Edith", 25}, - {"Charles", 40}, - {"Anthony", 26}, - {"Abel", 33}, - } - ret1 := LambdaArray(us).All(func(u user) bool { return u.age > 30 }) - fmt.Println(ret1) - ret2 := LambdaArray(us).All(func(u user) bool { return u.age > 10 }) - fmt.Println(ret2) -} - -func Test__array_Count(t *testing.T) { - defer report(t, time.Now()) - ints := LambdaArray(makeIntArray()) - ret := []int{ - ints.Count(nil), - ints.Count(func(ele int) bool { return ele%2 == 0 }), - } - isTrue(t, ret[0] == count) - isTrue(t, ret[1]*2 == count) - - us := []user{ - {"Abraham", 20}, - {"Edith", 25}, - {"Charles", 40}, - {"Anthony", 26}, - {"Abel", 33}, - } - ret1 := LambdaArray(us).Count(func(u user) bool { return u.age > 30 }) - fmt.Println(ret1) - ret2 := LambdaArray(us).Count(func(u user) bool { return u.age > 20 }) - fmt.Println(ret2) -} - -func Test__array_First(t *testing.T) { - defer report(t, time.Now()) - want := []int{1, 5, 6, 3, 8, 9, 3, 12, 56, 186, 4, 9, 14} - if c, err := LambdaArray(want).First(func(e int) bool { return e > 30 }); err == nil { - t.Log(c) - isTrue(t, c == 56) - } else { - t.Fail() - } - - us := []user{ - {"Abraham", 20}, - {"Edith", 25}, - {"Charles", 40}, - {"Anthony", 26}, - {"Abel", 33}, - } - arr := LambdaArray(us) - if u, err := arr.First(func(u user) bool { return u.name == "Charles" }); err == nil { - fmt.Println(u, " found") - } else { - fmt.Println("not found") - } - - if u, err := arr.First(func(u user) bool { return u.name == "jack" }); err == nil { - fmt.Println(u, " found") - } else { - fmt.Println("not found") - } -} - -func Test__array_Index(t *testing.T) { - defer report(t, time.Now()) - ints := LambdaArray(makeIntArray()) - ret := []int{ - ints.Count(nil), - ints.Count(func(ele int) bool { return ele%2 == 0 }), - } - isTrue(t, ret[0] == count) - isTrue(t, ret[1]*2 == count) - - if element, err := LambdaArray([]int{1, 2, 3, 4, 5}).Index(3); err == nil { - fmt.Println(element) - } else { - fmt.Println(err) - } - if element, err := LambdaArray([]int{1, 2, 3, 4, 5}).Index(10); err == nil { - fmt.Println(element) - } else { - fmt.Println(err) - } -} - -func Test__array_Last(t *testing.T) { - defer report(t, time.Now()) - want := []int{1, 5, 6, 3, 8, 9, 3, 12, 56, 186, 4, 9, 14} - if c, err := LambdaArray(want).Last(func(e int) bool { return e > 30 }); err == nil { - t.Log(c) - isTrue(t, c == 186) - } else { - t.Fail() - } - - us := []user{ - {"Abraham", 20}, - {"Edith", 25}, - {"Charles", 40}, - {"Anthony", 26}, - {"Abel", 33}, - } - arr := LambdaArray(us) - if u, err := arr.Last(func(u user) bool { return u.name == "Anthony" }); err == nil { - fmt.Println(u, " found") - } else { - fmt.Println("not found") - } - - if u, err := arr.Last(func(u user) bool { return u.age > 35 }); err == nil { - fmt.Println(u, " found") - } else { - fmt.Println("not found") - } -} - -func Test__array_Take(t *testing.T) { - defer report(t, time.Now()) - ints := LambdaArray(makeIntArray()) - ret := ints.Take(200, 10).Pointer().([]int) - isTrue(t, ret[0] == 201) - isTrue(t, ret[9] == 210) - - ret1 := LambdaArray([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}).Take(4, 10).Pointer().([]int) - fmt.Println(ret1) - ret2 := LambdaArray([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}).Take(10, 10).Pointer().([]int) - fmt.Println(ret2) -} - -func Test__array_Sum(t *testing.T) { - defer report(t, time.Now()) - ret := LambdaArray(makeIntArray()).Sum(nil).(int) - ret2 := LambdaArray(makeUserArray()).Sum(func(u user) int { return u.age }) - t.Log(ret, ret2) - - us := []user{ - {"Abraham", 20}, - {"Edith", 25}, - {"Charles", 40}, - {"Anthony", 26}, - {"Abel", 33}, - } - arr := LambdaArray(us) - fmt.Println("total user age is", arr.Sum(func(u user) int { return u.age })) -} - -func Test__array_Avg(t *testing.T) { - defer report(t, time.Now()) - ints := LambdaArray(makeIntArray()) - ret := ints.Average(nil) - t.Log(ret) - - us := []user{ - {"Abraham", 20}, - {"Edith", 25}, - {"Charles", 40}, - {"Anthony", 26}, - {"Abel", 33}, - } - arr := LambdaArray(us) - fmt.Println("all user average age is", arr.Average(func(u user) int { return u.age })) -} - -func Test__array_Contain(t *testing.T) { - - defer report(t, time.Now()) - - want := makeIntArray() - arr := LambdaArray(want) - ret := []bool{arr.Contains(7777), arr.Contains(count + 1)} - isTrue(t, ret[0]) - isFalse(t, ret[1]) - - users := LambdaArray(makeUserArray()) - ret = []bool{ - users.Contains(user{"un:18", 18}), - users.Contains(user{"zzz", 18}), - } - isTrue(t, ret[0]) - isFalse(t, ret[1]) - - ret = []bool{ - users.Contains(func(u user) bool { return u.age > 5000 }), - users.Contains(func(u user) bool { return u.age > count+1 }), - } - isTrue(t, ret[0]) - isFalse(t, ret[1]) - - us := []user{ - {"Abraham", 20}, - {"Edith", 25}, - {"Charles", 40}, - {"Anthony", 26}, - {"Abel", 33}, - } - arr2 := LambdaArray(us) - fmt.Println(arr2.Contains(func(u user) bool { return u.age > 25 })) - - fmt.Println(LambdaArray([]int{1, 2, 3, 4, 5, 6, 7, 8, 9}).Contains(9)) - fmt.Println(LambdaArray([]int{1, 2, 3, 4, 5, 6, 7, 8, 9}).Contains(0)) -} - -func (u user) Equals(obj interface{}) bool { - if c, ok := obj.(user); ok { - return u.name == c.name && u.age == c.age - } - return false -} diff --git a/num/compare.go b/num/compare.go deleted file mode 100644 index 287b86374b2710a6fd6e6059e31d4247dbbc27d4..0000000000000000000000000000000000000000 --- a/num/compare.go +++ /dev/null @@ -1,106 +0,0 @@ -package num - -import ( - "errors" - "fmt" - "math" - "reflect" - "strings" -) - -func BasicComparator(ele interface{}) (Compare, error) { - if c, ok := ele.(Compare); ok { - return c, nil - } - - k := reflect.TypeOf(ele).Kind() - - support := []reflect.Kind{ - reflect.Int, - reflect.Uint8, - reflect.Uint16, - reflect.Uint32, - reflect.Uint64, - reflect.Int8, - reflect.Int16, - reflect.Uint32, - reflect.Int64, - reflect.Float32, - reflect.Float64, - reflect.String, - } - if contain(support, k) { - return &BasicCompare{ele}, nil - } - - return nil, errors.New("unknown type") -} - -func contain(kinds []reflect.Kind, target reflect.Kind) bool { - for _, k := range kinds { - if k == target { - return true - } - } - return false -} - -type Compare interface { - CompareTo(a interface{}) int -} - -type BasicCompare struct { - v interface{} -} - -func (p *BasicCompare) CompareTo(a interface{}) int { - - vt, at := reflect.TypeOf(p.v), reflect.TypeOf(a) - if vt.Kind() != at.Kind() { - panic(fmt.Sprintf("%s is not %s", vt.String(), at.String())) - } - - switch p.v.(type) { - case int: - return p.v.(int) - a.(int) - case uint8: - return int(p.v.(uint8) - a.(uint8)) - case uint16: - return int(p.v.(uint16) - a.(uint16)) - case uint32: - return int(p.v.(uint32) - a.(uint32)) - case uint64: - return int(p.v.(uint64) - a.(uint64)) - case int8: - return int(p.v.(int8) - a.(int8)) - case int16: - return int(p.v.(int16) - a.(int16)) - case int32: - return int(p.v.(int32) - a.(int32)) - case int64: - return int(p.v.(int64) - a.(int64)) - case float32: - v := p.v.(float32) - a.(float32) - if v > 0 { - return int(math.Ceil(float64(v))) - } else if v == 0 { - return 0 - } else { - return int(math.Floor(float64(v))) - } - - case float64: - v := p.v.(float64) - a.(float64) - if v > 0 { - return int(math.Ceil(v)) - } else if v == 0 { - return 0 - } else { - return int(math.Floor(v)) - } - case string: - return strings.Compare(p.v.(string), a.(string)) - default: - panic("unknown type " + at.String()) - } -} diff --git a/num/equal.go b/num/equal.go deleted file mode 100644 index 64f13139516ec746db5a39d878b3da07b7c02b8a..0000000000000000000000000000000000000000 --- a/num/equal.go +++ /dev/null @@ -1,5 +0,0 @@ -package num - -type Equal interface { - Equals(obj interface{}) bool -} diff --git a/series_generic.go b/series_generic.go index 2d10343f6aadffed8afe7811592494d587940ada..c7f63efde6fc9d66923b1e12b71c3b21b85adaba 100644 --- a/series_generic.go +++ b/series_generic.go @@ -1,28 +1,14 @@ package pandas import ( + "fmt" "gitee.com/quant1x/pandas/stat" "reflect" ) -// 初始化全局的私有变量 -var ( - rawBool bool = true - typeBool = reflect.TypeOf([]bool{}) - rawInt32 int32 = int32(0) - typeInt32 = reflect.TypeOf([]int32{}) - rawInt64 int64 = int64(0) - typeInt64 = reflect.TypeOf([]int64{}) - rawFloat32 float32 = float32(0) - typeFloat32 = reflect.TypeOf([]float32{}) - rawFloat64 float64 = float64(0) - typeFloat64 = reflect.TypeOf([]float64{}) - typeString = reflect.TypeOf([]string{}) -) - // NewSeriesWithoutType 不带类型创新一个新series -func NewSeriesWithoutType(name string, values ...interface{}) Series { - _type, err := detectTypeBySlice(values...) +func NewSeriesWithoutType(name string, values ...any) stat.Series { + _type, err := stat.DetectTypeBySlice(values...) if err != nil { return nil } @@ -30,64 +16,82 @@ func NewSeriesWithoutType(name string, values ...interface{}) Series { } // NewSeriesWithType 通过类型创新一个新series -func NewSeriesWithType(_type Type, name string, values ...interface{}) Series { +func NewSeriesWithType(_type stat.Type, name string, values ...any) stat.Series { frame := NDFrame{ formatter: stat.DefaultFormatter, name: name, - type_: SERIES_TYPE_INVAILD, + type_: stat.SERIES_TYPE_INVAILD, nilCount: 0, rows: 0, //values: []E{}, } - //_type, err := detectTypeBySlice(values) - //if err != nil { - // return nil - //} + frame.type_ = _type - if frame.type_ == SERIES_TYPE_BOOL { + if frame.type_ == stat.SERIES_TYPE_BOOL { // bool - frame.values = reflect.MakeSlice(typeBool, 0, 0).Interface() - } else if frame.type_ == SERIES_TYPE_INT64 { + frame.values = reflect.MakeSlice(stat.TypeBool, 0, 0).Interface() + } else if frame.type_ == stat.SERIES_TYPE_INT64 { // int64 - frame.values = reflect.MakeSlice(typeInt64, 0, 0).Interface() - } else if frame.type_ == SERIES_TYPE_FLOAT32 { + frame.values = reflect.MakeSlice(stat.TypeInt64, 0, 0).Interface() + } else if frame.type_ == stat.SERIES_TYPE_FLOAT32 { // float32 - frame.values = reflect.MakeSlice(typeFloat32, 0, 0).Interface() - } else if frame.type_ == SERIES_TYPE_FLOAT64 { + frame.values = reflect.MakeSlice(stat.TypeFloat32, 0, 0).Interface() + } else if frame.type_ == stat.SERIES_TYPE_FLOAT64 { // float64 - frame.values = reflect.MakeSlice(typeFloat64, 0, 0).Interface() + frame.values = reflect.MakeSlice(stat.TypeFloat64, 0, 0).Interface() } else { // string, 字符串最后容错使用 - frame.values = reflect.MakeSlice(typeString, 0, 0).Interface() + frame.values = reflect.MakeSlice(stat.TypeString, 0, 0).Interface() } - //series.Data = make([]float64, 0) // Warning: filled with 0.0 (not NaN) - //size := len(series.values) - //size := 0 - //for idx, v := range values { - // switch val := v.(type) { - // case nil, int8, uint8, int16, uint16, int32, uint32, int64, uint64, int, uint, float32, float64, bool, string: - // // 基础类型 - // series_append(&frame, idx, size, val) - // default: - // vv := reflect.ValueOf(val) - // vk := vv.Kind() - // switch vk { - // //case reflect.Invalid: // {interface} nil - // // series.assign(idx, size, Nil2Float64) - // case reflect.Slice, reflect.Array: // 切片或数组 - // for i := 0; i < vv.Len(); i++ { - // tv := vv.Index(i).Interface() - // //series.assign(idx, size, str) - // series_append(&frame, idx, size, tv) - // } - // case reflect.Struct: // 忽略结构体 - // continue - // default: - // series_append(&frame, idx, size, nil) - // } - // } - //} frame.Append(values...) return &frame } + +// NewSeries 指定类型创建序列 +func NewSeries(t stat.Type, name string, vals any) stat.Series { + var series stat.Series + if t == stat.SERIES_TYPE_BOOL { + series = NewSeriesWithType(stat.SERIES_TYPE_BOOL, name, vals) + } else if t == stat.SERIES_TYPE_INT64 { + series = NewSeriesWithType(stat.SERIES_TYPE_INT64, name, vals) + } else if t == stat.SERIES_TYPE_STRING { + series = NewSeriesWithType(stat.SERIES_TYPE_STRING, name, vals) + } else if t == stat.SERIES_TYPE_FLOAT64 { + series = NewSeriesWithType(stat.SERIES_TYPE_FLOAT64, name, vals) + } else { + // 默认全部强制转换成float32 + series = NewSeriesWithType(stat.SERIES_TYPE_FLOAT32, name, vals) + } + return series +} + +// GenericSeries 泛型方法, 构造序列, 比其它方式对类型的统一性要求更严格 +func GenericSeries[T stat.GenericType](name string, values ...T) stat.Series { + // 第一遍, 确定类型, 找到第一个非nil的值 + var _type stat.Type = stat.SERIES_TYPE_STRING + for _, v := range values { + // 泛型处理这里会出现一个错误, invalid operation: v == nil (mismatched types T and untyped nil) + //if v == nil { + // continue + //} + vv := reflect.ValueOf(v) + vk := vv.Kind() + switch vk { + case reflect.Bool: + _type = stat.SERIES_TYPE_BOOL + case reflect.Int64: + _type = stat.SERIES_TYPE_INT64 + case reflect.Float32: + _type = stat.SERIES_TYPE_FLOAT32 + case reflect.Float64: + _type = stat.SERIES_TYPE_FLOAT64 + case reflect.String: + _type = stat.SERIES_TYPE_STRING + default: + panic(fmt.Errorf("unknown type, %+v", v)) + } + break + } + return NewSeries(_type, name, values) +} diff --git a/series_string.go b/series_string.go new file mode 100644 index 0000000000000000000000000000000000000000..8a2cb2f25c3e5c43460b10ca8cdbc6de965fed8a --- /dev/null +++ b/series_string.go @@ -0,0 +1,167 @@ +package pandas + +import ( + "gitee.com/quant1x/pandas/stat" +) + +type SeriesString struct { + stat.NDArray[string] + name string +} + +func (self SeriesString) Name() string { + return self.name +} + +func (self SeriesString) Rename(name string) { + self.name = name +} + +func (self SeriesString) Type() stat.Type { + return self.NDArray.Type() +} + +func (self SeriesString) Values() any { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) NaN() any { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) Floats() []float32 { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) DTypes() []stat.DType { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) Ints() []stat.Int { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) Len() int { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) Less(i, j int) bool { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) Swap(i, j int) { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) Empty(t ...stat.Type) stat.Series { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) Copy() stat.Series { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) Records() []string { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) Subset(start, end int, opt ...any) stat.Series { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) Repeat(x any, repeats int) stat.Series { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) Shift(periods int) stat.Series { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) Rolling(param any) stat.RollingAndExpandingMixin { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) Mean() stat.DType { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) StdDev() stat.DType { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) FillNa(v any, inplace bool) stat.Series { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) Max() any { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) Min() any { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) Select(r stat.ScopeLimit) stat.Series { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) Append(values ...any) stat.Series { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) Apply(f func(idx int, v any)) { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) Logic(f func(idx int, v any) bool) []bool { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) Diff(param any) (s stat.Series) { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) Ref(param any) (s stat.Series) { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) Std() stat.DType { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) Sum() stat.DType { + //TODO implement me + panic("implement me") +} + +func (self SeriesString) EWM(alpha stat.EW) stat.ExponentialMovingWindow { + //TODO implement me + panic("implement me") +} diff --git a/series_string_test.go b/series_string_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c8705bd1111eee761a2b713bb57e58efe04b5414 --- /dev/null +++ b/series_string_test.go @@ -0,0 +1,13 @@ +package pandas + +import ( + "fmt" + "testing" +) + +func TestSeriesString_Type(t *testing.T) { + s := new(SeriesString) + fmt.Println(s.Type()) + s1 := s.NDArray + fmt.Println(s1) +} diff --git a/stat/and.go b/stat/and.go new file mode 100644 index 0000000000000000000000000000000000000000..79cf5554f0392d20cb3d5ef04333f9a88c71c9fe --- /dev/null +++ b/stat/and.go @@ -0,0 +1,50 @@ +package stat + +import "github.com/viterin/vek" + +// And 两者为真 +func And[T Number | ~bool](x, y []T) []bool { + switch vs := any(x).(type) { + case []bool: + return vek.And(vs, any(y).([]bool)) + case []int8: + return __and_go(vs, any(y).([]int8)) + case []uint8: + return __and_go(vs, any(y).([]uint8)) + case []int16: + return __and_go(vs, any(y).([]int16)) + case []uint16: + return __and_go(vs, any(y).([]uint16)) + case []int32: + return __and_go(vs, any(y).([]int32)) + case []uint32: + return __and_go(vs, any(y).([]uint32)) + case []int64: + return __and_go(vs, any(y).([]int64)) + case []uint64: + return __and_go(vs, any(y).([]uint64)) + case []int: + return __and_go(vs, any(y).([]int)) + case []uint: + return __and_go(vs, any(y).([]uint)) + case []uintptr: + return __and_go(vs, any(y).([]uintptr)) + case []float32: + return __and_go(vs, any(y).([]float32)) + case []float64: + return __and_go(vs, any(y).([]float64)) + } + panic(Throw(x)) +} + +func __and_go[T Number](x, y []T) []bool { + d := make([]bool, len(x)) + for i := 0; i < len(x); i++ { + if x[i] != 0 && y[i] != 0 { + d[i] = true + } else { + d[i] = false + } + } + return d +} diff --git a/stat/builtin.go b/stat/builtin.go index 532e72bc251e28f6de30154ca85df2ec9b9cd6b7..0b9cccb8dd9e3999faea2545cd21c33fcdc2bacc 100644 --- a/stat/builtin.go +++ b/stat/builtin.go @@ -1,6 +1,7 @@ package stat import ( + gc "github.com/huandu/go-clone" "github.com/viterin/vek" "math" "reflect" @@ -69,3 +70,8 @@ func IsEmpty(s string) bool { return false } } + +// Clone 克隆一个any +func Clone(v any) any { + return gc.Clone(v) +} diff --git a/stat/frame.go b/stat/frame.go deleted file mode 100644 index 77b8130aa4ceb21346987ea812f8fd5f53720bda..0000000000000000000000000000000000000000 --- a/stat/frame.go +++ /dev/null @@ -1,82 +0,0 @@ -package stat - -type Frame interface { - // Name 取得series名称 - Name() string - // Rename renames the series. - Rename(name string) - - // Type returns the type of data the series holds. - // 返回series的数据类型 - Type() Type - // Values 获得全部数据集 - Values() any - - // NaN 输出默认的NaN - NaN() any - // Floats 强制转成[]float32 - Floats() []float32 - // DTypes 强制转[]stat.DType - DTypes() []DType - // Ints 强制转换成整型 - Ints() []Int - - // sort.Interface - - // Len 获得行数, 实现sort.Interface接口的获取元素数量方法 - Len() int - // Less 实现sort.Interface接口的比较元素方法 - Less(i, j int) bool - // Swap 实现sort.Interface接口的交换元素方法 - Swap(i, j int) - - // Empty returns an empty Series of the same type - Empty() Frame - // Copy 复制 - Copy() Frame - // Records returns the elements of a Series as a []string - Records() []string - // Subset 获取子集 - Subset(start, end int, opt ...any) Frame - // Repeat elements of an array. - Repeat(x any, repeats int) Frame - // Shift index by desired number of periods with an optional time freq. - // 使用可选的时间频率按所需的周期数移动索引. - Shift(periods int) Frame - // Rolling 序列化版本 - //Rolling(param any) RollingAndExpandingMixin - - // Mean calculates the average value of a series - Mean() DType - // StdDev calculates the standard deviation of a series - StdDev() DType - // FillNa Fill NA/NaN values using the specified method. - FillNa(v any, inplace bool) Frame - // Max 找出最大值 - Max() any - // Min 找出最小值 - Min() any - // Select 选取一段记录 - Select(r ScopeLimit) Frame - // Append 增加一批记录 - Append(values ...any) Frame - // Apply 接受一个回调函数 - Apply(f func(idx int, v any)) - // Logic 逻辑处理 - Logic(f func(idx int, v any) bool) []bool - // Diff 元素的第一个离散差 - Diff(param any) Frame - // Ref 引用其它周期的数据 - Ref(param any) Frame - // Std 计算标准差 - Std() DType - // Sum 计算累和 - Sum() DType - // EWM Provide exponentially weighted (EW) calculations. - // - // Exactly one of ``com``, ``span``, ``halflife``, or ``alpha`` must be - // provided if ``times`` is not provided. If ``times`` is provided, - // ``halflife`` and one of ``com``, ``span`` or ``alpha`` may be provided. - //EWM(alpha EW) ExponentialMovingWindow - -} diff --git a/stat/ndarray.go b/stat/ndarray.go index ef0a6c3457aed48c7b1024f7bc8f5cc4356e74e7..096b68c12c6938bbf422781d75c1c35265aefa41 100644 --- a/stat/ndarray.go +++ b/stat/ndarray.go @@ -1,6 +1,7 @@ package stat import ( + "gitee.com/quant1x/pandas/exception" gc "github.com/huandu/go-clone" "reflect" ) @@ -8,13 +9,11 @@ import ( type NDArray[T BaseType] []T func (self NDArray[T]) Name() string { - //TODO implement me - panic("implement me") + return "x" } func (self NDArray[T]) Rename(name string) { - //TODO implement me - panic("implement me") + } func (self NDArray[T]) Type() Type { @@ -58,12 +57,12 @@ func (self NDArray[T]) Ints() []Int { return d } -func (self NDArray[T]) Empty() Frame { - var empty []T +func (self NDArray[T]) Empty(tv ...Type) Series { + empty := []T{} return NDArray[T](empty) } -func (self NDArray[T]) Copy() Frame { +func (self NDArray[T]) Copy() Series { vlen := self.Len() return self.Subset(0, vlen, true) } @@ -77,7 +76,7 @@ func (self NDArray[T]) Records() []string { } -func (self NDArray[T]) Subset(start, end int, opt ...any) Frame { +func (self NDArray[T]) Subset(start, end int, opt ...any) Series { // 默认不copy var __optCopy bool = false if len(opt) > 0 { @@ -97,9 +96,10 @@ func (self NDArray[T]) Subset(start, end int, opt ...any) Frame { rows = vv.Len() if __optCopy && rows > 0 { vs = gc.Clone(vs) + //vs = slices.Clone(vs) } rows = vvs.Len() - var d Frame + var d Series d = NDArray[T](vs.([]T)) return d default: @@ -108,7 +108,7 @@ func (self NDArray[T]) Subset(start, end int, opt ...any) Frame { return self.Empty() } -func (self NDArray[T]) Repeat(x any, repeats int) Frame { +func (self NDArray[T]) Repeat(x any, repeats int) Series { var d any switch values := self.Values().(type) { case []bool: @@ -126,22 +126,28 @@ func (self NDArray[T]) Repeat(x any, repeats int) Frame { return NDArray[T](d.([]T)) } -func (self NDArray[T]) Shift(periods int) Frame { +func (self NDArray[T]) Shift(periods int) Series { values := self.Values().([]T) d := Shift(values, periods) return NDArray[T](d) } func (self NDArray[T]) Mean() DType { + if self.Len() < 1 { + return NaN() + } d := Mean2(self) return Any2DType(d) } func (self NDArray[T]) StdDev() DType { + if self.Len() < 1 { + return NaN() + } return self.Std() } -func (self NDArray[T]) FillNa(v any, inplace bool) Frame { +func (self NDArray[T]) FillNa(v any, inplace bool) Series { d := FillNa(self, v, inplace) return NDArray[T](d) } @@ -156,7 +162,7 @@ func (self NDArray[T]) Min() any { return d } -func (self NDArray[T]) Select(r ScopeLimit) Frame { +func (self NDArray[T]) Select(r ScopeLimit) Series { start, end, err := r.Limits(self.Len()) if err != nil { return nil @@ -179,24 +185,81 @@ func (self NDArray[T]) Logic(f func(idx int, v any) bool) []bool { return d } -func (self NDArray[T]) Diff(param any) Frame { +func (self NDArray[T]) Diff(param any) Series { d := Diff2(self, param) return NDArray[T](d) } -func (self NDArray[T]) Ref(param any) Frame { +func (self NDArray[T]) Ref(param any) Series { values := self.Values().([]T) d := Shift3(values, param) return NDArray[T](d) } func (self NDArray[T]) Std() DType { + if self.Len() < 1 { + return NaN() + } d := Std(self) return Any2DType(d) } func (self NDArray[T]) Sum() DType { + if self.Len() < 1 { + return NaN() + } values := Slice2DType(self) d := Sum(values) return Any2DType(d) } + +func (self NDArray[T]) Rolling(param any) RollingAndExpandingMixin { + var N []DType + switch v := param.(type) { + case int: + N = Repeat[DType](DType(v), self.Len()) + case []DType: + N = Align(v, DTypeNaN, self.Len()) + case Series: + vs := v.DTypes() + N = Align(vs, DTypeNaN, self.Len()) + default: + panic(exception.New(1, "error window")) + } + w := RollingAndExpandingMixin{ + Window: N, + Series: self, + } + return w +} + +func (self NDArray[T]) EWM(alpha EW) ExponentialMovingWindow { + atype := AlphaAlpha + param := 0.00 + adjust := alpha.Adjust + ignoreNA := alpha.IgnoreNA + if alpha.Com != 0 { + atype = AlphaCom + param = alpha.Com + } else if alpha.Span != 0 { + atype = AlphaSpan + param = alpha.Span + } else if alpha.HalfLife != 0 { + atype = AlphaHalfLife + param = alpha.HalfLife + } else { + atype = AlphaAlpha + param = alpha.Alpha + } + + dest := NewSeries[DType]() + dest = dest.Append(self) + return ExponentialMovingWindow{ + Data: dest, + AType: atype, + Param: param, + Adjust: adjust, + IgnoreNA: ignoreNA, + Cb: alpha.Callback, + } +} diff --git a/stat/ndarray_append.go b/stat/ndarray_append.go index 8ae823aa413eea93c8fde8018cb6629a5ceba324..9f15818d2bb6d7301a85554141679c91eb410472 100644 --- a/stat/ndarray_append.go +++ b/stat/ndarray_append.go @@ -3,7 +3,7 @@ package stat import "reflect" // 赋值 -func assign[T BaseType](type_ Type, array Frame, idx, size int, v T) Frame { +func assign[T BaseType](type_ Type, array Series, idx, size int, v T) Series { _vv := reflect.ValueOf(v) _vi := _vv.Interface() // float和string类型有可能是NaN, 对nil和NaN进行计数 @@ -68,7 +68,7 @@ func (self NDArray[T]) insert(idx, size int, v any) NDArray[T] { return self } -func (self NDArray[T]) Append(values ...any) Frame { +func (self NDArray[T]) Append(values ...any) Series { size := 0 for idx, v := range values { switch val := v.(type) { diff --git a/stat/ndarray_convert.go b/stat/ndarray_convert.go index 8c120d744975140c17d406b99166f0d14bf760a4..a11827f81e015aca9223bdff9beae7317e449517 100644 --- a/stat/ndarray_convert.go +++ b/stat/ndarray_convert.go @@ -7,7 +7,7 @@ import ( ) // 这里做数组统一转换 -func convert[T GenericType](s Frame, v T) { +func convert[T GenericType](s Series, v T) { values := s.Values() rawType := checkoutRawType(values) values, ok := values.([]T) @@ -15,7 +15,7 @@ func convert[T GenericType](s Frame, v T) { _ = ok } -func ToFloat32(s Frame) []float32 { +func ToFloat32(s Series) []float32 { length := s.Len() defaultSlice := vek32.Repeat(Nil2Float32, length) values := s.Values() @@ -37,7 +37,7 @@ func ToFloat32(s Frame) []float32 { } } -func ToFloat64(s Frame) []float64 { +func ToFloat64(s Series) []float64 { length := s.Len() defaultSlice := vek.Repeat(Nil2Float64, length) values := s.Values() @@ -59,7 +59,7 @@ func ToFloat64(s Frame) []float64 { } } -func ToBool(s Frame) []bool { +func ToBool(s Series) []bool { length := s.Len() defaultSlice := make([]bool, length) values := s.Values() diff --git a/stat/ndarray_test.go b/stat/ndarray_test.go index 93096aeb2248cad1e1fcdf3e539877e69e22cbfb..b005ad9f1293c65bd7d8766a3ca0aa80f7f377cc 100644 --- a/stat/ndarray_test.go +++ b/stat/ndarray_test.go @@ -24,7 +24,7 @@ func TestNDArrayAll(t *testing.T) { d := []float32{1, 2, 3, 4, 5} sh1 := (*reflect.SliceHeader)(unsafe.Pointer(&d)) fmt.Printf("s : %#v\n", sh1) - var s Frame + var s Series s = NDArray[float32](d) //s3 := []float32(s) //fmt.Println(s3) @@ -54,3 +54,14 @@ func TestNDArrayAll(t *testing.T) { fmt.Println(s) _ = s4 } + +func TestNDArray_Rolling(t *testing.T) { + d1 := []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} + s := NewSeries(d1...) + r1 := s.Rolling(5).Mean() + fmt.Println(r1) + + d2 := []float64{1, 2, 3, 4, 3, 3, 2, 1, Nil2Float64, Nil2Float64, Nil2Float64, Nil2Float64} + r2 := s.Rolling(d2).Mean() + fmt.Println(r2) +} diff --git a/stat/ndarray_type.go b/stat/ndarray_type.go index e0af29018e3f07a8443a25da8bd90813eafd2c1f..bdf97db46748de70181d781abc60a35fe6e9ecd3 100644 --- a/stat/ndarray_type.go +++ b/stat/ndarray_type.go @@ -81,16 +81,16 @@ var ( // 初始化全局的私有变量 var ( rawBool bool = true - typeBool = reflect.TypeOf([]bool{}) + TypeBool = reflect.TypeOf([]bool{}) rawInt32 int32 = int32(0) typeInt32 = reflect.TypeOf([]int32{}) rawInt64 int64 = int64(0) - typeInt64 = reflect.TypeOf([]int64{}) + TypeInt64 = reflect.TypeOf([]int64{}) rawFloat32 float32 = float32(0) - typeFloat32 = reflect.TypeOf([]float32{}) + TypeFloat32 = reflect.TypeOf([]float32{}) rawFloat64 float64 = float64(0) - typeFloat64 = reflect.TypeOf([]float64{}) - typeString = reflect.TypeOf([]string{}) + TypeFloat64 = reflect.TypeOf([]float64{}) + TypeString = reflect.TypeOf([]string{}) ) // 从泛型检测出类型 diff --git a/stat/rolling.go b/stat/rolling.go index d83b3071fea4bb89746a05ba82c9668b60ae9c68..1ff78a50c75e151860fc3758a386619b32ff6cd7 100644 --- a/stat/rolling.go +++ b/stat/rolling.go @@ -6,7 +6,7 @@ import ( // Rolling returns an array with elements that roll beyond the last position are re-introduced at the first. // 滑动窗口, 数据不足是用空数组占位 -func Rolling[T Number | bool](S []T, N any) [][]T { +func Rolling[T BaseType](S []T, N any) [][]T { sLen := len(S) // 这样就具备了序列化滑动窗口的特性了 var window []DType diff --git a/stat/rolling_count.go b/stat/rolling_count.go new file mode 100644 index 0000000000000000000000000000000000000000..17f42e736621020602453d3986f42b5b58259ac4 --- /dev/null +++ b/stat/rolling_count.go @@ -0,0 +1,24 @@ +package stat + +import ( + "github.com/viterin/vek" +) + +func (r RollingAndExpandingMixin) Count() (s Series) { + if r.Series.Type() != SERIES_TYPE_BOOL { + panic("不支持非bool序列") + } + values := make([]DType, r.Series.Len()) + for i, block := range r.GetBlocks() { + if block.Len() == 0 { + values[i] = 0 + continue + } + bs := block.Values().([]bool) + values[i] = DType(vek.Count(bs)) + } + s = r.Series.Empty(SERIES_TYPE_DTYPE) + s.Rename(r.Series.Name()) + s = s.Append(values) + return +} diff --git a/stat/rolling_max.go b/stat/rolling_max.go new file mode 100644 index 0000000000000000000000000000000000000000..56360768e0d50ba7fb87b05d24d1dd9b25ceedd7 --- /dev/null +++ b/stat/rolling_max.go @@ -0,0 +1,19 @@ +package stat + +func (r RollingAndExpandingMixin) Max() (s Series) { + s = r.Series.Empty() + for _, block := range r.GetBlocks() { + //// 1. 排序处理方式 + //if block.Len() == 0 { + // s.Append(s.NaN()) + // continue + //} + //sort.Sort(block) + //r := RangeFinite(-1) + //_s := block.Select(r) + //s.Append(_s.Values()) + // 2. Series.Max方法 + s.Append(block.Max()) + } + return +} diff --git a/stat/rolling_mean.go b/stat/rolling_mean.go new file mode 100644 index 0000000000000000000000000000000000000000..8b8bd6dbe1cd9f325ca635ca1ca9673dccb8fb97 --- /dev/null +++ b/stat/rolling_mean.go @@ -0,0 +1,14 @@ +package stat + +// Mean returns the rolling mean. +func (r RollingAndExpandingMixin) Mean() (s Series) { + var d []DType + for _, block := range r.GetBlocks() { + d = append(d, block.Mean()) + } + //s = pandas.NewSeries(SERIES_TYPE_DTYPE, r.series.Name(), d) + s = r.Series.Empty(SERIES_TYPE_DTYPE) + s.Rename(r.Series.Name()) + s = s.Append(d) + return +} diff --git a/stat/rolling_min.go b/stat/rolling_min.go new file mode 100644 index 0000000000000000000000000000000000000000..1881d58f4b9f0d0d522843b907559bdd040f561d --- /dev/null +++ b/stat/rolling_min.go @@ -0,0 +1,19 @@ +package stat + +func (r RollingAndExpandingMixin) Min() (s Series) { + s = r.Series.Empty() + for _, block := range r.GetBlocks() { + //// 1. 排序处理方式 + //if block.Len() == 0 { + // s.Append(s.NaN()) + // continue + //} + //sort.Sort(block) + //r := RangeFinite(0, 1) + //_s := block.Select(r) + //s.Append(_s.Values()) + // 2. Series.Max方法 + s.Append(block.Min()) + } + return +} diff --git a/stat/rolling_std.go b/stat/rolling_std.go new file mode 100644 index 0000000000000000000000000000000000000000..7ef3761aae3d7bfbedb88570837c1de5c7f44769 --- /dev/null +++ b/stat/rolling_std.go @@ -0,0 +1,19 @@ +package stat + +func (r RollingAndExpandingMixin) Std() Series { + s := r.Series.Empty() + for _, block := range r.GetBlocks() { + //// 1. 排序处理方式 + //if block.Len() == 0 { + // s.Append(s.NaN()) + // continue + //} + //sort.Sort(block) + //r := RangeFinite(-1) + //_s := block.Select(r) + //s.Append(_s.Values()) + // 2. Series.Max方法 + s.Append(block.Std()) + } + return s +} diff --git a/stat/rolling_sum.go b/stat/rolling_sum.go new file mode 100644 index 0000000000000000000000000000000000000000..315eac2486b02f6ffdd5b31860f8a98f4b994a1b --- /dev/null +++ b/stat/rolling_sum.go @@ -0,0 +1,13 @@ +package stat + +func (r RollingAndExpandingMixin) Sum() Series { + var d []DType + for _, block := range r.GetBlocks() { + d = append(d, block.Sum()) + } + //s := pandas.NewSeries(SERIES_TYPE_DTYPE, r.series.Name(), d) + s := r.Series.Empty(SERIES_TYPE_DTYPE) + s.Rename(r.Series.Name()) + s = s.Append(d) + return s +} diff --git a/stat/rolling_test.go b/stat/rolling_test.go index 0d2eb319c0a67b33c9032c49f3ec6b151507f3a4..e2c0000f0acea92f65298e8e9c3b6137e3f1aa20 100644 --- a/stat/rolling_test.go +++ b/stat/rolling_test.go @@ -36,8 +36,4 @@ func TestRolling(t *testing.T) { t.Errorf("Got %v, want %v", output, expected) } - //output = Rolling(testSliceFloat, []int8{3, 3, 3, 3, 3, 3}) - //if reflect.DeepEqual(expected, output) != true { - // t.Errorf("Got %v, want %v", output, expected) - //} } diff --git a/stat/series.go b/stat/series.go index 71d8f894a68e689979efdbc984932530fa882874..b0447cb5a491a53f1b608805f98db63b14bc6f3f 100644 --- a/stat/series.go +++ b/stat/series.go @@ -1,13 +1,151 @@ package stat -type series interface { - // Len 长度 +import ( + "fmt" + "reflect" +) + +type Series interface { + // Name 取得series名称 + Name() string + // Rename renames the series. + Rename(name string) + // Type returns the type of Data the series holds. + // 返回series的数据类型 + Type() Type + // Values 获得全部数据集 + Values() any + + // NaN 输出默认的NaN + NaN() any + // Floats 强制转成[]float32 + Floats() []float32 + // DTypes 强制转[]stat.DType + DTypes() []DType + // Ints 强制转换成整型 + Ints() []Int + + // sort.Interface + + // Len 获得行数, 实现sort.Interface接口的获取元素数量方法 Len() int + // Less 实现sort.Interface接口的比较元素方法 + Less(i, j int) bool + // Swap 实现sort.Interface接口的交换元素方法 + Swap(i, j int) + + // Empty returns an empty Series of the same type + Empty(t ...Type) Series + // Copy 复制 + Copy() Series + // Records returns the elements of a Series as a []string + Records() []string + // Subset 获取子集 + Subset(start, end int, opt ...any) Series + // Repeat elements of an array. + Repeat(x any, repeats int) Series + // Shift index by desired number of periods with an optional time freq. + // 使用可选的时间频率按所需的周期数移动索引. + Shift(periods int) Series + // Rolling 序列化版本 + Rolling(param any) RollingAndExpandingMixin + // Mean calculates the average value of a series + Mean() DType + // StdDev calculates the standard deviation of a series + StdDev() DType + // FillNa Fill NA/NaN values using the specified method. + FillNa(v any, inplace bool) Series + // Max 找出最大值 + Max() any + // Min 找出最小值 + Min() any + // Select 选取一段记录 + Select(r ScopeLimit) Series + // Append 增加一批记录 + Append(values ...any) Series + // Apply 接受一个回调函数 + Apply(f func(idx int, v any)) + // Logic 逻辑处理 + Logic(f func(idx int, v any) bool) []bool + // Diff 元素的第一个离散差 + Diff(param any) (s Series) + // Ref 引用其它周期的数据 + Ref(param any) (s Series) + // Std 计算标准差 + Std() DType + // Sum 计算累和 + Sum() DType + // EWM Provide exponentially weighted (EW) calculations. + // + // Exactly one of ``com``, ``span``, ``halflife``, or ``alpha`` must be + // provided if ``times`` is not provided. If ``times`` is provided, + // ``halflife`` and one of ``com``, ``span`` or ``alpha`` may be provided. + EWM(alpha EW) ExponentialMovingWindow } -// Series 数据序列化 -// -// 第一个参数, data -func Series(data any, args any) { +// DetectTypeBySlice 检测类型 +func DetectTypeBySlice(arr ...any) (Type, error) { + var hasFloat32s, hasFloat64s, hasInts, hasBools, hasStrings bool + for _, v := range arr { + switch value := v.(type) { + case string: + hasStrings = true + continue + case float32: + hasFloat32s = true + continue + case float64: + hasFloat64s = true + continue + case int, int32, int64: + hasInts = true + continue + case bool: + hasBools = true + continue + default: + vv := reflect.ValueOf(v) + vk := vv.Kind() + switch vk { + case reflect.Slice, reflect.Array: // 切片或数组 + for i := 0; i < vv.Len(); i++ { + tv := vv.Index(i).Interface() + t_, err := DetectTypeBySlice(tv) + if err == nil { + return t_, nil + } + } + case reflect.Struct: // 忽略结构体 + continue + default: + } + _ = value + } + } + + switch { + case hasStrings: + return SERIES_TYPE_STRING, nil + case hasBools: + return SERIES_TYPE_BOOL, nil + case hasFloat32s: + return SERIES_TYPE_FLOAT32, nil + case hasFloat64s: + return SERIES_TYPE_FLOAT64, nil + case hasInts: + return SERIES_TYPE_INT64, nil + default: + return SERIES_TYPE_STRING, fmt.Errorf("couldn't detect type") + } +} +// NewSeries 构建一个新的Series +func NewSeries[T BaseType](data ...T) Series { + var S Series + values := []T{} + if len(data) > 0 { + values = append(values, data...) + } + S = NDArray[T](values) + return S } diff --git a/stat/series_test.go b/stat/series_test.go new file mode 100644 index 0000000000000000000000000000000000000000..56b1aa08aafe2eea54a896c678b8e0093f63319f --- /dev/null +++ b/stat/series_test.go @@ -0,0 +1,13 @@ +package stat + +import ( + "fmt" + "testing" +) + +func TestNewSeries(t *testing.T) { + d1 := []DType{} + fmt.Println(d1) + s1 := NewSeries[DType]() + fmt.Println(s1) +} diff --git a/stat/window_ewm.go b/stat/window_ewm.go new file mode 100644 index 0000000000000000000000000000000000000000..a06bdabfa677f9520362228895c862d32bf201a2 --- /dev/null +++ b/stat/window_ewm.go @@ -0,0 +1,145 @@ +package stat + +import "math" + +type AlphaType int + +// https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.ewm.html +const ( + // AlphaAlpha Specify smoothing factor α directly, 0<α≤1. + AlphaAlpha AlphaType = iota + // AlphaCom Specify decay in terms of center of mass, α=1/(1+com), for com ≥ 0. + AlphaCom + // AlphaSpan Specify decay in terms of span, α=2/(span+1), for span ≥ 1. + AlphaSpan + // AlphaHalfLife Specify decay in terms of half-life, α=1−exp(−ln(2)/halflife), for halflife > 0. + AlphaHalfLife +) + +// EW (Factor) 指数加权(EW)计算Alpha 结构属性非0即为有效启动同名算法 +type EW struct { + Com DType // 根据质心指定衰减 + Span DType // 根据跨度指定衰减 + HalfLife DType // 根据半衰期指定衰减 + Alpha DType // 直接指定的平滑因子α + Adjust bool // 除以期初的衰减调整系数以核算 相对权重的不平衡(将 EWMA 视为移动平均线) + IgnoreNA bool // 计算权重时忽略缺失值 + Callback func(idx int) DType +} + +// ExponentialMovingWindow 加权移动窗口 +type ExponentialMovingWindow struct { + Data Series // 序列 + AType AlphaType // 计算方式: com/span/halflefe/alpha + Param DType // 参数类型为浮点 + Adjust bool // 默认为真, 是否调整, 默认真时, 计算序列的EW移动平均线, 为假时, 计算指数加权递归 + IgnoreNA bool // 默认为假, 计算权重时是否忽略缺失值NaN + minPeriods int // 默认为0, 窗口中具有值所需的最小观测值数,否则结果为NaN + axis int // {0,1}, 默认为0, 0跨行计算, 1跨列计算 + Cb func(idx int) DType +} + +func (w ExponentialMovingWindow) Mean() Series { + var alpha DType + + switch w.AType { + case AlphaAlpha: + if w.Param <= 0 { + panic("alpha param must be > 0") + } + alpha = w.Param + + case AlphaCom: + if w.Param <= 0 { + panic("com param must be >= 0") + } + alpha = 1 / (1 + w.Param) + + case AlphaSpan: + if w.Param < 1 { + panic("span param must be >= 1") + } + alpha = 2 / (w.Param + 1) + + case AlphaHalfLife: + if w.Param <= 0 { + panic("halflife param must be > 0") + } + alpha = 1 - math.Exp(-math.Ln2/w.Param) + } + + return w.applyMean(w.Data, alpha) +} + +func (w ExponentialMovingWindow) applyMean(data Series, alpha DType) Series { + if w.Adjust { + w.adjustedMean(data, alpha, w.IgnoreNA) + } else { + w.notadjustedMean(data, alpha, w.IgnoreNA) + } + return data +} + +func (w ExponentialMovingWindow) adjustedMean(data Series, alpha DType, ignoreNA bool) { + var ( + values = data.Values().([]DType) + weight DType = 1 + last = values[0] + ) + + alpha = 1 - alpha + for t := 1; t < len(values); t++ { + + w := alpha*weight + 1 + x := values[t] + if DTypeIsNaN(x) { + if ignoreNA { + weight = w + } + values[t] = last + continue + } + + last = last + (x-last)/(w) + weight = w + values[t] = last + } +} + +func (w ExponentialMovingWindow) notadjustedMean(data Series, alpha DType, ignoreNA bool) { + hasCallback := false + if DTypeIsNaN(alpha) { + hasCallback = true + alpha = w.Cb(0) + } + var ( + count int + values = data.Values().([]DType) + beta = 1 - alpha + last = values[0] + ) + if DTypeIsNaN(last) { + last = 0 + values[0] = last + } + for t := 1; t < len(values); t++ { + x := values[t] + + if DTypeIsNaN(x) { + values[t] = last + continue + } + if hasCallback { + alpha = w.Cb(t) + beta = 1 - alpha + } + // yt = (1−α)*y(t−1) + α*x(t) + last = (beta * last) + (alpha * x) + if DTypeIsNaN(last) { + last = values[t-1] + } + values[t] = last + + count++ + } +} diff --git a/stat/window_rolling.go b/stat/window_rolling.go new file mode 100644 index 0000000000000000000000000000000000000000..d9d9f7f575c8c8accc582b8fdd69f3def7d78a9d --- /dev/null +++ b/stat/window_rolling.go @@ -0,0 +1,40 @@ +package stat + +// RollingAndExpandingMixin 滚动和扩展静态横切 +type RollingAndExpandingMixin struct { + Window []DType + Series Series +} + +func (r RollingAndExpandingMixin) GetBlocks() (blocks []Series) { + for i := 0; i < r.Series.Len(); i++ { + N := r.Window[i] + if DTypeIsNaN(N) || int(N) > i+1 { + blocks = append(blocks, r.Series.Empty()) + continue + } + window := int(N) + start := i + 1 - window + end := i + 1 + blocks = append(blocks, r.Series.Subset(start, end, false)) + } + + return +} + +// Apply 接受一个回调 +func (r RollingAndExpandingMixin) Apply(f func(S Series, N DType) DType) (s Series) { + values := make([]DType, r.Series.Len()) + for i, block := range r.GetBlocks() { + if block.Len() == 0 { + values[i] = DTypeNaN + continue + } + v := f(block, r.Window[i]) + values[i] = v + } + s = r.Series.Empty(SERIES_TYPE_DTYPE) + s.Rename(r.Series.Name()) + s = s.Append(values) + return +} diff --git a/strategy/no1.go b/strategy/no1.go index 87016629f91437959f0653c4458e8707a37683bf..8c2920af7c63fe127f507e22309adff89a2097e8 100644 --- a/strategy/no1.go +++ b/strategy/no1.go @@ -1,10 +1,11 @@ package main import ( - "gitee.com/quant1x/pandas" + pandas "gitee.com/quant1x/pandas" "gitee.com/quant1x/pandas/data/cache" "gitee.com/quant1x/pandas/data/security" . "gitee.com/quant1x/pandas/formula" + "github.com/mymmsc/gox/logger" "github.com/mymmsc/gox/util/treemap" ) @@ -35,24 +36,36 @@ func (this *FormulaNo1) Evaluate(fullCode string, info *security.StaticBasic, re df.SetNames("date", "open", "high", "low", "close", "volume") // 收盘价序列 CLOSE := df.Col("close") + days := CLOSE.Len() // 取5、10、20日均线 ma5 := MA(CLOSE, 5) ma10 := MA(CLOSE, 10) ma20 := MA(CLOSE, 20) + if len(ma5) != days || len(ma10) != days || len(ma20) != days { + logger.Errorf("均线, 数据没对齐") + } // 两个金叉 - c1 := CROSS2(ma5, ma10) - c2 := CROSS2(ma10, ma20) - + c1 := CROSS(ma5, ma10) + c2 := CROSS(ma10, ma20) + if len(c1) != days || len(c2) != days { + logger.Errorf("金叉, 数据没对齐") + } // 两个统计 - r1 := COUNT2(c1, N) - r2 := COUNT2(c2, N) - + r1 := COUNT(c1, N) + r2 := COUNT(c2, N) + if len(r1) != days || len(r2) != days { + logger.Errorf("统计, 数据没对齐") + } // 横向对比 + d := AND(r1, r2) + if len(d) != days { + logger.Errorf("横向对比, 数据没对齐") + } + //cc1 := CompareGte(r1, 1) - days := CLOSE.Len() - rLen := len(r1) - if rLen > 1 && r1[rLen-1] >= 1 && r2[rLen-1] >= 1 { + rLen := len(d) + if rLen > 1 && d[rLen-1] { buy := ma10[days-1] sell := buy * 1.05 date := df.Col("date").Values().([]string)[days-1] diff --git a/strategy/strategy.go b/strategy/strategy.go index 51fbc372a8c018d55b5e5a0c00d1f3a6046b07c9..5a5397ba056fa96aefa256a1128e7777824f5ab5 100644 --- a/strategy/strategy.go +++ b/strategy/strategy.go @@ -10,9 +10,8 @@ import ( "gitee.com/quant1x/pandas/stat" "github.com/mymmsc/gox/logger" "github.com/mymmsc/gox/util/treemap" - termTable "github.com/olekukonko/tablewriter" - - "github.com/qianlnk/pgbar" + tableView "github.com/olekukonko/tablewriter" + progressbar "github.com/qianlnk/pgbar" "os" "runtime" "sync" @@ -58,7 +57,7 @@ func main() { var wg = sync.WaitGroup{} fmt.Println("Quant1X 预警系统") fmt.Printf("CPU: %d, AVX2: %t\n", cpuNum, stat.GetAvx2Enabled()) - bar := pgbar.NewBar(0, "执行["+api.Name()+"]", count) + bar := progressbar.NewBar(0, "执行["+api.Name()+"]", count) var mapStock *treemap.Map mapStock = treemap.NewWithStringComparator() mainStart := time.Now() @@ -85,7 +84,7 @@ func main() { elapsedTime := time.Since(mainStart) / time.Millisecond fmt.Printf("CPU: %d, AVX2: %t, 总耗时: %.3fs, 总记录: %d, 平均: %.3f/s\n", cpuNum, stat.GetAvx2Enabled(), float64(elapsedTime)/1000, count, float64(count)/(float64(elapsedTime)/1000)) logger.Infof("CPU: %d, AVX2: %t, 总耗时: %.3fs, 总记录: %d, 平均: %.3f/s", cpuNum, stat.GetAvx2Enabled(), float64(elapsedTime)/1000, count, float64(count)/(float64(elapsedTime)/1000)) - table := termTable.NewWriter(os.Stdout) + table := tableView.NewWriter(os.Stdout) var row ResultInfo table.SetHeader(row.Headers()) diff --git a/builtin.go b/v1/builtin.go similarity index 99% rename from builtin.go rename to v1/builtin.go index 7af328d52c93458761460a92ab3935a3d1d46be3..9a58cef188a996656b77fcacdfaaec5efb961934 100644 --- a/builtin.go +++ b/v1/builtin.go @@ -1,4 +1,4 @@ -package pandas +package v1 import ( gc "github.com/huandu/go-clone" diff --git a/builtin_test.go b/v1/builtin_test.go similarity index 98% rename from builtin_test.go rename to v1/builtin_test.go index 6c208502ce29546c8164dea98b0b5e42f84ae0d2..dc85afc60f165c7b43a260d5f629941baf6dd6d8 100644 --- a/builtin_test.go +++ b/v1/builtin_test.go @@ -1,4 +1,4 @@ -package pandas +package v1 import ( "fmt" diff --git a/v1/dataframe.go b/v1/dataframe.go new file mode 100644 index 0000000000000000000000000000000000000000..bdfeb6a5c4d871e18682e7382dab834fe8d637c5 --- /dev/null +++ b/v1/dataframe.go @@ -0,0 +1,232 @@ +package v1 + +import ( + "fmt" + "sort" +) + +// DataFrame 以gota的DataFrame的方法为主, 兼顾新流程, 避免单元格元素结构化 +type DataFrame struct { + columns []Series + ncols int + nrows int + + // deprecated: Use Error() instead + Err error +} + +// NewDataFrame is the generic DataFrame constructor +func NewDataFrame(se ...Series) DataFrame { + if se == nil || len(se) == 0 { + return DataFrame{Err: fmt.Errorf("empty DataFrame")} + } + + columns := make([]Series, len(se)) + for i, s := range se { + var d Series + if s.Type() == SERIES_TYPE_INT64 { + d = NewSeries(SERIES_TYPE_INT64, s.Name(), s.Values()) + } else if s.Type() == SERIES_TYPE_BOOL { + d = NewSeries(SERIES_TYPE_BOOL, s.Name(), s.Values()) + } else if s.Type() == SERIES_TYPE_STRING { + d = NewSeries(SERIES_TYPE_STRING, s.Name(), s.Values()) + } else if s.Type() == SERIES_TYPE_FLOAT32 { + d = NewSeries(SERIES_TYPE_FLOAT32, s.Name(), s.Values()) + } else { + d = NewSeries(SERIES_TYPE_FLOAT64, s.Name(), s.Values()) + } + columns[i] = d + } + nrows, ncols, err := checkColumnsDimensions(columns...) + if err != nil { + return DataFrame{Err: err} + } + + // Fill DataFrame base structure + df := DataFrame{ + columns: columns, + ncols: ncols, + nrows: nrows, + } + colnames := df.Names() + fixColnames(colnames) + for i, colname := range colnames { + df.columns[i].Rename(colname) + } + return df +} + +// Dims retrieves the dimensions of a DataFrame. +func (self DataFrame) Dims() (int, int) { + return self.Nrow(), self.Ncol() +} + +// Nrow returns the number of rows on a DataFrame. +func (self DataFrame) Nrow() int { + return self.nrows +} + +// Ncol returns the number of columns on a DataFrame. +func (self DataFrame) Ncol() int { + return self.ncols +} + +// Returns error or nil if no error occured +func (self DataFrame) Error() error { + return self.Err +} + +// 检查列的尺寸 +func checkColumnsDimensions(se ...Series) (nrows, ncols int, err error) { + ncols = len(se) + nrows = -1 + if se == nil || ncols == 0 { + err = fmt.Errorf("no Series given") + return + } + for i, s := range se { + //if s.Err != nil { + // err = fmt.Errorf("error on series %d: %v", i, s.Err) + // return + //} + if nrows == -1 { + nrows = s.Len() + } + if nrows != s.Len() { + err = fmt.Errorf("arguments have different dimensions") + return + } + _ = i + } + return +} + +// Types returns the types of the columns on a DataFrame. +func (self DataFrame) Types() []string { + coltypes := make([]string, self.ncols) + for i, s := range self.columns { + coltypes[i] = s.Type().String() + } + return coltypes +} + +// Records return the string record representation of a DataFrame. +func (self DataFrame) Records() [][]string { + var records [][]string + records = append(records, self.Names()) + if self.ncols == 0 || self.nrows == 0 { + return records + } + var tRecords [][]string + for _, col := range self.columns { + tRecords = append(tRecords, col.Records()) + } + records = append(records, transposeRecords(tRecords)...) + return records +} + +// Getters/Setters for DataFrame fields +// ==================================== + +// Names returns the name of the columns on a DataFrame. +func (self DataFrame) Names() []string { + colnames := make([]string, self.ncols) + for i, s := range self.columns { + colnames[i] = s.Name() + } + return colnames +} + +func transposeRecords(x [][]string) [][]string { + n := len(x) + if n == 0 { + return x + } + m := len(x[0]) + y := make([][]string, m) + for i := 0; i < m; i++ { + z := make([]string, n) + for j := 0; j < n; j++ { + z[j] = x[j][i] + } + y[i] = z + } + return y +} + +// fixColnames assigns a name to the missing column names and makes it so that the +// column names are unique. +func fixColnames(colnames []string) { + // Find duplicated and missing colnames + dupnamesidx := make(map[string][]int) + var missingnames []int + for i := 0; i < len(colnames); i++ { + a := colnames[i] + if a == "" { + missingnames = append(missingnames, i) + continue + } + // for now, dupnamesidx contains the indices of *all* the columns + // the columns with unique locations will be removed after this loop + dupnamesidx[a] = append(dupnamesidx[a], i) + } + // NOTE: deleting a map key in a range is legal and correct in Go. + for k, places := range dupnamesidx { + if len(places) < 2 { + delete(dupnamesidx, k) + } + } + // Now: dupnameidx contains only keys that appeared more than once + + // Autofill missing column names + counter := 0 + for _, i := range missingnames { + proposedName := fmt.Sprintf("X%d", counter) + for findInStringSlice(proposedName, colnames) != -1 { + counter++ + proposedName = fmt.Sprintf("X%d", counter) + } + colnames[i] = proposedName + counter++ + } + + // Sort map keys to make sure it always follows the same order + var keys []string + for k := range dupnamesidx { + keys = append(keys, k) + } + sort.Strings(keys) + + // Add a suffix to the duplicated colnames + for _, name := range keys { + idx := dupnamesidx[name] + if name == "" { + name = "X" + } + counter := 0 + for _, i := range idx { + proposedName := fmt.Sprintf("%s_%d", name, counter) + for findInStringSlice(proposedName, colnames) != -1 { + counter++ + proposedName = fmt.Sprintf("%s_%d", name, counter) + } + colnames[i] = proposedName + counter++ + } + } +} + +func findInStringSlice(str string, s []string) int { + for i, e := range s { + if e == str { + return i + } + } + return -1 +} + +// Read/Write Methods +// ================= + +// LoadOption is the type used to configure the load of elements +type LoadOption func(*loadOptions) diff --git a/v1/dataframe_csv.go b/v1/dataframe_csv.go new file mode 100644 index 0000000000000000000000000000000000000000..fbc2be66baa7071ed6fc090ab70d091ad363768c --- /dev/null +++ b/v1/dataframe_csv.go @@ -0,0 +1,105 @@ +package v1 + +import ( + "encoding/csv" + "github.com/mymmsc/gox/api" + "github.com/mymmsc/gox/logger" + "github.com/mymmsc/gox/util/homedir" + "io" + "os" +) + +// ReadCSV reads a CSV file from a io.Reader and builds a DataFrame with the +// resulting records. +// 支持文件名和io两种方式读取数据 +func ReadCSV(in any, options ...LoadOption) DataFrame { + var ( + reader io.Reader + filename string + ) + switch param := in.(type) { + case io.Reader: + reader = param + case string: + filename = param + } + + if !IsEmpty(filename) { + filepath, err := homedir.Expand(filename) + if err != nil { + logger.Errorf("%s, error=%+v\n", filename, err) + return DataFrame{} + } + csvFile, err := os.Open(filepath) + if err != nil { + logger.Errorf("%s, error=%+v\n", filename, err) + return DataFrame{} + } + defer api.CloseQuietly(csvFile) + reader = csvFile + } + + csvReader := csv.NewReader(reader) + cfg := loadOptions{ + delimiter: ',', + lazyQuotes: false, + comment: 0, + } + for _, option := range options { + option(&cfg) + } + + csvReader.Comma = cfg.delimiter + csvReader.LazyQuotes = cfg.lazyQuotes + csvReader.Comment = cfg.comment + + records, err := csvReader.ReadAll() + if err != nil { + return DataFrame{Err: err} + } + return LoadRecords(records, options...) +} + +// WriteCSV writes the DataFrame to the given io.Writer as a CSV file. +// 支持文件名和io两种方式写入数据 +func (self DataFrame) WriteCSV(out any, options ...WriteOption) error { + var ( + writer io.Writer + filename string + ) + switch param := out.(type) { + case io.Writer: + writer = param + case string: + filename = param + } + + if !IsEmpty(filename) { + filepath, err := homedir.Expand(filename) + if err != nil { + return err + } + csvFile, err := os.Create(filepath) + if err != nil { + return err + } + defer api.CloseQuietly(csvFile) + writer = csvFile + } + // Set the default write options + cfg := writeOptions{ + writeHeader: true, + } + + // Set any custom write options + for _, option := range options { + option(&cfg) + } + + records := self.Records() + if !cfg.writeHeader { + records = records[1:] + } + + return csv.NewWriter(writer).WriteAll(records) +} diff --git a/v1/dataframe_csv_test.go b/v1/dataframe_csv_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c46aac5e2c3981b3039bc28dba6f35e010f092c4 --- /dev/null +++ b/v1/dataframe_csv_test.go @@ -0,0 +1,89 @@ +package v1 + +import ( + "bytes" + "encoding/json" + "fmt" + "strings" + "testing" +) + +func TestCsv(t *testing.T) { + csvStr := ` +Country,Date,Age,Amount,Id,close +"United States",2012-02-01,50,112.1,01234,1.23 +"United States",2012-02-01,32,321.31,54320,1.23 +"United Kingdom",2012-02-01,17,18.2,12345,1.23 +"United States",2012-02-01,32,321.31,54320,1.23 +"United Kingdom",2012-02-01,NA,18.2,12345,1.23 +"United States",2012-02-01,32,321.31,54320,1.23 +"United States",2012-02-01,32,321.31,54320,1.23 +Spain,2012-02-01,66,555.42,00241,1.23 +` + df := ReadCSV(strings.NewReader(csvStr)) + fmt.Println(df) + filename := "../testfiles/test-tutorials-w01.csv" + _ = df.WriteCSV(filename) + buf := new(bytes.Buffer) + _ = df.WriteCSV(buf) + df = ReadCSV(filename) + fmt.Println(df) + df.SetNames("a", "b", "c", "d", "e") + //s1 := df.Col("d") + //fmt.Println(s1) + // + //closes := df.Col("d") + //ma5 := closes.RollingV1(5).Mean() + //dframe.NewSeries(closes, dframe.Floats, "") + //fmt.Println(ma5) + d := df.Col("d") + fmt.Println(d) + _ = csvStr + +} + +type T1 struct { + X []int64 `json:"x"` +} + +func TestEwm(t *testing.T) { + //a := make(map[string][]int, 8) + t01 := map[string]int64{ + "x": 1, + } + fmt.Println(t01) + t02 := map[string][]int64{ + "x": {1, 2, 3, 4, 5, 6, 7, 8, 9}, + } + fmt.Println(t02) + text := `{"x":[1,2,3,4,5,6,7,8,9]}` + reader := strings.NewReader(text) + parser := json.NewDecoder(reader) + var t1 T1 + a1 := parser.Decode(&t1) + fmt.Println(a1, t1) + var t2 map[string][]int + a2 := parser.Decode(&t2) + fmt.Println(a2, t2) + //df := dframe.ReadJSON(reader) + //fmt.Println(df) + //values := []int64{1, 2, 3, 4, 5, 6, 7, 8, 9} + //s1 := dframe.NewSeries(values, dframe.Int, "x") + //df = dframe.NewFrame(s1) + //fmt.Println(df) + //xs := df.Col("x") + //r1 := xs.RollingV1(5).Mean() + //fmt.Println(r1) + // + //e1 := xs.EWM(dframe.Alpha{Span: 5, At: dframe.AlphaSpan}, false, false).Mean() + //fmt.Println(e1) + // + //df1 := dframe.NewFrame(e1) + //fmt.Println(df1) + // + //e2 := xs.EWM(dframe.Alpha{Span: 5, At: dframe.AlphaSpan}, true, false).Mean() + //fmt.Println(e2) + // + //df2 := dframe.NewFrame(e1, e2) + //fmt.Println(df2) +} diff --git a/v1/dataframe_excel.go b/v1/dataframe_excel.go new file mode 100644 index 0000000000000000000000000000000000000000..bd5319648289174de09713d41a601b62162b528c --- /dev/null +++ b/v1/dataframe_excel.go @@ -0,0 +1,86 @@ +package v1 + +import ( + "fmt" + "github.com/mymmsc/gox/logger" + "github.com/mymmsc/gox/util/homedir" + xlsv1 "github.com/tealeg/xlsx" + xlsv3 "github.com/tealeg/xlsx/v3" + "strings" +) + +// 读取excel文件 +func ReadExcel(filename string, options ...LoadOption) DataFrame { + if IsEmpty(filename) { + return DataFrame{Err: fmt.Errorf("filaname is empty")} + } + + filepath, err := homedir.Expand(filename) + if err != nil { + logger.Errorf("%s, error=%+v\n", filename, err) + return DataFrame{Err: err} + } + //filename := "test.xlsx" + xlFile, err := xlsv1.OpenFile(filepath) + if err != nil { + return DataFrame{Err: err} + } + colnums := make([][]string, 0) + for _, sheet := range xlFile.Sheets { + //fmt.Printf("Sheet Name: %s\n", sheet.Name) + for _, row := range sheet.Rows { + col := make([]string, 0) + for _, cell := range row.Cells { + //cell.SetStringFormula("%s") + if cell.IsTime() { + cell.SetFormat("yyyy-mm-dd") + } else if strings.HasPrefix(cell.Value, "0") { + cell.SetFormat("") + } + text := cell.String() + col = append(col, text) + } + colnums = append(colnums, col) + } + // 只展示第一个sheet + break + } + + return LoadRecords(colnums, options...) +} + +// WriteExcel 支持文件名和io两种方式写入数据 +func (self DataFrame) WriteExcel(filename string, options ...WriteOption) error { + filepath, err := homedir.Expand(filename) + if err != nil { + return err + } + xlFile := xlsv3.NewFile() + sheet, err := xlFile.AddSheet("Sheet(pandas)") + if err != nil { + return err + } + // Set the default write options + cfg := writeOptions{ + writeHeader: true, + } + + // Set any custom write options + for _, option := range options { + option(&cfg) + } + + records := self.Records() + if !cfg.writeHeader { + records = records[1:] + } + for _, cols := range records { + row := sheet.AddRow() + for _, col := range cols { + cell := row.AddCell() + cell.SetString(col) + } + } + + return xlFile.Save(filepath) +} diff --git a/v1/dataframe_excel_test.go b/v1/dataframe_excel_test.go new file mode 100644 index 0000000000000000000000000000000000000000..4986a2b9ab71cbc4aaa29b72df4529e65a559acd --- /dev/null +++ b/v1/dataframe_excel_test.go @@ -0,0 +1,17 @@ +package v1 + +import ( + "fmt" + "testing" +) + +func TestReadExcel(t *testing.T) { + filename := "../testfiles/test-excel-r01.xlsx" + df := ReadExcel(filename) + fmt.Println(df) + toFile := "../testfiles/test-excel-w01.xlsx" + err := df.WriteExcel(toFile) + if err != nil { + t.Errorf("write excel=%s, failed", toFile) + } +} diff --git a/v1/dataframe_fillna.go b/v1/dataframe_fillna.go new file mode 100644 index 0000000000000000000000000000000000000000..0303a54354f865cfe614ed066e87c40b8de14848 --- /dev/null +++ b/v1/dataframe_fillna.go @@ -0,0 +1,10 @@ +package v1 + +// FillNa dataframe实现FillNa +func (self DataFrame) FillNa(v any, inplace bool) { + for _, series := range self.columns { + if series.Len() > 0 { + series.FillNa(v, inplace) + } + } +} diff --git a/v1/dataframe_indexes.go b/v1/dataframe_indexes.go new file mode 100644 index 0000000000000000000000000000000000000000..e2c28594e269e4fc99dffdeee26e5d52bbd236e0 --- /dev/null +++ b/v1/dataframe_indexes.go @@ -0,0 +1,110 @@ +package v1 + +import "fmt" + +func parseSelectIndexes(l int, indexes SelectIndexes, colnames []string) ([]int, error) { + var idx []int + switch indexes.(type) { + case []int: + idx = indexes.([]int) + case int: + idx = []int{indexes.(int)} + case []bool: + bools := indexes.([]bool) + if len(bools) != l { + return nil, fmt.Errorf("indexing error: index dimensions mismatch") + } + for i, b := range bools { + if b { + idx = append(idx, i) + } + } + case string: + s := indexes.(string) + i := findInStringSlice(s, colnames) + if i < 0 { + return nil, fmt.Errorf("can't select columns: column name %q not found", s) + } + idx = append(idx, i) + case []string: + xs := indexes.([]string) + for _, s := range xs { + i := findInStringSlice(s, colnames) + if i < 0 { + return nil, fmt.Errorf("can't select columns: column name %q not found", s) + } + idx = append(idx, i) + } + //case Series: + // s := indexes.(Series) + // //if err := s.Err; err != nil { + // // return nil, fmt.Errorf("indexing error: new values has errors: %v", err) + // //} + // //if s.HasNaN() { + // // return nil, fmt.Errorf("indexing error: indexes contain NaN") + // //} + // switch s.Type() { + // case SERIES_TYPE_INT64: + // return s.Ints() + // case series.Bool: + // bools, err := s.Bool() + // if err != nil { + // return nil, fmt.Errorf("indexing error: %v", err) + // } + // return parseSelectIndexes(l, bools, colnames) + // case series.String: + // xs := indexes.(series.Series).Records() + // return parseSelectIndexes(l, xs, colnames) + // default: + // return nil, fmt.Errorf("indexing error: unknown indexing mode") + // } + default: + return nil, fmt.Errorf("indexing error: unknown indexing mode") + } + return idx, nil +} + +// SelectIndexes are the supported indexes used for the DataFrame.Select method. Currently supported are: +// +// int // Matches the given index number +// []int // Matches all given index numbers +// []bool // Matches all columns marked as true +// string // Matches the column with the matching column name +// []string // Matches all columns with the matching column names +// Series [Int] // Same as []int +// Series [Bool] // Same as []bool +// Series [String] // Same as []string +type SelectIndexes interface{} + +// Select the given DataFrame columns +func (df DataFrame) Select(indexes SelectIndexes) DataFrame { + if df.Err != nil { + return df + } + idx, err := parseSelectIndexes(df.ncols, indexes, df.Names()) + if err != nil { + return DataFrame{Err: fmt.Errorf("can't select columns: %v", err)} + } + columns := make([]Series, len(idx)) + for k, i := range idx { + if i < 0 || i >= df.ncols { + return DataFrame{Err: fmt.Errorf("can't select columns: index out of range")} + } + columns[k] = df.columns[i].Copy() + } + nrows, ncols, err := checkColumnsDimensions(columns...) + if err != nil { + return DataFrame{Err: err} + } + df = DataFrame{ + columns: columns, + ncols: ncols, + nrows: nrows, + } + colnames := df.Names() + fixColnames(colnames) + for i, colname := range colnames { + df.columns[i].Rename(colname) + } + return df +} diff --git a/v1/dataframe_join.go b/v1/dataframe_join.go new file mode 100644 index 0000000000000000000000000000000000000000..acc4a1d20a8ce09e13b94d04f5d2f4270e88b389 --- /dev/null +++ b/v1/dataframe_join.go @@ -0,0 +1,59 @@ +package v1 + +import "gitee.com/quant1x/pandas/stat" + +func (self DataFrame) align(ss ...Series) []Series { + defaultValue := []Series{} + sLen := len(ss) + if sLen == 0 { + return defaultValue + } + ls := make([]float32, sLen) + for i, v := range ss { + ls[i] = float32(v.Len()) + } + + maxLength := stat.Max(ls) + if maxLength <= 0 { + return defaultValue + } + cols := make([]Series, sLen) + for i, v := range ss { + vt := v.Type() + vn := v.Name() + vs := v.Values() + // 声明any的ns变量用于接收逻辑分支的输出 + // 切片数据不能直接对齐, 需要根据类型指定Nil和NaN默认值 + var ns any + if vt == SERIES_TYPE_BOOL { + ns = stat.Align(vs.([]bool), stat.Nil2Bool, int(maxLength)) + } else if vt == SERIES_TYPE_INT64 { + ns = stat.Align(vs.([]int64), stat.Nil2Int64, int(maxLength)) + } else if vt == SERIES_TYPE_STRING { + ns = stat.Align(vs.([]string), stat.Nil2String, int(maxLength)) + } else if vt == SERIES_TYPE_FLOAT32 { + ns = stat.Align(vs.([]float32), Nil2Float32, int(maxLength)) + } else if vt == SERIES_TYPE_FLOAT64 { + ns = stat.Align(vs.([]float64), Nil2Float64, int(maxLength)) + } + cols[i] = NewSeries(vt, vn, ns) + } + return cols +} + +// Join 默认右连接, 加入一个series +func (self DataFrame) Join(series Series) DataFrame { + if series.Len() < 0 { + return self + } + nCol := self.Ncol() + cols := make([]Series, nCol+1) + cols[len(cols)-1] = series + for i, s := range self.columns { + cols[i] = s + } + cols = self.align(cols...) + df := NewDataFrame(cols...) + self = df + return self +} diff --git a/v1/dataframe_join_test.go b/v1/dataframe_join_test.go new file mode 100644 index 0000000000000000000000000000000000000000..2b85185d31dbf57aebee769ee352d4359d5b6f9e --- /dev/null +++ b/v1/dataframe_join_test.go @@ -0,0 +1,26 @@ +package v1 + +import ( + "fmt" + "testing" +) + +func TestDataFrame_Join(t *testing.T) { + type testStruct struct { + A string + B int + C bool + D float32 + } + data := []testStruct{ + {"a", 1, true, 0.0}, + {"b", 2, false, 0.5}, + } + df1 := LoadStructs(data) + fmt.Println(df1) + + // 增加1列 + s_e := GenericSeries[string]("", "a0", "a1", "a2", "a3") + df2 := df1.Join(s_e) + fmt.Println(df2) +} diff --git a/v1/dataframe_map.go b/v1/dataframe_map.go new file mode 100644 index 0000000000000000000000000000000000000000..038ccc7019c089cac942d428b3fa71829c8d58c4 --- /dev/null +++ b/v1/dataframe_map.go @@ -0,0 +1,47 @@ +package v1 + +import ( + "fmt" + "sort" +) + +// LoadMaps creates a new DataFrame based on the given maps. This function assumes +// that every map on the array represents a row of observations. +func LoadMaps(maps []map[string]interface{}, options ...LoadOption) DataFrame { + if len(maps) == 0 { + return DataFrame{Err: fmt.Errorf("load maps: empty array")} + } + inStrSlice := func(i string, s []string) bool { + for _, v := range s { + if v == i { + return true + } + } + return false + } + // Detect all colnames + var colnames []string + for _, v := range maps { + for k := range v { + if exists := inStrSlice(k, colnames); !exists { + colnames = append(colnames, k) + } + } + } + sort.Strings(colnames) + records := make([][]string, len(maps)+1) + records[0] = colnames + for k, m := range maps { + row := make([]string, len(colnames)) + for i, colname := range colnames { + element := "" + val, ok := m[colname] + if ok { + element = fmt.Sprint(val) + } + row[i] = element + } + records[k+1] = row + } + return LoadRecords(records, options...) +} diff --git a/v1/dataframe_matrix.go b/v1/dataframe_matrix.go new file mode 100644 index 0000000000000000000000000000000000000000..e45b0be01e411ee683f4db357e5d2f285fb369f1 --- /dev/null +++ b/v1/dataframe_matrix.go @@ -0,0 +1,32 @@ +package v1 + +import "gonum.org/v1/gonum/mat" + +// LoadMatrix loads the given Matrix as a DataFrame +// TODO: Add Loadoptions +func LoadMatrix(mat mat.Matrix) DataFrame { + nrows, ncols := mat.Dims() + columns := make([]Series, ncols) + for i := 0; i < ncols; i++ { + floats := make([]float64, nrows) + for j := 0; j < nrows; j++ { + floats[j] = mat.At(j, i) + } + columns[i] = NewSeries(SERIES_TYPE_FLOAT64, "", floats) + } + nrows, ncols, err := checkColumnsDimensions(columns...) + if err != nil { + return DataFrame{Err: err} + } + df := DataFrame{ + columns: columns, + ncols: ncols, + nrows: nrows, + } + colnames := df.Names() + fixColnames(colnames) + for i, colname := range colnames { + df.columns[i].Rename(colname) + } + return df +} diff --git a/v1/dataframe_options.go b/v1/dataframe_options.go new file mode 100644 index 0000000000000000000000000000000000000000..a491d1ea55fdccfc4482e24e312e01da4e3933b3 --- /dev/null +++ b/v1/dataframe_options.go @@ -0,0 +1,110 @@ +package v1 + +type loadOptions struct { + // Specifies which is the default type in case detectTypes is disabled. + defaultType Type + + // If set, the type of each column will be automatically detected unless + // otherwise specified. + detectTypes bool + + // If set, the first row of the tabular structure will be used as column + // names. + hasHeader bool + + // The names to set as columns names. + names []string + + // Defines which values are going to be considered as NaN when parsing from string. + nanValues []string + + // Defines the csv delimiter + delimiter rune + + // EnablesLazyQuotes + lazyQuotes bool + + // Defines the comment delimiter + comment rune + + // The types of specific columns can be specified via column name. + types map[string]Type +} + +// DefaultType sets the defaultType option for loadOptions. +func DefaultType(t Type) LoadOption { + return func(c *loadOptions) { + c.defaultType = t + } +} + +// DetectTypes sets the detectTypes option for loadOptions. +func DetectTypes(b bool) LoadOption { + return func(c *loadOptions) { + c.detectTypes = b + } +} + +// HasHeader sets the hasHeader option for loadOptions. +func HasHeader(b bool) LoadOption { + return func(c *loadOptions) { + c.hasHeader = b + } +} + +// Names sets the names option for loadOptions. +func Names(names ...string) LoadOption { + return func(c *loadOptions) { + c.names = names + } +} + +// NaNValues sets the nanValues option for loadOptions. +func NaNValues(nanValues []string) LoadOption { + return func(c *loadOptions) { + c.nanValues = nanValues + } +} + +// WithTypes sets the types option for loadOptions. +func WithTypes(coltypes map[string]Type) LoadOption { + return func(c *loadOptions) { + c.types = coltypes + } +} + +// WithDelimiter sets the csv delimiter other than ',', for example '\t' +func WithDelimiter(b rune) LoadOption { + return func(c *loadOptions) { + c.delimiter = b + } +} + +// WithLazyQuotes sets csv parsing option to LazyQuotes +func WithLazyQuotes(b bool) LoadOption { + return func(c *loadOptions) { + c.lazyQuotes = b + } +} + +// WithComments sets the csv comment line detect to remove lines +func WithComments(b rune) LoadOption { + return func(c *loadOptions) { + c.comment = b + } +} + +// WriteOption is the type used to configure the writing of elements +type WriteOption func(*writeOptions) + +type writeOptions struct { + // Specifies whether the header is also written + writeHeader bool +} + +// WriteHeader sets the writeHeader option for writeOptions. +func WriteHeader(b bool) WriteOption { + return func(c *writeOptions) { + c.writeHeader = b + } +} diff --git a/v1/dataframe_records.go b/v1/dataframe_records.go new file mode 100644 index 0000000000000000000000000000000000000000..6c9a2e333aaf967b87ecefb17543b34ac0842f1f --- /dev/null +++ b/v1/dataframe_records.go @@ -0,0 +1,98 @@ +package v1 + +import ( + "fmt" + "gitee.com/quant1x/pandas/stat" +) + +// LoadRecords creates a new DataFrame based on the given records. +// 这个方法是从本地缓存文件读取数据的第二步, 数据从形式上只能是字符串 +func LoadRecords(records [][]string, options ...LoadOption) DataFrame { + // Set the default load options + cfg := loadOptions{ + defaultType: SERIES_TYPE_STRING, + detectTypes: true, + hasHeader: true, + nanValues: stat.PossibleNaOfString, + } + + // Set any custom load options + for _, option := range options { + option(&cfg) + } + + if len(records) == 0 { + return DataFrame{Err: fmt.Errorf("load records: empty DataFrame")} + } + if cfg.hasHeader && len(records) <= 1 { + return DataFrame{Err: fmt.Errorf("load records: empty DataFrame")} + } + if cfg.names != nil && len(cfg.names) != len(records[0]) { + if len(cfg.names) > len(records[0]) { + return DataFrame{Err: fmt.Errorf("load records: too many column names")} + } + return DataFrame{Err: fmt.Errorf("load records: not enough column names")} + } + + // Extract headers + headers := make([]string, len(records[0])) + if cfg.hasHeader { + headers = records[0] + records = records[1:] + } + if cfg.names != nil { + headers = cfg.names + } + + types := make([]Type, len(headers)) + rawcols := make([][]string, len(headers)) + for i, colname := range headers { + rawcol := make([]string, len(records)) + for j := 0; j < len(records); j++ { + rawcol[j] = records[j][i] + // 收敛NaN的情况, 统一替换为NaN + if findInStringSlice(rawcol[j], cfg.nanValues) != -1 { + rawcol[j] = "NaN" + } + } + rawcols[i] = rawcol + + t, ok := cfg.types[colname] + if !ok { + t = cfg.defaultType + if cfg.detectTypes { + if l, err := findTypeByString(rawcol); err == nil { + t = l + } + } + } + types[i] = t + } + + columns := make([]Series, len(headers)) + for i, colname := range headers { + cols := rawcols[i] + col := NewSeries(types[i], colname, cols) + //col := NewSeriesWithType(types[i], colname, cols) + //if col.Err != nil { + // return DataFrame{Err: col.Err} + //} + columns[i] = col + } + nrows, ncols, err := checkColumnsDimensions(columns...) + if err != nil { + return DataFrame{Err: err} + } + df := DataFrame{ + columns: columns, + ncols: ncols, + nrows: nrows, + } + + colnames := df.Names() + fixColnames(colnames) + for i, colname := range colnames { + df.columns[i].Rename(colname) + } + return df +} diff --git a/v1/dataframe_remove.go b/v1/dataframe_remove.go new file mode 100644 index 0000000000000000000000000000000000000000..46f456d9844d71b3ed83c03e602d98c73b657980 --- /dev/null +++ b/v1/dataframe_remove.go @@ -0,0 +1,28 @@ +package v1 + +import "gitee.com/quant1x/pandas/stat" + +// Remove 删除一段范围内的记录 +func (self DataFrame) Remove(p stat.ScopeLimit) DataFrame { + rowLen := self.Nrow() + start, end, err := p.Limits(rowLen) + if err != nil { + return self + } + columns := []Series{} + for i := range self.columns { + ht := self.columns[i].Subset(0, start, true) + tail := self.columns[i].Subset(end+1, rowLen).Values() + ht.Append(tail) + columns = append(columns, ht) + } + nrows, ncols, err := checkColumnsDimensions(columns...) + if err != nil { + return DataFrame{Err: err} + } + return DataFrame{ + columns: columns, + ncols: ncols, + nrows: nrows, + } +} diff --git a/v1/dataframe_remove_test.go b/v1/dataframe_remove_test.go new file mode 100644 index 0000000000000000000000000000000000000000..63258ffeedb16f6eb49dc92ab2a91fa347c184cc --- /dev/null +++ b/v1/dataframe_remove_test.go @@ -0,0 +1,31 @@ +package v1 + +import ( + "fmt" + "gitee.com/quant1x/pandas/stat" + "testing" +) + +func TestDataFrame_Remove(t *testing.T) { + type testStruct struct { + A string + B int + C bool + D float64 + } + data := []testStruct{ + {"a", 1, true, 0.0}, + {"b", 2, false, 0.5}, + } + df1 := LoadStructs(data) + fmt.Println(df1) + + // 增加1列 + s_e := GenericSeries[string]("x", "a0", "a1", "a2", "a3", "a4") + df2 := df1.Join(s_e) + fmt.Println(df2) + r := stat.RangeFinite(3, 3) + df3 := df2.Remove(r) + fmt.Println(df3) + +} diff --git a/v1/dataframe_select.go b/v1/dataframe_select.go new file mode 100644 index 0000000000000000000000000000000000000000..ffbdb772fb135e26a9ab71013f7f7f2e12517193 --- /dev/null +++ b/v1/dataframe_select.go @@ -0,0 +1,39 @@ +package v1 + +import "fmt" + +// Col returns a copy of the Series with the given column name contained in the DataFrame. +// 选取一列 +func (self DataFrame) Col(colname string) Series { + if self.Err != nil { + return NewSeriesWithType(SERIES_TYPE_INVAILD, "") + } + // Check that colname exist on dataframe + idx := findInStringSlice(colname, self.Names()) + if idx < 0 { + return NewSeriesWithType(SERIES_TYPE_INVAILD, "") + } + return self.columns[idx].Copy() +} + +// SetNames changes the column names of a DataFrame to the ones passed as an +// argument. +// 修改全部的列名 +func (self DataFrame) SetNames(colnames ...string) error { + if len(colnames) != self.ncols { + return fmt.Errorf("setting names: wrong dimensions") + } + for k, s := range colnames { + self.columns[k].Rename(s) + } + return nil +} + +// SetName 修改一个series的名称 +func (self DataFrame) SetName(from string, to string) { + for _, s := range self.columns { + if s.Name() == from { + s.Rename(to) + } + } +} diff --git a/v1/dataframe_struct.go b/v1/dataframe_struct.go new file mode 100644 index 0000000000000000000000000000000000000000..dc0dfad21ab8b294e298543428781598ce5964e7 --- /dev/null +++ b/v1/dataframe_struct.go @@ -0,0 +1,156 @@ +package v1 + +import ( + "fmt" + "gitee.com/quant1x/pandas/stat" + "reflect" + "strings" +) + +// LoadStructs creates a new DataFrame from arbitrary struct slices. +// +// LoadStructs will ignore unexported fields inside an struct. Note also that +// unless otherwise specified the column names will correspond with the name of +// the field. +// +// You can configure each field with the `dataframe:"name[,type]"` struct +// tag. If the name on the tag is the empty string `""` the field name will be +// used instead. If the name is `"-"` the field will be ignored. +// +// Examples: +// +// // field will be ignored +// field int +// +// // Field will be ignored +// Field int `dataframe:"-"` +// +// // Field will be parsed with column name Field and type int +// Field int +// +// // Field will be parsed with column name `field_column` and type int. +// Field int `dataframe:"field_column"` +// +// // Field will be parsed with column name `field` and type string. +// Field int `dataframe:"field,string"` +// +// // Field will be parsed with column name `Field` and type string. +// Field int `dataframe:",string"` +// +// If the struct tags and the given LoadOptions contradict each other, the later +// will have preference over the former. +func LoadStructs(i interface{}, options ...LoadOption) DataFrame { + if i == nil { + return DataFrame{Err: fmt.Errorf("load: can't create DataFrame from value")} + } + + // Set the default load options + cfg := loadOptions{ + defaultType: SERIES_TYPE_STRING, + detectTypes: true, + hasHeader: true, + nanValues: stat.PossibleNaOfString, + } + + // Set any custom load options + for _, option := range options { + option(&cfg) + } + + tpy, val := reflect.TypeOf(i), reflect.ValueOf(i) + switch tpy.Kind() { + case reflect.Slice: + if tpy.Elem().Kind() != reflect.Struct { + return DataFrame{Err: fmt.Errorf( + "load: type %s (%s %s) is not supported, must be []struct", tpy.Name(), tpy.Elem().Kind(), tpy.Kind())} + } + if val.Len() == 0 { + return DataFrame{Err: fmt.Errorf("load: can't create DataFrame from empty slice")} + } + + numFields := val.Index(0).Type().NumField() + var columns []Series + for j := 0; j < numFields; j++ { + // Extract field metadata + if !val.Index(0).Field(j).CanInterface() { + continue + } + field := val.Index(0).Type().Field(j) + fieldName := field.Name + fieldType := field.Type.String() + + // Process struct tags + fieldTags := field.Tag.Get("dataframe") + if fieldTags == "-" { + continue + } + tagOpts := strings.Split(fieldTags, ",") + if len(tagOpts) > 2 { + return DataFrame{Err: fmt.Errorf("malformed struct tag on field %s: %s", fieldName, fieldTags)} + } + if len(tagOpts) > 0 { + if name := strings.TrimSpace(tagOpts[0]); name != "" { + fieldName = name + } + if len(tagOpts) == 2 { + if tagType := strings.TrimSpace(tagOpts[1]); tagType != "" { + fieldType = tagType + } + } + } + + // Handle `types` option + var t Type + if cfgtype, ok := cfg.types[fieldName]; ok { + t = cfgtype + } else { + // Handle `detectTypes` option + if cfg.detectTypes { + // Parse field type + parsedType, err := parseType(fieldType) + if err != nil { + return DataFrame{Err: err} + } + t = parsedType + } else { + t = cfg.defaultType + } + } + + // Create Series for this field + elements := make([]interface{}, val.Len()) + for i := 0; i < val.Len(); i++ { + fieldValue := val.Index(i).Field(j) + elements[i] = fieldValue.Interface() + + // Handle `nanValues` option + if findInStringSlice(fmt.Sprint(elements[i]), cfg.nanValues) != -1 { + elements[i] = nil + } + } + + // Handle `hasHeader` option + if !cfg.hasHeader { + tmp := make([]interface{}, 1) + tmp[0] = fieldName + elements = append(tmp, elements...) + fieldName = "" + } + if t == SERIES_TYPE_STRING { + columns = append(columns, NewSeries(SERIES_TYPE_STRING, fieldName, elements)) + } else if t == SERIES_TYPE_BOOL { + columns = append(columns, NewSeries(SERIES_TYPE_BOOL, fieldName, elements)) + } else if t == SERIES_TYPE_INT64 { + columns = append(columns, NewSeries(SERIES_TYPE_INT64, fieldName, elements)) + } else if t == SERIES_TYPE_FLOAT32 { + columns = append(columns, NewSeries(SERIES_TYPE_FLOAT32, fieldName, elements)) + } else { + // 默认float + columns = append(columns, NewSeries(SERIES_TYPE_FLOAT64, fieldName, elements)) + } + } + return NewDataFrame(columns...) + } + return DataFrame{Err: fmt.Errorf( + "load: type %s (%s) is not supported, must be []struct", tpy.Name(), tpy.Kind())} +} diff --git a/v1/dataframe_subset.go b/v1/dataframe_subset.go new file mode 100644 index 0000000000000000000000000000000000000000..85f7bcde15359fece0df224c607c4f2ba08caf2d --- /dev/null +++ b/v1/dataframe_subset.go @@ -0,0 +1,43 @@ +package v1 + +import "gitee.com/quant1x/pandas/stat" + +// Subset returns a subset of the rows of the original DataFrame based on the +// Series subsetting indexes. +func (self DataFrame) Subset(start, end int) DataFrame { + if self.Err != nil { + return self + } + columns := make([]Series, self.ncols) + for i, column := range self.columns { + s := column.Subset(start, end) + columns[i] = s + } + nrows, ncols, err := checkColumnsDimensions(columns...) + if err != nil { + return DataFrame{Err: err} + } + return DataFrame{ + columns: columns, + ncols: ncols, + nrows: nrows, + } +} + +// Select 选择一段记录 +func (self DataFrame) SelectRows(p stat.ScopeLimit) DataFrame { + columns := []Series{} + for i := range self.columns { + columns = append(columns, self.columns[i].Select(p)) + } + nrows, ncols, err := checkColumnsDimensions(columns...) + if err != nil { + return DataFrame{Err: err} + } + newDF := DataFrame{ + columns: columns, + ncols: ncols, + nrows: nrows, + } + return newDF +} diff --git a/v1/dataframe_test.go b/v1/dataframe_test.go new file mode 100644 index 0000000000000000000000000000000000000000..0d50140b31af2c7371bf0e5f51fc727bd7eea3c1 --- /dev/null +++ b/v1/dataframe_test.go @@ -0,0 +1,54 @@ +package v1 + +import ( + "fmt" + "testing" +) + +func TestDataFrameT0(t *testing.T) { + var s1 Series + s1 = NewSeriesFloat64("sales", nil, 50.3, 23.4, 56.2) + fmt.Println(s1) + expected := 4 + + if s1.Len() != expected { + t.Errorf("wrong val: expected: %v actual: %v", expected, s1.Len()) + } + s2 := s1.Shift(-2) + df := NewDataFrame(s1, s2) + fmt.Println(df) + df.FillNa(0.00, true) + fmt.Println(df) + + _ = s2 +} + +func TestLoadStructs(t *testing.T) { + type testStruct struct { + A string + B int + C bool + D float64 + } + type testStructTags struct { + A string `dataframe:"a,string"` + B int `dataframe:"b,string"` + C bool `dataframe:"c,string"` + D float64 `dataframe:"d,string"` + E int `dataframe:"-"` // ignored + f int // ignored + } + data := []testStruct{ + {"a", 1, true, 0.0}, + {"b", 2, false, 0.5}, + } + dataTags := []testStructTags{ + {"a", 1, true, 0.0, 0, 0}, + {"NA", 2, false, 0.5, 1, 3}, + {"NA", 3, false, 1.5, 2, 4}, + } + df1 := LoadStructs(data) + fmt.Println(df1) + df2 := LoadStructs(dataTags) + fmt.Println(df2) +} diff --git a/v1/dataframe_xstring.go b/v1/dataframe_xstring.go new file mode 100644 index 0000000000000000000000000000000000000000..58122e2ce9a2333f68d661595d24184858dfca7a --- /dev/null +++ b/v1/dataframe_xstring.go @@ -0,0 +1,177 @@ +package v1 + +import ( + "fmt" + "strconv" + "strings" + "unicode/utf8" +) + +// String implements the Stringer interface for DataFrame +func (self DataFrame) String() (str string) { + return self.print(true, false, true, true, 10, 70, "DataFrame") +} + +func (self DataFrame) print( + shortRows, shortCols, showDims, showTypes bool, + maxRows int, + maxCharsTotal int, + class string) (str string) { + + addRightPadding := func(s string, nchar int) string { + if utf8.RuneCountInString(s) < nchar { + return s + strings.Repeat(" ", nchar-utf8.RuneCountInString(s)) + } + return s + } + + addLeftPadding := func(s string, nchar int) string { + if utf8.RuneCountInString(s) < nchar { + return strings.Repeat(" ", nchar-utf8.RuneCountInString(s)) + s + } + return s + } + + if self.Err != nil { + str = fmt.Sprintf("%s error: %v", class, self.Err) + return + } + nMinRows := int(maxRows / 2) + nTotal := 0 + nrows, ncols := self.Dims() + if nrows == 0 || ncols == 0 { + str = fmt.Sprintf("Empty %s", class) + return + } + var records [][]string + shortening := false + if shortRows && nrows > maxRows { + shortening = true + dfHead := self.Subset(0, nMinRows) + records = dfHead.Records() + nTotal += dfHead.Nrow() + if shortening { + dots := make([]string, ncols) + for i := 0; i < ncols; i++ { + dots[i] = "..." + } + records = append(records, dots) + } + nTotal += 1 + dfTail := self.Subset(nrows-nMinRows, nrows) + tails := dfTail.Records() + nTotal += dfTail.Nrow() + records = append(records, tails[1:]...) + } else { + records = self.Records() + nTotal += self.Nrow() + } + + if showDims { + str += fmt.Sprintf("[%dx%d] %s\n\n", nrows, ncols, class) + } + + // Add the row numbers + for i := 0; i < nTotal+1; /*self.nrows+1*/ i++ { + add := "" + if i == 0 || (i == nMinRows+1 && shortening) { + // 跳过 + } else if i < nMinRows+1 { + add = strconv.Itoa(i-1) + ":" + } else { + add = strconv.Itoa(nrows-maxRows+i-1) + ":" + } + //fmt.Println(i) + records[i] = append([]string{add}, records[i]...) + } + //if shortening { + // dots := make([]string, ncols+1) + // for i := 1; i < ncols+1; i++ { + // dots[i] = "..." + // } + // records = append(records, dots) + //} + types := self.Types() + typesrow := make([]string, ncols) + for i := 0; i < ncols; i++ { + typesrow[i] = fmt.Sprintf("<%v>", types[i]) + } + typesrow = append([]string{""}, typesrow...) + + if showTypes { + records = append(records, typesrow) + } + + maxChars := make([]int, self.ncols+1) + for i := 0; i < len(records); i++ { + for j := 0; j < self.ncols+1; j++ { + // Escape special characters + records[i][j] = strconv.Quote(records[i][j]) + records[i][j] = records[i][j][1 : len(records[i][j])-1] + + // Detect maximum number of characters per column + if len(records[i][j]) > maxChars[j] { + maxChars[j] = utf8.RuneCountInString(records[i][j]) + } + } + } + maxCols := len(records[0]) + var notShowing []string + if shortCols { + maxCharsCum := 0 + for colnum, m := range maxChars { + maxCharsCum += m + if maxCharsCum > maxCharsTotal { + maxCols = colnum + break + } + } + notShowingNames := records[0][maxCols:] + notShowingTypes := typesrow[maxCols:] + notShowing = make([]string, len(notShowingNames)) + for i := 0; i < len(notShowingNames); i++ { + notShowing[i] = fmt.Sprintf("%s %s", notShowingNames[i], notShowingTypes[i]) + } + } + for i := 0; i < len(records); i++ { + // Add right padding to all elements + records[i][0] = addLeftPadding(records[i][0], maxChars[0]+1) + for j := 1; j < self.ncols; j++ { + records[i][j] = addRightPadding(records[i][j], maxChars[j]) + } + records[i] = records[i][0:maxCols] + if shortCols && len(notShowing) != 0 { + records[i] = append(records[i], "...") + } + // Create the final string + str += strings.Join(records[i], " ") + str += "\n" + } + // 没有显示字段处理逻辑 + if shortCols && len(notShowing) != 0 { + var notShown string + var notShownArr [][]string + cum := 0 + i := 0 + for n, ns := range notShowing { + cum += len(ns) + if cum > maxCharsTotal { + notShownArr = append(notShownArr, notShowing[i:n]) + cum = 0 + i = n + } + } + if i < len(notShowing) { + notShownArr = append(notShownArr, notShowing[i:]) + } + for k, ns := range notShownArr { + notShown += strings.Join(ns, ", ") + if k != len(notShownArr)-1 { + notShown += "," + } + notShown += "\n" + } + str += fmt.Sprintf("\nNot Showing: %s", notShown) + } + return str +} diff --git a/v1/generic.go b/v1/generic.go new file mode 100644 index 0000000000000000000000000000000000000000..be1e7605a2521e9138908a64838e097646b82388 --- /dev/null +++ b/v1/generic.go @@ -0,0 +1,345 @@ +package v1 + +import ( + "gitee.com/quant1x/pandas/stat" + gs "gonum.org/v1/gonum/stat" + "reflect" + "sync" +) + +// NDFrame 这里本意是想做一个父类, 实际的效果是一个抽象类 +type NDFrame struct { + lock sync.RWMutex // 读写锁 + formatter stat.StringFormatter // 字符串格式化工具 + name string // 帧名称 + type_ Type // values元素类型 + copy_ bool // 是否副本 + nilCount int // nil和nan的元素有多少, 这种统计在bool和int64类型中不会大于0, 只对float64及string有效 + rows int // 行数 + values any // 只能是一个一维slice, 在所有的运算中, values强制转换成float64切片 + +} + +//""" +//N-dimensional analogue of DataFrame. Store multi-dimensional in a +//size-mutable, labeled data structure +// +//Parameters +//---------- +//data : BlockManager +//axes : list +//copy : bool, default False +//""" + +func NewNDFrame[E stat.GenericType](name string, rows ...E) *NDFrame { + frame := NDFrame{ + formatter: stat.DefaultFormatter, + name: name, + type_: SERIES_TYPE_INVAILD, + nilCount: 0, + rows: 0, + values: []E{}, + } + // TODO: 不知道rows是否存在全部为空的情况, 只能先创建一个空的slice + frame.values = make([]E, 0) // Warning: filled with 0.0 (not NaN) + // 这个地方可以放心的强制转换, E已经做了类型约束 + size := len(frame.values.([]E)) + for idx, v := range rows { + assign(&frame, idx, size, v) + } + + return &frame +} + +// 赋值 +func assign[T stat.GenericType](frame *NDFrame, idx, size int, v T) { + // 检测类型 + if frame.type_ == SERIES_TYPE_INVAILD { + _type, _ := detectTypes(v) + if _type != SERIES_TYPE_INVAILD { + frame.type_ = _type + } + } + _vv := reflect.ValueOf(v) + _vi := _vv.Interface() + // float和string类型有可能是NaN, 对nil和NaN进行计数 + if frame.Type() == SERIES_TYPE_FLOAT32 && stat.Float32IsNaN(_vi.(float32)) { + frame.nilCount++ + } else if frame.Type() == SERIES_TYPE_FLOAT64 && stat.Float64IsNaN(_vi.(float64)) { + frame.nilCount++ + } else if frame.Type() == SERIES_TYPE_STRING && stat.StringIsNaN(_vi.(string)) { + frame.nilCount++ + // 以下修正string的NaN值, 统一为"NaN" + //_rv := reflect.ValueOf(StringNaN) + //_vv.Set(_rv) // 这样赋值会崩溃 + // TODO:值可修改条件之一: 可被寻址 + // 通过反射修改变量值的前提条件之一: 这个值必须可以被寻址, 简单地说就是这个变量必须能被修改. + // 第一步: 通过变量v反射(v的地址) + _vp := reflect.ValueOf(&v) + // 第二步: 取出v地址的元素(v的值) + _vv := _vp.Elem() + // 判断_vv是否能被修改 + if _vv.CanSet() { + // 修改v的值为新值 + _vv.SetString(stat.StringNaN) + // 执行之后, 通过debug可以看到assign入参的v已经变成了"NaN" + } + } + // 确保只添加了1个元素 + if idx < size { + frame.values.([]T)[idx] = v + } else { + frame.values = append(frame.values.([]T), v) + } + // 行数+1 + frame.rows += 1 +} + +// Repeat 重复生成a +func Repeat[T stat.GenericType](a T, n int) []T { + dst := make([]T, n) + for i := 0; i < n; i++ { + dst[i] = a + } + return dst +} + +// Repeat2 重复生成a +func Repeat2[T stat.GenericType](dst []T, a T, n int) []T { + for i := 0; i < n; i++ { + dst[i] = a + } + return dst +} + +func (self *NDFrame) Name() string { + return self.name +} + +func (self *NDFrame) Rename(n string) { + self.name = n +} + +func (self *NDFrame) Type() Type { + return self.type_ +} + +func (self *NDFrame) Values() any { + return self.values +} + +// NaN 输出默认的NaN +func (self *NDFrame) NaN() any { + switch self.values.(type) { + case []bool: + return stat.BoolNaN + case []string: + return stat.StringNaN + case []int64: + return stat.Nil2Int64 + case []float32: + return stat.Nil2Float32 + case []float64: + return stat.Nil2Float64 + default: + panic(ErrUnsupportedType) + } +} + +func (self *NDFrame) Float() []float32 { + return stat.SliceToFloat32(self.values) +} + +// DTypes 计算以这个函数为主 +func (self *NDFrame) DTypes() []stat.DType { + return stat.Slice2DType(self.Values()) +} + +// AsInt 强制转换成整型 +func (self *NDFrame) AsInt() []stat.Int { + values := self.DTypes() + fs := stat.Fill[stat.DType](values, stat.DType(0)) + ns := stat.DType2Int(fs) + return ns +} + +func (self *NDFrame) Empty() Series { + var frame NDFrame + if self.type_ == stat.SERIES_TYPE_STRING { + frame = NDFrame{ + formatter: self.formatter, + name: self.name, + type_: self.type_, + nilCount: 0, + rows: 0, + values: []string{}, + } + } else if self.type_ == stat.SERIES_TYPE_BOOL { + frame = NDFrame{ + formatter: self.formatter, + name: self.name, + type_: self.type_, + nilCount: 0, + rows: 0, + values: []bool{}, + } + } else if self.type_ == stat.SERIES_TYPE_INT64 { + frame = NDFrame{ + formatter: self.formatter, + name: self.name, + type_: self.type_, + nilCount: 0, + rows: 0, + values: []int64{}, + } + } else if self.type_ == stat.SERIES_TYPE_FLOAT32 { + frame = NDFrame{ + formatter: self.formatter, + name: self.name, + type_: self.type_, + nilCount: 0, + rows: 0, + values: []float32{}, + } + } else if self.type_ == stat.SERIES_TYPE_FLOAT64 { + frame = NDFrame{ + formatter: self.formatter, + name: self.name, + type_: self.type_, + nilCount: 0, + rows: 0, + values: []float64{}, + } + } else { + panic(ErrUnsupportedType) + } + return &frame +} + +func (self *NDFrame) Records() []string { + ret := make([]string, self.Len()) + self.Apply(func(idx int, v any) { + ret[idx] = stat.AnyToString(v) + }) + return ret +} + +func (self *NDFrame) Repeat(x any, repeats int) Series { + switch values := self.values.(type) { + case []bool: + _ = values + vs := Repeat(stat.AnyToBool(x), repeats) + return NewNDFrame(self.name, vs...) + case []string: + vs := Repeat(stat.AnyToString(x), repeats) + return NewNDFrame(self.name, vs...) + case []int64: + vs := Repeat(stat.AnyToInt64(x), repeats) + return NewNDFrame(self.name, vs...) + case []float32: + vs := Repeat(stat.AnyToFloat32(x), repeats) + return NewNDFrame(self.name, vs...) + default: //case []float64: + vs := Repeat(stat.AnyToFloat64(x), repeats) + return NewNDFrame(self.name, vs...) + } +} + +func (self *NDFrame) Shift(periods int) Series { + var d Series + d = clone(self).(Series) + //return Shift[float64](&d, periods, func() float64 { + // return Nil2Float64 + //}) + switch values := self.values.(type) { + case []bool: + _ = values + return Shift[bool](&d, periods, func() bool { + return stat.BoolNaN + }) + case []string: + return Shift[string](&d, periods, func() string { + return stat.StringNaN + }) + case []int64: + return Shift[int64](&d, periods, func() int64 { + return stat.Nil2Int64 + }) + case []float32: + return Shift[float32](&d, periods, func() float32 { + return Nil2Float32 + }) + default: //case []float64: + return Shift[float64](&d, periods, func() float64 { + return Nil2Float64 + }) + } +} + +func (self *NDFrame) Mean() stat.DType { + if self.Len() < 1 { + return NaN() + } + fs := make([]stat.DType, 0) + self.Apply(func(idx int, v any) { + f := stat.Any2DType(v) + fs = append(fs, f) + }) + stdDev := stat.Mean(fs) + return stdDev +} + +func (self *NDFrame) StdDev() stat.DType { + if self.Len() < 1 { + return NaN() + } + values := make([]stat.DType, self.Len()) + self.Apply(func(idx int, v any) { + values[idx] = stat.Any2DType(v) + }) + stdDev := gs.StdDev(values, nil) + return stdDev +} + +func (self *NDFrame) Std() stat.DType { + if self.Len() < 1 { + return NaN() + } + values := make([]stat.DType, self.Len()) + self.Apply(func(idx int, v any) { + values[idx] = stat.Any2DType(v) + }) + stdDev := stat.Std(values) + return stdDev +} + +func (self *NDFrame) FillNa(v any, inplace bool) Series { + values := self.Values() + switch rows := values.(type) { + case []string: + for idx, iv := range rows { + if stat.StringIsNaN(iv) && inplace { + rows[idx] = stat.AnyToString(v) + } + } + case []int64: + for idx, iv := range rows { + if stat.Float64IsNaN(float64(iv)) && inplace { + rows[idx] = stat.AnyToInt64(v) + } + } + case []float32: + for idx, iv := range rows { + if stat.Float32IsNaN(iv) && inplace { + rows[idx] = stat.AnyToFloat32(v) + } + } + case []float64: + for idx, iv := range rows { + if stat.Float64IsNaN(iv) && inplace { + rows[idx] = stat.AnyToFloat64(v) + } + } + } + return self +} diff --git a/v1/generic_append.go b/v1/generic_append.go new file mode 100644 index 0000000000000000000000000000000000000000..e753d1855334e05a47db09af8e5ab8c05a7dfb5c --- /dev/null +++ b/v1/generic_append.go @@ -0,0 +1,54 @@ +package v1 + +import ( + "gitee.com/quant1x/pandas/stat" + "reflect" +) + +// 插入一条记录 +func (self *NDFrame) insert(idx, size int, v any) { + if self.type_ == SERIES_TYPE_BOOL { + val := stat.AnyToBool(v) + assign[bool](self, idx, size, val) + } else if self.type_ == SERIES_TYPE_INT64 { + val := stat.AnyToInt64(v) + assign[int64](self, idx, size, val) + } else if self.type_ == SERIES_TYPE_FLOAT32 { + val := stat.AnyToFloat32(v) + assign[float32](self, idx, size, val) + } else if self.type_ == SERIES_TYPE_FLOAT64 { + val := stat.AnyToFloat64(v) + assign[float64](self, idx, size, val) + } else { + val := stat.AnyToString(v) + assign[string](self, idx, size, val) + } +} + +// Append 批量增加记录 +func (self *NDFrame) Append(values ...any) { + size := 0 + for idx, v := range values { + switch val := v.(type) { + case nil, int8, uint8, int16, uint16, int32, uint32, int64, uint64, int, uint, float32, float64, bool, string: + // 基础类型 + self.insert(idx, size, val) + default: + vv := reflect.ValueOf(val) + vk := vv.Kind() + switch vk { + //case reflect.Invalid: // {interface} nil + // series.assign(idx, size, Nil2Float64) + case reflect.Slice, reflect.Array: // 切片或数组 + for i := 0; i < vv.Len(); i++ { + tv := vv.Index(i).Interface() + self.insert(idx, size, tv) + } + case reflect.Struct: // 忽略结构体 + continue + default: + self.insert(idx, size, nil) + } + } + } +} diff --git a/v1/generic_apply.go b/v1/generic_apply.go new file mode 100644 index 0000000000000000000000000000000000000000..4a8596ab998bfa238164aada7894a46e41733195 --- /dev/null +++ b/v1/generic_apply.go @@ -0,0 +1,47 @@ +package v1 + +import "reflect" + +func (self *NDFrame) Apply(f func(idx int, v any)) { + vv := reflect.ValueOf(self.values) + vk := vv.Kind() + switch vk { + case reflect.Invalid: // {interface} nil + //series.assign(idx, size, Nil2Float64) + case reflect.Slice: // 切片, 不定长 + for i := 0; i < vv.Len(); i++ { + tv := vv.Index(i).Interface() + f(i, tv) + } + case reflect.Array: // 数组, 定长 + for i := 0; i < vv.Len(); i++ { + tv := vv.Index(i).Interface() + f(i, tv) + } + default: + // 其它类型忽略 + } +} + +func (self *NDFrame) Logic(f func(idx int, v any) bool) []bool { + x := make([]bool, self.Len()) + vv := reflect.ValueOf(self.values) + vk := vv.Kind() + switch vk { + case reflect.Invalid: // {interface} nil + //series.assign(idx, size, Nil2Float64) + case reflect.Slice: // 切片, 不定长 + for i := 0; i < vv.Len(); i++ { + tv := vv.Index(i).Interface() + x[i] = f(i, tv) + } + case reflect.Array: // 数组, 定长 + for i := 0; i < vv.Len(); i++ { + tv := vv.Index(i).Interface() + x[i] = f(i, tv) + } + default: + // 其它类型忽略 + } + return x +} diff --git a/v1/generic_diff.go b/v1/generic_diff.go new file mode 100644 index 0000000000000000000000000000000000000000..bd3014634d77a132d9d3ae136b977331a39aa290 --- /dev/null +++ b/v1/generic_diff.go @@ -0,0 +1,55 @@ +package v1 + +import ( + "gitee.com/quant1x/pandas/stat" + "reflect" +) + +// Diff 元素的第一个离散差 +// First discrete difference of element. +// Calculates the difference of a {klass} element compared with another +// element in the {klass} (default is element in previous row). +func (self *NDFrame) Diff(param any) (s Series) { + if !(self.type_ == SERIES_TYPE_INT64 || self.type_ == SERIES_TYPE_FLOAT32 || self.type_ == SERIES_TYPE_FLOAT64) { + return NewSeries(SERIES_TYPE_INVAILD, "", "") + } + var N []stat.DType + switch v := param.(type) { + case int: + N = stat.Repeat[stat.DType](stat.DType(v), self.Len()) + case Series: + vs := v.DTypes() + N = stat.Align(vs, stat.DTypeNaN, self.Len()) + default: + //periods = 1 + N = stat.Repeat[stat.DType](stat.DType(1), self.Len()) + } + r := RollingAndExpandingMixin{ + window: N, + series: self, + } + var d []stat.DType + var front = stat.DTypeNaN + for _, block := range r.getBlocks() { + vs := reflect.ValueOf(block.Values()) + vl := vs.Len() + if vl == 0 { + d = append(d, stat.DTypeNaN) + continue + } + vf := vs.Index(0).Interface() + vc := vs.Index(vl - 1).Interface() + cu := stat.Any2DType(vc) + cf := stat.Any2DType(vf) + if stat.DTypeIsNaN(cu) || stat.DTypeIsNaN(front) { + front = cf + d = append(d, stat.DTypeNaN) + continue + } + diff := cu - front + d = append(d, diff) + front = cf + } + s = NewSeries(SERIES_TYPE_DTYPE, r.series.Name(), d) + return +} diff --git a/v1/generic_diff_test.go b/v1/generic_diff_test.go new file mode 100644 index 0000000000000000000000000000000000000000..b940e0487a94f971896f2fa9d292b90be271942e --- /dev/null +++ b/v1/generic_diff_test.go @@ -0,0 +1,24 @@ +package v1 + +import ( + "fmt" + "testing" +) + +func TestNDFrame_Diff(t *testing.T) { + d1 := []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} + s1 := NewNDFrame[float64]("x", d1...) + df := NewDataFrame(s1) + fmt.Println(df) + fmt.Println("------------------------------------------------------------") + N := 1 + fmt.Println("固定的参数, N =", N) + r1 := df.Col("x").Diff(N).Values() + fmt.Println("序列化结果:", r1) + fmt.Println("------------------------------------------------------------") + d2 := []float64{1, 2, 3, 4, 3, 3, 2, 1, Nil2Float64, Nil2Float64, Nil2Float64, Nil2Float64} + s2 := NewSeries(SERIES_TYPE_FLOAT64, "x", d2) + fmt.Printf("序列化参数: %+v\n", s2.Values()) + r2 := df.Col("x").Diff(s2).Values() + fmt.Println("序列化结果:", r2) +} diff --git a/v1/generic_ewm.go b/v1/generic_ewm.go new file mode 100644 index 0000000000000000000000000000000000000000..b594bedba8a45e9c64ca5b0060b5f68dd90510fb --- /dev/null +++ b/v1/generic_ewm.go @@ -0,0 +1,180 @@ +package v1 + +import ( + "gitee.com/quant1x/pandas/stat" + "math" +) + +type AlphaType int + +// https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.ewm.html +const ( + // AlphaAlpha Specify smoothing factor α directly, 0<α≤1. + AlphaAlpha AlphaType = iota + // AlphaCom Specify decay in terms of center of mass, α=1/(1+com), for com ≥ 0. + AlphaCom + // AlphaSpan Specify decay in terms of span, α=2/(span+1), for span ≥ 1. + AlphaSpan + // AlphaHalfLife Specify decay in terms of half-life, α=1−exp(−ln(2)/halflife), for halflife > 0. + AlphaHalfLife +) + +// EW (Factor) 指数加权(EW)计算Alpha 结构属性非0即为有效启动同名算法 +type EW struct { + Com stat.DType // 根据质心指定衰减 + Span stat.DType // 根据跨度指定衰减 + HalfLife stat.DType // 根据半衰期指定衰减 + Alpha stat.DType // 直接指定的平滑因子α + Adjust bool // 除以期初的衰减调整系数以核算 相对权重的不平衡(将 EWMA 视为移动平均线) + IgnoreNA bool // 计算权重时忽略缺失值 + Callback func(idx int) stat.DType +} + +// ExponentialMovingWindow 加权移动窗口 +type ExponentialMovingWindow struct { + data Series // 序列 + atype AlphaType // 计算方式: com/span/halflefe/alpha + param stat.DType // 参数类型为浮点 + adjust bool // 默认为真, 是否调整, 默认真时, 计算序列的EW移动平均线, 为假时, 计算指数加权递归 + ignoreNA bool // 默认为假, 计算权重时是否忽略缺失值NaN + minPeriods int // 默认为0, 窗口中具有值所需的最小观测值数,否则结果为NaN + axis int // {0,1}, 默认为0, 0跨行计算, 1跨列计算 + cb func(idx int) stat.DType +} + +// EWM provides exponential weighted calculations. +func (s *NDFrame) EWM(alpha EW) ExponentialMovingWindow { + atype := AlphaAlpha + param := 0.00 + adjust := alpha.Adjust + ignoreNA := alpha.IgnoreNA + if alpha.Com != 0 { + atype = AlphaCom + param = alpha.Com + } else if alpha.Span != 0 { + atype = AlphaSpan + param = alpha.Span + } else if alpha.HalfLife != 0 { + atype = AlphaHalfLife + param = alpha.HalfLife + } else { + atype = AlphaAlpha + param = alpha.Alpha + } + + dest := NewSeries(SERIES_TYPE_FLOAT64, s.name, s.Values()) + return ExponentialMovingWindow{ + data: dest, + atype: atype, + param: param, + adjust: adjust, + ignoreNA: ignoreNA, + cb: alpha.Callback, + } +} + +func (w ExponentialMovingWindow) Mean() Series { + var alpha stat.DType + + switch w.atype { + case AlphaAlpha: + if w.param <= 0 { + panic("alpha param must be > 0") + } + alpha = w.param + + case AlphaCom: + if w.param <= 0 { + panic("com param must be >= 0") + } + alpha = 1 / (1 + w.param) + + case AlphaSpan: + if w.param < 1 { + panic("span param must be >= 1") + } + alpha = 2 / (w.param + 1) + + case AlphaHalfLife: + if w.param <= 0 { + panic("halflife param must be > 0") + } + alpha = 1 - math.Exp(-math.Ln2/w.param) + } + + return w.applyMean(w.data, alpha) +} + +func (w ExponentialMovingWindow) applyMean(data Series, alpha stat.DType) Series { + if w.adjust { + w.adjustedMean(data, alpha, w.ignoreNA) + } else { + w.notadjustedMean(data, alpha, w.ignoreNA) + } + return data +} + +func (w ExponentialMovingWindow) adjustedMean(data Series, alpha stat.DType, ignoreNA bool) { + var ( + values = data.Values().([]stat.DType) + weight stat.DType = 1 + last = values[0] + ) + + alpha = 1 - alpha + for t := 1; t < len(values); t++ { + + w := alpha*weight + 1 + x := values[t] + if stat.DTypeIsNaN(x) { + if ignoreNA { + weight = w + } + values[t] = last + continue + } + + last = last + (x-last)/(w) + weight = w + values[t] = last + } +} + +func (w ExponentialMovingWindow) notadjustedMean(data Series, alpha stat.DType, ignoreNA bool) { + hasCallback := false + if stat.DTypeIsNaN(alpha) { + hasCallback = true + alpha = w.cb(0) + } + var ( + count int + values = data.Values().([]stat.DType) + //values = data.DTypes() // Dtypes有复制功能 + beta = 1 - alpha + last = values[0] + ) + if stat.Float64IsNaN(last) { + last = 0 + values[0] = last + } + for t := 1; t < len(values); t++ { + x := values[t] + + if stat.DTypeIsNaN(x) { + values[t] = last + continue + } + if hasCallback { + alpha = w.cb(t) + beta = 1 - alpha + } + // yt = (1−α)*y(t−1) + α*x(t) + last = (beta * last) + (alpha * x) + if stat.DTypeIsNaN(last) { + last = values[t-1] + } + values[t] = last + + count++ + } +} diff --git a/v1/generic_fillna.go b/v1/generic_fillna.go new file mode 100644 index 0000000000000000000000000000000000000000..34b0d7b114a2559c06a31051d6a80ae5cb6c0d40 --- /dev/null +++ b/v1/generic_fillna.go @@ -0,0 +1,25 @@ +package v1 + +import "gitee.com/quant1x/pandas/stat" + +// FillNa 填充NaN的元素为v +// inplace为真是修改series元素的值 +// 如果v和Values()返回值的slice类型不一致就会panic +func FillNa[T stat.GenericType](s *NDFrame, v T, inplace bool) *NDFrame { + values := s.Values() + switch rows := values.(type) { + case []string: + for idx, iv := range rows { + if stat.StringIsNaN(iv) && inplace { + rows[idx] = stat.AnyToString(v) + } + } + case []float64: + for idx, iv := range rows { + if stat.Float64IsNaN(iv) && inplace { + rows[idx] = stat.AnyToFloat64(v) + } + } + } + return s +} diff --git a/v1/generic_max.go b/v1/generic_max.go new file mode 100644 index 0000000000000000000000000000000000000000..580777888705db9588f64336730e1bd6081a9fdf --- /dev/null +++ b/v1/generic_max.go @@ -0,0 +1,113 @@ +package v1 + +import "gitee.com/quant1x/pandas/stat" + +func (self *NDFrame) Max() any { + values := self.Values() + switch rows := values.(type) { + case []bool: + max := false + i := 0 + for idx, iv := range rows { + if iv && !max { + max = iv + i += 1 + } + _ = idx + } + if i > 0 { + return max + } + return false + case []string: + max := "" + hasNaN := false + i := 0 + for idx, iv := range rows { + if stat.StringIsNaN(iv) { + hasNaN = true + break + } + if iv > max { + max = iv + i += 1 + } + _ = idx + } + if hasNaN { + return stat.StringNaN + } else if i > 0 { + return max + } + return stat.StringNaN + case []int64: + max := stat.MinInt64 + //i := 0 + for idx, iv := range rows { + if stat.Float64IsNaN(float64(iv)) { + continue + } + if iv > max { + max = iv + //i = idx + } + _ = idx + } + return max + case []float32: + max := stat.MinFloat32 + hasNan := false + i := 0 + for idx, iv := range rows { + if stat.Float32IsNaN(iv) { + hasNan = true + break + } + if iv > max { + max = iv + i += 1 + } + _ = idx + } + if hasNan { + return Nil2Float32 + } else if i > 0 { + return max + } + return Nil2Float32 + //case []float32: + // if self.Len() == 0 { + // return Nil2Float32 + // } + // return stat.Max(rows) + case []float64: + max := stat.MinFloat64 + hasNaN := false + i := 0 + for idx, iv := range rows { + if stat.Float64IsNaN(iv) { + hasNaN = true + break + } + if iv > max { + max = iv + i += 1 + } + _ = idx + } + if hasNaN { + return Nil2Float64 + } else if i > 0 { + return max + } + return Nil2Float64 + //case []float64: + // if self.Len() == 0 { + // return Nil2Float64 + // } + // return stat.Max(rows) + default: + panic(ErrUnsupportedType) + } + //return Nil2Float64 +} diff --git a/v1/generic_min.go b/v1/generic_min.go new file mode 100644 index 0000000000000000000000000000000000000000..fb9ac0771451acb50982dbeb9dce978f21ced117 --- /dev/null +++ b/v1/generic_min.go @@ -0,0 +1,119 @@ +package v1 + +import "gitee.com/quant1x/pandas/stat" + +func (self *NDFrame) Min() any { + values := self.Values() + switch rows := values.(type) { + case []bool: + min := true + i := 0 + for idx, iv := range rows { + if !iv && min { + min = iv + i += 1 + } + _ = idx + } + if i > 0 { + return min + } + return false + case []string: + min := "" + hasNaN := false + i := 0 + for idx, iv := range rows { + if stat.StringIsNaN(iv) { + hasNaN = true + break + } else if i < 1 { + min = iv + i += 1 + } + if iv < min { + min = iv + i += 1 + } + _ = idx + } + if hasNaN { + return stat.StringNaN + } else if i > 0 { + return min + } + return stat.StringNaN + case []int64: + min := stat.MaxInt64 + i := 0 + for idx, iv := range rows { + if stat.Float64IsNaN(float64(iv)) { + continue + } else if i < 1 { + min = iv + i += 1 + } + if iv < min { + min = iv + i += 1 + } + _ = idx + } + return min + case []float32: + min := stat.MaxFloat32 + hasNan := false + i := 0 + for idx, iv := range rows { + if stat.Float32IsNaN(iv) { + hasNan = true + break + } + if iv < min { + min = iv + i += 1 + } + _ = idx + } + if hasNan { + return stat.Nil2Float32 + } else if i > 0 { + return min + } + return stat.Nil2Float32 + //case []float32: + // if self.Len() == 0 { + // return Nil2Float32 + // } + // return stat.Min(rows) + case []float64: + min := stat.MaxFloat64 + hasNaN := false + i := 0 + for idx, iv := range rows { + if stat.Float64IsNaN(iv) { + hasNaN = true + break + } + if iv < min { + min = iv + i += 1 + } + _ = idx + } + if hasNaN { + return stat.Nil2Float64 + } else if i > 0 { + return min + } + return stat.Nil2Float64 + //case []float64: + // if self.Len() == 0 { + // return Nil2Float64 + // } + // return stat.Min(rows) + default: + panic(ErrUnsupportedType) + } + return stat.Nil2Float64 +} diff --git a/v1/generic_range.go b/v1/generic_range.go new file mode 100644 index 0000000000000000000000000000000000000000..ba467eee40d22c2798d2bcd1e9432986ea508c3a --- /dev/null +++ b/v1/generic_range.go @@ -0,0 +1,126 @@ +package v1 + +import ( + "gitee.com/quant1x/pandas/stat" + gc "github.com/huandu/go-clone" + "reflect" +) + +// Copy 复制一个副本 +func (self *NDFrame) Copy() Series { + vlen := self.Len() + return self.Subset(0, vlen, true) +} + +func (self *NDFrame) Subset(start, end int, opt ...any) Series { + // 默认不copy + var __optCopy bool = false + if len(opt) > 0 { + // 第一个参数为是否copy + if _cp, ok := opt[0].(bool); ok { + __optCopy = _cp + } + } + var vs any + var rows int + vv := reflect.ValueOf(self.values) + vk := vv.Kind() + switch vk { + case reflect.Slice, reflect.Array: // 切片和数组同样的处理逻辑 + vvs := vv.Slice(start, end) + vs = vvs.Interface() + rows = vv.Len() + if __optCopy && rows > 0 { + vs = gc.Clone(vs) + } + rows = vvs.Len() + frame := NDFrame{ + formatter: self.formatter, + name: self.name, + type_: self.type_, + nilCount: 0, + rows: rows, + values: vs, + } + return &frame + default: + // 其它类型忽略 + } + return self.Empty() +} + +func (self *NDFrame) oldSubset(start, end int, opt ...any) Series { + // 默认不copy + var __optCopy bool = false + if len(opt) > 0 { + // 第一个参数为是否copy + if _cp, ok := opt[0].(bool); ok { + __optCopy = _cp + } + } + var vs any + var rows int + switch values := self.values.(type) { + case []bool: + subset := values[start:end] + rows = len(subset) + if !__optCopy { + vs = subset + } else { + _vs := make([]bool, 0) + _vs = append(_vs, subset...) + vs = _vs + } + case []string: + subset := values[start:end] + rows = len(subset) + if !__optCopy { + vs = subset + } else { + _vs := make([]string, 0) + _vs = append(_vs, subset...) + vs = _vs + } + case []int64: + subset := values[start:end] + rows = len(subset) + if !__optCopy { + vs = subset + } else { + _vs := make([]int64, 0) + _vs = append(_vs, subset...) + vs = _vs + } + case []float64: + subset := values[start:end] + rows = len(subset) + if !__optCopy { + vs = subset + } else { + _vs := make([]float64, 0) + _vs = append(_vs, subset...) + vs = _vs + } + } + frame := NDFrame{ + formatter: self.formatter, + name: self.name, + type_: self.type_, + nilCount: 0, + rows: rows, + values: vs, + } + var s Series + s = &frame + return s +} + +// Select 选取一段记录 +func (self *NDFrame) Select(r stat.ScopeLimit) Series { + start, end, err := r.Limits(self.Len()) + if err != nil { + return nil + } + series := self.Subset(start, end+1) + return series +} diff --git a/v1/generic_ref.go b/v1/generic_ref.go new file mode 100644 index 0000000000000000000000000000000000000000..83b89f5a192dd988872bdcf7002668bcd983d508 --- /dev/null +++ b/v1/generic_ref.go @@ -0,0 +1,53 @@ +package v1 + +import ( + "gitee.com/quant1x/pandas/exception" + "gitee.com/quant1x/pandas/stat" +) + +func (self *NDFrame) Ref(param any) (s Series) { + var N []float32 + switch v := param.(type) { + case int: + N = stat.Repeat[float32](float32(v), self.Len()) + case []float32: + N = stat.Align(v, Nil2Float32, self.Len()) + case Series: + vs := v.Values() + N = stat.SliceToFloat32(vs) + N = stat.Align(N, Nil2Float32, self.Len()) + default: + panic(exception.New(1, "error window")) + } + + var d Series + d = clone(self).(Series) + //return Shift[float64](&d, periods, func() float64 { + // return Nil2Float64 + //}) + switch values := self.values.(type) { + case []bool: + _ = values + return Shift2[bool](&d, N, func() bool { + return stat.BoolNaN + }) + case []string: + return Shift2[string](&d, N, func() string { + return stat.StringNaN + }) + case []int64: + return Shift2[int64](&d, N, func() int64 { + return stat.Nil2Int64 + }) + case []float32: + return Shift2[float32](&d, N, func() float32 { + return Nil2Float32 + }) + default: //case []float64: + return Shift2[float64](&d, N, func() float64 { + return Nil2Float64 + }) + } + + return d +} diff --git a/v1/generic_rolling.go b/v1/generic_rolling.go new file mode 100644 index 0000000000000000000000000000000000000000..eb267f8ea4d06de9a4a4fe51a1ba468fa0b8a1bf --- /dev/null +++ b/v1/generic_rolling.go @@ -0,0 +1,77 @@ +package v1 + +import ( + "gitee.com/quant1x/pandas/exception" + "gitee.com/quant1x/pandas/stat" +) + +// RollingAndExpandingMixin 滚动和扩展静态横切 +type RollingAndExpandingMixin struct { + window []stat.DType + series Series +} + +// Rolling RollingAndExpandingMixin +func (self *NDFrame) Rolling(param any) RollingAndExpandingMixin { + var N []stat.DType + switch v := param.(type) { + case int: + N = stat.Repeat[stat.DType](stat.DType(v), self.Len()) + case []stat.DType: + N = stat.Align(v, stat.DTypeNaN, self.Len()) + case Series: + vs := v.DTypes() + N = stat.Align(vs, stat.DTypeNaN, self.Len()) + default: + panic(exception.New(1, "error window")) + } + w := RollingAndExpandingMixin{ + window: N, + series: self, + } + return w +} + +func (r RollingAndExpandingMixin) getBlocks() (blocks []Series) { + for i := 0; i < r.series.Len(); i++ { + N := r.window[i] + if stat.DTypeIsNaN(N) || int(N) > i+1 { + blocks = append(blocks, r.series.Empty()) + continue + } + window := int(N) + start := i + 1 - window + end := i + 1 + blocks = append(blocks, r.series.Subset(start, end, false)) + } + + return +} + +func (r RollingAndExpandingMixin) Apply_v1(f func(S Series, N stat.DType) stat.DType) (s Series) { + s = r.series.Empty() + for i, block := range r.getBlocks() { + if block.Len() == 0 { + s.Append(stat.DTypeNaN) + continue + } + v := f(block, r.window[i]) + s.Append(v) + } + return +} + +// Apply 接受一个回调 +func (r RollingAndExpandingMixin) Apply(f func(S Series, N stat.DType) stat.DType) (s Series) { + values := make([]stat.DType, r.series.Len()) + for i, block := range r.getBlocks() { + if block.Len() == 0 { + values[i] = stat.DTypeNaN + continue + } + v := f(block, r.window[i]) + values[i] = v + } + s = NewSeries(SERIES_TYPE_DTYPE, r.series.Name(), values) + return +} diff --git a/v1/generic_shift.go b/v1/generic_shift.go new file mode 100644 index 0000000000000000000000000000000000000000..c93210e27dfd08541e20ce642e583149bb3d1599 --- /dev/null +++ b/v1/generic_shift.go @@ -0,0 +1,64 @@ +package v1 + +import ( + "gitee.com/quant1x/pandas/stat" + "math" +) + +// Shift series切片, 使用可选的时间频率按所需的周期数移动索引 +func Shift[T stat.GenericType](s *Series, periods int, cbNan func() T) Series { + var d Series + d = clone(*s).(Series) + if periods == 0 { + return d + } + + values := d.Values().([]T) + + var ( + naVals []T + dst []T + src []T + ) + + if shlen := int(math.Abs(float64(periods))); shlen < len(values) { + if periods > 0 { + naVals = values[:shlen] + dst = values[shlen:] + src = values + } else { + naVals = values[len(values)-shlen:] + dst = values[:len(values)-shlen] + src = values[shlen:] + } + copy(dst, src) + } else { + naVals = values + } + for i := range naVals { + naVals[i] = cbNan() + } + _ = naVals + return d +} + +// Shift2 series切片, 使用可选的时间频率按所需的周期数移动索引 +func Shift2[T stat.GenericType](s *Series, N []float32, cbNan func() T) Series { + var d Series + d = clone(*s).(Series) + if len(N) == 0 { + return d + } + S := (*s).Values().([]T) + values := d.Values().([]T) + for i, _ := range S { + x := N[i] + if stat.Float32IsNaN(x) || int(x) > i { + values[i] = cbNan() + continue + } + values[i] = S[i-int(x)] + } + + return d +} diff --git a/v1/generic_sort.go b/v1/generic_sort.go new file mode 100644 index 0000000000000000000000000000000000000000..ab375905ae8cca84b38e9928c52aef6c3a9a6d0c --- /dev/null +++ b/v1/generic_sort.go @@ -0,0 +1,66 @@ +package v1 + +// Len 获得行数, 实现sort.Interface接口的获取元素数量方法 +func (self *NDFrame) Len() int { + return self.rows +} + +// Less 实现sort.Interface接口的比较元素方法 +func (self *NDFrame) Less(i, j int) bool { + if self.type_ == SERIES_TYPE_BOOL { + values := self.Values().([]bool) + var ( + a = int(0) + b = int(0) + ) + if values[i] { + a = 1 + } + if values[j] { + b = 1 + } + return a < b + } else if self.type_ == SERIES_TYPE_INT64 { + values := self.Values().([]int64) + return values[i] < values[j] + } else if self.type_ == SERIES_TYPE_FLOAT32 { + values := self.Values().([]float32) + return values[i] < values[j] + } else if self.type_ == SERIES_TYPE_FLOAT64 { + values := self.Values().([]float64) + return values[i] < values[j] + } else if self.type_ == SERIES_TYPE_STRING { + values := self.Values().([]string) + return values[i] < values[j] + } else { + // SERIES_TYPE_INVAILD + // 应该到不了这里, Len()会返回0 + panic(ErrUnsupportedType) + } + return false + +} + +// Swap 实现sort.Interface接口的交换元素方法 +func (self *NDFrame) Swap(i, j int) { + if self.type_ == SERIES_TYPE_BOOL { + values := self.Values().([]bool) + values[i], values[j] = values[j], values[i] + } else if self.type_ == SERIES_TYPE_INT64 { + values := self.Values().([]int64) + values[i], values[j] = values[j], values[i] + } else if self.type_ == SERIES_TYPE_FLOAT32 { + values := self.Values().([]float32) + values[i], values[j] = values[j], values[i] + } else if self.type_ == SERIES_TYPE_FLOAT64 { + values := self.Values().([]float64) + values[i], values[j] = values[j], values[i] + } else if self.type_ == SERIES_TYPE_STRING { + values := self.Values().([]string) + values[i], values[j] = values[j], values[i] + } else { + // SERIES_TYPE_INVAILD + // 应该到不了这里, Len()会返回0 + panic(ErrUnsupportedType) + } +} diff --git a/v1/generic_sum.go b/v1/generic_sum.go new file mode 100644 index 0000000000000000000000000000000000000000..1e0151699f086141acc4a76268c8c071213d468e --- /dev/null +++ b/v1/generic_sum.go @@ -0,0 +1,8 @@ +package v1 + +import "gitee.com/quant1x/pandas/stat" + +func (self *NDFrame) Sum() stat.DType { + fs := self.DTypes() + return stat.Sum(fs) +} diff --git a/v1/generic_test.go b/v1/generic_test.go new file mode 100644 index 0000000000000000000000000000000000000000..78ee20be7b2ede18212e04e07aca1284a988de03 --- /dev/null +++ b/v1/generic_test.go @@ -0,0 +1,87 @@ +package v1 + +import ( + "fmt" + "gitee.com/quant1x/pandas/stat" + "testing" +) + +func TestSeriesFrame(t *testing.T) { + data := []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + s1 := NewSeries(SERIES_TYPE_FLOAT64, "x", data) + fmt.Printf("%+v\n", s1) + + var d1 any + d1 = data + s2 := NewSeries(SERIES_TYPE_FLOAT64, "x", d1) + fmt.Printf("%+v\n", s2) + + var s3 Series + // s3 = NewSeriesBool("x", data) + s3 = NewSeries(SERIES_TYPE_BOOL, "x", data) + fmt.Printf("%+v\n", s3.Values()) + + var s4 Series + ts4 := GenericSeries[float64]("x", data...) + ts4 = NewSeries(SERIES_TYPE_FLOAT64, "x", data) + s4 = ts4 + fmt.Printf("%+v\n", s4.Values()) +} + +func TestNDFrameNew(t *testing.T) { + // float64 + //d1 := []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, NaN(), 12} + d1 := []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} + nd1 := NewNDFrame[float64]("x", d1...) + fmt.Println(nd1) + + r := stat.RangeFinite(-1) + ndr1 := nd1.Select(r) + fmt.Println(ndr1.Values()) + + fmt.Println(nd1.Records()) + nd11 := nd1.Subset(1, 2, true) + fmt.Println(nd11.Records()) + fmt.Println(nd1.Max()) + fmt.Println(nd1.RollingV1(5).Max()) + fmt.Println(nd1.RollingV1(5).Min()) + + nd12 := nd1.RollingV1(5).Mean() + d12 := nd12.Values() + fmt.Println(d12) + + nd13 := nd1.Shift(3) + fmt.Println(nd13.Values()) + nd14 := nd1.RollingV1(5).StdDev() + fmt.Println(nd14.Values()) + + // string + d2 := []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "nan", "12"} + nd2 := NewNDFrame[string]("x", d2...) + fmt.Println(nd2) + nd21 := nd2.RollingV1(5).Max() + fmt.Println(nd21) + nd2.FillNa(0, true) + fmt.Println(nd2) + fmt.Println(nd2.Records()) + fmt.Println(nd2.Empty()) +} + +func TestRolling2(t *testing.T) { + d1 := []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} + s1 := NewNDFrame[float64]("x", d1...) + df := NewDataFrame(s1) + fmt.Println(df) + fmt.Println("------------------------------------------------------------") + + N := 5 + fmt.Println("固定的参数, N =", N) + r1 := df.Col("x").Rolling(5).Mean().Values() + fmt.Println("序列化结果:", r1) + fmt.Println("------------------------------------------------------------") + d2 := []float64{1, 2, 3, 4, 3, 3, 2, 1, Nil2Float64, Nil2Float64, Nil2Float64, Nil2Float64} + s2 := NewSeries(SERIES_TYPE_FLOAT64, "x", d2) + fmt.Printf("序列化参数: %+v\n", s2.Values()) + r2 := df.Col("x").Rolling(s2).Mean().Values() + fmt.Println("序列化结果:", r2) +} diff --git a/generic_type.go b/v1/generic_type.go similarity index 99% rename from generic_type.go rename to v1/generic_type.go index aa5128ae5aad9155ea4539cfe27eb28a82262080..be3b581ac32496f75cb7813e9b1ffa87c61fdb8e 100644 --- a/generic_type.go +++ b/v1/generic_type.go @@ -1,4 +1,4 @@ -package pandas +package v1 import ( "fmt" diff --git a/rolling_count.go b/v1/rolling_count.go similarity index 97% rename from rolling_count.go rename to v1/rolling_count.go index 668438ad61ab8e9dbc2fbd0a01b591b06241a95f..39eaa40aab57f4a0d081298bd4462860af152379 100644 --- a/rolling_count.go +++ b/v1/rolling_count.go @@ -1,4 +1,4 @@ -package pandas +package v1 import ( "gitee.com/quant1x/pandas/stat" diff --git a/rolling_max.go b/v1/rolling_max.go similarity index 96% rename from rolling_max.go rename to v1/rolling_max.go index 5a13c854ba89f6aa4180914fa42e42cef96bdcc9..db7bea06db9071138a6a0a519cd68209ff1b01ec 100644 --- a/rolling_max.go +++ b/v1/rolling_max.go @@ -1,4 +1,4 @@ -package pandas +package v1 func (r RollingAndExpandingMixin) Max() (s Series) { s = r.series.Empty() diff --git a/rolling_mean.go b/v1/rolling_mean.go similarity index 95% rename from rolling_mean.go rename to v1/rolling_mean.go index 52a4ed57a17bc8f9d7691cdad6e0c352a0041163..3525ee4c445ff6b9848fc3b02f5c0815e89281cc 100644 --- a/rolling_mean.go +++ b/v1/rolling_mean.go @@ -1,4 +1,4 @@ -package pandas +package v1 import "gitee.com/quant1x/pandas/stat" diff --git a/rolling_min.go b/v1/rolling_min.go similarity index 96% rename from rolling_min.go rename to v1/rolling_min.go index 5bc79d62cef634db9bc337f6234a01f01614513b..7c901fd7b6d6b9e0dd7a3fdd3e37ccb3cb7b2dcc 100644 --- a/rolling_min.go +++ b/v1/rolling_min.go @@ -1,4 +1,4 @@ -package pandas +package v1 func (r RollingAndExpandingMixin) Min() (s Series) { s = r.series.Empty() diff --git a/rolling_std.go b/v1/rolling_std.go similarity index 96% rename from rolling_std.go rename to v1/rolling_std.go index 4733b98cefad09a7efd7a27fedc3809efea34044..1e0beb0a75df72651ff9d78e1adfafaa7a8b04c7 100644 --- a/rolling_std.go +++ b/v1/rolling_std.go @@ -1,4 +1,4 @@ -package pandas +package v1 func (r RollingAndExpandingMixin) Std() Series { s := r.series.Empty() diff --git a/rolling_sum.go b/v1/rolling_sum.go similarity index 94% rename from rolling_sum.go rename to v1/rolling_sum.go index fc3a289bf36cca72f9b24b72cfdee94ddae8547f..151dfafc2e97e608b3921260753e0c915bb520ff 100644 --- a/rolling_sum.go +++ b/v1/rolling_sum.go @@ -1,4 +1,4 @@ -package pandas +package v1 import "gitee.com/quant1x/pandas/stat" diff --git a/rolling_v1.go b/v1/rolling_v1.go similarity index 99% rename from rolling_v1.go rename to v1/rolling_v1.go index 32f14f5b2e2607f21c6388f377bb678e19591a45..d503f134887770236018ad1e3b2207af7cf131ae 100644 --- a/rolling_v1.go +++ b/v1/rolling_v1.go @@ -1,4 +1,4 @@ -package pandas +package v1 // RollingWindowV1 is used for rolling window calculations. // Deprecated: 使用RollingAndExpandingMixin diff --git a/series.go b/v1/series.go similarity index 99% rename from series.go rename to v1/series.go index a5b30abdf00cff082a1d0519a3f6a4a8fcf71d61..0aa307313f0041a9d662d3f4d2bc902548f6d22a 100644 --- a/series.go +++ b/v1/series.go @@ -1,4 +1,4 @@ -package pandas +package v1 import ( "fmt" diff --git a/series_bool.go b/v1/series_bool.go similarity index 99% rename from series_bool.go rename to v1/series_bool.go index ec349bf4c2aa35a94593affc26678fd430b238eb..de7cd2cddffef1a2164d6ad938c3c60f00296e96 100644 --- a/series_bool.go +++ b/v1/series_bool.go @@ -1,4 +1,4 @@ -package pandas +package v1 import ( "gitee.com/quant1x/pandas/stat" diff --git a/series_bool_test.go b/v1/series_bool_test.go similarity index 97% rename from series_bool_test.go rename to v1/series_bool_test.go index f9899f14a52f4d6d925b56b5af9bf22d2ee3b31d..4a7944bdfcb9507f4fca5d483002ee54460cb2d7 100644 --- a/series_bool_test.go +++ b/v1/series_bool_test.go @@ -1,4 +1,4 @@ -package pandas +package v1 import ( "fmt" diff --git a/series_float32.go b/v1/series_float32.go similarity index 99% rename from series_float32.go rename to v1/series_float32.go index 07809d7fed60786e758de95d7f523339eca86b39..4c8a17fbafefa187114111229f54e57323bd45c1 100644 --- a/series_float32.go +++ b/v1/series_float32.go @@ -1,4 +1,4 @@ -package pandas +package v1 // TODO:留给于总的作业 type SeriesFloat32 struct { diff --git a/series_float64.go b/v1/series_float64.go similarity index 99% rename from series_float64.go rename to v1/series_float64.go index c49159932465ea6ed3238a080caf260fdf5fb712..d576ebbc953e33b061253dbf0e56f452c3422dfa 100644 --- a/series_float64.go +++ b/v1/series_float64.go @@ -1,4 +1,4 @@ -package pandas +package v1 import ( "gitee.com/quant1x/pandas/stat" diff --git a/series_float64_test.go b/v1/series_float64_test.go similarity index 98% rename from series_float64_test.go rename to v1/series_float64_test.go index 131b0c8a85761b59e4cde3e42f750f527925d6d8..c89110d7a027ecbf039135503e7aa30914eb9bfa 100644 --- a/series_float64_test.go +++ b/v1/series_float64_test.go @@ -1,4 +1,4 @@ -package pandas +package v1 import ( "fmt" diff --git a/v1/series_generic.go b/v1/series_generic.go new file mode 100644 index 0000000000000000000000000000000000000000..dda8b1fff5aea4b0384bc64877477457343385a0 --- /dev/null +++ b/v1/series_generic.go @@ -0,0 +1,93 @@ +package v1 + +import ( + "gitee.com/quant1x/pandas/stat" + "reflect" +) + +// 初始化全局的私有变量 +var ( + rawBool bool = true + typeBool = reflect.TypeOf([]bool{}) + rawInt32 int32 = int32(0) + typeInt32 = reflect.TypeOf([]int32{}) + rawInt64 int64 = int64(0) + typeInt64 = reflect.TypeOf([]int64{}) + rawFloat32 float32 = float32(0) + typeFloat32 = reflect.TypeOf([]float32{}) + rawFloat64 float64 = float64(0) + typeFloat64 = reflect.TypeOf([]float64{}) + typeString = reflect.TypeOf([]string{}) +) + +// NewSeriesWithoutType 不带类型创新一个新series +func NewSeriesWithoutType(name string, values ...interface{}) Series { + _type, err := detectTypeBySlice(values...) + if err != nil { + return nil + } + return NewSeriesWithType(_type, name, values...) +} + +// NewSeriesWithType 通过类型创新一个新series +func NewSeriesWithType(_type Type, name string, values ...interface{}) Series { + frame := NDFrame{ + formatter: stat.DefaultFormatter, + name: name, + type_: SERIES_TYPE_INVAILD, + nilCount: 0, + rows: 0, + //values: []E{}, + } + //_type, err := detectTypeBySlice(values) + //if err != nil { + // return nil + //} + frame.type_ = _type + if frame.type_ == SERIES_TYPE_BOOL { + // bool + frame.values = reflect.MakeSlice(typeBool, 0, 0).Interface() + } else if frame.type_ == SERIES_TYPE_INT64 { + // int64 + frame.values = reflect.MakeSlice(typeInt64, 0, 0).Interface() + } else if frame.type_ == SERIES_TYPE_FLOAT32 { + // float32 + frame.values = reflect.MakeSlice(typeFloat32, 0, 0).Interface() + } else if frame.type_ == SERIES_TYPE_FLOAT64 { + // float64 + frame.values = reflect.MakeSlice(typeFloat64, 0, 0).Interface() + } else { + // string, 字符串最后容错使用 + frame.values = reflect.MakeSlice(typeString, 0, 0).Interface() + } + //series.Data = make([]float64, 0) // Warning: filled with 0.0 (not NaN) + //size := len(series.values) + //size := 0 + //for idx, v := range values { + // switch val := v.(type) { + // case nil, int8, uint8, int16, uint16, int32, uint32, int64, uint64, int, uint, float32, float64, bool, string: + // // 基础类型 + // series_append(&frame, idx, size, val) + // default: + // vv := reflect.ValueOf(val) + // vk := vv.Kind() + // switch vk { + // //case reflect.Invalid: // {interface} nil + // // series.assign(idx, size, Nil2Float64) + // case reflect.Slice, reflect.Array: // 切片或数组 + // for i := 0; i < vv.Len(); i++ { + // tv := vv.Index(i).Interface() + // //series.assign(idx, size, str) + // series_append(&frame, idx, size, tv) + // } + // case reflect.Struct: // 忽略结构体 + // continue + // default: + // series_append(&frame, idx, size, nil) + // } + // } + //} + frame.Append(values...) + + return &frame +} diff --git a/series_generic_test.go b/v1/series_generic_test.go similarity index 97% rename from series_generic_test.go rename to v1/series_generic_test.go index c034e6e891f950e2b1202224670f48cd7578a62c..75ac3f9ab2af3581f5a2ec6c792685730a5475ce 100644 --- a/series_generic_test.go +++ b/v1/series_generic_test.go @@ -1,4 +1,4 @@ -package pandas +package v1 import ( "fmt" diff --git a/series_int64.go b/v1/series_int64.go similarity index 99% rename from series_int64.go rename to v1/series_int64.go index 2c2f2e3242089701f07b91e9f8d877c01eac517e..b2eeb3f4eea008a45555b079c903f72665d2cd0c 100644 --- a/series_int64.go +++ b/v1/series_int64.go @@ -1,4 +1,4 @@ -package pandas +package v1 import ( "gitee.com/quant1x/pandas/stat" diff --git a/series_int64_test.go b/v1/series_int64_test.go similarity index 97% rename from series_int64_test.go rename to v1/series_int64_test.go index 9cf91bb039860f29a00de69aea8ecdcb107128ad..38de53cf816b6a5b226f1ed7891235373ceec883 100644 --- a/series_int64_test.go +++ b/v1/series_int64_test.go @@ -1,4 +1,4 @@ -package pandas +package v1 import ( "fmt" diff --git a/series_type.go b/v1/series_type.go similarity index 97% rename from series_type.go rename to v1/series_type.go index a5c85291d1d16adc41d62f6927f2fb87002c0ae1..733b7044ba62f9cd2d17ba475458c206cfdbb56d 100644 --- a/series_type.go +++ b/v1/series_type.go @@ -1,4 +1,4 @@ -package pandas +package v1 import ( "fmt" diff --git a/series_xstring.go b/v1/series_xstring.go similarity index 99% rename from series_xstring.go rename to v1/series_xstring.go index a38f08a46987631914127cb4b6eb1e06ec2c4053..771c0c74cc155891ad121e2d950809afd84090cb 100644 --- a/series_xstring.go +++ b/v1/series_xstring.go @@ -1,4 +1,4 @@ -package pandas +package v1 import ( "gitee.com/quant1x/pandas/stat" diff --git a/series_xstring_test.go b/v1/series_xstring_test.go similarity index 97% rename from series_xstring_test.go rename to v1/series_xstring_test.go index 597b9606b9ee6183e93d61a237e9a69ee4cedaf4..640a904b4bcc64818c4f526e61369b3aa3fc0f25 100644 --- a/series_xstring_test.go +++ b/v1/series_xstring_test.go @@ -1,4 +1,4 @@ -package pandas +package v1 import ( "fmt" diff --git a/unsafe.go b/v1/unsafe.go similarity index 98% rename from unsafe.go rename to v1/unsafe.go index 025dd1f5dc60baa9e2f46a071e9506313f9a4d9e..832e90f9657b77b2c25ac69e0c4ce73e798ae898 100644 --- a/unsafe.go +++ b/v1/unsafe.go @@ -3,7 +3,7 @@ //go:build !js && !appengine && !safe // +build !js,!appengine,!safe -package pandas +package v1 import ( "unsafe"