文章目录
- 前言
- 基本使用
- 初始化db实例
- 定义model
- 增删改查
- 数据结构
- gorm.DB
- Statement
- Schema元数据
- clone
- 初始化
- 初始化DB
- 初始化dialector
- 用dialector初始化db
- 注册crud函数
- 执行器processor
- 注册callback
- Clause抽象
- 解析元数据
- 解析schema
- 解析field
- 总结
前言
上一篇文章介绍了什么是ORM框架,以及为啥需要ORM框架。从这篇文章开始将介绍gorm的底层实现原理
本文基于源码:https://github.com/go-gorm/gorm,版本:1.25.12
基本使用
关于gorm的更多详细的介绍,可以看官方文档 https://gorm.io/zh_CN/docs/index.html
这里简单介绍下其基本使用:
初始化db实例
- 通过
gorm.Open
创建一个db实例,都需的所有操作都在这个实例上 - 其参数为
mysql.Open(连接mysql的dsn)
创建的方言dialector
import (
"gorm.io/driver/mysql"
"gorm.io/gorm"
)
func TestDB(t *testing.T) {
db, err := gorm.Open(mysql.Open("账号:密码@(连接地址)/数据库名称"))
if err != nil {
panic(err)
}
}
定义model
基于ORM的定义,一个model对应db里的一张表,model里的字段对应表里的列
如果结构体实现了TableName
方法,gorm使用该方法的返回值当做表名
默认使用字段名的驼峰转下划线当做列名。如果需要使用其他列名,需要通过标签gorm:"column:names"
来指定
type User struct {
Id int64
Name string `gorm:"column:names"`
Number int64
}
func (*User) TableName() string {
return "user"
}
增删改查
下面介绍如何使用gorm进行简单的增删改查:
增:新增一条记录,name=tom,number=10
func TestDB(t *testing.T) {
err = db.Create(&User{
Name: "tom",
Number: 10,
}).Error
if err != nil {
panic(err)
}
}
gorm会生成如下sql:INSERT INTO user (name,number) VALUES (?,?)
,参数为"tom",10
删:删除id=10的记录
func TestDB(t *testing.T) {
err = db.Where("id = ?", 10).Delete(&User{}).Error
}
会生成如下sql:DELETE FROM user WHERE id = ?
,参数为10
改:更新id=10的记录,将其name设为jerry,number设为20
db.Where("id = ?", 10).Updates(User{
Name: "jerry",
Number: 20,
})
生成的sql为:UPDATE user SET name=?,number=? WHERE id = ?
,参数为: "jerry", 20, 10
查:查询id=10的记录
var users []User
db.Where("id = ?", 10).Find(&users)
生成的sql为:SELECT * FROM user WHERE id = ?
,参数为:10
数据结构
gorm.DB
gorm.DB是gorm的核心结构,代表一次操作,拥有这次操作需要的的所有信息。基本上所有的api都定义在该结构体上
type DB struct {
*Config
Error error
RowsAffected int64
Statement *Statement
clone int
}
- Error:本次话执行中遇到的错误
- RowsAffected:本休操作的影响的行数
- 原生和mysql server交互,只有exec操作有会返回RowsAffected
- gorm给查询操作也赋予了RowsAffected语意,代表查出了多少行
- Statement:一次会话的状态信息,比如查询条件,查询结果
- clone:
- 如果clone = 1,代表始祖db实例
- 如果 clone = 0,代表从始祖DB克隆出来的会话
Statement
里面存储了一次会话中包含的状态信息,重点有以下这些字段:
- Clauses:sql的各个部分
- Schema:要操作表的元数据
- Selects,Omits:要select哪些字段,要omit哪些字段
- query操作中,Selects就是select哪些列
- update操作中,Selects就是要更新哪些列
- Dest:查询结果往哪写
- ConnPool:使用的连接池。普通模式下是 database/sql.DB,预编译模式为gorm.PreparedStmtDB
type Statement struct {
*DB
TableExpr *clause.Expr
// 表名
Table string
// 操作的po
Model interface{}
// ...
// 将结果反序列化到这
Dest interface{}
ReflectValue reflect.Value
// 各种条件
Clauses map[string]clause.Clause
BuildClauses []string
Distinct bool
// 要select哪些字段
Selects []string
// 要忽略哪些字段
Omits []string
// ...
// 连接池,通常情况下是 database/sql.DB. 在prepare模式为gorm.PreparedStmtDB
ConnPool ConnPool
// 表的schema
Schema *schema.Schema
// 请求生命周期控制
Context context.Context
// 如果数据未找到,是否要抛出recordNotFound
RaiseErrorOnNotFound bool
// 是否要跳过hooks
SkipHooks bool
// 要执行的sql
SQL strings.Builder
// sql的参数
Vars []interface{}
// ...
}
Schema元数据
schema代表一张表的元数据信息,通过解析一个model结构体得到,用于构建sql,结果集处理。主要包含以下信息:
- Table:根据model解析出的表名。一般来说表名和结构体名都不一样,需要通过实现TableName方法指定表名
- PrioritizedPrimaryField:主键
- Fields:有哪些字段
- FieldsByDBName:根据在db中的字段名找field,用于加速解析查询结果到结构体中
type Schema struct {
// model的name
Name string
// model的类型
ModelType reflect.Type
// 表名
Table string
// 主键
PrioritizedPrimaryField *Field
DBNames []string
PrimaryFields []*Field
PrimaryFieldDBNames []string
// 字段
Fields []*Field
FieldsByName map[string]*Field
FieldsByDBName map[string]*Field
// 被改写的clause:用于软删除,本文暂时不介绍
CreateClauses []clause.Interface
QueryClauses []clause.Interface
UpdateClauses []clause.Interface
DeleteClauses []clause.Interface
// 是否有各种回调
BeforeCreate, AfterCreate bool
BeforeUpdate, AfterUpdate bool
BeforeDelete, AfterDelete bool
BeforeSave, AfterSave bool
cacheStore *sync.Map
}
Field代表一个字段的元数据,主要根据结构体中定义的gorm tag解析出来
常见的用法为:
- 当结构体的字段名和db中列名不一致时,通过
gorm:"column:列名"
指定在db中的列名 - 通过
gorm:"autoUpdateTime:milli"
,gorm:"autoCreateTime"
指定某字段需要自动填充更新时间,自动填充创建时间
type Field struct {
// 字段名
Name string
// 列名
DBName string
// ...
AutoCreateTime TimeType
AutoUpdateTime TimeType
// ...
}
TimeType有这些枚举,例如当AutoCreateTime=3时,表示创建一条记录时自动设置当前时间的毫秒时间戳
const (
UnixTime TimeType = 1
UnixSecond TimeType = 2
UnixMillisecond TimeType = 3
UnixNanosecond TimeType = 4
)
clone
基于gorm.Open创建出来的db实例执行各种操作时,首先要clone一份db出来,当做这次操作的会话。之后对新的db的链式调用,在新的db上追加各种状态信息,做到会话之间数据隔离
会话信息主要存在db.Statement,因此clone的主要操作是new一个新的Statement出来
db.clone代表是不是始祖db实例,以及是否需要克隆
clone=1
:代表是始祖db实例,基于这种db调api时,需要克隆出一个新的db实例,专门给这次操作用clone=0
:代表是从始祖db实例克隆出来的实例,基于这种db调api时,不需要克隆
gorm.Open时,生成的始祖db的clone为1:
func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
// ...
db = &DB{Config: config, clone: 1}
// ...
}
基于始祖db调任何方法时,会创建新的db实例,其clone=0
如果已经clone过了(clone=0)就不会再次clone
func (db *DB) getInstance() *DB {
if db.clone > 0 {
// 需要clone,db.clone字段变成0,这样下次进入getInstance方法时就不会clone
tx := &DB{Config: db.Config, Error: db.Error}
// 初始化时,clone=1
if db.clone == 1 {
// 第一次clone
tx.Statement = &Statement{
DB: tx,
ConnPool: db.Statement.ConnPool,
Context: db.Statement.Context,
// 使用新的clause,vars
Clauses: map[string]clause.Clause{},
Vars: make([]interface{}, 0, 8),
SkipHooks: db.Statement.SkipHooks,
}
if db.Config.PropagateUnscoped {
tx.Statement.Unscoped = db.Statement.Unscoped
}
} else {
// with clone statement
tx.Statement = db.Statement.clone()
tx.Statement.DB = tx
}
return tx
}
return db
}
初始化
接下来介绍gorm.Open的核心流程
初始化DB
核心步骤如下:
- 传入dialector,本文使用 mysql的Dialector
- 初始化
clone=1
,这样后面操作时,发现clone=1就表示需要新构造一个 - 用Dialector初始化db
- 完成
connPool
的创建以及各类processor fns
函数的注册
- 完成
- 如果启用了prepare模式,需要使用preparedStmtDB当做connPool
- 构造statement实例
- 根据策略,决定是否通过 ping 请求测试连接
- 返回db实例
func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
config := &Config{}
// ...
if config.NamingStrategy == nil {
config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64
}
// ...
// 设置Dialector
if dialector != nil {
config.Dialector = dialector
}
// ...
if config.cacheStore == nil {
config.cacheStore = &sync.Map{}
}
// 初始化clone=1,这样后面操作时,发现clone=1就表示需要新构造一个
db = &DB{Config: config, clone: 1}
// 完成 callbacks 中 crud 等几类processor的创建
db.callbacks = initializeCallbacks(db)
if config.Dialector != nil {
// 用Dialector初始化db
err = config.Dialector.Initialize(db)
if err != nil {
if db, _ := db.DB(); db != nil {
_ = db.Close()
}
}
}
// 如果启用了 prepare 模式,需要使用 preparedStmtDB替换调原来的sql.DB
if config.PrepareStmt {
preparedStmt := NewPreparedStmtDB(db.ConnPool)
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
db.ConnPool = preparedStmt
}
// 构造statement实例
db.Statement = &Statement{
DB: db,
ConnPool: db.ConnPool,
Context: context.Background(),
Clauses: map[string]clause.Clause{},
}
// 根据策略,决定是否ping测试和db server的连接
if err == nil && !config.DisableAutomaticPing {
if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
err = pinger.Ping()
}
}
if err != nil {
config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err)
}
return
}
初始化dialector
传入连接mysql的dsn
func Open(dsn string) gorm.Dialector {
dsnConf, _ := mysql.ParseDSN(dsn)
return &Dialector{Config: &Config{DSN: dsn, DSNConfig: dsnConf}}
}
用dialector初始化db
- 初始化
sql.DB
,作为gorm.DB的连接池 - 注册crud的函数链,设置需要执行的
clauses
- 设置一些mysql独有的
clauseBuilde
,以后遇到这里面的clause就用这个builder构造- mysql dialector自己定义了
ON CONFLICT
和VALUES
两个clause的实现
- mysql dialector自己定义了
关于什么是clause,下文会详细说明
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
// ...
if dialector.Conn != nil {
db.ConnPool = dialector.Conn
} else {
// 初始化一个sql.DB,作为gorm.DB的连接池
db.ConnPool, err = sql.Open(dialector.DriverName, dialector.DSN)
if err != nil {
return err
}
}
withReturning := false
if !dialector.Config.SkipInitializeWithVersion {
err = db.ConnPool.QueryRowContext(context.Background(), "SELECT VERSION()").Scan(&dialector.ServerVersion)
if err != nil {
return err
}
// 根据mysql的版本做一些特殊配置
}
// register callbacks
callbackConfig := &callbacks.Config{
CreateClauses: CreateClauses,
QueryClauses: QueryClauses,
UpdateClauses: UpdateClauses,
DeleteClauses: DeleteClauses,
}
// ...
// 注册crud的函数链,设置需要执行的clauses
callbacks.RegisterDefaultCallbacks(db, callbackConfig)
// 设置一些mysql独有的clauseBuilder,以后遇到这里面的clause就用这个builder构造
for k, v := range dialector.ClauseBuilders() {
db.ClauseBuilders[k] = v
}
return
}
注册crud函数
给CURD注册函数链,设置需要执行的clauses
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
// ...
createCallback := db.Callback().Create()
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
// 注册create的函数链
createCallback.Register("gorm:before_create", BeforeCreate)
createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true))
createCallback.Register("gorm:create", Create(config))
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
createCallback.Register("gorm:after_create", AfterCreate)
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
// 设置create的clauses
createCallback.Clauses = config.CreateClauses
queryCallback := db.Callback().Query()
// 注册query的函数链
queryCallback.Register("gorm:query", Query)
queryCallback.Register("gorm:preload", Preload)
queryCallback.Register("gorm:after_query", AfterQuery)
// 设置query的clauses
queryCallback.Clauses = config.QueryClauses
// 注册delete,update的函数链,设置delete,updata的clauses
}
执行器processor
gorm将crud要用的逻辑封装到processor
type callbacks struct {
processors map[string]*processor
}
初始化:
gorm.Open
-> initializeCallbacks:
func initializeCallbacks(db *DB) *callbacks {
return &callbacks{
processors: map[string]*processor{
"create": {db: db},
"query": {db: db},
"update": {db: db},
"delete": {db: db},
"row": {db: db},
"raw": {db: db},
},
}
}
执行请求时,会根据crud的类型,从callbacks中获取processor,例如查询操作,通过Query方法获取:
func (cs *callbacks) Query() *processor {
return cs.processors["query"]
}
type processor struct {
// 从属的gorm.DB实例
db *DB
// 拼接sql的关键字顺序, 例如Query,固定为 SELECT,FROM,WHERE,GROUP BY, ORDER BY, LIMIT, FOR
Clauses []string
// 执行函数链
fns []func(*DB)
callbacks []*callback
}
所有请求遵循的处理思路都是,首先根据其从属的 crud 类型,找到对应的 processor,然后调用 processor 的 Execute 方法,执行该 processor 下的 fns 函数链
注册callback
回到grom.Open初始化流程中,初始化注册Callback时:
以Query为例,会按顺序注册Query,Preload,AfterQuery方法
if len(config.QueryClauses) == 0 {
config.QueryClauses = queryClauses
}
// ...
queryCallback := db.Callback().Query()
queryCallback.Register("gorm:query", Query)
queryCallback.Register("gorm:preload", Preload)
queryCallback.Register("gorm:after_query", AfterQuery)
// 设置query的clauses
queryCallback.Clauses = config.QueryClauses
处理函数链到processor中:
func (p *processor) Register(name string, fn func(*DB)) error {
return (&callback{processor: p}).Register(name, fn)
}
func (c *callback) Register(name string, fn func(*DB)) error {
c.name = name
c.handler = fn
// 将自己注册到 processor的callbacks中
c.processor.callbacks = append(c.processor.callbacks, c)
// 按照callback定义的before,after对callbacks排序,将结果放到processor.fns中
return c.processor.compile()
}
compile方法的作用:将callbacks中的方法按照name排序,整理到fns中
// 将callbacks整理到fns中
func (p *processor) compile() (err error) {
// ...
// 对p.callback排序,结果收集到p.fns中
if p.fns, err = sortCallbacks(p.callbacks); err != nil {
p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
}
return
}
sortCallbacks:根据每个callback定义的before,after,决定怎么排序
func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
var (
names, sorted []string
sortCallback func(*callback) error
)
/**
before为*的,排到后面,方便后面处理*的callback
after为*的,排到后面
*/
sort.SliceStable(cs, func(i, j int) bool {
if cs[j].before == "*" && cs[i].before != "*" {
return true
}
if cs[j].after == "*" && cs[i].after != "*" {
return true
}
return false
})
// 收集所有name
for _, c := range cs {
// ...
names = append(names, c.name)
}
sortCallback = func(c *callback) error {
// 定义了before
if c.before != "" {
// before是*,加到所有的callback后面
if c.before == "*" && len(sorted) > 0 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
sorted = append([]string{c.name}, sorted...)
}
// 要before的哪个name在sorted中存在
} else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
// 自己不存在
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
// 将自己添加到c.before的前面
sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
} else if curIdx > sortedIdx {
return fmt.Errorf("conflicting callback %s with before %s", c.name, c.before)
}
// 要before的那个name在sorted中存在,但在name中存在
// 那就在那个name以后添加时,添加到c的后面
} else if idx := getRIndex(names, c.before); idx != -1 {
// if before callback exists
cs[idx].after = c.name
}
}
// 定义了after
if c.after != "" {
// 类似before,这里就不展开了
}
// 没添加过,添加到尾部
if getRIndex(sorted, c.name) == -1 {
sorted = append(sorted, c.name)
}
return nil
}
// 根据callback的before,after排序
for _, c := range cs {
if err = sortCallback(c); err != nil {
return
}
}
// 将排序好的结果添加到fns中
for _, name := range sorted {
if idx := getRIndex(names, name); !cs[idx].remove {
fns = append(fns, cs[idx].handler)
}
}
return
}
Clause抽象
gorm相当于一个组装厂,采用分而治之的策略:它不知道怎么构造sql,但每个部分知道怎么构造自己
具体来说,gorm将sql的各个部分都抽象成一个clause
,
例如 SELECT * FROM user WHERE id < 10 ORDER by id
这条sql,就包含了 SELECT、FROM、WHERE 和 ORDER
四个 clause
在mysql中,每种操作拥有的clause,以及其顺序如下:
var (
CreateClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
QueryClauses = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
UpdateClauses = []string{"UPDATE", "SET", "WHERE", "ORDER BY", "LIMIT"}
DeleteClauses = []string{"DELETE", "FROM", "WHERE", "ORDER BY", "LIMIT"}
)
生成sql的入口在Statement.Build,入参为curd中某一个processor的BuildClauses
func (stmt *Statement) Build(clauses ...string) {
var firstClauseWritten bool
for _, name := range clauses {
if c, ok := stmt.Clauses[name]; ok {
if firstClauseWritten {
stmt.WriteByte(' ')
}
firstClauseWritten = true
if b, ok := stmt.DB.ClauseBuilders[name]; ok {
b(c, stmt)
} else {
// 每个clause build自己负责的一部分
c.Build(stmt)
}
}
}
}
以query为例,会遵循以下的顺序:
SELECT->FROM->WHERE->GROUP BY->ORDER BY->LIMIT->FOR
依次从 statement.Clauses
中获取对应的 clause,调用 clause.Build
方法,构造出sql的各个部分,组装到statement.SQL 字段中
当使用方通过链式操作克隆 DB时,对应追加的状态信息就会生成一个新的 clause,追加到 statement 对应的 clauses 集合当中
请求实际执行时,会取出 clauses 集合,拼接生成完整的 sql
接下来看看clause的接口抽象:
- Interface:描述一个clause,sql的每个部分都是一个clause
- 例如:select,where,from
- Expiression:表达式,核心是build方法,也就是构造自己这部分的sql
- 所有的Interface也都是Expiression,因为都有Build方法
- 各个比较符都有一个Expression实现,例如Eq,In
- Not,And,Or被认为是一系列Expression的集合
- Builder:可以看作一个[]byte,Builder.WriteString方法就是往其后面追加字符串
- 实现类为Statement,实际就是往Statement.SQL追加字符串
- 实现类为Statement,实际就是往Statement.SQL追加字符串
这种接口设计,使得gorm有非常强的扩展性,只要实现这些接口,可以把各种sql都构造出来
但缺点是:代码可读性降低了。如果要看某种操作的源码实现,需要看这种操作有哪些clause,然后看每种clause内部是怎么构造sql的
这也是代码设计的取舍,高扩展性往往带来复杂的接口机制,使得代码代码可读性低
每个clause都实现了clause.Interface接口,该接口每个方法的含义如下:
type Interface interface {
// clause的name
Name() string
// 该clause怎么构造自己这部分sql
Build(Builder)
// 添加多个相同name的clause时,怎么进行合并
MergeClause(*Clause)
}
以Select为例,其具体实现如下:
type Select struct {
Distinct bool
Columns []Column
Expression Expression
}
func (s Select) Name() string {
return "SELECT"
}
func (s Select) Build(builder Builder) {
// 查询指定的列
if len(s.Columns) > 0 {
if s.Distinct {
builder.WriteString("DISTINCT ")
}
// 拼接指定的列
for idx, column := range s.Columns {
if idx > 0 {
builder.WriteByte(',')
}
// 前后加反引号
builder.WriteQuoted(column)
}
// 否则就是select *
} else {
builder.WriteByte('*')
}
}
每个clause被Statement持有,放在Clauses字段中
type Statement struct {
// ...
Clauses map[string]clause.Clause
// ...
}
clause.Clause结构体如下,一般在Expression字段持有Select,From,Where等具体的clause
type Clause struct {
Name string // WHERE
BeforeExpression Expression // 没有使用
AfterNameExpression Expression // 没有使用
AfterExpression Expression // 没有使用
// 一般在这里持有Select,From,Where等具体结构
Expression Expression
Builder ClauseBuilder
}
Expression只有Build方法,因此所有实现了clause.Interface接口的,也实现了Expression接口
type Expression interface {
Build(builder Builder)
}
往Statement添加clause的方法如下:
func (stmt *Statement) AddClause(v clause.Interface) {
if optimizer, ok := v.(StatementModifier); ok {
optimizer.ModifyStatement(stmt)
} else {
name := v.Name()
c := stmt.Clauses[name]
c.Name = name
// 将v和已有的clause合并
v.MergeClause(&c)
// 合并结果写到Statement.Clause中
stmt.Clauses[name] = c
}
}
Clause.Build方法如下:
- 往builder写
clause.Name
- 调具体clause.Build,例如如果
Clause.Expression
是Select,就调Select.Build
func (c Clause) Build(builder Builder) {
if c.Builder != nil {
c.Builder(c, builder)
} else if c.Expression != nil {
// 前置处理
// ...
// 写name
if c.Name != "" {
builder.WriteString(c.Name)
builder.WriteByte(' ')
}
// ...
// 调具体clause.Build
// 例如:调Select.Build
c.Expression.Build(builder)
// ...
}
}
解析元数据
接下来看gorm如何把一个结构体解析成元数据
解析schema
入口为Statement.Parse
func (stmt *Statement) Parse(value interface{}) (err error) {
return stmt.ParseWithSpecialTableName(value, "")
}
func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) {
if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" {
// ...
}
return err
}
ParseWithSpecialTableName把dest解析成schema
将结构体解析成schema不是一个简单的操作,如果每次调gorm的方法都要解析一次,那太消耗性能了。因此参数中有个cacheStore
,在gorm.DB层面唯一。如果cache中有,就不会重复解析,直接拿来用
核心步骤如下:
- 如果dest是指针,接口,slice,得到其指向的具体类型
- 以
reflect.Type
作为key,去cacheStore
看是否解析过,如果是直接返回已经解析好的schema - 解析tableName:
- 如果实现了
Tabler
接口(有TableName() string方法),使用该返回值作为tableName - 否则根据struct的name先生成一个默认的tableName,规则为驼峰转下划线,转小写
- 如果实现了
- 解析schema中的每个字段
- 如果某个字段没有在tag指定
column name
, 使用默认的columnName:驼峰转下划线,转小写 - 将解析好的field加入3个结构中,用于后续根据对象转化为sql,和根据查询结果转化为对象时使用
DBNames []string
FieldsByName map[string]*Field
FieldsByDBName map[string]*Field
- 寻找name叫id的字段,当做主键
- 如果model定义了BeforeCreate,BeforeUpdate等方法,将schema中对应字段设为true
- 如果field实现了CRUD的ClausesInterface接口,加入schema中
- 主要用于deletedAt使用,本文不展开介绍
- 主要用于deletedAt使用,本文不展开介绍
func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
// ...
// 如果dest是指针,接口,slice,得到其指向的具体类型
value := reflect.ValueOf(dest)
if value.Kind() == reflect.Ptr && value.IsNil() {
value = reflect.New(value.Type().Elem())
}
modelType := reflect.Indirect(value).Type()
if modelType.Kind() == reflect.Interface {
modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type()
}
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
// 到这里,modelType一定是struct
if modelType.Kind() != reflect.Struct {
if modelType.PkgPath() == "" {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
}
return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
}
var schemaCacheKey interface{}
if specialTableName != "" {
// ...
} else {
// 以reflect.Type作为key
schemaCacheKey = modelType
}
// 如果已经解析过schema,使用解析好的
if v, ok := cacheStore.Load(schemaCacheKey); ok {
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
<-s.initialized
return s, s.err
}
// 下面开始正式解析
modelValue := reflect.New(modelType)
// 根据struct的name先生成一个默认的tableName,规则为驼峰转下划线,转小写
tableName := namer.TableName(modelType.Name())
// 如果实现了Tabler接口(TableName() string方法),使用该返回值作为tableName
if tabler, ok := modelValue.Interface().(Tabler); ok {
tableName = tabler.TableName()
}
// ...
// 构建schema
schema := &Schema{
Name: modelType.Name(),
ModelType: modelType,
Table: tableName,
FieldsByName: map[string]*Field{},
FieldsByBindName: map[string]*Field{},
FieldsByDBName: map[string]*Field{},
Relationships: Relationships{Relations: map[string]*Relationship{}},
cacheStore: cacheStore,
namer: namer,
initialized: make(chan struct{}),
}
// close schema.initialized,代表初始化完成
defer close(schema.initialized)
// Load exist schema cache, return if exists
if v, ok := cacheStore.Load(schemaCacheKey); ok {
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
<-s.initialized
return s, s.err
}
for i := 0; i < modelType.NumField(); i++ {
if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
// 解析每个字段,收集到Fields中
if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil {
schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...)
} else {
schema.Fields = append(schema.Fields, field)
}
}
}
for _, field := range schema.Fields {
// 如果没有在tag指定column name, 使用默认的columnName:驼峰转下划线,转小写
if field.DBName == "" && field.DataType != "" {
field.DBName = namer.ColumnName(schema.Table, field.Name)
}
bindName := field.BindName()
if field.DBName != "" {
// nonexistence or shortest path or first appear prioritized if has permission
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
// 收集到DBNames数组中
schema.DBNames = append(schema.DBNames, field.DBName)
}
// 建立3个map,根据dbname,name查找field
schema.FieldsByDBName[field.DBName] = field
schema.FieldsByName[field.Name] = field
schema.FieldsByBindName[bindName] = field
// 收集主键
if v != nil && v.PrimaryKey {
for idx, f := range schema.PrimaryFields {
if f == v {
schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...)
}
}
}
if field.PrimaryKey {
schema.PrimaryFields = append(schema.PrimaryFields, field)
}
}
}
if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
schema.FieldsByName[field.Name] = field
}
if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" {
schema.FieldsByBindName[bindName] = field
}
field.setupValuerAndSetter()
}
// 找name叫id的字段,当作主键
prioritizedPrimaryField := schema.LookUpField("id")
if prioritizedPrimaryField == nil {
prioritizedPrimaryField = schema.LookUpField("ID")
}
if prioritizedPrimaryField != nil {
if prioritizedPrimaryField.PrimaryKey {
schema.PrioritizedPrimaryField = prioritizedPrimaryField
} else if len(schema.PrimaryFields) == 0 {
prioritizedPrimaryField.PrimaryKey = true
// 设为 schema.PrioritizedPrimaryField
// id也会作为PrimaryFields
schema.PrimaryFields = append(schema.PrimaryFields, prioritizedPrimaryField)
}
}
// ...
// ...
if field := schema.PrioritizedPrimaryField; field != nil {
switch field.GORMDataType {
case Int, Uint:
if _, ok := field.TagSettings["AUTOINCREMENT"]; !ok {
if !field.HasDefaultValue || field.DefaultValueInterface != nil {
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
}
field.HasDefaultValue = true
// 将id的AutoIncrement强制设为true
field.AutoIncrement = true
}
}
}
callbackTypes := []callbackType{
callbackTypeBeforeCreate, callbackTypeAfterCreate,
callbackTypeBeforeUpdate, callbackTypeAfterUpdate,
callbackTypeBeforeSave, callbackTypeAfterSave,
callbackTypeBeforeDelete, callbackTypeAfterDelete,
callbackTypeAfterFind,
}
// 如果model定义了BeforeCreate,BeforeUpdate等方法,将schema中对应字段设为true
for _, cbName := range callbackTypes {
if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
switch methodValue.Type().String() {
case "func(*gorm.DB) error":
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true)
default:
// ...
}
}
}
// 将schema放入缓存
if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded {
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
<-s.initialized
return s, s.err
}
defer func() {
if schema.err != nil {
logger.Default.Error(context.Background(), schema.err.Error())
cacheStore.Delete(modelType)
}
}()
if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
for _, field := range schema.Fields {
// ...
fieldValue := reflect.New(field.IndirectFieldType)
fieldInterface := fieldValue.Interface()
// 如果field实现了CRUD的ClausesInterface接口,加入schema中
// 主要用于 deletedAt使用
if fc, ok := fieldInterface.(CreateClausesInterface); ok {
field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...)
}
if fc, ok := fieldInterface.(QueryClausesInterface); ok {
field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...)
}
if fc, ok := fieldInterface.(UpdateClausesInterface); ok {
field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...)
}
if fc, ok := fieldInterface.(DeleteClausesInterface); ok {
field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...)
}
}
}
return schema, schema.err
}
解析field
首先解析field中的tag
tagSetting = ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";")
ParseTagSetting:解析field的tag
func ParseTagSetting(str string, sep string) map[string]string {
settings := map[string]string{}
names := strings.Split(str, sep)
for i := 0; i < len(names); i++ {
j := i
// ...
values := strings.Split(names[j], ":")
k := strings.TrimSpace(strings.ToUpper(values[0]))
// key: 冒号前的值,value:冒号后的值
if len(values) >= 2 {
settings[k] = strings.Join(values[1:], ":")
// 如果没有冒号,key就等于value
} else if k != "" {
settings[k] = k
}
}
return settings
}
接下来看看如何将每个字段解析成field结构:
- 解析每个字段的tag,按;分割
- 如果在tag指定了列名,设置到
field.DBName
中 - 如果指定了自动设置创建时间 ,更新时间,设置到
field.AutoCreateTime,field.AutoUpdateTime
中
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
var (
err error
// gorm的tag,按;分割
tagSetting = ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";")
)
field := &Field{
Name: fieldStruct.Name,
// 如果在tag指定了列名
DBName: tagSetting["COLUMN"],
BindNames: []string{fieldStruct.Name},
EmbeddedBindNames: []string{fieldStruct.Name},
FieldType: fieldStruct.Type,
IndirectFieldType: fieldStruct.Type,
StructField: fieldStruct,
Tag: fieldStruct.Tag,
TagSettings: tagSetting,
Schema: schema,
Creatable: true,
Updatable: true,
Readable: true,
// tag里指定了PRIMARYKEY
PrimaryKey: utils.CheckTruth(tagSetting["PRIMARYKEY"], tagSetting["PRIMARY_KEY"]),
AutoIncrement: utils.CheckTruth(tagSetting["AUTOINCREMENT"]),
HasDefaultValue: utils.CheckTruth(tagSetting["AUTOINCREMENT"]),
NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
Unique: utils.CheckTruth(tagSetting["UNIQUE"]),
Comment: tagSetting["COMMENT"],
AutoIncrementIncrement: DefaultAutoIncrementIncrement,
}
for field.IndirectFieldType.Kind() == reflect.Ptr {
field.IndirectFieldType = field.IndirectFieldType.Elem()
}
// ...
// ...
// 自增
if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok {
field.AutoIncrementIncrement, _ = strconv.ParseInt(num, 10, 64)
}
// 指定了默认值
if v, ok := field.TagSettings["DEFAULT"]; ok {
field.HasDefaultValue = true
field.DefaultValue = v
}
// ...
// 指定了自动设置创建时间
if v, ok := field.TagSettings["AUTOCREATETIME"]; (ok && utils.CheckTruth(v)) || (!ok && field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
if field.DataType == Time {
field.AutoCreateTime = UnixTime
} else if strings.ToUpper(v) == "NANO" {
field.AutoCreateTime = UnixNanosecond
} else if strings.ToUpper(v) == "MILLI" {
// 使用毫秒时间戳的创建时间
field.AutoCreateTime = UnixMillisecond
} else {
field.AutoCreateTime = UnixSecond
}
}
// 指定了自动设置更新时间
if v, ok := field.TagSettings["AUTOUPDATETIME"]; (ok && utils.CheckTruth(v)) || (!ok && field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
if field.DataType == Time {
field.AutoUpdateTime = UnixTime
} else if strings.ToUpper(v) == "NANO" {
field.AutoUpdateTime = UnixNanosecond
} else if strings.ToUpper(v) == "MILLI" {
field.AutoUpdateTime = UnixMillisecond
} else {
field.AutoUpdateTime = UnixSecond
}
}
// ...
return field
}
总结
本文介绍了gorm的整体设计和db的初始化,下一篇文章继续深入源码,介绍gorm怎么实现增删改查,和结果集处理