前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >基于sqlmock模拟数据库驱动编写Golang单元测试用例

基于sqlmock模拟数据库驱动编写Golang单元测试用例

原创
作者头像
KunkkaWu
发布2024-03-25 10:15:25
2680
发布2024-03-25 10:15:25

TOC

1. 场景

当前golang开发人员,在编写完成代码后,通常会写对应的单测来保证代码的健壮。对于很多大厂来说,编写单测已经是代码规范的一部分。基于官方提供的gomock框架和mockgen辅助工具就可以满足绝大部分场景,对于不能直接创建的依赖进行mock。但是,当我们编写API接口的时候,往往会对数据库进行操作,那么就需要支持对SQL进行mock的场景。

2. sqlmock 简介

在使用gorm等orm框架时,由于需要和数据库进行交互,并且CICD服务器在对代码检测的时候,往往也无法连接真正的数据库,因此编写单元测试,就会变得很困难。

go-sqlmock 本质是一个实现了 sql/driver 接口的 mock 库,它的设计目标是支持在测试中,模拟任何 sql driver 的行为,而不需要一个真正的数据库连接。因此,可以很好的解决这个问题。

3. 安装 go-sqlmock

代码语言:go
复制
go get github.com/DATA-DOG/go-sqlmock

4. sqlmock实战

首先我们模拟一下,在实际开发中会使用到gorm来对数据库查询操作。

目录结构:

  • main.go: 主程序,加载TagController,并注入已经初始化后的*gorm.DB, 然后调用TagController中的方法PrintTagList()
  • controller
    • tag.go: 包含控制器TagController的代码
  • model
    • tag.go: 包含model层,使用gorm需要定义的tag表的字段信息

4.1 定义接口

4.1.1 main.go

这里省去了,我们可能会用到的gin等框架负载的启动逻辑。假设main函数中,就是单纯的初始化gorm,并实例化控制器后,调用控制器的方法,获取数据库中的结果。

dsn连接信息,这里预设的是本地的数据库连接信息。

代码语言:go
复制
package main

import (
    "test/utils/sqlmock/controller"

    "gorm.io/driver/mysql"
    "gorm.io/gorm"
)

func main() {
    db := initDB()
    tagCtrl := controller.TagController{
        DB: db,
    }
    tagCtrl.PrintTagList()
}

func initDB() *gorm.DB {
    dsn := "root:@tcp(127.0.0.1:3306)/registration-service?charset=utf8&parseTime=true&loc=Asia%2FShanghai"
    db, err := gorm.Open(mysql.Open(dsn))
    if err != nil {
        panic(err)
    }
    return db
}
4.1.2 controller/tag.go

类似于一般的框架,MVC架构下,通常会首先进入controller中,然后通过controller来访问model层的代码。这里提供了,TagControllerPrintTagList()方法,来打印所有从数据库中查询出来的TagName

代码语言:go
复制
package controller

import (
    "fmt"
    "test/utils/sqlmock/model"

    "gorm.io/gorm"
)

type TagController struct {
    DB *gorm.DB
}

func (c *TagController) PrintTagList() {
    var tagModel []*model.Tag
    if err := c.DB.Find(&tagModel).Error; err != nil {
        fmt.Println(err)
    }
    for _, tag := range tagModel {
        fmt.Println(tag.TagName)
    }
}
4.1.3 model/tag.go

MVC 中的model层代码,这里是按照gorm的使用规范,定义了Tag表的结构信息。

代码语言:go
复制
package model

import (
    "time"
)

// Tag 表
type Tag struct {
    Id        uint      `gorm:"column:id;type:int(11) unsigned;primary_key;AUTO_INCREMENT" json:"id"`
    TagName   string    `gorm:"column:tag_name;type:varchar(20);comment:关键字;NOT NULL" json:"tag_name"`
    Enabled   int32     `gorm:"column:enabled;type:tinyint(2);default:0;comment:是否启用:1启用,0禁用;NOT NULL" json:"enabled"`
    CreatedAt time.Time `gorm:"column:created_at;type:datetime;comment:创建时间" json:"created_at"`
    UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;comment:更新时间" json:"updated_at"`
}

// TableName -
func (m *Tag) TableName() string {
    return "tag"
}
4.1.4 执行main.go

当然,实际的数据,已经预先写入到了数据库中。这里可以正确的被打印出来

代码语言:go
复制
结果:
tag1
tag2
apple
orange
water
banana

4.2 通过sqlmock来对TagController的代码编写单测

创建controller/tag_test.go的单测文件,填写以下信息:

代码语言:go
复制
package controller

import (
    "fmt"
    "testing"

    "gorm.io/driver/mysql"
    "gorm.io/gorm"

    "github.com/DATA-DOG/go-sqlmock"
)

// 初始化sqlmock
func initTest() (*gorm.DB, sqlmock.Sqlmock) {
    // 1. 初始化 sql mock
    db, mock, err := sqlmock.New()
    if err != nil {
        fmt.Println("err:", err)
    }

    // 2. mock数据库版本查询
    mock.ExpectQuery("SELECT VERSION()").WillReturnRows(sqlmock.NewRows([]string{"version"}).AddRow("5.7.25"))

    // 3. mock gorm driver
    gormDB, err := gorm.Open(mysql.New(mysql.Config{
        Conn: db,
    }), &gorm.Config{})
    if err != nil {
        panic(err) // Error here
    }

    return gormDB, mock
}

// 对PrintTagList方法单测
func TestTagController_PrintTagList(t *testing.T) {
    // 初始化sqlmock
    gormDB, sqlMock := initTest()
    // 对即将产生的sql,预先打桩处理
    mockExpect(sqlMock)
    
    // 初始化控制器,并将mock后的gorm注入
    tagController := &TagController{
        DB: gormDB,
    }
    // 调用需要测试的方法
    tagController.PrintTagList()
}

// 对即将产生的sql,预先打桩处理
func mockExpect(mock sqlmock.Sqlmock) {
    mock.ExpectQuery("^SELECT (.+) FROM `tag`").WillReturnRows(sqlmock.NewRows([]string{"tag_name"}).AddRow("apple").AddRow("orange"))
}

在执行结果中,将会显mock的内容AddRow("apple").AddRow("orange")

代码语言:go
复制
// 执行结果:
=== RUN   TestTagController_PrintTagList
apple
orange
--- PASS: TestTagController_PrintTagList (0.00s)
PASS

4.3 支持事务

4.3.1 在TagController中增加方法Create()
代码语言:go
复制
func (c *TagController) Create(tagName string) {
    // 开启事务
    tx := c.DB.Begin()
    tagModel := &model.Tag{
        TagName:   tagName,
        Enabled:   1,
        CreatedAt: time.Now(),
    }
    if err := tx.Create(tagModel).Error; err != nil {
        // 创建失败回滚
        tx.Rollback()
        fmt.Println(err)
    }
    // 提交事务
    if err := tx.Commit().Error; err != nil {
        fmt.Println(err)
    }
}
4.3.2 增加单测
代码语言:go
复制
func TestTagController_Create(t *testing.T) {
    gormDB, sqlMock := initTest()
    mockCreateExpect(sqlMock)

    tagController := &TagController{
        DB: gormDB,
    }
    // 测试创建失败
    tagController.Create("banana")
    // 测试创建成功
    tagController.Create("banana")
}

func mockCreateExpect(mock sqlmock.Sqlmock) {
    // mock创建失败
    mock.ExpectBegin()
    mock.ExpectExec("^INSERT INTO `tag` ").WillReturnError(gorm.ErrInvalidData)
    mock.ExpectRollback()

    // mock创建成功
    mock.ExpectBegin()
    mock.ExpectExec("^INSERT INTO `tag` ").WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(),
        sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1))
    mock.ExpectCommit()
}

其中sqlmock.AnyArg() 跳过对参数的匹配校验,但是.WithArgs()方法要求,对参数的数量需要一致。

4.4. 完整代码

controller/tag.go

代码语言:go
复制
package controller

import (
    "fmt"
    "test/utils/sqlmock/model"
    "time"

    "gorm.io/gorm"
)

type TagController struct {
    DB *gorm.DB
}

func (c *TagController) PrintTagList() {
    var tagModel []*model.Tag
    if err := c.DB.Find(&tagModel).Error; err != nil {
        fmt.Println(err)
    }
    for _, tag := range tagModel {
        fmt.Println(tag.TagName)
    }
}

func (c *TagController) Create(tagName string) {
    // 开启事务
    tx := c.DB.Begin()
    tagModel := &model.Tag{
        TagName:   tagName,
        Enabled:   1,
        CreatedAt: time.Now(),
    }
    if err := tx.Create(tagModel).Error; err != nil {
        // 创建失败回滚
        tx.Rollback()
        fmt.Println(err)
    }
    // 提交事务
    if err := tx.Commit().Error; err != nil {
        fmt.Println(err)
    }
}

controller/tag_test.go

代码语言:go
复制
package controller

import (
    "fmt"
    "testing"

    "gorm.io/driver/mysql"
    "gorm.io/gorm"

    "github.com/DATA-DOG/go-sqlmock"
)

func initTest() (*gorm.DB, sqlmock.Sqlmock) {
    // 1. 初始化 sql mock
    db, mock, err := sqlmock.New()
    if err != nil {
        fmt.Println("err:", err)
    }

    // 2. mock数据库版本查询
    mock.ExpectQuery("SELECT VERSION()").WillReturnRows(sqlmock.NewRows([]string{"version"}).AddRow("5.7.25"))

    // 3. 组装mock的gorm
    gormDB, err := gorm.Open(mysql.New(mysql.Config{
        Conn: db,
    }), &gorm.Config{})
    if err != nil {
        panic(err) // Error here
    }

    return gormDB, mock
}

func TestTagController_PrintTagList(t *testing.T) {
    gormDB, sqlMock := initTest()
    mockExpect(sqlMock)

    tagController := &TagController{
        DB: gormDB,
    }
    tagController.PrintTagList()

}

func TestTagController_Create(t *testing.T) {
    gormDB, sqlMock := initTest()
    mockCreateExpect(sqlMock)

    tagController := &TagController{
        DB: gormDB,
    }
    // 测试创建失败
    tagController.Create("banana")
    // 测试创建成功
    tagController.Create("banana")
}

func mockExpect(mock sqlmock.Sqlmock) {
    mock.ExpectQuery("^SELECT (.+) FROM `tag`").WillReturnRows(sqlmock.NewRows([]string{"tag_name"}).AddRow("apple").AddRow("orange"))
}

func mockCreateExpect(mock sqlmock.Sqlmock) {
    // mock创建失败
    mock.ExpectBegin()
    mock.ExpectExec("^INSERT INTO `tag` ").WillReturnError(gorm.ErrInvalidData)
    mock.ExpectRollback()

    // mock创建成功
    mock.ExpectBegin()
    mock.ExpectExec("^INSERT INTO `tag` ").WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(),
        sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1))
    mock.ExpectCommit()
}

4.5 注意事项

  1. initTest()方法中,对gorm driver进行mock的时候,低版本和高版本的代码实现会有一定的差异。目前网上搜索到的示例大多数都是旧版本的实现方式,本文中的示例,是基于gorm.io/gorm v1.25.5版本实现的。[error] failed to initialize database, got error all expectations were already fulfilled, call to Query 'SELECT VERSION()' with args [] was not expected --- FAIL: TestTagController_Create (0.00s) panic: all expectations were already fulfilled, call to Query 'SELECT VERSION()' with args [] was not expected [recovered] panic: all expectations were already fulfilled, call to Query 'SELECT VERSION()' with args [] was not expected需要增加Expect:mock.ExpectQuery("SELECT VERSION()").WillReturnRows(sqlmock.NewRows([]string{"version"}).AddRow("5.7.25")) 4. 数据库连接关闭问题sql: database is closed sql: database is closed; invalid transaction通过db, mock, err := sqlmock.New()获取到db后,千万不要defer db.Close(),否则会导致后续对数据库操作,引起database is closed的问题。 // 1. 初始化 sql mock db, mock, err := sqlmock.New() // defer db.Close() 不要加这一行 if err != nil { fmt.Println("err:", err) }
  2. mock.ExpectQuery()方法中,支持正则表达式来对sql语句进行匹配。
  3. 初始化数据库,SELECT VERSION()问题

5. 总结

上面主要是,简单的介绍和示例了,通过sqlmock来对gorm打桩mock。从而更加简单和方便的来对使用到数据库操作的业务代码进行单测的编写。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. 场景
  • 2. sqlmock 简介
  • 3. 安装 go-sqlmock
  • 4. sqlmock实战
    • 4.1 定义接口
      • 4.1.1 main.go
      • 4.1.2 controller/tag.go
      • 4.1.3 model/tag.go
      • 4.1.4 执行main.go
    • 4.2 通过sqlmock来对TagController的代码编写单测
      • 4.3 支持事务
        • 4.3.1 在TagController中增加方法Create()
        • 4.3.2 增加单测
      • 4.4. 完整代码
        • 4.5 注意事项
        • 5. 总结
        领券
        问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档

        http://www.vxiaotou.com