From e5412eba878125e56915602130b0d04064ebd11e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20K=C3=A1n=C4=9B?= Date: Sun, 21 Mar 2021 15:30:42 +0100 Subject: [PATCH] Pass current time to model methods by arguments to make them purer and therefore more testable --- cmd/web/handlers.go | 7 ++++--- pkg/model/model.go | 15 ++++++--------- pkg/model/model_test.go | 6 +++--- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/cmd/web/handlers.go b/cmd/web/handlers.go index c30ffa1..2d6046b 100644 --- a/cmd/web/handlers.go +++ b/cmd/web/handlers.go @@ -7,6 +7,7 @@ import ( "net/http" "net/url" "strings" + "time" "vkane.cz/tinyquiz/pkg/model" "vkane.cz/tinyquiz/pkg/model/ent" ) @@ -43,7 +44,7 @@ func (app *application) play(w http.ResponseWriter, r *http.Request, params http return } - if player, err := app.model.RegisterPlayer(player, code, r.Context()); err == nil { + if player, err := app.model.RegisterPlayer(player, code, time.Now(), r.Context()); err == nil { if session, err := player.Unwrap().QuerySession().Only(r.Context()); err == nil { if su, err := app.model.GetPlayersStateUpdate(session.ID, r.Context()); err == nil { app.rtClients.SendToAll(session.ID, su) @@ -106,7 +107,7 @@ func (app *application) nextQuestion(w http.ResponseWriter, r *http.Request, par if player, err := app.model.GetPlayerWithSessionAndGame(playerUid, r.Context()); err == nil { var sessionId = player.Edges.Session.ID - if err := app.model.NextQuestion(sessionId, r.Context()); err == nil { + if err := app.model.NextQuestion(sessionId, time.Now(), r.Context()); err == nil { if su, err := app.model.GetQuestionStateUpdate(sessionId, r.Context()); err == nil { app.rtClients.SendToAll(sessionId, su) } else { @@ -144,7 +145,7 @@ func (app *application) answer(w http.ResponseWriter, r *http.Request, params ht return } - if _, err := app.model.SaveAnswer(playerUid, choiceUid, r.Context()); err == nil { + if _, err := app.model.SaveAnswer(playerUid, choiceUid, time.Now(), r.Context()); err == nil { // TODO notify organisers w.WriteHeader(http.StatusCreated) // TODO or StatusNoContent? return diff --git a/pkg/model/model.go b/pkg/model/model.go index bd10683..c0fcd49 100644 --- a/pkg/model/model.go +++ b/pkg/model/model.go @@ -60,7 +60,7 @@ func (m *Model) GetStats(c context.Context) (Stats, error) { // returns the player's UUID if error is nil // err = NoSuchEntity if the sessionCode is incorrect -func (m *Model) RegisterPlayer(playerName string, sessionCode string, c context.Context) (*ent.Player, error) { +func (m *Model) RegisterPlayer(playerName string, sessionCode string, now time.Time, c context.Context) (*ent.Player, error) { tx, err := m.c.BeginTx(c, &sql.TxOptions{ Isolation: sql.LevelRepeatableRead, }) @@ -74,7 +74,7 @@ func (m *Model) RegisterPlayer(playerName string, sessionCode string, c context. } else if err != nil { return nil, err } else { - if p, err := tx.Player.Create().SetID(uuid.New()).SetJoined(time.Now()).SetName(playerName).SetSession(s).Save(c); err == nil { + if p, err := tx.Player.Create().SetID(uuid.New()).SetJoined(now).SetName(playerName).SetSession(s).Save(c); err == nil { return p, nil } else if ent.IsConstraintError(err) { return nil, ConstraintViolation @@ -178,7 +178,7 @@ var NoNextQuestion = errors.New("there is no next question") // TODO fill // TODO retry on serialization failure // TODO validate sessionId -func (m *Model) NextQuestion(sessionId uuid.UUID, c context.Context) error { +func (m *Model) NextQuestion(sessionId uuid.UUID, now time.Time, c context.Context) error { tx, err := m.c.BeginTx(c, &sql.TxOptions{ Isolation: sql.LevelSerializable, }) @@ -188,13 +188,10 @@ func (m *Model) NextQuestion(sessionId uuid.UUID, c context.Context) error { // TODO rollback only if not yet committed defer tx.Rollback() - var now = time.Now() - var query = tx.Question.Query().Where(question.HasGameWith(game.HasSessionsWith(session.ID(sessionId)))).Order(ent.Asc(question.FieldOrder)) if current, err := tx.AskedQuestion.Query().Where(askedquestion.HasSessionWith(session.ID(sessionId))).WithQuestion().Order(ent.Desc(askedquestion.FieldAsked)).First(c); err == nil { query.Where(question.OrderGT(current.Edges.Question.Order)) - // TODO make sure we do not extend the deadline by slow processing if current.Ended.After(now) { if _, err := current.Update().SetEnded(now).Save(c); err != nil { return err @@ -221,7 +218,7 @@ func (m *Model) NextQuestion(sessionId uuid.UUID, c context.Context) error { var QuestionClosed = errors.New("the deadline for answers to this question has passed") var AlreadyAnswered = errors.New("the player has already answered the question") -func (m *Model) SaveAnswer(playerId uuid.UUID, choiceId uuid.UUID, c context.Context) (*ent.Answer, error) { +func (m *Model) SaveAnswer(playerId uuid.UUID, choiceId uuid.UUID, now time.Time, c context.Context) (*ent.Answer, error) { tx, err := m.c.BeginTx(c, &sql.TxOptions{ Isolation: sql.LevelSerializable, }) @@ -247,7 +244,7 @@ func (m *Model) SaveAnswer(playerId uuid.UUID, choiceId uuid.UUID, c context.Con // check if the question is open // Asked[0] is guaranteed to exist thanks to the previous query - if !q.Edges.Asked[0].Ended.After(time.Now()) { + if !q.Edges.Asked[0].Ended.After(now) { return nil, QuestionClosed } @@ -258,7 +255,7 @@ func (m *Model) SaveAnswer(playerId uuid.UUID, choiceId uuid.UUID, c context.Con return nil, AlreadyAnswered } - if a, err := tx.Answer.Create().SetID(uuid.New()).SetAnswered(time.Now()).SetChoiceID(choiceId).SetAnswererID(playerId).Save(c); err == nil { + if a, err := tx.Answer.Create().SetID(uuid.New()).SetAnswered(now).SetChoiceID(choiceId).SetAnswererID(playerId).Save(c); err == nil { tx.Commit() return a, nil } else { diff --git a/pkg/model/model_test.go b/pkg/model/model_test.go index e5d464d..d149c23 100644 --- a/pkg/model/model_test.go +++ b/pkg/model/model_test.go @@ -112,7 +112,7 @@ func TestModel_NextQuestion(t *testing.T) { m := newTestModelWithData(t) c := context.Background() - if err := m.NextQuestion(uuid.MustParse("b3d2f5b2-d5eb-4461-b352-622431a35b12"), c); err != nil { + if err := m.NextQuestion(uuid.MustParse("b3d2f5b2-d5eb-4461-b352-622431a35b12"), time.Unix(1613388006, 0), c); err != nil { t.Fatalf("Unexpected error when switching to next question: %v", err) } } @@ -122,11 +122,11 @@ func TestModel_NextQuestion_noNextQuestion(t *testing.T) { m := newTestModelWithData(t) c := context.Background() - if err := m.NextQuestion(uuid.MustParse("b3d2f5b2-d5eb-4461-b352-622431a35b12"), c); err != nil { + if err := m.NextQuestion(uuid.MustParse("b3d2f5b2-d5eb-4461-b352-622431a35b12"), time.Unix(1613388006, 0), c); err != nil { t.Fatalf("Unexpected error when switching to next question: %v", err) } - if err := m.NextQuestion(uuid.MustParse("b3d2f5b2-d5eb-4461-b352-622431a35b12"), c); err == nil { + if err := m.NextQuestion(uuid.MustParse("b3d2f5b2-d5eb-4461-b352-622431a35b12"), time.Unix(1613388008, 0), c); err == nil { t.Fatalf("Switching to next question from the last one did not fail") } else if !errors.Is(err, NoNextQuestion) { t.Fatalf("Unexpected error type after switching to next question from the last one: %v", err)