120 lines
3.2 KiB
Go
120 lines
3.2 KiB
Go
package repository
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/jackc/pgx/v5/pgconn"
|
|
"github.com/jmoiron/sqlx"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.urec56.ru/urec/chat_back_go/internal/database"
|
|
"git.urec56.ru/urec/chat_back_go/internal/domain"
|
|
"git.urec56.ru/urec/chat_back_go/internal/logger"
|
|
)
|
|
|
|
type userRepository struct {
|
|
db *sqlx.DB
|
|
l *logger.Logger
|
|
}
|
|
|
|
func newUserRepo(db *sqlx.DB, l *logger.Logger) *userRepository {
|
|
return &userRepository{db: db, l: l}
|
|
}
|
|
|
|
// GetByID loads the user whose id equals userID.
//
// It returns domain.UserNotFoundError when no row matches, and
// domain.InternalServerError for any other database failure. The underlying
// error is logged before it is replaced, so the cause is not lost.
func (r *userRepository) GetByID(userID int) (domain.User, error) {
	const query = `SELECT * FROM users WHERE id = $1`

	var user domain.User
	if err := r.db.Get(&user, query, userID); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return domain.User{}, domain.UserNotFoundError
		}
		// Previously this branch discarded err silently; log it like the
		// other repository methods do before masking it for the caller.
		r.l.Errorf("getting user by id %d: %s", userID, err)
		return domain.User{}, domain.InternalServerError
	}

	return user, nil
}
|
|
|
|
func (r *userRepository) GetAll(username string) ([]domain.User, error) {
|
|
var users []domain.User
|
|
username = fmt.Sprint("%", username, "%")
|
|
query := fmt.Sprintf(`SELECT * FROM users WHERE username ILIKE $1 AND role != %d`, domain.AdminUser)
|
|
err := r.db.Select(&users, query, username)
|
|
if err != nil {
|
|
r.l.Errorf("getting users: %s", err)
|
|
return nil, domain.InternalServerError
|
|
}
|
|
return users, nil
|
|
}
|
|
|
|
// FindOne returns the first user matching every non-empty filter.
//
// username and email are optional: an empty string means "do not filter on
// this column". When both are supplied, a row must match both (AND).
//
// It returns domain.UserNotFoundError when no filter is supplied or no row
// matches. Any other database failure is logged and reported as
// domain.InternalServerError.
func (r *userRepository) FindOne(username, email string) (domain.User, error) {
	var (
		conditions []string
		args       []any
	)

	if username != "" {
		conditions = append(conditions, `username = ?`)
		args = append(args, username)
	}
	if email != "" {
		conditions = append(conditions, `email = ?`)
		args = append(args, email)
	}

	// Without any filter the query would degrade to
	// `SELECT * FROM users LIMIT 1` and hand back an arbitrary user
	// (e.g. to a login attempt with empty credentials). Refuse instead.
	if len(conditions) == 0 {
		return domain.User{}, domain.UserNotFoundError
	}

	// Conditions are written with '?' placeholders; Rebind converts them to
	// the driver's bind syntax ($1, $2, ...).
	query := r.db.Rebind(`SELECT * FROM users WHERE ` + strings.Join(conditions, ` AND `) + ` LIMIT 1`)

	var user domain.User
	if err := r.db.Get(&user, query, args...); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return domain.User{}, domain.UserNotFoundError
		}
		r.l.Errorf("finding user: %s", err)
		return domain.User{}, domain.InternalServerError
	}

	return user, nil
}
|
|
|
|
// Register creates a new user and its user_avatar row in a single transaction,
// and returns the inserted user as read back via RETURNING *.
//
// It returns domain.UserAlreadyExistsError when inserting the user fails with
// SQLSTATE database.IntegrityErrorCode. Any other failure (begin, either
// insert, or commit) is logged and reported as domain.InternalServerError.
// On every non-success path the transaction is rolled back.
func (r *userRepository) Register(email, hashedPassword, username string, dateOfBirth time.Time) (domain.User, error) {
	const (
		uQuery = `INSERT INTO users (email, hashed_password, username, date_of_birth) VALUES ($1, $2, $3, $4) RETURNING *`
		aQuery = `INSERT INTO user_avatar (user_id, avatar_image) VALUES ($1, $2)`
	)

	tx, err := r.db.Beginx()
	if err != nil {
		r.l.Errorf("user registration: tx begin: %s", err)
		return domain.User{}, domain.InternalServerError
	}

	// Always attempt a rollback. After a successful Commit it is a no-op that
	// returns sql.ErrTxDone, which is expected and not logged. Unlike checking
	// a captured err, this also rolls back if a panic unwinds through here.
	defer func() {
		if rbErr := tx.Rollback(); rbErr != nil && !errors.Is(rbErr, sql.ErrTxDone) {
			r.l.Errorf("user registration: tx rollback: %s", rbErr)
		}
	}()

	var user domain.User
	if err = tx.Get(&user, uQuery, email, hashedPassword, username, dateOfBirth); err != nil {
		var pgErr *pgconn.PgError
		if errors.As(err, &pgErr) && pgErr.SQLState() == database.IntegrityErrorCode {
			// Expected outcome for duplicate sign-ups; not worth logging.
			return domain.User{}, domain.UserAlreadyExistsError
		}
		// Other Postgres errors used to go unlogged; log everything unexpected.
		r.l.Errorf("user registration: inserting user: %s", err)
		return domain.User{}, domain.InternalServerError
	}

	if _, err = tx.Exec(aQuery, user.ID, user.AvatarImage); err != nil {
		r.l.Errorf("user registration: inserting avatar: %s", err)
		return domain.User{}, domain.InternalServerError
	}

	if err = tx.Commit(); err != nil {
		r.l.Errorf("user registration: tx commit: %s", err)
		return domain.User{}, domain.InternalServerError
	}

	return user, nil
}
|