package middleware

import (
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"go.uber.org/mock/gomock"

	"git.urec56.ru/urec/chat_back_go/config"
	"git.urec56.ru/urec/chat_back_go/internal/domain"
	"git.urec56.ru/urec/chat_back_go/internal/logger"
	mock_service "git.urec56.ru/urec/chat_back_go/internal/service/mocks"
)

// authMiddlewareTestCase describes one scenario for the authentication
// middlewares. The mocked service calls are chained: DecodeAuthToken is only
// expected when extractErr is nil, and the user lookup is only expected when
// decodeErr is nil as well.
type authMiddlewareTestCase struct {
	name               string
	reqToken           string // value of the Authorization request header
	extractToken       string
	extractErr         error
	decodeUserID       int
	decodeErr          error
	getUser            domain.User
	getErr             error
	expectedUser       domain.User // user the wrapped handler must see in its context
	expectedStatusCode int
}

// newAuthTestUser returns the fully populated user the mocked service hands
// back in the success scenarios. A fresh value is built on every call.
func newAuthTestUser() domain.User {
	return domain.User{
		ID:                 1,
		Role:               1,
		Username:           "urec",
		Email:              "mail@mail.ru",
		HashedPassword:     "hp",
		AvatarImage:        "image",
		BlackPhoenix:       true,
		DateOfBirth:        domain.CustomDate{Time: time.Date(2002, time.February, 2, 0, 0, 0, 0, time.UTC)},
		DateOfRegistration: domain.CustomDate{Time: time.Date(2025, time.February, 2, 0, 0, 0, 0, time.UTC)},
	}
}

// runAuthMiddlewareTests runs every case as a subtest.
//
// expectGet registers the user-lookup expectation specific to the middleware
// under test. wrap applies that middleware to the given next handler.
//
// Besides the response status, each subtest checks that the wrapped handler
// runs exactly when the expected status is 200 OK. Without that check, the
// in-handler assertion would pass vacuously in the error cases.
func runAuthMiddlewareTests(
	t *testing.T,
	cases []authMiddlewareTestCase,
	expectGet func(s *mock_service.MockServ, userID int, user domain.User, err error),
	wrap func(m *Middleware, next func(http.ResponseWriter, *http.Request)) http.Handler,
) {
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// NewController(t) registers Finish via t.Cleanup, so no explicit defer is needed.
			c := gomock.NewController(t)
			serv := mock_service.NewMockServ(c)

			req := httptest.NewRequest(http.MethodGet, "/", nil)
			req.Header.Set("Authorization", tc.reqToken)
			w := httptest.NewRecorder()

			// The middleware is expected to stop at the first failing step, so
			// later service calls are only registered when earlier ones succeed.
			serv.EXPECT().ExtractAuthToken(req).Return(tc.extractToken, tc.extractErr)
			if tc.extractErr == nil {
				serv.EXPECT().DecodeAuthToken(tc.extractToken).Return(tc.decodeUserID, tc.decodeErr)
				if tc.decodeErr == nil {
					expectGet(serv, tc.decodeUserID, tc.getUser, tc.getErr)
				}
			}

			m := &Middleware{serv: serv, l: logger.NewLogger(config.Config{Mode: "TEST"})}

			nextCalled := false
			handler := wrap(m, func(w http.ResponseWriter, r *http.Request) {
				nextCalled = true
				// NOTE(review): relies on the middleware storing the user under the
				// plain string context key "user".
				assert.Equal(t, tc.expectedUser, r.Context().Value("user"))
			})
			handler.ServeHTTP(w, req)

			assert.Equal(t, tc.expectedStatusCode, w.Code)
			assert.Equal(t, tc.expectedStatusCode == http.StatusOK, nextCalled,
				"wrapped handler must run if and only if authentication succeeds")
		})
	}
}

// TestMiddleware_Auth covers Middleware.Auth. The user is resolved with
// Get(userID, false).
func TestMiddleware_Auth(t *testing.T) {
	testTable := []authMiddlewareTestCase{
		{
			name:               "ok",
			reqToken:           "Bearer token",
			extractToken:       "token",
			decodeUserID:       1,
			getUser:            newAuthTestUser(),
			expectedUser:       newAuthTestUser(),
			expectedStatusCode: http.StatusOK,
		},
		{
			name:               "user_not_found",
			reqToken:           "Bearer token",
			extractToken:       "token",
			getErr:             domain.UserNotFoundError,
			expectedStatusCode: http.StatusNotFound,
		},
		{
			name:               "extract_error",
			extractErr:         domain.AnyError,
			expectedStatusCode: http.StatusUnauthorized,
		},
		{
			name:               "decode_error",
			decodeErr:          domain.AnyError,
			expectedStatusCode: http.StatusUnauthorized,
		},
		{
			name:               "get_error",
			getErr:             domain.AnyError,
			expectedStatusCode: http.StatusInternalServerError,
		},
	}

	runAuthMiddlewareTests(t, testTable,
		func(s *mock_service.MockServ, userID int, user domain.User, err error) {
			s.EXPECT().Get(userID, false).Return(user, err)
		},
		func(m *Middleware, next func(http.ResponseWriter, *http.Request)) http.Handler {
			return m.Auth(next)
		},
	)
}

// TestMiddleware_VerificatedAuth covers Middleware.VerificatedAuth. The user is
// resolved with GetVerificated(userID), and an unverified user maps to 409.
func TestMiddleware_VerificatedAuth(t *testing.T) {
	testTable := []authMiddlewareTestCase{
		{
			name:               "ok",
			reqToken:           "Bearer token",
			extractToken:       "token",
			decodeUserID:       1,
			getUser:            newAuthTestUser(),
			expectedUser:       newAuthTestUser(),
			expectedStatusCode: http.StatusOK,
		},
		{
			name:               "unverified_user",
			reqToken:           "Bearer token",
			extractToken:       "token",
			getErr:             domain.UnverifiedUserError,
			expectedStatusCode: http.StatusConflict,
		},
		{
			name:               "user_not_found",
			reqToken:           "Bearer token",
			extractToken:       "token",
			getErr:             domain.UserNotFoundError,
			expectedStatusCode: http.StatusNotFound,
		},
		{
			name:               "extract_error",
			extractErr:         domain.AnyError,
			expectedStatusCode: http.StatusUnauthorized,
		},
		{
			name:               "decode_error",
			decodeErr:          domain.AnyError,
			expectedStatusCode: http.StatusUnauthorized,
		},
		{
			name:               "get_error",
			getErr:             domain.AnyError,
			expectedStatusCode: http.StatusInternalServerError,
		},
	}

	runAuthMiddlewareTests(t, testTable,
		func(s *mock_service.MockServ, userID int, user domain.User, err error) {
			s.EXPECT().GetVerificated(userID).Return(user, err)
		},
		func(m *Middleware, next func(http.ResponseWriter, *http.Request)) http.Handler {
			return m.VerificatedAuth(next)
		},
	)
}