How do I use gomock (or something similar) to mock/verify calls to the database?

This is Go, using gorm to ORM-map to the database (PSQL).

I have the following code:

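package dbstuff

import (
	"github.com/google/uuid"
	"github.com/jinzhu/gorm"
	_ "github.com/jinzhu/gorm/dialects/postgres"
)

// Order is assumed to be defined elsewhere in package dbstuff.

// OrderPersister reads orders through gorm.
type OrderPersister struct {
	db *gorm.DB
}

// GetOrder loads the row from the orders table whose order_id matches id.
func (p *OrderPersister) GetOrder(id uuid.UUID) (*Order, error) {
	ret := &Order{}
	err := p.db.Table("orders").Where("order_id = ?", id).Scan(ret).Error
	return ret, err
}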

I'm trying to write a unit test for it, like this:

package dbstuff

import (
	"testing"

	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"
)

func TestErrInternalServerError(t *testing.T) {
	// given
	id := uuid.New()
	op := OrderPersister{} // db is left nil here

	// when
	order, err := op.GetOrder(id)

	// then
	assert.NotNil(t, order)
	assert.NotNil(t, err)
}

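When I run it, I get an "invalid memory address or nil pointer dereference" error, because I never instantiate and set a *gorm.DB on my OrderPersister instance. Is there a simple way to mock/stub this, so that my test confirms we try to query the orders table and return the ORM-mapped result?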


呼如林

慕桂英4014372

I'll use the testify package to write unit tests for your code. Instead of using the concrete type *gorm.DB, declare a DB interface for the OrderPersister struct. Because we can't mock a concrete type and its methods in Go, we need to create an abstraction layer: an interface.

63622995/db/db.go:

package db

// OrmDBWithError carries the chain plus the error produced by Scan.
type OrmDBWithError struct {
	OrmDB
	Error error
}

// OrmDB is the subset of gorm's chainable API that OrderPersister uses.
type OrmDB interface {
	Table(name string) OrmDB
	Where(query interface{}, args ...interface{}) OrmDB
	Scan(dest interface{}) *OrmDBWithError
}

63622995/main.go:

package main

import (
	"github.com/google/uuid"
	_ "github.com/jinzhu/gorm/dialects/postgres"
	"github.com/mrdulin/golang/src/stackoverflow/63622995/db"
)

type Order struct {
	order_id string
}

type OrderPersister struct {
	DB db.OrmDB
	//DB *gorm.DB
}

func (p *OrderPersister) GetOrder(id uuid.UUID) (*Order, error) {
	ret := &Order{}
	err := p.DB.Table("orders").Where("order_id = ?", id).Scan(ret).Error
	return ret, err
}

Next, create a mock db object that implements the OrmDB interface. You can then create this mock DB object and pass it to the OrderPersister struct.

63622995/mocks/db.go:

package mocks

import (
	"github.com/mrdulin/golang/src/stackoverflow/63622995/db"
	"github.com/stretchr/testify/mock"
)

// MockedOrmDB records calls and returns whatever each expectation specifies.
type MockedOrmDB struct {
	mock.Mock
}

func (s *MockedOrmDB) Table(name string) db.OrmDB {
	args := s.Called(name)
	return args.Get(0).(db.OrmDB)
}

func (s *MockedOrmDB) Where(query interface{}, args ...interface{}) db.OrmDB {
	arguments := s.Called(query, args)
	return arguments.Get(0).(db.OrmDB)
}

func (s *MockedOrmDB) Scan(dest interface{}) *db.OrmDBWithError {
	args := s.Called(dest)
	return args.Get(0).(*db.OrmDBWithError)
}

63622995/main_test.go:

package main

import (
	"errors"
	"testing"

	"github.com/google/uuid"
	"github.com/mrdulin/golang/src/stackoverflow/63622995/db"
	"github.com/mrdulin/golang/src/stackoverflow/63622995/mocks"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
)

func TestOrderPersister_GetOrder(t *testing.T) {
	assert := assert.New(t)
	t.Run("should get order", func(t *testing.T) {
		testDb := new(mocks.MockedOrmDB)
		id := uuid.New()
		testDb.
			On("Table", "orders").
			Return(testDb).
			On("Where", "order_id = ?", mock.Anything).
			Return(testDb).
			On("Scan", mock.Anything).
			Run(func(args mock.Arguments) {
				// Simulate gorm filling the destination struct.
				ret := args.Get(0).(*Order)
				ret.order_id = "123"
			}).
			Return(&db.OrmDBWithError{Error: nil})
		op := OrderPersister{DB: testDb}
		got, err := op.GetOrder(id)
		testDb.AssertExpectations(t)
		assert.Nil(err)
		assert.Equal(Order{order_id: "123"}, *got)
	})

	t.Run("should return error", func(t *testing.T) {
		testDb := new(mocks.MockedOrmDB)
		id := uuid.New()
		testDb.
			On("Table", "orders").
			Return(testDb).
			On("Where", "order_id = ?", mock.Anything).
			Return(testDb).
			On("Scan", mock.Anything).
			Return(&db.OrmDBWithError{Error: errors.New("network")})
		op := OrderPersister{DB: testDb}
		got, err := op.GetOrder(id)
		testDb.AssertExpectations(t)
		assert.Equal(Order{}, *got)
		assert.Equal(err.Error(), "network")
	})
}

Unit test result:

=== RUN   TestOrderPersister_GetOrder
=== RUN   TestOrderPersister_GetOrder/should_get_order
    TestOrderPersister_GetOrder/should_get_order: main_test.go:32: PASS:    Table(string)
    TestOrderPersister_GetOrder/should_get_order: main_test.go:32: PASS:    Where(string,string)
    TestOrderPersister_GetOrder/should_get_order: main_test.go:32: PASS:    Scan(string)
=== RUN   TestOrderPersister_GetOrder/should_return_error
    TestOrderPersister_GetOrder/should_return_error: main_test.go:49: PASS: Table(string)
    TestOrderPersister_GetOrder/should_return_error: main_test.go:49: PASS: Where(string,string)
    TestOrderPersister_GetOrder/should_return_error: main_test.go:49: PASS: Scan(string)
--- PASS: TestOrderPersister_GetOrder (0.00s)
    --- PASS: TestOrderPersister_GetOrder/should_get_order (0.00s)
    --- PASS: TestOrderPersister_GetOrder/should_return_error (0.00s)
PASS

Process finished with exit code 0

Coverage report: (screenshot not reproduced)