commit c1af6554271669412463ed574cc8ddabc17b3b33
parent 802ea27f60823d52329db441c2f9d3b6cf6c0848
Author: mdnrz <mehdeenoroozi@gmail.com>
Date: Mon, 21 Apr 2025 00:57:09 +0330
add database for authenticated users
Diffstat:
| M | client.go | | | 191 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++------------------------ |
| M | go.mod | | | 1 | + |
| M | server.go | | | 363 | ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------------------- |
3 files changed, 404 insertions(+), 151 deletions(-)
diff --git a/client.go b/client.go
@@ -2,39 +2,39 @@ package main
import (
"fmt"
- "net"
"log"
+ "net"
"strings"
"github.com/jroimartin/gocui"
)
type Command struct {
- Name string
+ Name string
Description string
- Signature string
- Function func(*gocui.View, string) error
+ Signature string
+ Function func(*gocui.View, []string) error
}
-const commandCnt = 3
+const commandCnt = 5
+
var commands [commandCnt]Command
var gui *gocui.Gui
var serverConn net.Conn
const serverAddr = "127.0.0.1:6969"
-const initMsg =
-`This is a client for connecting to GoTel chat server.
+const initMsg = `This is a client for connecting to GoTel chat server.
=======================================================`
func main() {
- initCommands();
+ initCommands()
gui, _ = gocui.NewGui(gocui.OutputNormal)
// if err != nil {
// log.Panicln(err)
// }
defer gui.Close()
- gui.Cursor = true;
+ gui.Cursor = true
gui.SetManagerFunc(layout)
if err := gui.SetKeybinding("", gocui.KeyCtrlC, gocui.ModNone, quit); err != nil {
@@ -68,95 +68,173 @@ func layout(g *gocui.Gui) error {
if promptErr != gocui.ErrUnknownView {
return promptErr
}
- prompt.Editable = true;
- prompt.Wrap = true;
- if _,err := g.SetCurrentView("prompt"); err != nil {
- return err;
+ prompt.Editable = true
+ prompt.Wrap = true
+ if _, err := g.SetCurrentView("prompt"); err != nil {
+ return err
}
}
return nil
}
-func getCommandArg(items []string) string {
- if len(items) >= 2 {
- return items[1]
- }
- return ""
-}
+// func getCommandArg(items []string) string {
+// if len(items) >= 2 {
+// return items[1]
+// }
+// return ""
+// }
func getInput(g *gocui.Gui, v *gocui.View) error {
- input := strings.TrimRight(v.Buffer(), "\r\n");
+ input := strings.TrimRight(v.Buffer(), "\r\n")
items := strings.Split(input, " ")
- v.Clear();
- v.SetCursor(0, 0);
- chatLog, _ := g.View("chatLog");
- for _, cmd := range commands {
- if strings.HasPrefix(cmd.Signature, items[0]) {
- cmd.Function(chatLog, getCommandArg(items));
- return nil;
+ v.Clear()
+ v.SetCursor(0, 0)
+ chatLog, _ := g.View("chatLog")
+ if strings.HasPrefix(items[0], "/") {
+ for _, cmd := range commands {
+ if strings.HasPrefix(cmd.Signature, items[0]) {
+ cmd.Function(chatLog, items)
+ return nil
+ }
}
+ fmt.Fprintf(chatLog, "Invalid command: %s\n", items[0])
+ return nil
}
- // fmt.Fprintf(chatLog, "you entered: %s\n", input);
serverConn.Write([]byte(input))
- return nil;
+ return nil
}
func initCommands() {
- commands = [commandCnt]Command {
+ commands = [commandCnt]Command{
- Command {
- Name: "help",
+ Command{
+ Name: "help",
Description: "Print help menu",
- Signature: "/help",
- Function: printHelp,
+ Signature: "/help [command]",
+ Function: printHelp,
},
- Command {
- Name: "join",
+ Command{
+ Name: "join",
Description: "Join the chat",
- Signature: "/join <Token>",
- Function: sendJoin,
+ Signature: "/join <ipv4>:<port> <TOKEN>",
+ Function: sendJoin,
+ },
+
+ Command{
+ Name: "signup",
+ Description: "Sign up to server",
+ Signature: "/signup <username> <password>",
+ Function: sendSignup,
},
- Command {
- Name: "quit",
+ Command{
+ Name: "login",
+ Description: "Login to server",
+ Signature: "/login <username> <password>",
+ Function: sendLogin,
+ },
+
+ Command{
+ Name: "exit",
Description: "Logout from server",
- Signature: "/quit",
- Function: sendQuit,
+ Signature: "/exit",
+ Function: sendQuit,
},
}
}
-func printHelp(v *gocui.View, input string) error {
- for _, cmd := range commands {
- fmt.Fprintf(v, "%s - %s\n", cmd.Signature, cmd.Description)
+func printHelp(v *gocui.View, input []string) error {
+ if len(input) == 1 {
+ fmt.Fprintf(v, "Commands:\n")
+ for _, cmd := range commands {
+ fmt.Fprintf(v, "%s - %s\n", cmd.Signature, cmd.Description)
+ }
+ return nil
+ }
+ if len(input) == 2 {
+ for _, cmd := range commands {
+ if strings.HasPrefix(cmd.Signature, input[1]) {
+ fmt.Fprintf(v, "%s - %s\n", cmd.Signature, cmd.Description)
+ return nil
+ }
+ }
+ fmt.Fprintf(v, "%s is not a valid command. Type /help to see the full list of commands.\n", input[1])
+ return nil
}
+ fmt.Fprintln(v, "Too many arguments for /help command.")
+ fmt.Fprintf(v, "%s - %s\n", commands[0].Signature, commands[0].Description)
return nil
}
-func sendJoin(v *gocui.View, input string) error {
- serverConn, _ = net.Dial("tcp", serverAddr);
- // if err != nil {
- // return err
- // }
- _, err := serverConn.Write([]byte(input));
+func sendJoin(v *gocui.View, input []string) error {
+ var err error
+ if len(input) != 3 {
+ fmt.Fprintf(v, "Invalid join command.\n")
+ input := []string{"/help", "/join"}
+ printHelp(v, input)
+ return nil
+ }
+
+ if len(input[2]) != 32 {
+ fmt.Fprintf(v, "Invalid token: %s\nThe token should be a 32-character string.\n", input[1])
+ return nil
+ }
+
+ serverConn, err = net.Dial("tcp", input[1])
if err != nil {
- return err
+ fmt.Fprintf(v, "Server is not responding.\n")
+ return nil
+ }
+ _, err = serverConn.Write([]byte(input[0] + " " + input[2]))
+ if err != nil {
+ fmt.Fprintf(v, "Could not send join request to the server: %s\n", err)
+ return nil
}
go getMsg(serverConn)
return nil
}
-func sendQuit (v *gocui.View, input string) error {
- fmt.Fprintln(v, "Quiting from server");
+func sendSignup(v *gocui.View, input []string) error {
+ if len(input) != 3 {
+ fmt.Fprintf(v, "Invalid signup command.\n")
+ input := []string{"/help", "/signup"}
+ printHelp(v, input)
+ return nil
+ }
+
+ _, err := serverConn.Write([]byte(input[0] + " " + input[1] + " " + input[2]))
+ if err != nil {
+ fmt.Fprintf(v, "Could not send signup request to the server: %s\n", err)
+ }
+ return nil
+}
+
+func sendLogin(v *gocui.View, input []string) error {
+ if len(input) != 3 {
+ fmt.Fprintf(v, "Invalid login command.\n")
+ input := []string{"/help", "/login"}
+ printHelp(v, input)
+ return nil
+ }
+
+ _, err := serverConn.Write([]byte(input[0] + " " + input[1] + " " + input[2]))
+ if err != nil {
+ fmt.Fprintf(v, "Could not send login request to the server: %s\n", err)
+ }
+ return nil
+}
+
+func sendQuit(v *gocui.View, input []string) error {
+ fmt.Fprintln(v, "Quiting from server")
serverConn.Close()
return nil
}
-func getMsg (conn net.Conn) {
+func getMsg(conn net.Conn) {
readBuf := make([]byte, 512)
for {
- n, readErr := conn.Read(readBuf);
+ n, readErr := conn.Read(readBuf)
if readErr != nil {
return
@@ -172,7 +250,6 @@ func getMsg (conn net.Conn) {
}
}
-
func quit(g *gocui.Gui, v *gocui.View) error {
return gocui.ErrQuit
}
diff --git a/go.mod b/go.mod
@@ -5,5 +5,6 @@ go 1.24.0
require (
github.com/jroimartin/gocui v0.5.0 // indirect
github.com/mattn/go-runewidth v0.0.9 // indirect
+ github.com/mattn/go-sqlite3 v1.14.28 // indirect
github.com/nsf/termbox-go v1.1.1 // indirect
)
diff --git a/server.go b/server.go
@@ -1,76 +1,110 @@
package main
import (
- "os"
- "math/big"
"crypto/rand"
+ "fmt"
"log"
+ "math/big"
"net"
- "time"
+ "os"
"strings"
- "fmt"
+ "time"
+ "database/sql"
+ _"github.com/mattn/go-sqlite3"
)
-var commands [3]string = [3]string{"/help", "/login", "/quit"};
-
type msgType int
+
const (
msgConnect msgType = iota + 1
+ msgJoin
+ msgSignup
msgLogin
msgText
msgQuit
)
+var RequestMap map[string]msgType
+var userDB *sql.DB
+
+type Msg struct {
+ Author Client
+ Type msgType
+ Args []string
+}
+
type loginStage int
+
const (
verification loginStage = iota + 1
username
password
)
-
type Client struct {
- Stage loginStage
- UserName string
- Password string
+ Stage loginStage
+ UserName string
+ Password string
LastMsgTime time.Time
- Request msgType
- Strike int
- Banned bool
- BanEnd time.Time
- Conn net.Conn
- Text string
+ Strike int
+ Banned bool
+ BanEnd time.Time
+ Conn net.Conn
}
const (
msgCoolDownTimeSec = 1
- banLimit = 5
- banTimeoutSec = 180
- Port = "6969"
+ banLimit = 5
+ banTimeoutSec = 180
+ Port = "6969"
)
-func addClient(conn net.Conn, Client_q chan Client) {
- // loginPrompt := "Who are you?\n> "
- // _, err := conn.Write([]byte(loginPrompt))
- // if err != nil {
- // log.Printf("[ERROR] Could not send login prompt to user %s: %s\n",
- // conn.RemoteAddr().String(), err)
- // }
+func initReqMap() {
+ RequestMap = make(map[string]msgType)
+ RequestMap["/join"] = msgJoin
+ RequestMap["/signup"] = msgSignup
+ RequestMap["/login"] = msgLogin
+}
+
+func clientRoutine(conn net.Conn, Msg_q chan Msg) {
+ Msg_q <- Msg{
+ Type: msgConnect,
+ Author: Client{
+ Conn: conn,
+ Stage: verification,
+ },
+ }
readBuf := make([]byte, 512)
for {
n, err := conn.Read(readBuf)
if n == 0 {
- Client_q <- Client {
- Request: msgQuit,
- Conn: conn,
+ Msg_q <- Msg{
+ Type: msgQuit,
+ Author: Client{
+ Conn: conn,
+ },
}
- return;
+ return
}
if n > 0 {
- Client_q <- Client {
- Request: msgText,
- Conn: conn,
- Text: string(readBuf[:n]),
+ if strings.HasPrefix(string(readBuf), "/") {
+ items := strings.Split(string(readBuf[:n]), " ")
+ Msg_q <- Msg{
+ Type: RequestMap[items[0]],
+ Args: items[1:len(items)],
+ Author: Client{
+ Conn: conn,
+ },
+ }
+ } else {
+ log.Printf("Got message from %s\n", conn.RemoteAddr().String())
+ Msg_q <- Msg{
+ Type: msgText,
+ Args: []string{string(readBuf[:n])},
+ Author: Client{
+ Conn: conn,
+ },
+ }
}
}
if err != nil {
@@ -83,33 +117,35 @@ func addClient(conn net.Conn, Client_q chan Client) {
func canMessage(client *Client) bool {
if !client.Banned {
- diff := time.Now().Sub(client.LastMsgTime).Seconds();
+ diff := time.Now().Sub(client.LastMsgTime).Seconds()
if diff <= msgCoolDownTimeSec {
- client.Strike += 1;
+ client.Strike += 1
if client.Strike >= banLimit {
- client.Banned = true;
+ client.Banned = true
client.BanEnd = time.Now().Add(banTimeoutSec * time.Second)
}
- return false;
+ return false
}
return true
}
- banTimeRemaining := client.BanEnd.Sub(time.Now()).Seconds();
- if banTimeRemaining >= 0.0 {
+ banTimeRemaining := client.BanEnd.Sub(time.Now()).Seconds()
+ if banTimeRemaining >= 0.0 {
banTimeRemainingStr := fmt.Sprintf("You're banned. Try again in %.0f seconds.\n", banTimeRemaining)
- client.Conn.Write([]byte(banTimeRemainingStr));
- return false;
- }
- client.Strike = 0;
- client.Banned = false;
- return true;
+ client.Conn.Write([]byte(banTimeRemainingStr))
+ return false
+ }
+ client.Strike = 0
+ client.Banned = false
+ return true
}
func checkForDuplicateUN(needle string, heystack map[string]Client) bool {
for _, client := range heystack {
- if client.UserName == needle { return true }
+ if client.UserName == needle {
+ return true
+ }
}
- return false;
+ return false
}
func verifyToken(input string) bool {
@@ -128,17 +164,27 @@ func verifyToken(input string) bool {
return input == string(tokenBytes)
}
+func trimNewline(input rune) bool {
+ if input == '\n' {
+ return true
+ }
+ if input == '\r' {
+ return true
+ }
+ return false
+}
+
func login(client *Client, msg string) bool {
if client.Stage == verification {
- if !verifyToken(strings.TrimRight(msg, "\r\n")) {
- _, err := client.Conn.Write([]byte("Invalid token.\n"));
+ if !verifyToken(strings.TrimRightFunc(msg, trimNewline)) {
+ _, err := client.Conn.Write([]byte("Invalid token.\n"))
if err != nil {
log.Printf("Could not send invalid token message to client %s\n", client.Conn.RemoteAddr().String())
}
client.LastMsgTime = time.Now()
- return false;
+ return false
}
- _, err := client.Conn.Write([]byte("Token verified successfully!\nEnter your user name"));
+ _, err := client.Conn.Write([]byte("Token verified successfully!\nEnter your user name"))
if err != nil {
log.Printf("Could not send verification message to %s\n", client.Conn.RemoteAddr().String())
}
@@ -147,18 +193,18 @@ func login(client *Client, msg string) bool {
return false
}
if client.Stage == username {
- client.UserName = strings.TrimRight(msg, "\r\n");
+ client.UserName = strings.TrimRight(msg, "\r\n")
client.LastMsgTime = time.Now()
- _, err := client.Conn.Write([]byte("Enter your password:\n"));
+ _, err := client.Conn.Write([]byte("Enter your password:\n"))
if err != nil {
log.Printf("Could not send passowrd message to %s\n", client.Conn.RemoteAddr().String())
}
client.Stage = password
return false
}
- client.Password = strings.TrimRight(msg, "\r\n");
+ client.Password = strings.TrimRight(msg, "\r\n")
client.LastMsgTime = time.Now()
- _, err := client.Conn.Write([]byte("Welcome " + client.UserName + "\n"));
+ _, err := client.Conn.Write([]byte("Welcome " + client.UserName + "\n"))
if err != nil {
log.Printf("Could not send welcome message to %s\n", client.Conn.RemoteAddr().String())
}
@@ -166,60 +212,172 @@ func login(client *Client, msg string) bool {
return true
}
-func server(Client_q chan Client) {
+func isUserInDB (username string) (bool, error) {
+ var count int
+ check := "SELECT COUNT(*) FROM users WHERE username = ?"
+ err := userDB.QueryRow(check, username).Scan(&count)
+ if err != nil {
+ return false, err
+ }
+ return count > 0, nil
+}
+
+func addUserToDB (client Client) error {
+ insert := "INSERT INTO users (username, password, banned, banEnd) VALUES ($1, $2, $3, $4)"
+ _, err := userDB.Exec(insert, client.UserName, client.Password, client.Banned, client.BanEnd)
+ return err
+}
+
+func checkPassword(username string, inputPass string) (bool, error) {
+ var dbPass string
+ query := "SELECT password FROM users WHERE username = ?"
+ err := userDB.QueryRow(query, username).Scan(&dbPass)
+ if err != nil {
+ return false, nil
+ }
+ return dbPass == inputPass, nil
+}
+
+func server(Msg_q chan Msg) {
onlineList := make(map[string]Client)
+ joinedList := make(map[string]Client)
offlineList := make(map[string]Client)
for {
- client := <-Client_q
- keyString := client.Conn.RemoteAddr().String();
- switch client.Request {
+ msg := <-Msg_q
+ keyString := msg.Author.Conn.RemoteAddr().String()
+ switch msg.Type {
case msgConnect:
// TODO: implement rate limit for connection requests
- log.Printf("Got join request from %s\n", keyString);
- offlineList[keyString] = client;
+ log.Printf("Got connection request from %s\n", keyString)
+ offlineList[keyString] = msg.Author
+ break
case msgQuit:
author, ok := onlineList[keyString]
if ok {
- log.Printf("%s logged out.\n", author.UserName);
- author.Conn.Close();
- delete(onlineList, keyString);
+ log.Printf("%s logged out.\n", author.UserName)
+ author.Conn.Close()
+ delete(onlineList, keyString)
+ break
}
- case msgText:
- clientOffline, ok := offlineList[keyString]
- if ok {
- if !canMessage(&clientOffline) {
- offlineList[keyString] = clientOffline // update the timings
+ break
+
+ case msgJoin:
+ log.Printf("server: Got join request from %s\n", keyString)
+ _, offline := offlineList[keyString]
+ _, joined := joinedList[keyString]
+ if offline {
+ if verifyToken(msg.Args[0]) {
+ joinedList[keyString] = msg.Author
+ delete(offlineList, keyString)
+ msg.Author.Conn.Write([]byte("Authentication successfull."))
+ break
+ }
+ msg.Author.Conn.Write([]byte("Provided token is not valid."))
+ break
+ }
+ if joined {
+ msg.Author.Conn.Write([]byte("You are already joined the server.\nTry logging in or signing up."))
+ break
+ }
+ msg.Author.Conn.Write([]byte("You are currently logged in."))
+ break
+
+ case msgSignup:
+ _, offline := offlineList[keyString]
+ _, joined := joinedList[keyString]
+ _, online := onlineList[keyString]
+ if online {
+ msg.Author.Conn.Write([]byte("This username already exists."))
+ break
+ }
+ if joined {
+ yes, err := isUserInDB(msg.Args[0])
+ if err != nil {
+ msg.Author.Conn.Write([]byte("Database error: " + err.Error()))
break
}
- if login(&clientOffline, client.Text) {
- onlineList[keyString] = clientOffline
- delete(offlineList, keyString)
+ if yes {
+ msg.Author.Conn.Write([]byte("This username already exists."))
+ break
+ }
+ msg.Author.UserName = msg.Args[0]
+ msg.Author.Password = msg.Args[1]
+ msg.Author.LastMsgTime = time.Now()
+ if err := addUserToDB(msg.Author); err != nil {
+ msg.Author.Conn.Write([]byte("Could not signup user. Database error: " + err.Error()))
+ break
+ }
+ onlineList[keyString] = msg.Author
+ delete(joinedList, keyString)
+ msg.Author.Conn.Write([]byte("Welcome " + msg.Author.UserName))
+ break
+ }
+ if offline {
+ msg.Author.Conn.Write([]byte("You should provide the token first with the /join command.\n"))
+ break
+ }
+ case msgLogin:
+ _, offline := offlineList[keyString]
+ _, joined := joinedList[keyString]
+ _, online := onlineList[keyString]
+ if online {
+ msg.Author.Conn.Write([]byte("You are currently logged in."))
+ break
+ }
+ if offline {
+ msg.Author.Conn.Write([]byte("You should provide the token first with the /join command."))
+ break
+ }
+ if joined {
+ yes, err := isUserInDB(msg.Args[0])
+ if err != nil {
+ msg.Author.Conn.Write([]byte("Database error: " + err.Error()))
+ break
+ }
+ if !yes {
+ msg.Author.Conn.Write([]byte("Username does not exist. You can create new user using /signup command."))
+ break
+ }
+ passOK, err := checkPassword(msg.Args[0], msg.Args[1])
+ if err != nil {
+ msg.Author.Conn.Write([]byte("Database error: " + err.Error()))
break
}
- offlineList[keyString] = clientOffline
+ if !passOK {
+ // TODO: limited retry
+ msg.Author.Conn.Write([]byte("Incorrect password."))
+ break
+ }
+ msg.Author.UserName = msg.Args[0]
+ msg.Author.Password = msg.Args[1]
+ msg.Author.LastMsgTime = time.Now()
+ onlineList[keyString] = msg.Author
+ delete(joinedList, keyString)
+ msg.Author.Conn.Write([]byte("Welcome " + msg.Author.UserName))
break
}
+ break
- author, ok := onlineList[keyString];
- if !ok {
- log.Fatal("cannot find client\n");
+ case msgText:
+ author, online := onlineList[keyString]
+ if !online {
+ msg.Author.Conn.Write([]byte("You must be logged in to send messages.\n"))
+ break
}
if !canMessage(&author) {
onlineList[keyString] = author // update the timings
- break;
+ break
}
author.LastMsgTime = time.Now()
- author.Text = client.Text;
onlineList[keyString] = author
- for _, value := range onlineList {
- // if value.Conn == author.Conn {
- // continue
- // }
- _, err := value.Conn.Write([]byte(author.UserName + ": " + author.Text))
+ for _, client := range onlineList {
+ _, err := client.Conn.Write([]byte(author.UserName + ": " + msg.Args[0]))
if err != nil {
- log.Printf("Could not send message to client %s\n", value.UserName)
+ log.Printf("Could not send message to client %s\n", client.UserName)
}
}
+ default:
+ log.Printf("server: Invalid request\n")
}
}
}
@@ -234,7 +392,7 @@ func genToken() {
if err != nil {
log.Fatalf("Could not generate random number: %s\n", err)
}
- tokenStr = fmt.Sprintf(tokenStr + "%X", randInt)
+ tokenStr = fmt.Sprintf(tokenStr+"%X", randInt)
}
tokenFile, err := os.Create("TOKEN")
if err != nil {
@@ -247,13 +405,28 @@ func genToken() {
}
func main() {
+ userDB, _ = sql.Open("sqlite3", "./users.db")
+ createTable := `CREATE TABLE IF NOT EXISTS users (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ username VARCHAR(50) UNIQUE NOT NULL,
+ password VARCHAR(100) NOT NULL,
+ banned BOOLEAN,
+ banEnd TIMESTAMP
+ );`
+ _, err := userDB.Exec(createTable)
+ if err != nil {
+ log.Fatalf("Error creating users table: %s\n", err)
+ }
+ log.Printf("Users table created.\n")
+
+ initReqMap()
genToken()
ln, err := net.Listen("tcp", ":"+Port)
if err != nil {
log.Fatalf("Could not listen to port %s: %s\n", Port, err)
}
- Client_q := make(chan Client)
- go server(Client_q)
+ Msg_q := make(chan Msg)
+ go server(Msg_q)
for {
conn, err := ln.Accept()
if err != nil {
@@ -261,11 +434,13 @@ func main() {
continue
}
log.Printf("Accepted connection from %s", conn.RemoteAddr())
- Client_q <- Client {
- Request: msgConnect,
- Conn: conn,
- Stage: verification,
- }
- go addClient(conn, Client_q)
+ // Msg_q <- Msg{
+ // Type: msgConnect,
+ // Author: Client{
+ // Conn: conn,
+ // Stage: verification,
+ // },
+ // }
+ go clientRoutine(conn, Msg_q)
}
}