From 8ae71ce933c69682649bc0f2a7535c896cf599c2 Mon Sep 17 00:00:00 2001 From: wangfeng Date: Tue, 21 Feb 2023 08:19:55 +0800 Subject: [PATCH 1/3] =?UTF-8?q?series=E5=A2=9E=E5=8A=A0=E6=8C=89=E5=88=87?= =?UTF-8?q?=E7=89=87=E4=B8=8B=E6=A0=87=E5=8F=96=E4=B8=80=E8=A1=8C=E8=AE=B0?= =?UTF-8?q?=E5=BD=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- stat/errors.go | 1 + stat/ndarray.go | 52 ++++------------------------------- stat/ndarray_range.go | 63 +++++++++++++++++++++++++++++++++++++++++++ stat/series.go | 1 + 4 files changed, 70 insertions(+), 47 deletions(-) create mode 100644 stat/ndarray_range.go diff --git a/stat/errors.go b/stat/errors.go index bcf8e00..0722f9c 100644 --- a/stat/errors.go +++ b/stat/errors.go @@ -12,6 +12,7 @@ const ( var ( // ErrUnsupportedType 不支持的类型 ErrUnsupportedType = exception.New(errorTypeBase+0, "Unsupported type") + ErrRange = exception.New(errorTypeBase+1, "range error") ) func Throw(tv any) *exception.Exception { diff --git a/stat/ndarray.go b/stat/ndarray.go index fc59da7..d35e0d4 100644 --- a/stat/ndarray.go +++ b/stat/ndarray.go @@ -2,7 +2,6 @@ package stat import ( "github.com/mymmsc/gox/exception" - "reflect" ) type NDArray[T BaseType] []T @@ -59,38 +58,6 @@ func (self NDArray[T]) Records() []string { } -func (self NDArray[T]) Subset(start, end int, opt ...any) Series { - // 默认不copy - var __optCopy bool = false - if len(opt) > 0 { - // 第一个参数为是否copy - if _cp, ok := opt[0].(bool); ok { - __optCopy = _cp - } - } - var vs any - var rows int - vv := reflect.ValueOf(self.Values()) - vk := vv.Kind() - switch vk { - case reflect.Slice, reflect.Array: // 切片和数组同样的处理逻辑 - vvs := vv.Slice(start, end) - vs = vvs.Interface() - rows = vv.Len() - if __optCopy && rows > 0 { - vs = Clone(vs) - //vs = slices.Clone(vs) - } - rows = vvs.Len() - var d Series - d = NDArray[T](vs.([]T)) - return d - default: - // 其它类型忽略 - } - return self.Empty() -} - func (self NDArray[T]) Repeat(x any, repeats int) Series { var d any switch values := self.Values().(type) { @@ -109,6 +76,11 @@ func (self NDArray[T]) Repeat(x any, repeats int) Series { return NDArray[T](d.([]T)) } +func (self NDArray[T]) FillNa(v any, inplace bool) Series { + d := FillNa(self, v, inplace) + return NDArray[T](d) +} + func (self NDArray[T]) Shift(periods int) Series { values := self.Values().([]T) d := Shift(values, periods) @@ -130,11 +102,6 @@ func (self NDArray[T]) StdDev() DType { return self.Std() } -func (self NDArray[T]) FillNa(v any, inplace bool) Series { - d := FillNa(self, v, inplace) - return NDArray[T](d) -} - func (self NDArray[T]) Max() any { d := Max2(self) return d @@ -145,15 +112,6 @@ func (self NDArray[T]) Min() any { return d } -func (self NDArray[T]) Select(r ScopeLimit) Series { - start, end, err := r.Limits(self.Len()) - if err != nil { - return nil - } - series := self.Subset(start, end+1) - return series -} - func (self NDArray[T]) Apply(f func(idx int, v any)) { //inplace := true for i, v := range self { diff --git a/stat/ndarray_range.go b/stat/ndarray_range.go new file mode 100644 index 0000000..30ae184 --- /dev/null +++ b/stat/ndarray_range.go @@ -0,0 +1,63 @@ +package stat + +import "reflect" + +func (self NDArray[T]) IndexOf(index int, opt ...any) (any, error) { + if index < 0 || index >= self.Len() { + return nil, ErrRange + } + var __optInplace = false + if len(opt) > 0 { + // 第一个参数为是否copy + if _cp, ok := opt[0].(bool); ok { + __optInplace = _cp + } + } + value := self[index] + if __optInplace { + return reflect.ValueOf(value).Elem(), nil + } + return value, nil + +} + +func (self NDArray[T]) Subset(start, end int, opt ...any) Series { + // 默认不copy + var __optCopy bool = false + if len(opt) > 0 { + // 第一个参数为是否copy + if _cp, ok := opt[0].(bool); ok { + __optCopy = _cp + } + } + var vs any + var rows int + vv := reflect.ValueOf(self.Values()) + vk := vv.Kind() + switch vk { + case reflect.Slice, reflect.Array: // 切片和数组同样的处理逻辑 + vvs := vv.Slice(start, end) + vs = vvs.Interface() + rows = vv.Len() + if __optCopy && rows > 0 { + vs = Clone(vs) + //vs = slices.Clone(vs) + } + rows = vvs.Len() + var d Series + d = NDArray[T](vs.([]T)) + return d + default: + // 其它类型忽略 + } + return self.Empty() +} + +func (self NDArray[T]) Select(r ScopeLimit) Series { + start, end, err := r.Limits(self.Len()) + if err != nil { + return nil + } + series := self.Subset(start, end+1) + return series +} diff --git a/stat/series.go b/stat/series.go index 0a73a62..cfefef4 100644 --- a/stat/series.go +++ b/stat/series.go @@ -54,6 +54,7 @@ type Series interface { // Records returns the elements of a Series as a []string Records() []string + IndexOf(index int, opt ...any) (any, error) // Subset 获取子集 Subset(start, end int, opt ...any) Series // Repeat elements of an array. -- Gitee From 4cbb17108781c3af5b16082ca1ea7fee5857deda Mon Sep 17 00:00:00 2001 From: wangfeng Date: Tue, 21 Feb 2023 08:27:03 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E4=BF=AE=E8=AE=A2IndexOf=E5=87=BD=E6=95=B0?= =?UTF-8?q?=E7=9A=84=E8=BF=94=E5=9B=9E=E5=80=BC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- stat/ndarray_range.go | 12 +++++++----- stat/series.go | 3 ++- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/stat/ndarray_range.go b/stat/ndarray_range.go index 30ae184..b25cf45 100644 --- a/stat/ndarray_range.go +++ b/stat/ndarray_range.go @@ -2,9 +2,11 @@ package stat import "reflect" -func (self NDArray[T]) IndexOf(index int, opt ...any) (any, error) { - if index < 0 || index >= self.Len() { - return nil, ErrRange +func (self NDArray[T]) IndexOf(index int, opt ...any) any { + if index < 0 { + index = self.Len() + index + } else { + index = self.Len() - 1 } var __optInplace = false if len(opt) > 0 { @@ -15,9 +17,9 @@ func (self NDArray[T]) IndexOf(index int, opt ...any) (any, error) { } value := self[index] if __optInplace { - return reflect.ValueOf(value).Elem(), nil + return reflect.ValueOf(value).Elem() } - return value, nil + return value } diff --git a/stat/series.go b/stat/series.go index cfefef4..504c930 100644 --- a/stat/series.go +++ b/stat/series.go @@ -54,7 +54,8 @@ type Series interface { // Records returns the elements of a Series as a []string Records() []string - IndexOf(index int, opt ...any) (any, error) + // IndexOf 取一条记录, index<0时, 从后往前取值 + IndexOf(index int, opt ...any) any // Subset 获取子集 Subset(start, end int, opt ...any) Series // Repeat elements of an array. -- Gitee From 6b095f863717855d87e6517a111c99c9cf892e07 Mon Sep 17 00:00:00 2001 From: wangfeng Date: Tue, 21 Feb 2023 09:19:18 +0800 Subject: [PATCH 3/3] =?UTF-8?q?#I6GHRR=20=E5=AE=9E=E7=8E=B0IndexOf?= =?UTF-8?q?=E6=96=B9=E6=B3=95,=20=E6=94=AF=E6=8C=81=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dataframe_subset.go | 25 +++++++++++++++++++++++++ dataframe_subset_test.go | 38 ++++++++++++++++++++++++++++++++++++++ series_range.go | 21 +++++++++++++++++++++ stat/ndarray_range.go | 11 ++++++----- stat/ndarray_range_test.go | 18 ++++++++++++++++++ 5 files changed, 108 insertions(+), 5 deletions(-) create mode 100644 dataframe_subset_test.go create mode 100644 stat/ndarray_range_test.go diff --git a/dataframe_subset.go b/dataframe_subset.go index 165c57e..747f277 100644 --- a/dataframe_subset.go +++ b/dataframe_subset.go @@ -87,3 +87,28 @@ func (self DataFrame) Concat(dfb DataFrame) DataFrame { } return NewDataFrame(expandedSeries...) } + +// IndexOf 取一条记录 +// +// idx 为负值时从后往前取 +func (self DataFrame) IndexOf(idx int, opt ...any) map[string]any { + one := map[string]any{} + if idx < 0 { + idx = self.Nrow() + idx + } else if idx >= self.Nrow() { + idx = self.Nrow() - 1 + } + var __optInplace = false + if len(opt) > 0 { + // 第一个参数为是否copy + if _opt, ok := opt[0].(bool); ok { + __optInplace = _opt + } + } + for _, series := range self.columns { + key := series.Name() + value := series.IndexOf(idx, __optInplace) + one[key] = value + } + return one +} diff --git a/dataframe_subset_test.go b/dataframe_subset_test.go new file mode 100644 index 0000000..4094079 --- /dev/null +++ b/dataframe_subset_test.go @@ -0,0 +1,38 @@ +package pandas + +import ( + "fmt" + "reflect" + "testing" +) + +func TestDataFrame_IndexOf(t *testing.T) { + type testStruct struct { + A string + B int + C bool + D float64 + } + data := []testStruct{ + {"a", 1, true, 0.0}, + {"b", 2, false, 0.5}, + } + df1 := LoadStructs(data) + //fmt.Println(df1) + + // 增加1列 + s_e := GenericSeries[string]("x", "a0", "a1", "a2", "a3", "a4") + df2 := df1.Join(s_e) + //fmt.Println(df2) + df := df2.Select([]string{"A"}) + fmt.Println(df) + m := df.IndexOf(1, true) + a, ok := m["A"] + if ok { + fmt.Println(a) + if v, ok := a.(reflect.Value); ok { + v.SetString("1") + } + fmt.Println(df) + } +} diff --git a/series_range.go b/series_range.go index 7a469bb..01179aa 100644 --- a/series_range.go +++ b/series_range.go @@ -124,3 +124,24 @@ func (self *NDFrame) Select(r stat.ScopeLimit) stat.Series { series := self.Subset(start, end+1) return series } + +func (self *NDFrame) IndexOf(index int, opt ...any) any { + if index < 0 { + index = self.Len() + index + } else if index >= self.Len() { + index = self.Len() - 1 + } + var __optInplace = false + if len(opt) > 0 { + // 第一个参数为是否copy + if _opt, ok := opt[0].(bool); ok { + __optInplace = _opt + } + } + if !__optInplace { + return reflect.ValueOf(self.Values()).Index(index).Interface() + } + vv := reflect.ValueOf(self.values) + tv := vv.Index(index) + return tv +} diff --git a/stat/ndarray_range.go b/stat/ndarray_range.go index b25cf45..1b24c96 100644 --- a/stat/ndarray_range.go +++ b/stat/ndarray_range.go @@ -5,19 +5,20 @@ import "reflect" func (self NDArray[T]) IndexOf(index int, opt ...any) any { if index < 0 { index = self.Len() + index - } else { + } else if index >= self.Len() { index = self.Len() - 1 } var __optInplace = false if len(opt) > 0 { - // 第一个参数为是否copy - if _cp, ok := opt[0].(bool); ok { - __optInplace = _cp + // 第一个参数为是否替换 + if _opt, ok := opt[0].(bool); ok { + __optInplace = _opt } } value := self[index] if __optInplace { - return reflect.ValueOf(value).Elem() + mv := reflect.ValueOf(self.Values()) + return mv.Index(index) } return value diff --git a/stat/ndarray_range_test.go b/stat/ndarray_range_test.go new file mode 100644 index 0000000..dba66dd --- /dev/null +++ b/stat/ndarray_range_test.go @@ -0,0 +1,18 @@ +package stat + +import ( + "fmt" + "reflect" + "testing" +) + +func TestNDArray_IndexOf(t *testing.T) { + d1 := []string{"a0", "a1", "a2", "a3", "a4"} + s1 := NDArray[string](d1) + fmt.Println(s1) + v1 := s1.IndexOf(1, true) + if mv, ok := v1.(reflect.Value); ok { + mv.SetString("1") + fmt.Println(s1) + } +} -- Gitee