1 Star 0 Fork 0

mqyqingkong / autowire

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
autowire.go 5.94 KB
一键复制 编辑 原始数据 按行查看 历史
mqyqingkong 提交于 2023-10-08 17:49 . Generics Method
package autowire
import (
"errors"
"fmt"
"log"
"reflect"
"sync"
)
var ErrBeanNotFound = errors.New("can not find bean")
// BeanFactory manages beans.
type BeanFactory interface {
GetBeanByName(beanName string) (bean any, found bool)
GetBeanByType(beanType reflect.Type) (bean any, err error)
GetBeansByType(beanType reflect.Type) (beans []any)
RegisterBeanByName(beanName string, bean any) error
RegisterBeans(beans ...any)
RegisterBeanCreatorByType(beanType reflect.Type, creator BeanCreator) error
RegisterBeanCreatorByName(beanName string, creator BeanCreator) error
// Autowire will wire all beans.
Autowire() error
// AutowireBean will wire the specific bean.
AutowireBean(bean any) error
}
func NewBeanFactory() BeanFactory {
bf := &beanFactory{
lock: sync.RWMutex{},
beanByName: map[string]any{},
beanByType: map[reflect.Type][]any{},
beanByTypeCreator: map[reflect.Type]BeanCreator{},
beanByNameCreator: map[string]BeanCreator{},
}
return bf
}
// BeanCreator creates bean.
type BeanCreator func(bf BeanFactory) any
type beanFactory struct {
lock sync.RWMutex
beanByName map[string]any
beanByType map[reflect.Type][]any
beanByTypeCreator map[reflect.Type]BeanCreator
beanByNameCreator map[string]BeanCreator
}
func (bf *beanFactory) GetBeanByName(beanName string) (bean any, found bool) {
bf.lock.RLock()
bean, found = bf.beanByName[beanName]
bf.lock.RUnlock()
if found {
return
}
bf.lock.RLock()
creator, found := bf.beanByNameCreator[beanName]
bf.lock.RUnlock()
if found {
bean = creator(bf)
}
return
}
func (bf *beanFactory) GetBeanByType(beanType reflect.Type) (bean any, err error) {
bf.lock.RLock()
defer bf.lock.RUnlock()
beans := bf.beanByType[beanType]
if len(beans) == 1 {
bean = beans[0]
err = nil
return
}
if len(beans) > 1 {
err = fmt.Errorf("find more than one bean by specific type:%s", beanType.Name())
return
}
beans = bf.GetBeansByType(beanType)
if len(beans) == 1 {
bean = beans[0]
err = nil
return
}
if len(beans) > 1 {
err = fmt.Errorf("find more than one bean by type:%s", beanType.Name())
return
}
err = ErrBeanNotFound
return
}
func (bf *beanFactory) GetBeansByType(beanType reflect.Type) (beans []any) {
typeIsInterface := beanType.Kind() == reflect.Interface
bf.lock.RLock()
for t, bs := range bf.beanByType {
if t == beanType || (typeIsInterface && t.Implements(beanType)) {
beans = append(beans, bs...)
}
}
creator, found := bf.beanByTypeCreator[beanType]
bf.lock.RUnlock()
if found {
bean := creator(bf)
beans = append(beans, bean)
return
}
if typeIsInterface {
var creators []BeanCreator
bf.lock.RLock()
for t, c := range bf.beanByTypeCreator {
if t.Implements(beanType) {
creators = append(creators, c)
}
}
bf.lock.RUnlock()
for i := 0; i < len(creators); i++ {
beans = append(beans, creators[i](bf))
}
}
return
}
func (bf *beanFactory) RegisterBeanByName(beanName string, bean any) error {
bf.lock.Lock()
defer bf.lock.Unlock()
if _, found := bf.beanByName[beanName]; found {
return fmt.Errorf("bean name[%s] duplicated", beanName)
}
bf.beanByName[beanName] = bean
beanType := reflect.TypeOf(bean)
bf.beanByType[beanType] = append(bf.beanByType[beanType], bean)
return nil
}
func (bf *beanFactory) RegisterBeans(beans ...any) {
bf.lock.Lock()
defer bf.lock.Unlock()
for i := 0; i < len(beans); i++ {
beanType := reflect.TypeOf(beans[i])
bf.beanByType[beanType] = append(bf.beanByType[beanType], beans[i])
}
}
func (bf *beanFactory) RegisterBeanCreatorByType(beanType reflect.Type, creator BeanCreator) error {
bf.lock.Lock()
defer bf.lock.Unlock()
if _, found := bf.beanByTypeCreator[beanType]; found {
return fmt.Errorf("bean creator[beanType:%s] duplicated", beanType.Name())
}
bf.beanByTypeCreator[beanType] = creator
return nil
}
func (bf *beanFactory) RegisterBeanCreatorByName(beanName string, creator BeanCreator) error {
bf.lock.Lock()
defer bf.lock.Unlock()
if _, found := bf.beanByNameCreator[beanName]; found {
return fmt.Errorf("bean creator[beanName:%s] duplicated", beanName)
}
bf.beanByNameCreator[beanName] = creator
return nil
}
func (bf *beanFactory) Autowire() error {
for _, beans := range bf.beanByType {
for _, bean := range beans {
err := bf.AutowireBean(bean)
if err != nil {
return err
}
}
}
return nil
}
func (bf *beanFactory) AutowireBean(bean any) error {
rv := reflect.ValueOf(bean)
if rv.Kind() != reflect.Pointer {
return fmt.Errorf("bean[%s] must be a pointer", bean)
}
rve := rv.Elem()
beanType := reflect.TypeOf(bean)
for i := 0; i < rve.NumField(); i++ {
fv := rve.Field(i)
if fv.CanSet() && fv.IsZero() {
field := beanType.Elem().Field(i)
beanName := field.Tag.Get("bean")
if len(beanName) > 0 {
// wire by name
b, found := bf.GetBeanByName(beanName)
if found {
fv.Set(reflect.ValueOf(b))
continue
}
}
// wire by type
fk := fv.Kind()
if fk == reflect.Interface || fk == reflect.Pointer || fk == reflect.Struct {
ft := fv.Type()
b, err := bf.GetBeanByType(ft)
if err == ErrBeanNotFound {
log.Printf(`can not find bean by type:%v, when autowire bean:%v, ignore the field:"%s"`, ft, beanType, field.Name)
continue
}
if err == nil {
fv.Set(reflect.ValueOf(b))
} else {
return err
}
}
}
}
return nil
}
func GetReflectType[T any]() reflect.Type {
ret := reflect.TypeOf((*T)(nil)).Elem()
return ret
}
func GetBeanByType[T any](bf BeanFactory) (bean T, err error) {
b, err := bf.GetBeanByType(GetReflectType[T]())
if err != nil {
return
}
bean = b.(T)
return
}
func GetBeansByType[T any](bf BeanFactory) (beans []T) {
bs := bf.GetBeansByType(GetReflectType[T]())
beans = make([]T, len(bs))
for i := 0; i < len(bs); i++ {
beans[i] = bs[i].(T)
}
return
}
func RegisterBeanCreatorByType[T any](bf BeanFactory, creator BeanCreator) error {
return bf.RegisterBeanCreatorByType(GetReflectType[T](), creator)
}
1
https://gitee.com/mqyqingkong/autowire.git
git@gitee.com:mqyqingkong/autowire.git
mqyqingkong
autowire
autowire
v0.0.2

搜索帮助