go-recipe-book/pkg/db/db_connect_test.go

60 lines
1013 B
Go
Raw Normal View History

2024-11-09 19:50:05 -08:00
package db_test
import (
"errors"
"fmt"
2024-11-10 18:20:13 -08:00
"io"
2024-11-09 19:50:05 -08:00
"math/rand"
2024-11-10 18:20:13 -08:00
"os"
2024-11-09 19:50:05 -08:00
"testing"
"github.com/stretchr/testify/assert"
. "recipe_book/pkg/db"
)
2024-11-10 18:20:13 -08:00
// seed_sql holds the text of the sample seed script. It is filled in by init
// and executed by get_test_db when a new test database is created.
var seed_sql string

// init reads ../../sample_data/seed.sql into seed_sql at package
// initialization. It panics if the file cannot be opened or read.
func init() {
	file, err := os.Open("../../sample_data/seed.sql")
	if err != nil {
		panic(err)
	}
	// The file is only read once; close it so the descriptor is not leaked
	// for the lifetime of the test binary.
	defer file.Close()

	data, err := io.ReadAll(file)
	if err != nil {
		panic(err)
	}
	seed_sql = string(data)
}
2024-11-09 19:50:05 -08:00
func get_test_db() DB {
db_path := "../../sample_data/data/test.db"
db, err := DBCreate(db_path)
if errors.Is(err, ErrTargetExists) {
db, err = DBConnect(db_path)
2024-11-10 18:20:13 -08:00
} else if err == nil {
db.DB.MustExec(seed_sql)
2024-11-09 19:50:05 -08:00
}
if err != nil {
panic(err)
}
return db
}
2024-11-10 18:20:13 -08:00
// get_food looks up the Food with the given id via db.GetFoodByID and
// returns it. It panics if the lookup returns an error, so callers do not
// need to handle one.
func get_food(db DB, id FoodID) Food {
	food, lookupErr := db.GetFoodByID(id)
	if lookupErr != nil {
		panic(lookupErr)
	}
	return food
}
2024-11-09 19:50:05 -08:00
// TestCreateAndConnectToDB checks that a database can be created at a new,
// randomly named path under ../../sample_data/data and that connecting to
// that same path afterwards succeeds.
func TestCreateAndConnectToDB(t *testing.T) {
	// Format the path once so create and connect are guaranteed to target
	// the same file.
	path := fmt.Sprintf("../../sample_data/data/random-%d.db", rand.Uint32())

	_, createErr := DBCreate(path)
	assert.NoError(t, createErr)

	_, connectErr := DBConnect(path)
	assert.NoError(t, connectErr)
}