server.go (10307B)
1 package main 2 3 import ( 4 "crypto/rand" 5 "database/sql" 6 "fmt" 7 _ "github.com/mattn/go-sqlite3" 8 "golang.org/x/crypto/bcrypt" 9 "log" 10 "math/big" 11 "net" 12 "os" 13 "strings" 14 "time" 15 ) 16 17 type msgType int 18 19 const ( 20 msgConnect msgType = iota + 1 21 msgJoin 22 msgSignup 23 msgLogin 24 msgText 25 msgQuit 26 ) 27 28 var RequestMap map[string]msgType 29 var userDB *sql.DB 30 31 type Msg struct { 32 Author Client 33 Type msgType 34 Args []string 35 } 36 37 type loginStage int 38 39 const ( 40 verification loginStage = iota + 1 41 username 42 password 43 ) 44 45 type Client struct { 46 Stage loginStage 47 UserName string 48 LastMsgTime time.Time 49 PassRetry int 50 Strike int 51 Banned bool 52 BanEnd time.Time 53 Conn net.Conn 54 } 55 56 const ( 57 msgCoolDownTimeSec = 1 58 banLimit = 5 59 banTimeoutSec = 180 60 Port = "6969" 61 ) 62 63 func initReqMap() { 64 RequestMap = make(map[string]msgType) 65 RequestMap["/join"] = msgJoin 66 RequestMap["/signup"] = msgSignup 67 RequestMap["/login"] = msgLogin 68 } 69 70 func clientRoutine(conn net.Conn, Msg_q chan Msg) { 71 Msg_q <- Msg{ 72 Type: msgConnect, 73 Author: Client{ 74 Conn: conn, 75 Stage: verification, 76 }, 77 } 78 readBuf := make([]byte, 512) 79 for { 80 n, err := conn.Read(readBuf) 81 if n == 0 { 82 Msg_q <- Msg{ 83 Type: msgQuit, 84 Author: Client{ 85 Conn: conn, 86 }, 87 } 88 return 89 } 90 if n > 0 { 91 if strings.HasPrefix(string(readBuf), "/") { 92 items := strings.Split(string(readBuf[:n]), " ") 93 Msg_q <- Msg{ 94 Type: RequestMap[items[0]], 95 Args: items[1:len(items)], 96 Author: Client{ 97 Conn: conn, 98 }, 99 } 100 } else { 101 log.Printf("Got message from %s\n", conn.RemoteAddr().String()) 102 Msg_q <- Msg{ 103 Type: msgText, 104 Args: []string{string(readBuf[:n])}, 105 Author: Client{ 106 Conn: conn, 107 }, 108 } 109 } 110 } 111 if err != nil { 112 log.Printf("Could not read message from client %s: %s\n", conn.RemoteAddr().String(), err) 113 conn.Close() 114 return 115 } 116 } 117 } 118 119 func canMessage(client *Client) bool { 120 if !client.Banned { 121 diff := time.Now().Sub(client.LastMsgTime).Seconds() 122 if diff <= msgCoolDownTimeSec { 123 client.Strike += 1 124 if client.Strike >= banLimit { 125 client.Banned = true 126 client.BanEnd = time.Now().Add(banTimeoutSec * time.Second) 127 } 128 return false 129 } 130 return true 131 } 132 banTimeRemaining := client.BanEnd.Sub(time.Now()).Seconds() 133 if banTimeRemaining >= 0.0 { 134 banTimeRemainingStr := fmt.Sprintf("You're banned. Try again in %.0f seconds.\n", banTimeRemaining) 135 client.Conn.Write([]byte(banTimeRemainingStr)) 136 return false 137 } 138 client.Strike = 0 139 client.Banned = false 140 return true 141 } 142 143 func checkForDuplicateUN(needle string, heystack map[string]Client) bool { 144 for _, client := range heystack { 145 if client.UserName == needle { 146 return true 147 } 148 } 149 return false 150 } 151 152 func verifyToken(input string) bool { 153 tokenBytes := make([]byte, 32) 154 tokenFile, err := os.Open("TOKEN") 155 if err != nil { 156 log.Fatalf("Could not open TOKEN file for authentication: %s\n", err) 157 } 158 n, err := tokenFile.Read(tokenBytes) 159 if err != nil { 160 log.Fatalf("Could not read TOKEN file: %s\n", err) 161 } 162 if n < 32 { 163 log.Fatalf("TOKEN file is not valid.\n") 164 } 165 return input == string(tokenBytes) 166 } 167 168 func trimNewline(input rune) bool { 169 if input == '\n' { 170 return true 171 } 172 if input == '\r' { 173 return true 174 } 175 return false 176 } 177 178 func isUserInDB(username string) (bool, error) { 179 var count int 180 check := "SELECT COUNT(*) FROM users WHERE username = ?" 181 err := userDB.QueryRow(check, username).Scan(&count) 182 if err != nil { 183 return false, err 184 } 185 return count > 0, nil 186 } 187 188 func addUserToDB(client Client, rawPass string) error { 189 insert := "INSERT INTO users (username, password, banned, banEnd) VALUES ($1, $2, $3, $4)" 190 passHashed, err := bcrypt.GenerateFromPassword([]byte(rawPass), bcrypt.DefaultCost) 191 if err != nil { 192 return err 193 } 194 _, err = userDB.Exec(insert, client.UserName, passHashed, client.Banned, client.BanEnd) 195 return err 196 } 197 198 func getPassHash(username string) (string, error) { 199 var passHash string 200 query := "SELECT password FROM users WHERE username = ?" 201 err := userDB.QueryRow(query, username).Scan(&passHash) 202 return passHash, err 203 } 204 205 func server(Msg_q chan Msg) { 206 onlineList := make(map[string]Client) 207 joinedList := make(map[string]Client) 208 offlineList := make(map[string]Client) 209 for { 210 msg := <-Msg_q 211 keyString := msg.Author.Conn.RemoteAddr().String() 212 switch msg.Type { 213 case msgConnect: 214 // TODO: implement rate limit for connection requests 215 log.Printf("Got connection request from %s\n", keyString) 216 offlineList[keyString] = msg.Author 217 break 218 case msgQuit: 219 author, ok := onlineList[keyString] 220 if ok { 221 log.Printf("%s logged out.\n", author.UserName) 222 author.Conn.Close() 223 delete(onlineList, keyString) 224 break 225 } 226 break 227 228 case msgJoin: 229 log.Printf("server: Got join request from %s\n", keyString) 230 _, offline := offlineList[keyString] 231 _, joined := joinedList[keyString] 232 if offline { 233 if verifyToken(msg.Args[0]) { 234 joinedList[keyString] = msg.Author 235 delete(offlineList, keyString) 236 msg.Author.Conn.Write([]byte("Authentication successfull.")) 237 break 238 } 239 msg.Author.Conn.Write([]byte("Provided token is not valid.")) 240 break 241 } 242 if joined { 243 msg.Author.Conn.Write([]byte("You are already joined the server.\nTry logging in or signing up.")) 244 break 245 } 246 msg.Author.Conn.Write([]byte("You are currently logged in.")) 247 break 248 249 case msgSignup: 250 _, offline := offlineList[keyString] 251 client, joined := joinedList[keyString] 252 _, online := onlineList[keyString] 253 if online { 254 msg.Author.Conn.Write([]byte("This username already exists.")) 255 break 256 } 257 if joined { 258 if !canMessage(&client) { 259 joinedList[keyString] = client // update the timings 260 break 261 } 262 yes, err := isUserInDB(msg.Args[0]) 263 if err != nil { 264 msg.Author.Conn.Write([]byte("Database error: " + err.Error())) 265 break 266 } 267 if yes { 268 msg.Author.Conn.Write([]byte("This username already exists.")) 269 break 270 } 271 msg.Author.UserName = msg.Args[0] 272 msg.Author.LastMsgTime = time.Now() 273 if err := addUserToDB(msg.Author, msg.Args[1]); err != nil { 274 msg.Author.Conn.Write([]byte("Could not signup user. Database error: " + err.Error())) 275 break 276 } 277 onlineList[keyString] = msg.Author 278 delete(joinedList, keyString) 279 msg.Author.Conn.Write([]byte("Welcome " + msg.Author.UserName)) 280 break 281 } 282 if offline { 283 msg.Author.Conn.Write([]byte("You should provide the token first with the /join command.\n")) 284 break 285 } 286 case msgLogin: 287 _, offline := offlineList[keyString] 288 client, joined := joinedList[keyString] 289 _, online := onlineList[keyString] 290 291 if online { 292 msg.Author.Conn.Write([]byte("You are currently logged in.")) 293 break 294 } 295 if offline { 296 msg.Author.Conn.Write([]byte("You should provide the token first with the /join command.")) 297 break 298 } 299 if joined { 300 if !canMessage(&client) { 301 joinedList[keyString] = client // update the timings 302 break 303 } 304 yes, err := isUserInDB(msg.Args[0]) 305 if err != nil { 306 msg.Author.Conn.Write([]byte("Database error: " + err.Error())) 307 break 308 } 309 if !yes { 310 msg.Author.Conn.Write([]byte("Username does not exist. You can create new user using /signup command.")) 311 break 312 } 313 passHash, err := getPassHash(msg.Args[0]) 314 if err != nil { 315 msg.Author.Conn.Write([]byte("Database error: " + err.Error())) 316 break 317 } 318 if err = bcrypt.CompareHashAndPassword([]byte(passHash), []byte(msg.Args[1])); err != nil { 319 client.PassRetry += 1 320 joinedList[keyString] = client 321 if client.PassRetry >= 3 { 322 client.Banned = true 323 client.BanEnd = time.Now().Add(banTimeoutSec * time.Second) 324 msg.Author.Conn.Write([]byte("Reached the limit of retries. Youre banned for 180 seconds.")) 325 joinedList[keyString] = client 326 break 327 } 328 msg.Author.Conn.Write([]byte("Incorrect password. You have " + fmt.Sprintf("%d", 3-client.PassRetry) + " chances before getting banned for 3 minuetes.")) 329 break 330 } 331 msg.Author.UserName = msg.Args[0] 332 msg.Author.LastMsgTime = time.Now() 333 onlineList[keyString] = msg.Author 334 delete(joinedList, keyString) 335 msg.Author.Conn.Write([]byte("Welcome " + msg.Author.UserName)) 336 break 337 } 338 break 339 340 case msgText: 341 author, online := onlineList[keyString] 342 if !online { 343 msg.Author.Conn.Write([]byte("You must be logged in to send messages.\n")) 344 break 345 } 346 if !canMessage(&author) { 347 onlineList[keyString] = author // update the timings 348 break 349 } 350 author.LastMsgTime = time.Now() 351 onlineList[keyString] = author 352 for _, client := range onlineList { 353 _, err := client.Conn.Write([]byte(author.UserName + ": " + msg.Args[0])) 354 if err != nil { 355 log.Printf("Could not send message to client %s\n", client.UserName) 356 } 357 } 358 default: 359 log.Printf("server: Invalid request\n") 360 } 361 } 362 } 363 364 func genToken() { 365 max := big.NewInt(0xF) 366 var randInt *big.Int 367 var err error 368 var tokenStr string 369 for range [32]int{} { 370 randInt, err = rand.Int(rand.Reader, max) 371 if err != nil { 372 log.Fatalf("Could not generate random number: %s\n", err) 373 } 374 tokenStr = fmt.Sprintf(tokenStr+"%X", randInt) 375 } 376 tokenFile, err := os.Create("TOKEN") 377 if err != nil { 378 log.Fatalf("Could not create token file: %s\n", err) 379 } 380 _, err = tokenFile.WriteString(tokenStr) 381 if err != nil { 382 log.Fatalf("Could not write token file: %s\n", err) 383 } 384 } 385 386 func main() { 387 userDB, _ = sql.Open("sqlite3", "./users.db") 388 createTable := `CREATE TABLE IF NOT EXISTS users ( 389 id INTEGER PRIMARY KEY AUTOINCREMENT, 390 username VARCHAR(50) UNIQUE NOT NULL, 391 password VARCHAR(100) NOT NULL, 392 banned BOOLEAN, 393 banEnd TIMESTAMP 394 );` 395 _, err := userDB.Exec(createTable) 396 if err != nil { 397 log.Fatalf("Error creating users table: %s\n", err) 398 } 399 log.Printf("Users table created.\n") 400 401 initReqMap() 402 genToken() 403 ln, err := net.Listen("tcp", ":"+Port) 404 if err != nil { 405 log.Fatalf("Could not listen to port %s: %s\n", Port, err) 406 } 407 Msg_q := make(chan Msg) 408 go server(Msg_q) 409 for { 410 conn, err := ln.Accept() 411 if err != nil { 412 log.Printf("Could not accept the connection: %s\n", err) 413 continue 414 } 415 log.Printf("Accepted connection from %s", conn.RemoteAddr()) 416 go clientRoutine(conn, Msg_q) 417 } 418 }