gotel

simple terminal chat program
git clone git://git.mdnr.space/gotel
Log | Files | Refs | README | LICENSE

commit c1af6554271669412463ed574cc8ddabc17b3b33
parent 802ea27f60823d52329db441c2f9d3b6cf6c0848
Author: mdnrz <mehdeenoroozi@gmail.com>
Date:   Mon, 21 Apr 2025 00:57:09 +0330

add database for authenticated users

Diffstat:
Mclient.go | 191+++++++++++++++++++++++++++++++++++++++++++++++++++++++------------------------
Mgo.mod | 1+
Mserver.go | 363++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------------------
3 files changed, 404 insertions(+), 151 deletions(-)

diff --git a/client.go b/client.go @@ -2,39 +2,39 @@ package main import ( "fmt" - "net" "log" + "net" "strings" "github.com/jroimartin/gocui" ) type Command struct { - Name string + Name string Description string - Signature string - Function func(*gocui.View, string) error + Signature string + Function func(*gocui.View, []string) error } -const commandCnt = 3 +const commandCnt = 5 + var commands [commandCnt]Command var gui *gocui.Gui var serverConn net.Conn const serverAddr = "127.0.0.1:6969" -const initMsg = -`This is a client for connecting to GoTel chat server. +const initMsg = `This is a client for connecting to GoTel chat server. =======================================================` func main() { - initCommands(); + initCommands() gui, _ = gocui.NewGui(gocui.OutputNormal) // if err != nil { // log.Panicln(err) // } defer gui.Close() - gui.Cursor = true; + gui.Cursor = true gui.SetManagerFunc(layout) if err := gui.SetKeybinding("", gocui.KeyCtrlC, gocui.ModNone, quit); err != nil { @@ -68,95 +68,173 @@ func layout(g *gocui.Gui) error { if promptErr != gocui.ErrUnknownView { return promptErr } - prompt.Editable = true; - prompt.Wrap = true; - if _,err := g.SetCurrentView("prompt"); err != nil { - return err; + prompt.Editable = true + prompt.Wrap = true + if _, err := g.SetCurrentView("prompt"); err != nil { + return err } } return nil } -func getCommandArg(items []string) string { - if len(items) >= 2 { - return items[1] - } - return "" -} +// func getCommandArg(items []string) string { +// if len(items) >= 2 { +// return items[1] +// } +// return "" +// } func getInput(g *gocui.Gui, v *gocui.View) error { - input := strings.TrimRight(v.Buffer(), "\r\n"); + input := strings.TrimRight(v.Buffer(), "\r\n") items := strings.Split(input, " ") - v.Clear(); - v.SetCursor(0, 0); - chatLog, _ := g.View("chatLog"); - for _, cmd := range commands { - if strings.HasPrefix(cmd.Signature, items[0]) { - cmd.Function(chatLog, getCommandArg(items)); - return nil; + v.Clear() + v.SetCursor(0, 0) + chatLog, _ := g.View("chatLog") + if strings.HasPrefix(items[0], "/") { + for _, cmd := range commands { + if strings.HasPrefix(cmd.Signature, items[0]) { + cmd.Function(chatLog, items) + return nil + } } + fmt.Fprintf(chatLog, "Invalid command: %s\n", items[0]) + return nil } - // fmt.Fprintf(chatLog, "you entered: %s\n", input); serverConn.Write([]byte(input)) - return nil; + return nil } func initCommands() { - commands = [commandCnt]Command { + commands = [commandCnt]Command{ - Command { - Name: "help", + Command{ + Name: "help", Description: "Print help menu", - Signature: "/help", - Function: printHelp, + Signature: "/help [command]", + Function: printHelp, }, - Command { - Name: "join", + Command{ + Name: "join", Description: "Join the chat", - Signature: "/join <Token>", - Function: sendJoin, + Signature: "/join <ipv4>:<port> <TOKEN>", + Function: sendJoin, + }, + + Command{ + Name: "signup", + Description: "Sign up to server", + Signature: "/signup <username> <password>", + Function: sendSignup, }, - Command { - Name: "quit", + Command{ + Name: "login", + Description: "Login to server", + Signature: "/login <username> <password>", + Function: sendLogin, + }, + + Command{ + Name: "exit", Description: "Logout from server", - Signature: "/quit", - Function: sendQuit, + Signature: "/exit", + Function: sendQuit, }, } } -func printHelp(v *gocui.View, input string) error { - for _, cmd := range commands { - fmt.Fprintf(v, "%s - %s\n", cmd.Signature, cmd.Description) +func printHelp(v *gocui.View, input []string) error { + if len(input) == 1 { + fmt.Fprintf(v, "Commands:\n") + for _, cmd := range commands { + fmt.Fprintf(v, "%s - %s\n", cmd.Signature, cmd.Description) + } + return nil + } + if len(input) == 2 { + for _, cmd := range commands { + if strings.HasPrefix(cmd.Signature, input[1]) { + fmt.Fprintf(v, "%s - %s\n", cmd.Signature, cmd.Description) + return nil + } + } + fmt.Fprintf(v, "%s is not a valid command. Type /help to see the full list of commands.\n", input[1]) + return nil } + fmt.Fprintln(v, "Too many arguments for /help command.") + fmt.Fprintf(v, "%s - %s\n", commands[0].Signature, commands[0].Description) return nil } -func sendJoin(v *gocui.View, input string) error { - serverConn, _ = net.Dial("tcp", serverAddr); - // if err != nil { - // return err - // } - _, err := serverConn.Write([]byte(input)); +func sendJoin(v *gocui.View, input []string) error { + var err error + if len(input) != 3 { + fmt.Fprintf(v, "Invalid join command.\n") + input := []string{"/help", "/join"} + printHelp(v, input) + return nil + } + + if len(input[2]) != 32 { + fmt.Fprintf(v, "Invalid token: %s\nThe token should be a 32-character string.\n", input[1]) + return nil + } + + serverConn, err = net.Dial("tcp", input[1]) if err != nil { - return err + fmt.Fprintf(v, "Server is not responding.\n") + return nil + } + _, err = serverConn.Write([]byte(input[0] + " " + input[2])) + if err != nil { + fmt.Fprintf(v, "Could not send join request to the server: %s\n", err) + return nil } go getMsg(serverConn) return nil } -func sendQuit (v *gocui.View, input string) error { - fmt.Fprintln(v, "Quiting from server"); +func sendSignup(v *gocui.View, input []string) error { + if len(input) != 3 { + fmt.Fprintf(v, "Invalid signup command.\n") + input := []string{"/help", "/signup"} + printHelp(v, input) + return nil + } + + _, err := serverConn.Write([]byte(input[0] + " " + input[1] + " " + input[2])) + if err != nil { + fmt.Fprintf(v, "Could not send signup request to the server: %s\n", err) + } + return nil +} + +func sendLogin(v *gocui.View, input []string) error { + if len(input) != 3 { + fmt.Fprintf(v, "Invalid login command.\n") + input := []string{"/help", "/login"} + printHelp(v, input) + return nil + } + + _, err := serverConn.Write([]byte(input[0] + " " + input[1] + " " + input[2])) + if err != nil { + fmt.Fprintf(v, "Could not send login request to the server: %s\n", err) + } + return nil +} + +func sendQuit(v *gocui.View, input []string) error { + fmt.Fprintln(v, "Quiting from server") serverConn.Close() return nil } -func getMsg (conn net.Conn) { +func getMsg(conn net.Conn) { readBuf := make([]byte, 512) for { - n, readErr := conn.Read(readBuf); + n, readErr := conn.Read(readBuf) if readErr != nil { return @@ -172,7 +250,6 @@ func getMsg (conn net.Conn) { } } - func quit(g *gocui.Gui, v *gocui.View) error { return gocui.ErrQuit } diff --git a/go.mod b/go.mod @@ -5,5 +5,6 @@ go 1.24.0 require ( github.com/jroimartin/gocui v0.5.0 // indirect github.com/mattn/go-runewidth v0.0.9 // indirect + github.com/mattn/go-sqlite3 v1.14.28 // indirect github.com/nsf/termbox-go v1.1.1 // indirect ) diff --git a/server.go b/server.go @@ -1,76 +1,110 @@ package main import ( - "os" - "math/big" "crypto/rand" + "fmt" "log" + "math/big" "net" - "time" + "os" "strings" - "fmt" + "time" + "database/sql" + _"github.com/mattn/go-sqlite3" ) -var commands [3]string = [3]string{"/help", "/login", "/quit"}; - type msgType int + const ( msgConnect msgType = iota + 1 + msgJoin + msgSignup msgLogin msgText msgQuit ) +var RequestMap map[string]msgType +var userDB *sql.DB + +type Msg struct { + Author Client + Type msgType + Args []string +} + type loginStage int + const ( verification loginStage = iota + 1 username password ) - type Client struct { - Stage loginStage - UserName string - Password string + Stage loginStage + UserName string + Password string LastMsgTime time.Time - Request msgType - Strike int - Banned bool - BanEnd time.Time - Conn net.Conn - Text string + Strike int + Banned bool + BanEnd time.Time + Conn net.Conn } const ( msgCoolDownTimeSec = 1 - banLimit = 5 - banTimeoutSec = 180 - Port = "6969" + banLimit = 5 + banTimeoutSec = 180 + Port = "6969" ) -func addClient(conn net.Conn, Client_q chan Client) { - // loginPrompt := "Who are you?\n> " - // _, err := conn.Write([]byte(loginPrompt)) - // if err != nil { - // log.Printf("[ERROR] Could not send login prompt to user %s: %s\n", - // conn.RemoteAddr().String(), err) - // } +func initReqMap() { + RequestMap = make(map[string]msgType) + RequestMap["/join"] = msgJoin + RequestMap["/signup"] = msgSignup + RequestMap["/login"] = msgLogin +} + +func clientRoutine(conn net.Conn, Msg_q chan Msg) { + Msg_q <- Msg{ + Type: msgConnect, + Author: Client{ + Conn: conn, + Stage: verification, + }, + } readBuf := make([]byte, 512) for { n, err := conn.Read(readBuf) if n == 0 { - Client_q <- Client { - Request: msgQuit, - Conn: conn, + Msg_q <- Msg{ + Type: msgQuit, + Author: Client{ + Conn: conn, + }, } - return; + return } if n > 0 { - Client_q <- Client { - Request: msgText, - Conn: conn, - Text: string(readBuf[:n]), + if strings.HasPrefix(string(readBuf), "/") { + items := strings.Split(string(readBuf[:n]), " ") + Msg_q <- Msg{ + Type: RequestMap[items[0]], + Args: items[1:len(items)], + Author: Client{ + Conn: conn, + }, + } + } else { + log.Printf("Got message from %s\n", conn.RemoteAddr().String()) + Msg_q <- Msg{ + Type: msgText, + Args: []string{string(readBuf[:n])}, + Author: Client{ + Conn: conn, + }, + } } } if err != nil { @@ -83,33 +117,35 @@ func addClient(conn net.Conn, Client_q chan Client) { func canMessage(client *Client) bool { if !client.Banned { - diff := time.Now().Sub(client.LastMsgTime).Seconds(); + diff := time.Now().Sub(client.LastMsgTime).Seconds() if diff <= msgCoolDownTimeSec { - client.Strike += 1; + client.Strike += 1 if client.Strike >= banLimit { - client.Banned = true; + client.Banned = true client.BanEnd = time.Now().Add(banTimeoutSec * time.Second) } - return false; + return false } return true } - banTimeRemaining := client.BanEnd.Sub(time.Now()).Seconds(); - if banTimeRemaining >= 0.0 { + banTimeRemaining := client.BanEnd.Sub(time.Now()).Seconds() + if banTimeRemaining >= 0.0 { banTimeRemainingStr := fmt.Sprintf("You're banned. Try again in %.0f seconds.\n", banTimeRemaining) - client.Conn.Write([]byte(banTimeRemainingStr)); - return false; - } - client.Strike = 0; - client.Banned = false; - return true; + client.Conn.Write([]byte(banTimeRemainingStr)) + return false + } + client.Strike = 0 + client.Banned = false + return true } func checkForDuplicateUN(needle string, heystack map[string]Client) bool { for _, client := range heystack { - if client.UserName == needle { return true } + if client.UserName == needle { + return true + } } - return false; + return false } func verifyToken(input string) bool { @@ -128,17 +164,27 @@ func verifyToken(input string) bool { return input == string(tokenBytes) } +func trimNewline(input rune) bool { + if input == '\n' { + return true + } + if input == '\r' { + return true + } + return false +} + func login(client *Client, msg string) bool { if client.Stage == verification { - if !verifyToken(strings.TrimRight(msg, "\r\n")) { - _, err := client.Conn.Write([]byte("Invalid token.\n")); + if !verifyToken(strings.TrimRightFunc(msg, trimNewline)) { + _, err := client.Conn.Write([]byte("Invalid token.\n")) if err != nil { log.Printf("Could not send invalid token message to client %s\n", client.Conn.RemoteAddr().String()) } client.LastMsgTime = time.Now() - return false; + return false } - _, err := client.Conn.Write([]byte("Token verified successfully!\nEnter your user name")); + _, err := client.Conn.Write([]byte("Token verified successfully!\nEnter your user name")) if err != nil { log.Printf("Could not send verification message to %s\n", client.Conn.RemoteAddr().String()) } @@ -147,18 +193,18 @@ func login(client *Client, msg string) bool { return false } if client.Stage == username { - client.UserName = strings.TrimRight(msg, "\r\n"); + client.UserName = strings.TrimRight(msg, "\r\n") client.LastMsgTime = time.Now() - _, err := client.Conn.Write([]byte("Enter your password:\n")); + _, err := client.Conn.Write([]byte("Enter your password:\n")) if err != nil { log.Printf("Could not send passowrd message to %s\n", client.Conn.RemoteAddr().String()) } client.Stage = password return false } - client.Password = strings.TrimRight(msg, "\r\n"); + client.Password = strings.TrimRight(msg, "\r\n") client.LastMsgTime = time.Now() - _, err := client.Conn.Write([]byte("Welcome " + client.UserName + "\n")); + _, err := client.Conn.Write([]byte("Welcome " + client.UserName + "\n")) if err != nil { log.Printf("Could not send welcome message to %s\n", client.Conn.RemoteAddr().String()) } @@ -166,60 +212,172 @@ func login(client *Client, msg string) bool { return true } -func server(Client_q chan Client) { +func isUserInDB (username string) (bool, error) { + var count int + check := "SELECT COUNT(*) FROM users WHERE username = ?" + err := userDB.QueryRow(check, username).Scan(&count) + if err != nil { + return false, err + } + return count > 0, nil +} + +func addUserToDB (client Client) error { + insert := "INSERT INTO users (username, password, banned, banEnd) VALUES ($1, $2, $3, $4)" + _, err := userDB.Exec(insert, client.UserName, client.Password, client.Banned, client.BanEnd) + return err +} + +func checkPassword(username string, inputPass string) (bool, error) { + var dbPass string + query := "SELECT password FROM users WHERE username = ?" + err := userDB.QueryRow(query, username).Scan(&dbPass) + if err != nil { + return false, nil + } + return dbPass == inputPass, nil +} + +func server(Msg_q chan Msg) { onlineList := make(map[string]Client) + joinedList := make(map[string]Client) offlineList := make(map[string]Client) for { - client := <-Client_q - keyString := client.Conn.RemoteAddr().String(); - switch client.Request { + msg := <-Msg_q + keyString := msg.Author.Conn.RemoteAddr().String() + switch msg.Type { case msgConnect: // TODO: implement rate limit for connection requests - log.Printf("Got join request from %s\n", keyString); - offlineList[keyString] = client; + log.Printf("Got connection request from %s\n", keyString) + offlineList[keyString] = msg.Author + break case msgQuit: author, ok := onlineList[keyString] if ok { - log.Printf("%s logged out.\n", author.UserName); - author.Conn.Close(); - delete(onlineList, keyString); + log.Printf("%s logged out.\n", author.UserName) + author.Conn.Close() + delete(onlineList, keyString) + break } - case msgText: - clientOffline, ok := offlineList[keyString] - if ok { - if !canMessage(&clientOffline) { - offlineList[keyString] = clientOffline // update the timings + break + + case msgJoin: + log.Printf("server: Got join request from %s\n", keyString) + _, offline := offlineList[keyString] + _, joined := joinedList[keyString] + if offline { + if verifyToken(msg.Args[0]) { + joinedList[keyString] = msg.Author + delete(offlineList, keyString) + msg.Author.Conn.Write([]byte("Authentication successfull.")) + break + } + msg.Author.Conn.Write([]byte("Provided token is not valid.")) + break + } + if joined { + msg.Author.Conn.Write([]byte("You are already joined the server.\nTry logging in or signing up.")) + break + } + msg.Author.Conn.Write([]byte("You are currently logged in.")) + break + + case msgSignup: + _, offline := offlineList[keyString] + _, joined := joinedList[keyString] + _, online := onlineList[keyString] + if online { + msg.Author.Conn.Write([]byte("This username already exists.")) + break + } + if joined { + yes, err := isUserInDB(msg.Args[0]) + if err != nil { + msg.Author.Conn.Write([]byte("Database error: " + err.Error())) break } - if login(&clientOffline, client.Text) { - onlineList[keyString] = clientOffline - delete(offlineList, keyString) + if yes { + msg.Author.Conn.Write([]byte("This username already exists.")) + break + } + msg.Author.UserName = msg.Args[0] + msg.Author.Password = msg.Args[1] + msg.Author.LastMsgTime = time.Now() + if err := addUserToDB(msg.Author); err != nil { + msg.Author.Conn.Write([]byte("Could not signup user. Database error: " + err.Error())) + break + } + onlineList[keyString] = msg.Author + delete(joinedList, keyString) + msg.Author.Conn.Write([]byte("Welcome " + msg.Author.UserName)) + break + } + if offline { + msg.Author.Conn.Write([]byte("You should provide the token first with the /join command.\n")) + break + } + case msgLogin: + _, offline := offlineList[keyString] + _, joined := joinedList[keyString] + _, online := onlineList[keyString] + if online { + msg.Author.Conn.Write([]byte("You are currently logged in.")) + break + } + if offline { + msg.Author.Conn.Write([]byte("You should provide the token first with the /join command.")) + break + } + if joined { + yes, err := isUserInDB(msg.Args[0]) + if err != nil { + msg.Author.Conn.Write([]byte("Database error: " + err.Error())) + break + } + if !yes { + msg.Author.Conn.Write([]byte("Username does not exist. You can create new user using /signup command.")) + break + } + passOK, err := checkPassword(msg.Args[0], msg.Args[1]) + if err != nil { + msg.Author.Conn.Write([]byte("Database error: " + err.Error())) break } - offlineList[keyString] = clientOffline + if !passOK { + // TODO: limited retry + msg.Author.Conn.Write([]byte("Incorrect password.")) + break + } + msg.Author.UserName = msg.Args[0] + msg.Author.Password = msg.Args[1] + msg.Author.LastMsgTime = time.Now() + onlineList[keyString] = msg.Author + delete(joinedList, keyString) + msg.Author.Conn.Write([]byte("Welcome " + msg.Author.UserName)) break } + break - author, ok := onlineList[keyString]; - if !ok { - log.Fatal("cannot find client\n"); + case msgText: + author, online := onlineList[keyString] + if !online { + msg.Author.Conn.Write([]byte("You must be logged in to send messages.\n")) + break } if !canMessage(&author) { onlineList[keyString] = author // update the timings - break; + break } author.LastMsgTime = time.Now() - author.Text = client.Text; onlineList[keyString] = author - for _, value := range onlineList { - // if value.Conn == author.Conn { - // continue - // } - _, err := value.Conn.Write([]byte(author.UserName + ": " + author.Text)) + for _, client := range onlineList { + _, err := client.Conn.Write([]byte(author.UserName + ": " + msg.Args[0])) if err != nil { - log.Printf("Could not send message to client %s\n", value.UserName) + log.Printf("Could not send message to client %s\n", client.UserName) } } + default: + log.Printf("server: Invalid request\n") } } } @@ -234,7 +392,7 @@ func genToken() { if err != nil { log.Fatalf("Could not generate random number: %s\n", err) } - tokenStr = fmt.Sprintf(tokenStr + "%X", randInt) + tokenStr = fmt.Sprintf(tokenStr+"%X", randInt) } tokenFile, err := os.Create("TOKEN") if err != nil { @@ -247,13 +405,28 @@ func genToken() { } func main() { + userDB, _ = sql.Open("sqlite3", "./users.db") + createTable := `CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username VARCHAR(50) UNIQUE NOT NULL, + password VARCHAR(100) NOT NULL, + banned BOOLEAN, + banEnd TIMESTAMP + );` + _, err := userDB.Exec(createTable) + if err != nil { + log.Fatalf("Error creating users table: %s\n", err) + } + log.Printf("Users table created.\n") + + initReqMap() genToken() ln, err := net.Listen("tcp", ":"+Port) if err != nil { log.Fatalf("Could not listen to port %s: %s\n", Port, err) } - Client_q := make(chan Client) - go server(Client_q) + Msg_q := make(chan Msg) + go server(Msg_q) for { conn, err := ln.Accept() if err != nil { @@ -261,11 +434,13 @@ func main() { continue } log.Printf("Accepted connection from %s", conn.RemoteAddr()) - Client_q <- Client { - Request: msgConnect, - Conn: conn, - Stage: verification, - } - go addClient(conn, Client_q) + // Msg_q <- Msg{ + // Type: msgConnect, + // Author: Client{ + // Conn: conn, + // Stage: verification, + // }, + // } + go clientRoutine(conn, Msg_q) } }