gotel

simple terminal chat program
git clone git://git.mdnr.space/gotel
Log | Files | Refs | README | LICENSE

server.go (10307B)


      1 package main
      2 
      3 import (
      4 	"crypto/rand"
      5 	"database/sql"
      6 	"fmt"
      7 	_ "github.com/mattn/go-sqlite3"
      8 	"golang.org/x/crypto/bcrypt"
      9 	"log"
     10 	"math/big"
     11 	"net"
     12 	"os"
     13 	"strings"
     14 	"time"
     15 )
     16 
     17 type msgType int
     18 
     19 const (
     20 	msgConnect msgType = iota + 1
     21 	msgJoin
     22 	msgSignup
     23 	msgLogin
     24 	msgText
     25 	msgQuit
     26 )
     27 
     28 var RequestMap map[string]msgType
     29 var userDB *sql.DB
     30 
     31 type Msg struct {
     32 	Author Client
     33 	Type   msgType
     34 	Args   []string
     35 }
     36 
     37 type loginStage int
     38 
     39 const (
     40 	verification loginStage = iota + 1
     41 	username
     42 	password
     43 )
     44 
     45 type Client struct {
     46 	Stage       loginStage
     47 	UserName    string
     48 	LastMsgTime time.Time
     49 	PassRetry   int
     50 	Strike      int
     51 	Banned      bool
     52 	BanEnd      time.Time
     53 	Conn        net.Conn
     54 }
     55 
     56 const (
     57 	msgCoolDownTimeSec = 1
     58 	banLimit           = 5
     59 	banTimeoutSec      = 180
     60 	Port               = "6969"
     61 )
     62 
     63 func initReqMap() {
     64 	RequestMap = make(map[string]msgType)
     65 	RequestMap["/join"] = msgJoin
     66 	RequestMap["/signup"] = msgSignup
     67 	RequestMap["/login"] = msgLogin
     68 }
     69 
     70 func clientRoutine(conn net.Conn, Msg_q chan Msg) {
     71 	Msg_q <- Msg{
     72 		Type: msgConnect,
     73 		Author: Client{
     74 			Conn:  conn,
     75 			Stage: verification,
     76 		},
     77 	}
     78 	readBuf := make([]byte, 512)
     79 	for {
     80 		n, err := conn.Read(readBuf)
     81 		if n == 0 {
     82 			Msg_q <- Msg{
     83 				Type: msgQuit,
     84 				Author: Client{
     85 					Conn: conn,
     86 				},
     87 			}
     88 			return
     89 		}
     90 		if n > 0 {
     91 			if strings.HasPrefix(string(readBuf), "/") {
     92 				items := strings.Split(string(readBuf[:n]), " ")
     93 				Msg_q <- Msg{
     94 					Type: RequestMap[items[0]],
     95 					Args: items[1:len(items)],
     96 					Author: Client{
     97 						Conn: conn,
     98 					},
     99 				}
    100 			} else {
    101 				log.Printf("Got message from %s\n", conn.RemoteAddr().String())
    102 				Msg_q <- Msg{
    103 					Type: msgText,
    104 					Args: []string{string(readBuf[:n])},
    105 					Author: Client{
    106 						Conn: conn,
    107 					},
    108 				}
    109 			}
    110 		}
    111 		if err != nil {
    112 			log.Printf("Could not read message from client %s: %s\n", conn.RemoteAddr().String(), err)
    113 			conn.Close()
    114 			return
    115 		}
    116 	}
    117 }
    118 
    119 func canMessage(client *Client) bool {
    120 	if !client.Banned {
    121 		diff := time.Now().Sub(client.LastMsgTime).Seconds()
    122 		if diff <= msgCoolDownTimeSec {
    123 			client.Strike += 1
    124 			if client.Strike >= banLimit {
    125 				client.Banned = true
    126 				client.BanEnd = time.Now().Add(banTimeoutSec * time.Second)
    127 			}
    128 			return false
    129 		}
    130 		return true
    131 	}
    132 	banTimeRemaining := client.BanEnd.Sub(time.Now()).Seconds()
    133 	if banTimeRemaining >= 0.0 {
    134 		banTimeRemainingStr := fmt.Sprintf("You're banned. Try again in %.0f seconds.\n", banTimeRemaining)
    135 		client.Conn.Write([]byte(banTimeRemainingStr))
    136 		return false
    137 	}
    138 	client.Strike = 0
    139 	client.Banned = false
    140 	return true
    141 }
    142 
    143 func checkForDuplicateUN(needle string, heystack map[string]Client) bool {
    144 	for _, client := range heystack {
    145 		if client.UserName == needle {
    146 			return true
    147 		}
    148 	}
    149 	return false
    150 }
    151 
    152 func verifyToken(input string) bool {
    153 	tokenBytes := make([]byte, 32)
    154 	tokenFile, err := os.Open("TOKEN")
    155 	if err != nil {
    156 		log.Fatalf("Could not open TOKEN file for authentication: %s\n", err)
    157 	}
    158 	n, err := tokenFile.Read(tokenBytes)
    159 	if err != nil {
    160 		log.Fatalf("Could not read TOKEN file: %s\n", err)
    161 	}
    162 	if n < 32 {
    163 		log.Fatalf("TOKEN file is not valid.\n")
    164 	}
    165 	return input == string(tokenBytes)
    166 }
    167 
    168 func trimNewline(input rune) bool {
    169 	if input == '\n' {
    170 		return true
    171 	}
    172 	if input == '\r' {
    173 		return true
    174 	}
    175 	return false
    176 }
    177 
    178 func isUserInDB(username string) (bool, error) {
    179 	var count int
    180 	check := "SELECT COUNT(*) FROM users WHERE username = ?"
    181 	err := userDB.QueryRow(check, username).Scan(&count)
    182 	if err != nil {
    183 		return false, err
    184 	}
    185 	return count > 0, nil
    186 }
    187 
    188 func addUserToDB(client Client, rawPass string) error {
    189 	insert := "INSERT INTO users (username, password, banned, banEnd) VALUES ($1, $2, $3, $4)"
    190 	passHashed, err := bcrypt.GenerateFromPassword([]byte(rawPass), bcrypt.DefaultCost)
    191 	if err != nil {
    192 		return err
    193 	}
    194 	_, err = userDB.Exec(insert, client.UserName, passHashed, client.Banned, client.BanEnd)
    195 	return err
    196 }
    197 
    198 func getPassHash(username string) (string, error) {
    199 	var passHash string
    200 	query := "SELECT password FROM users WHERE username = ?"
    201 	err := userDB.QueryRow(query, username).Scan(&passHash)
    202 	return passHash, err
    203 }
    204 
    205 func server(Msg_q chan Msg) {
    206 	onlineList := make(map[string]Client)
    207 	joinedList := make(map[string]Client)
    208 	offlineList := make(map[string]Client)
    209 	for {
    210 		msg := <-Msg_q
    211 		keyString := msg.Author.Conn.RemoteAddr().String()
    212 		switch msg.Type {
    213 		case msgConnect:
    214 			// TODO: implement rate limit for connection requests
    215 			log.Printf("Got connection request from %s\n", keyString)
    216 			offlineList[keyString] = msg.Author
    217 			break
    218 		case msgQuit:
    219 			author, ok := onlineList[keyString]
    220 			if ok {
    221 				log.Printf("%s logged out.\n", author.UserName)
    222 				author.Conn.Close()
    223 				delete(onlineList, keyString)
    224 				break
    225 			}
    226 			break
    227 
    228 		case msgJoin:
    229 			log.Printf("server: Got join request from %s\n", keyString)
    230 			_, offline := offlineList[keyString]
    231 			_, joined := joinedList[keyString]
    232 			if offline {
    233 				if verifyToken(msg.Args[0]) {
    234 					joinedList[keyString] = msg.Author
    235 					delete(offlineList, keyString)
    236 					msg.Author.Conn.Write([]byte("Authentication successfull."))
    237 					break
    238 				}
    239 				msg.Author.Conn.Write([]byte("Provided token is not valid."))
    240 				break
    241 			}
    242 			if joined {
    243 				msg.Author.Conn.Write([]byte("You are already joined the server.\nTry logging in or signing up."))
    244 				break
    245 			}
    246 			msg.Author.Conn.Write([]byte("You are currently logged in."))
    247 			break
    248 
    249 		case msgSignup:
    250 			_, offline := offlineList[keyString]
    251 			client, joined := joinedList[keyString]
    252 			_, online := onlineList[keyString]
    253 			if online {
    254 				msg.Author.Conn.Write([]byte("This username already exists."))
    255 				break
    256 			}
    257 			if joined {
    258 				if !canMessage(&client) {
    259 					joinedList[keyString] = client // update the timings
    260 					break
    261 				}
    262 				yes, err := isUserInDB(msg.Args[0])
    263 				if err != nil {
    264 					msg.Author.Conn.Write([]byte("Database error: " + err.Error()))
    265 					break
    266 				}
    267 				if yes {
    268 					msg.Author.Conn.Write([]byte("This username already exists."))
    269 					break
    270 				}
    271 				msg.Author.UserName = msg.Args[0]
    272 				msg.Author.LastMsgTime = time.Now()
    273 				if err := addUserToDB(msg.Author, msg.Args[1]); err != nil {
    274 					msg.Author.Conn.Write([]byte("Could not signup user. Database error: " + err.Error()))
    275 					break
    276 				}
    277 				onlineList[keyString] = msg.Author
    278 				delete(joinedList, keyString)
    279 				msg.Author.Conn.Write([]byte("Welcome " + msg.Author.UserName))
    280 				break
    281 			}
    282 			if offline {
    283 				msg.Author.Conn.Write([]byte("You should provide the token first with the /join command.\n"))
    284 				break
    285 			}
    286 		case msgLogin:
    287 			_, offline := offlineList[keyString]
    288 			client, joined := joinedList[keyString]
    289 			_, online := onlineList[keyString]
    290 
    291 			if online {
    292 				msg.Author.Conn.Write([]byte("You are currently logged in."))
    293 				break
    294 			}
    295 			if offline {
    296 				msg.Author.Conn.Write([]byte("You should provide the token first with the /join command."))
    297 				break
    298 			}
    299 			if joined {
    300 				if !canMessage(&client) {
    301 					joinedList[keyString] = client // update the timings
    302 					break
    303 				}
    304 				yes, err := isUserInDB(msg.Args[0])
    305 				if err != nil {
    306 					msg.Author.Conn.Write([]byte("Database error: " + err.Error()))
    307 					break
    308 				}
    309 				if !yes {
    310 					msg.Author.Conn.Write([]byte("Username does not exist. You can create new user using /signup command."))
    311 					break
    312 				}
    313 				passHash, err := getPassHash(msg.Args[0])
    314 				if err != nil {
    315 					msg.Author.Conn.Write([]byte("Database error: " + err.Error()))
    316 					break
    317 				}
    318 				if err = bcrypt.CompareHashAndPassword([]byte(passHash), []byte(msg.Args[1])); err != nil {
    319 					client.PassRetry += 1
    320 					joinedList[keyString] = client
    321 					if client.PassRetry >= 3 {
    322 						client.Banned = true
    323 						client.BanEnd = time.Now().Add(banTimeoutSec * time.Second)
    324 						msg.Author.Conn.Write([]byte("Reached the limit of retries. Youre banned for 180 seconds."))
    325 						joinedList[keyString] = client
    326 						break
    327 					}
    328 					msg.Author.Conn.Write([]byte("Incorrect password. You have " + fmt.Sprintf("%d", 3-client.PassRetry) + " chances before getting banned for 3 minuetes."))
    329 					break
    330 				}
    331 				msg.Author.UserName = msg.Args[0]
    332 				msg.Author.LastMsgTime = time.Now()
    333 				onlineList[keyString] = msg.Author
    334 				delete(joinedList, keyString)
    335 				msg.Author.Conn.Write([]byte("Welcome " + msg.Author.UserName))
    336 				break
    337 			}
    338 			break
    339 
    340 		case msgText:
    341 			author, online := onlineList[keyString]
    342 			if !online {
    343 				msg.Author.Conn.Write([]byte("You must be logged in to send messages.\n"))
    344 				break
    345 			}
    346 			if !canMessage(&author) {
    347 				onlineList[keyString] = author // update the timings
    348 				break
    349 			}
    350 			author.LastMsgTime = time.Now()
    351 			onlineList[keyString] = author
    352 			for _, client := range onlineList {
    353 				_, err := client.Conn.Write([]byte(author.UserName + ": " + msg.Args[0]))
    354 				if err != nil {
    355 					log.Printf("Could not send message to client %s\n", client.UserName)
    356 				}
    357 			}
    358 		default:
    359 			log.Printf("server: Invalid request\n")
    360 		}
    361 	}
    362 }
    363 
    364 func genToken() {
    365 	max := big.NewInt(0xF)
    366 	var randInt *big.Int
    367 	var err error
    368 	var tokenStr string
    369 	for range [32]int{} {
    370 		randInt, err = rand.Int(rand.Reader, max)
    371 		if err != nil {
    372 			log.Fatalf("Could not generate random number: %s\n", err)
    373 		}
    374 		tokenStr = fmt.Sprintf(tokenStr+"%X", randInt)
    375 	}
    376 	tokenFile, err := os.Create("TOKEN")
    377 	if err != nil {
    378 		log.Fatalf("Could not create token file: %s\n", err)
    379 	}
    380 	_, err = tokenFile.WriteString(tokenStr)
    381 	if err != nil {
    382 		log.Fatalf("Could not write token file: %s\n", err)
    383 	}
    384 }
    385 
    386 func main() {
    387 	userDB, _ = sql.Open("sqlite3", "./users.db")
    388 	createTable := `CREATE TABLE IF NOT EXISTS users (
    389 		id INTEGER PRIMARY KEY AUTOINCREMENT,
    390 		username VARCHAR(50) UNIQUE NOT NULL,
    391 		password VARCHAR(100) NOT NULL,
    392 		banned BOOLEAN,
    393 		banEnd TIMESTAMP
    394 	);`
    395 	_, err := userDB.Exec(createTable)
    396 	if err != nil {
    397 		log.Fatalf("Error creating users table: %s\n", err)
    398 	}
    399 	log.Printf("Users table created.\n")
    400 
    401 	initReqMap()
    402 	genToken()
    403 	ln, err := net.Listen("tcp", ":"+Port)
    404 	if err != nil {
    405 		log.Fatalf("Could not listen to port %s: %s\n", Port, err)
    406 	}
    407 	Msg_q := make(chan Msg)
    408 	go server(Msg_q)
    409 	for {
    410 		conn, err := ln.Accept()
    411 		if err != nil {
    412 			log.Printf("Could not accept the connection: %s\n", err)
    413 			continue
    414 		}
    415 		log.Printf("Accepted connection from %s", conn.RemoteAddr())
    416 		go clientRoutine(conn, Msg_q)
    417 	}
    418 }