diff --git a/Makefile b/Makefile index 2d4cce3..392e07d 100644 --- a/Makefile +++ b/Makefile @@ -30,8 +30,17 @@ users.h: asfur: ${OBJ} ${CC} -o $@ ${OBJ} ${LDFLAGS} +test: test_network + ./test_network + +test_network: test_network.o network.o password.o + ${CC} -o $@ test_network.o network.o password.o ${LDFLAGS} + +test_network.o: test_network.c test.h network.h client.h password.h + ${CC} -c ${CFLAGS} test_network.c + clean: - rm -f asfur ${OBJ} asfur-${VERSION}.tar.gz + rm -f asfur test_network ${OBJ} test_network.o asfur-${VERSION}.tar.gz dist: clean mkdir -p asfur-${VERSION} @@ -53,4 +62,4 @@ uninstall: rm -f ${DESTDIR}${PREFIX}/bin/asfur\ ${DESTDIR}${MANPREFIX}/man1/asfur.1 -.PHONY: all options clean dist install uninstall +.PHONY: all options clean dist install uninstall test diff --git a/PROTOCOL b/PROTOCOL index bed9763..1843fc0 100644 --- a/PROTOCOL +++ b/PROTOCOL @@ -1,20 +1,177 @@ -for client: -server parses packets from left to right like this: - [u16 length][u8 type][char *data] +ASFUR PROTOCOL SPECIFICATION - length: the length of the entire packet - type: - PACKET_ERROR = 0 - PACKET_REGISTER = 1 - PACKET_AUTHENTICATE = 2 - PACKET_JOIN = 3 - PACKET_TEXT = 4 - PACKET_LEAVE = 5 - - data: bytes of data, what it could represent depends on the packet type - (following layouts are layouts for data byte array) - PACKET_REGISTER : [char username[20]][char password[20]] - PACKET_AUTHENTICATE : [char username[20]][char password[20]] - PACKET_JOIN : [] - PACKET_TEXT : - PACKET_LEAVE : +All packets follow this structure: + [u16 size][u8 type][data...] + + size: total length of the packet (including header) + type: packet type identifier + data: payload bytes (format depends on type) + +================================================================================ +PACKET TYPES +================================================================================ + +Client -> Server: + PACKET_REGISTER = 1 + PACKET_AUTHENTICATE = 2 + PACKET_JOIN = 3 + PACKET_TEXT = 4 + PACKET_LEAVE = 5 + PACKET_DM_OPEN = 7 + PACKET_CREATE_ROOM = 9 + PACKET_DELETE_ROOM = 10 + PACKET_LIST_ROOMS = 12 + +Server -> Client: + PACKET_ERROR = 0 + PACKET_OK = 6 + PACKET_DM_ROOM = 8 + PACKET_ROOM_CREATED = 11 + PACKET_ROOM_LIST = 13 + +Bidirectional: + PACKET_TEXT = 4 (client sends, server broadcasts) + +================================================================================ +PACKET DATA LAYOUTS +================================================================================ + +PACKET_ERROR (0) + Server response indicating an error occurred. + [u8 code] + + Error codes: + 0 = ERR_OK (no error) + 1 = ERR_UNKNOWN (unknown error) + 2 = ERR_INVALID_PACKET (malformed packet) + 3 = ERR_NOT_AUTHENTICATED (action requires authentication) + 4 = ERR_ALREADY_REGISTERED (username already taken) + 5 = ERR_INVALID_CREDENTIALS (wrong username/password) + 6 = ERR_REGISTRATION_DISABLED (server disabled registration) + 7 = ERR_DATABASE (internal database error) + 8 = ERR_USER_NOT_FOUND (target user does not exist) + 9 = ERR_ACCESS_DENIED (not allowed to access resource) + 10 = ERR_ROOM_NOT_FOUND (room does not exist) + 11 = ERR_ROOM_NAME_TAKEN (room name already in use) + 12 = ERR_NOT_ROOM_OWNER (action requires room ownership) + +PACKET_REGISTER (1) + Register a new user account. + [char username[20]][char password[100]] + + Response: PACKET_OK or PACKET_ERROR + +PACKET_AUTHENTICATE (2) + Authenticate with existing credentials. + [char username[20]][char password[100]] + + Response: PACKET_OK or PACKET_ERROR + +PACKET_JOIN (3) + Join a room (public or DM). + [u64 room_id] + + Response: PACKET_OK or PACKET_ERROR + Requires: authentication + +PACKET_TEXT (4) + Send a message to a room (client -> server). + [u64 room_id][char message[...]] + + message is variable length (packet size - header - 8 bytes) + Client must have joined the room first. + Requires: authentication, joined room + + Broadcast format (server -> client): + [u64 room_id][char username[32]][char message[...]] + + Server broadcasts to all authenticated clients in the room. + +PACKET_LEAVE (5) + Leave a room. + [u64 room_id] + + Response: PACKET_OK or PACKET_ERROR + Requires: authentication + +PACKET_OK (6) + Server response indicating success. + (no data) + +PACKET_DM_OPEN (7) + Open or get existing DM room with another user. + [char username[20]] + + Response: PACKET_DM_ROOM or PACKET_ERROR + Requires: authentication + +PACKET_DM_ROOM (8) + Server response with DM room ID. + [u64 room_id] + + The room_id can be used with PACKET_JOIN. + +PACKET_CREATE_ROOM (9) + Create a new public room. + [char name[32]] + + Response: PACKET_ROOM_CREATED or PACKET_ERROR + Requires: authentication + +PACKET_DELETE_ROOM (10) + Delete a public room. + [u64 room_id] + + Response: PACKET_OK or PACKET_ERROR + Requires: authentication, room ownership + +PACKET_ROOM_CREATED (11) + Server response after successful room creation. + [u64 room_id][char name[32]] + +PACKET_LIST_ROOMS (12) + Request list of all public rooms. + (no data) + + Response: PACKET_ROOM_LIST or PACKET_ERROR + Requires: authentication + +PACKET_ROOM_LIST (13) + Server response with list of public rooms. + [u32 count][room_entry[count]] + + room_entry format: + [u64 room_id][char name[32]][char owner[20]] + +================================================================================ +TYPICAL FLOWS +================================================================================ + +Registration: + Client: PACKET_REGISTER(username, password) + Server: PACKET_OK or PACKET_ERROR(ERR_ALREADY_REGISTERED) + +Authentication: + Client: PACKET_AUTHENTICATE(username, password) + Server: PACKET_OK or PACKET_ERROR(ERR_INVALID_CREDENTIALS) + +Join public room: + Client: PACKET_LIST_ROOMS + Server: PACKET_ROOM_LIST(rooms) + Client: PACKET_JOIN(room_id) + Server: PACKET_OK + +Send message: + Client: PACKET_TEXT(room_id, "hello") + Server: PACKET_TEXT(room_id, sender_username, "hello") -> all clients in room + +Direct message: + Client: PACKET_DM_OPEN(target_username) + Server: PACKET_DM_ROOM(room_id) + Client: PACKET_JOIN(room_id) + Server: PACKET_OK + Client: PACKET_TEXT(room_id, "private message") + +Create room: + Client: PACKET_CREATE_ROOM("my-room") + Server: PACKET_ROOM_CREATED(room_id, "my-room") diff --git a/config.def.h b/config.def.h index 5e6a394..b82f9d2 100644 --- a/config.def.h +++ b/config.def.h @@ -4,3 +4,4 @@ #define CONFIG_CERT_FILE "cert.pem" #define CONFIG_KEY_FILE "key.pem" +#define CONFIG_DB_PATH "asfur.db" diff --git a/config.mk b/config.mk index 0215fc5..fd36c1c 100644 --- a/config.mk +++ b/config.mk @@ -12,7 +12,7 @@ MANPREFIX = ${PREFIX}/share/man # includes and libs INCS = -I. -LIBS = -lssl -lcrypto -lcrypt -luv +LIBS = -lssl -lcrypto -lcrypt -luv -lsqlite3 # flags CFLAGS := -std=c11 -pedantic -Wall -O0 ${INCS} -DVERSION=\"${VERSION}\" CFLAGS := ${CFLAGS} -g diff --git a/main.c b/main.c index f6c5315..56f4060 100644 --- a/main.c +++ b/main.c @@ -5,7 +5,9 @@ #include #include #include +#include "config.h" #include "client.h" +#include "network.h" uv_loop_t *loop; SSL_CTX *ctx; @@ -13,6 +15,7 @@ SSL_CTX *ctx; void cleanup_client(uv_handle_t *handle) { struct client *client = (struct client*) handle; + network_client_remove(client); if (client->ssl) { SSL_free(client->ssl); } @@ -65,11 +68,11 @@ void on_read(uv_stream_t *stream, ssize_t nread, const uv_buf_t *buf) char plain_buf[4096]; int p; - + while ((p = SSL_read(client->ssl, plain_buf, sizeof(plain_buf))) > 0) { - ssl_write_msg(client, plain_buf, p); + network_handle_data(client, plain_buf, p); } - + flush_ssl_to_socket(client); } else if (nread < 0) { @@ -102,7 +105,13 @@ void on_new_connection(uv_stream_t *server, int status) client->wbio = BIO_new(BIO_s_mem()); SSL_set_bio(client->ssl, client->rbio, client->wbio); SSL_set_accept_state(client->ssl); - + + client->user_id = 0; + client->current_room_id = 0; + client->is_authenticated = false; + client->username[0] = '\0'; + + network_client_add(client); uv_read_start((uv_stream_t*) &client->handle, alloc_buffer, on_read); } else { uv_close((uv_handle_t*) &client->handle, cleanup_client); @@ -132,6 +141,12 @@ void init_openssl() int main() { init_openssl(); + + if (network_init() != 0) { + fprintf(stderr, "Failed to initialize network/database\n"); + return 1; + } + loop = uv_default_loop(); uv_tcp_t server; diff --git a/network.c b/network.c index 481dc96..153023b 100644 --- a/network.c +++ b/network.c @@ -1,7 +1,678 @@ #define _GNU_SOURCE +#include +#include +#include +#include #include "network.h" +#include "client.h" +#include "password.h" +#include "config.h" -void network_parse_packet() +#define MAX_CLIENTS 1024 + +static sqlite3 *db = NULL; +static struct client *clients[MAX_CLIENTS]; +static int client_count = 0; + +static const char *schema_sql = + "CREATE TABLE IF NOT EXISTS users (" + " id INTEGER PRIMARY KEY AUTOINCREMENT," + " username TEXT UNIQUE NOT NULL," + " password_hash TEXT NOT NULL," + " created_at INTEGER DEFAULT (strftime('%s', 'now'))" + ");" + "CREATE TABLE IF NOT EXISTS dm_rooms (" + " id INTEGER PRIMARY KEY AUTOINCREMENT," + " user1_id INTEGER NOT NULL," + " user2_id INTEGER NOT NULL," + " created_at INTEGER DEFAULT (strftime('%s', 'now'))," + " UNIQUE(user1_id, user2_id)," + " FOREIGN KEY(user1_id) REFERENCES users(id)," + " FOREIGN KEY(user2_id) REFERENCES users(id)" + ");" + "CREATE TABLE IF NOT EXISTS rooms (" + " id INTEGER PRIMARY KEY AUTOINCREMENT," + " name TEXT UNIQUE NOT NULL," + " owner_id INTEGER NOT NULL," + " created_at INTEGER DEFAULT (strftime('%s', 'now'))," + " FOREIGN KEY(owner_id) REFERENCES users(id)" + ");"; + +int network_init(void) { + int rc = sqlite3_open(CONFIG_DB_PATH, &db); + if (rc != SQLITE_OK) { + fprintf(stderr, "Cannot open database: %s\n", sqlite3_errmsg(db)); + return -1; + } + char *err_msg = NULL; + rc = sqlite3_exec(db, schema_sql, NULL, NULL, &err_msg); + if (rc != SQLITE_OK) { + fprintf(stderr, "SQL error: %s\n", err_msg); + sqlite3_free(err_msg); + return -1; + } + + return 0; +} + +void network_shutdown(void) +{ + if (db) { + sqlite3_close(db); + db = NULL; + } +} + +void network_client_add(struct client *client) +{ + if (client_count >= MAX_CLIENTS) return; + clients[client_count++] = client; +} + +void network_client_remove(struct client *client) +{ + for (int i = 0; i < client_count; i++) { + if (clients[i] == client) { + clients[i] = clients[--client_count]; + return; + } + } +} + +static void send_error(struct client *client, error_code code); +static void send_ok(struct client *client); +static int is_dm_room_member(uint64_t room_id, uint32_t user_id); +static int is_public_room(uint64_t room_id); + +extern void ssl_write_msg(struct client *client, const char *data, size_t len); + +static void send_packet(struct client *client, uint8_t type, const void *data, size_t data_len) +{ + size_t total_len = sizeof(struct packet_header) + data_len; + char *buf = malloc(total_len); + if (!buf) return; + + struct packet_header *hdr = (struct packet_header *)buf; + hdr->size = (uint16_t)total_len; + hdr->type = type; + + if (data && data_len > 0) { + memcpy(buf + sizeof(struct packet_header), data, data_len); + } + + ssl_write_msg(client, buf, total_len); + free(buf); +} + +static void send_error(struct client *client, error_code code) +{ + struct packet_error err = { .code = (uint8_t)code }; + send_packet(client, PACKET_ERROR, &err, sizeof(err)); +} + +static void send_ok(struct client *client) +{ + send_packet(client, PACKET_OK, NULL, 0); +} + +static int handle_register(struct client *client, const char *data, size_t len) +{ + if (len < sizeof(struct packet_register)) { + send_error(client, ERR_INVALID_PACKET); + return -1; + } + +#ifndef CONFIG_ALLOW_REGISTER + send_error(client, ERR_REGISTRATION_DISABLED); + return -1; +#else + if (!CONFIG_ALLOW_REGISTER) { + send_error(client, ERR_REGISTRATION_DISABLED); + return -1; + } +#endif + + const struct packet_register *reg = (const struct packet_register *)data; + + char username[21] = {0}; + char password[101] = {0}; + memcpy(username, reg->username, 20); + memcpy(password, reg->password, 100); + + if (strlen(username) == 0 || strlen(password) == 0) { + send_error(client, ERR_INVALID_PACKET); + return -1; + } + + char *hash = hash_password(password); + if (!hash) { + send_error(client, ERR_DATABASE); + return -1; + } + + sqlite3_stmt *stmt; + const char *sql = "INSERT INTO users (username, password_hash) VALUES (?, ?);"; + int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL); + if (rc != SQLITE_OK) { + free(hash); + send_error(client, ERR_DATABASE); + return -1; + } + + sqlite3_bind_text(stmt, 1, username, -1, SQLITE_STATIC); + sqlite3_bind_text(stmt, 2, hash, -1, SQLITE_STATIC); + + rc = sqlite3_step(stmt); + sqlite3_finalize(stmt); + free(hash); + + if (rc == SQLITE_CONSTRAINT) { + send_error(client, ERR_ALREADY_REGISTERED); + return -1; + } else if (rc != SQLITE_DONE) { + send_error(client, ERR_DATABASE); + return -1; + } + + send_ok(client); + return 0; +} + +static int handle_authenticate(struct client *client, const char *data, size_t len) +{ + if (len < sizeof(struct packet_auth)) { + send_error(client, ERR_INVALID_PACKET); + return -1; + } + + const struct packet_auth *auth = (const struct packet_auth *)data; + + char username[21] = {0}; + char password[101] = {0}; + memcpy(username, auth->username, 20); + memcpy(password, auth->password, 100); + + if (strlen(username) == 0 || strlen(password) == 0) { + send_error(client, ERR_INVALID_CREDENTIALS); + return -1; + } + + sqlite3_stmt *stmt; + const char *sql = "SELECT id, password_hash FROM users WHERE username = ?;"; + int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL); + if (rc != SQLITE_OK) { + send_error(client, ERR_DATABASE); + return -1; + } + + sqlite3_bind_text(stmt, 1, username, -1, SQLITE_STATIC); + + rc = sqlite3_step(stmt); + if (rc != SQLITE_ROW) { + sqlite3_finalize(stmt); + send_error(client, ERR_INVALID_CREDENTIALS); + return -1; + } + + int user_id = sqlite3_column_int(stmt, 0); + const char *stored_hash = (const char *)sqlite3_column_text(stmt, 1); + + if (!verify_password(password, stored_hash)) { + sqlite3_finalize(stmt); + send_error(client, ERR_INVALID_CREDENTIALS); + return -1; + } + + sqlite3_finalize(stmt); + + client->user_id = user_id; + client->is_authenticated = true; + strncpy(client->username, username, sizeof(client->username) - 1); + client->username[sizeof(client->username) - 1] = '\0'; + + send_ok(client); + return 0; +} + +static int handle_join(struct client *client, const char *data, size_t len) +{ + if (!client->is_authenticated) { + send_error(client, ERR_NOT_AUTHENTICATED); + return -1; + } + + if (len < sizeof(struct packet_join)) { + send_error(client, ERR_INVALID_PACKET); + return -1; + } + + const struct packet_join *join = (const struct packet_join *)data; + + if (is_public_room(join->room_id)) { + client->current_room_id = join->room_id; + send_ok(client); + return 0; + } + + if (is_dm_room_member(join->room_id, client->user_id)) { + client->current_room_id = join->room_id; + send_ok(client); + return 0; + } + + send_error(client, ERR_ACCESS_DENIED); + return -1; +} + +static int handle_leave(struct client *client, const char *data, size_t len) +{ + if (!client->is_authenticated) { + send_error(client, ERR_NOT_AUTHENTICATED); + return -1; + } + + if (len < sizeof(struct packet_leave)) { + send_error(client, ERR_INVALID_PACKET); + return -1; + } + + const struct packet_leave *leave = (const struct packet_leave *)data; + if (client->current_room_id == leave->room_id) { + client->current_room_id = 0; + } + + send_ok(client); + return 0; +} + +static int64_t get_or_create_dm_room(uint32_t user1_id, uint32_t user2_id) +{ + uint32_t low_id = user1_id < user2_id ? user1_id : user2_id; + uint32_t high_id = user1_id < user2_id ? user2_id : user1_id; + + sqlite3_stmt *stmt; + const char *select_sql = "SELECT id FROM dm_rooms WHERE user1_id = ? AND user2_id = ?;"; + int rc = sqlite3_prepare_v2(db, select_sql, -1, &stmt, NULL); + if (rc != SQLITE_OK) return -1; + + sqlite3_bind_int(stmt, 1, low_id); + sqlite3_bind_int(stmt, 2, high_id); + + rc = sqlite3_step(stmt); + if (rc == SQLITE_ROW) { + int64_t room_id = sqlite3_column_int64(stmt, 0); + sqlite3_finalize(stmt); + return room_id; + } + sqlite3_finalize(stmt); + + const char *insert_sql = "INSERT INTO dm_rooms (user1_id, user2_id) VALUES (?, ?);"; + rc = sqlite3_prepare_v2(db, insert_sql, -1, &stmt, NULL); + if (rc != SQLITE_OK) return -1; + + sqlite3_bind_int(stmt, 1, low_id); + sqlite3_bind_int(stmt, 2, high_id); + + rc = sqlite3_step(stmt); + sqlite3_finalize(stmt); + + if (rc != SQLITE_DONE) return -1; + + return sqlite3_last_insert_rowid(db); +} + +static int is_dm_room_member(uint64_t room_id, uint32_t user_id) +{ + sqlite3_stmt *stmt; + const char *sql = "SELECT 1 FROM dm_rooms WHERE id = ? AND (user1_id = ? OR user2_id = ?);"; + int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL); + if (rc != SQLITE_OK) return 0; + + sqlite3_bind_int64(stmt, 1, room_id); + sqlite3_bind_int(stmt, 2, user_id); + sqlite3_bind_int(stmt, 3, user_id); + + rc = sqlite3_step(stmt); + sqlite3_finalize(stmt); + + return rc == SQLITE_ROW; +} + +static int is_public_room(uint64_t room_id) +{ + sqlite3_stmt *stmt; + const char *sql = "SELECT 1 FROM rooms WHERE id = ?;"; + int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL); + if (rc != SQLITE_OK) return 0; + + sqlite3_bind_int64(stmt, 1, room_id); + + rc = sqlite3_step(stmt); + sqlite3_finalize(stmt); + + return rc == SQLITE_ROW; +} + +static int is_room_owner(uint64_t room_id, uint32_t user_id) +{ + sqlite3_stmt *stmt; + const char *sql = "SELECT 1 FROM rooms WHERE id = ? AND owner_id = ?;"; + int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL); + if (rc != SQLITE_OK) return 0; + + sqlite3_bind_int64(stmt, 1, room_id); + sqlite3_bind_int(stmt, 2, user_id); + + rc = sqlite3_step(stmt); + sqlite3_finalize(stmt); + + return rc == SQLITE_ROW; +} + +static int handle_create_room(struct client *client, const char *data, size_t len) +{ + if (!client->is_authenticated) { + send_error(client, ERR_NOT_AUTHENTICATED); + return -1; + } + + if (len < sizeof(struct packet_create_room)) { + send_error(client, ERR_INVALID_PACKET); + return -1; + } + + const struct packet_create_room *create = (const struct packet_create_room *)data; + + char name[33] = {0}; + memcpy(name, create->name, 32); + + if (strlen(name) == 0) { + send_error(client, ERR_INVALID_PACKET); + return -1; + } + + sqlite3_stmt *stmt; + const char *sql = "INSERT INTO rooms (name, owner_id) VALUES (?, ?);"; + int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL); + if (rc != SQLITE_OK) { + send_error(client, ERR_DATABASE); + return -1; + } + + sqlite3_bind_text(stmt, 1, name, -1, SQLITE_STATIC); + sqlite3_bind_int(stmt, 2, client->user_id); + + rc = sqlite3_step(stmt); + sqlite3_finalize(stmt); + + if (rc == SQLITE_CONSTRAINT) { + send_error(client, ERR_ROOM_NAME_TAKEN); + return -1; + } else if (rc != SQLITE_DONE) { + send_error(client, ERR_DATABASE); + return -1; + } + + int64_t room_id = sqlite3_last_insert_rowid(db); + + struct packet_room_created response; + response.room_id = (uint64_t)room_id; + memset(response.name, 0, sizeof(response.name)); + strncpy(response.name, name, sizeof(response.name) - 1); + + send_packet(client, PACKET_ROOM_CREATED, &response, sizeof(response)); + return 0; +} + +static int handle_delete_room(struct client *client, const char *data, size_t len) +{ + if (!client->is_authenticated) { + send_error(client, ERR_NOT_AUTHENTICATED); + return -1; + } + + if (len < sizeof(struct packet_delete_room)) { + send_error(client, ERR_INVALID_PACKET); + return -1; + } + + const struct packet_delete_room *del = (const struct packet_delete_room *)data; + + if (!is_public_room(del->room_id)) { + send_error(client, ERR_ROOM_NOT_FOUND); + return -1; + } + + if (!is_room_owner(del->room_id, client->user_id)) { + send_error(client, ERR_NOT_ROOM_OWNER); + return -1; + } + + sqlite3_stmt *stmt; + const char *sql = "DELETE FROM rooms WHERE id = ?;"; + int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL); + if (rc != SQLITE_OK) { + send_error(client, ERR_DATABASE); + return -1; + } + + sqlite3_bind_int64(stmt, 1, del->room_id); + + rc = sqlite3_step(stmt); + sqlite3_finalize(stmt); + + if (rc != SQLITE_DONE) { + send_error(client, ERR_DATABASE); + return -1; + } + + for (int i = 0; i < client_count; i++) { + if (clients[i]->current_room_id == del->room_id) { + clients[i]->current_room_id = 0; + } + } + + send_ok(client); + return 0; +} + +static int handle_list_rooms(struct client *client) +{ + if (!client->is_authenticated) { + send_error(client, ERR_NOT_AUTHENTICATED); + return -1; + } + + sqlite3_stmt *stmt; + const char *sql = "SELECT r.id, r.name, u.username FROM rooms r " + "JOIN users u ON r.owner_id = u.id ORDER BY r.name;"; + int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL); + if (rc != SQLITE_OK) { + send_error(client, ERR_DATABASE); + return -1; + } + + struct room_list_entry entries[256]; + uint32_t count = 0; + + while ((rc = sqlite3_step(stmt)) == SQLITE_ROW && count < 256) { + entries[count].room_id = sqlite3_column_int64(stmt, 0); + + memset(entries[count].name, 0, sizeof(entries[count].name)); + const char *name = (const char *)sqlite3_column_text(stmt, 1); + if (name) strncpy(entries[count].name, name, 31); + + memset(entries[count].owner, 0, sizeof(entries[count].owner)); + const char *owner = (const char *)sqlite3_column_text(stmt, 2); + if (owner) strncpy(entries[count].owner, owner, 19); + + count++; + } + sqlite3_finalize(stmt); + + size_t data_len = sizeof(uint32_t) + count * sizeof(struct room_list_entry); + char *data = malloc(data_len); + if (!data) { + send_error(client, ERR_DATABASE); + return -1; + } + + memcpy(data, &count, sizeof(uint32_t)); + memcpy(data + sizeof(uint32_t), entries, count * sizeof(struct room_list_entry)); + + send_packet(client, PACKET_ROOM_LIST, data, data_len); + free(data); + return 0; +} + +static int handle_dm_open(struct client *client, const char *data, size_t len) +{ + if (!client->is_authenticated) { + send_error(client, ERR_NOT_AUTHENTICATED); + return -1; + } + + if (len < sizeof(struct packet_dm_open)) { + send_error(client, ERR_INVALID_PACKET); + return -1; + } + + const struct packet_dm_open *dm = (const struct packet_dm_open *)data; + + char target_username[21] = {0}; + memcpy(target_username, dm->username, 20); + + if (strlen(target_username) == 0) { + send_error(client, ERR_INVALID_PACKET); + return -1; + } + + sqlite3_stmt *stmt; + const char *sql = "SELECT id FROM users WHERE username = ?;"; + int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL); + if (rc != SQLITE_OK) { + send_error(client, ERR_DATABASE); + return -1; + } + + sqlite3_bind_text(stmt, 1, target_username, -1, SQLITE_STATIC); + + rc = sqlite3_step(stmt); + if (rc != SQLITE_ROW) { + sqlite3_finalize(stmt); + send_error(client, ERR_USER_NOT_FOUND); + return -1; + } + + uint32_t target_user_id = sqlite3_column_int(stmt, 0); + sqlite3_finalize(stmt); + + int64_t room_id = get_or_create_dm_room(client->user_id, target_user_id); + if (room_id < 0) { + send_error(client, ERR_DATABASE); + return -1; + } + + struct packet_dm_room response = { .room_id = (uint64_t)room_id }; + send_packet(client, PACKET_DM_ROOM, &response, sizeof(response)); + return 0; +} + +static void broadcast_to_room(uint64_t room_id, const char *username, + const char *message, size_t msg_len) +{ + size_t bcast_len = sizeof(uint64_t) + 32 + msg_len; + char *bcast_data = malloc(bcast_len); + if (!bcast_data) return; + + memcpy(bcast_data, &room_id, sizeof(uint64_t)); + memset(bcast_data + sizeof(uint64_t), 0, 32); + strncpy(bcast_data + sizeof(uint64_t), username, 31); + memcpy(bcast_data + sizeof(uint64_t) + 32, message, msg_len); + + for (int i = 0; i < client_count; i++) { + struct client *c = clients[i]; + if (c->is_authenticated && c->current_room_id == room_id) { + send_packet(c, PACKET_TEXT, bcast_data, bcast_len); + } + } + + free(bcast_data); +} + +static int handle_text(struct client *client, const char *data, size_t len) +{ + if (!client->is_authenticated) { + send_error(client, ERR_NOT_AUTHENTICATED); + return -1; + } + + if (len < sizeof(uint64_t)) { + send_error(client, ERR_INVALID_PACKET); + return -1; + } + + const struct packet_text *text = (const struct packet_text *)data; + size_t msg_len = len - sizeof(uint64_t); + + if (client->current_room_id != text->room_id) { + send_error(client, ERR_NOT_AUTHENTICATED); + return -1; + } + + broadcast_to_room(text->room_id, client->username, text->message, msg_len); + return 0; +} + +void network_handle_data(struct client *client, const char *data, size_t len) +{ + if (len < sizeof(struct packet_header)) { + send_error(client, ERR_INVALID_PACKET); + return; + } + + const struct packet_header *hdr = (const struct packet_header *)data; + + if (hdr->size > len || hdr->size < sizeof(struct packet_header)) { + send_error(client, ERR_INVALID_PACKET); + return; + } + + const char *payload = data + sizeof(struct packet_header); + size_t payload_len = hdr->size - sizeof(struct packet_header); + + switch (hdr->type) { + case PACKET_REGISTER: + handle_register(client, payload, payload_len); + break; + case PACKET_AUTHENTICATE: + handle_authenticate(client, payload, payload_len); + break; + case PACKET_JOIN: + handle_join(client, payload, payload_len); + break; + case PACKET_LEAVE: + handle_leave(client, payload, payload_len); + break; + case PACKET_TEXT: + handle_text(client, payload, payload_len); + break; + case PACKET_DM_OPEN: + handle_dm_open(client, payload, payload_len); + break; + case PACKET_CREATE_ROOM: + handle_create_room(client, payload, payload_len); + break; + case PACKET_DELETE_ROOM: + handle_delete_room(client, payload, payload_len); + break; + case PACKET_LIST_ROOMS: + handle_list_rooms(client); + break; + default: + send_error(client, ERR_INVALID_PACKET); + break; + } } diff --git a/network.h b/network.h index 8d8c647..eae8bd7 100644 --- a/network.h +++ b/network.h @@ -3,9 +3,115 @@ #include +struct client; + +typedef enum { + PACKET_ERROR = 0, + PACKET_REGISTER = 1, + PACKET_AUTHENTICATE = 2, + PACKET_JOIN = 3, + PACKET_TEXT = 4, + PACKET_LEAVE = 5, + PACKET_OK = 6, + PACKET_DM_OPEN = 7, + PACKET_DM_ROOM = 8, + PACKET_CREATE_ROOM = 9, + PACKET_DELETE_ROOM = 10, + PACKET_ROOM_CREATED = 11, + PACKET_LIST_ROOMS = 12, + PACKET_ROOM_LIST = 13 +} packet_type; + +typedef enum { + ERR_OK = 0, + ERR_UNKNOWN = 1, + ERR_INVALID_PACKET = 2, + ERR_NOT_AUTHENTICATED = 3, + ERR_ALREADY_REGISTERED = 4, + ERR_INVALID_CREDENTIALS = 5, + ERR_REGISTRATION_DISABLED = 6, + ERR_DATABASE = 7, + ERR_USER_NOT_FOUND = 8, + ERR_ACCESS_DENIED = 9, + ERR_ROOM_NOT_FOUND = 10, + ERR_ROOM_NAME_TAKEN = 11, + ERR_NOT_ROOM_OWNER = 12 +} error_code; + struct packet_header { uint16_t size; uint8_t type; }; +struct packet_register { + char username[20]; + char password[100]; +}; + +struct packet_auth { + char username[20]; + char password[100]; +}; + +struct packet_join { + uint64_t room_id; +}; + +struct packet_leave { + uint64_t room_id; +}; + +struct packet_text { + uint64_t room_id; + char message[]; +}; + +struct packet_error { + uint8_t code; +}; + +struct packet_text_broadcast { + uint64_t room_id; + char username[32]; + char message[]; +}; + +struct packet_dm_open { + char username[20]; +}; + +struct packet_dm_room { + uint64_t room_id; +}; + +struct packet_create_room { + char name[32]; +}; + +struct packet_delete_room { + uint64_t room_id; +}; + +struct packet_room_created { + uint64_t room_id; + char name[32]; +}; + +struct room_list_entry { + uint64_t room_id; + char name[32]; + char owner[20]; +}; + +struct packet_room_list { + uint32_t count; + struct room_list_entry rooms[]; +}; + +int network_init(void); +void network_shutdown(void); +void network_handle_data(struct client *client, const char *data, size_t len); +void network_client_add(struct client *client); +void network_client_remove(struct client *client); + #endif diff --git a/test.h b/test.h new file mode 100644 index 0000000..78a8802 --- /dev/null +++ b/test.h @@ -0,0 +1,53 @@ +/* Simple unit test framework for asfur */ +#ifndef TEST_H +#define TEST_H + +#include +#include +#include + +static int tests_run = 0; +static int tests_passed = 0; +static int tests_failed = 0; + +#define TEST_ASSERT(cond, msg) do { \ + if (!(cond)) { \ + printf(" FAIL: %s\n", msg); \ + return 1; \ + } \ +} while(0) + +#define TEST_ASSERT_EQ(a, b, msg) do { \ + if ((a) != (b)) { \ + printf(" FAIL: %s (expected %d, got %d)\n", msg, (int)(b), (int)(a)); \ + return 1; \ + } \ +} while(0) + +#define TEST_ASSERT_STR_EQ(a, b, msg) do { \ + if (strcmp((a), (b)) != 0) { \ + printf(" FAIL: %s (expected '%s', got '%s')\n", msg, (b), (a)); \ + return 1; \ + } \ +} while(0) + +#define RUN_TEST(test_fn) do { \ + tests_run++; \ + printf("Running %s...\n", #test_fn); \ + if (test_fn() == 0) { \ + printf(" PASS\n"); \ + tests_passed++; \ + } else { \ + tests_failed++; \ + } \ +} while(0) + +#define TEST_SUMMARY() do { \ + printf("\n=== Test Summary ===\n"); \ + printf("Total: %d\n", tests_run); \ + printf("Passed: %d\n", tests_passed); \ + printf("Failed: %d\n", tests_failed); \ + return tests_failed > 0 ? 1 : 0; \ +} while(0) + +#endif diff --git a/test_network.c b/test_network.c new file mode 100644 index 0000000..d7fc99f --- /dev/null +++ b/test_network.c @@ -0,0 +1,1117 @@ +#define _GNU_SOURCE +#include +#include +#include +#include +#include "test.h" +#include "network.h" +#include "client.h" +#include "password.h" + +/* Mock structures and globals for capturing sent packets */ +#define MAX_SENT_PACKETS 32 +#define MAX_PACKET_SIZE 4096 + +static struct { + char data[MAX_PACKET_SIZE]; + size_t len; +} sent_packets[MAX_SENT_PACKETS]; +static int sent_packet_count = 0; + +/* Mock ssl_write_msg - captures packets for verification */ +void ssl_write_msg(struct client *client, const char *data, size_t len) +{ + (void)client; + if (sent_packet_count < MAX_SENT_PACKETS && len < MAX_PACKET_SIZE) { + memcpy(sent_packets[sent_packet_count].data, data, len); + sent_packets[sent_packet_count].len = len; + sent_packet_count++; + } +} + +/* Helper to reset captured packets */ +static void reset_sent_packets(void) +{ + sent_packet_count = 0; + memset(sent_packets, 0, sizeof(sent_packets)); +} + +/* Helper to get last sent packet header */ +static struct packet_header *get_last_packet_header(void) +{ + if (sent_packet_count == 0) return NULL; + return (struct packet_header *)sent_packets[sent_packet_count - 1].data; +} + +/* Helper to get last sent packet payload */ +static const char *get_last_packet_payload(void) +{ + if (sent_packet_count == 0) return NULL; + return sent_packets[sent_packet_count - 1].data + sizeof(struct packet_header); +} + +/* Helper to create a mock client */ +static struct client *create_mock_client(void) +{ + struct client *c = calloc(1, sizeof(struct client)); + c->user_id = 0; + c->current_room_id = 0; + c->is_authenticated = false; + c->username[0] = '\0'; + return c; +} + +/* Helper to build a packet */ +static size_t build_packet(char *buf, uint8_t type, const void *payload, size_t payload_len) +{ + size_t total = sizeof(struct packet_header) + payload_len; + struct packet_header *hdr = (struct packet_header *)buf; + hdr->size = (uint16_t)total; + hdr->type = type; + if (payload && payload_len > 0) { + memcpy(buf + sizeof(struct packet_header), payload, payload_len); + } + return total; +} + +/* ==== PACKET PARSING TESTS ==== */ + +static int test_invalid_packet_too_short(void) +{ + struct client *client = create_mock_client(); + reset_sent_packets(); + + /* Send a packet that's too short (less than header size) */ + char data[] = { 0x00, 0x01 }; /* Only 2 bytes */ + network_handle_data(client, data, sizeof(data)); + + TEST_ASSERT(sent_packet_count == 1, "Should send error response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_ERROR, "Should be error packet"); + const struct packet_error *err = (const struct packet_error *)get_last_packet_payload(); + TEST_ASSERT_EQ(err->code, ERR_INVALID_PACKET, "Should be INVALID_PACKET error"); + + free(client); + return 0; +} + +static int test_invalid_packet_size_mismatch(void) +{ + struct client *client = create_mock_client(); + reset_sent_packets(); + + /* Header claims size of 100 but we only provide 3 bytes */ + char data[3]; + struct packet_header *hdr = (struct packet_header *)data; + hdr->size = 100; + hdr->type = PACKET_REGISTER; + + network_handle_data(client, data, sizeof(data)); + + TEST_ASSERT(sent_packet_count == 1, "Should send error response"); + hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_ERROR, "Should be error packet"); + + free(client); + return 0; +} + +static int test_invalid_packet_unknown_type(void) +{ + struct client *client = create_mock_client(); + reset_sent_packets(); + + char buf[128]; + size_t len = build_packet(buf, 255, NULL, 0); /* Unknown type 255 */ + network_handle_data(client, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send error response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_ERROR, "Should be error packet"); + const struct packet_error *err = (const struct packet_error *)get_last_packet_payload(); + TEST_ASSERT_EQ(err->code, ERR_INVALID_PACKET, "Should be INVALID_PACKET error"); + + free(client); + return 0; +} + +/* ==== REGISTRATION TESTS ==== */ + +static int test_register_success(void) +{ + struct client *client = create_mock_client(); + reset_sent_packets(); + + struct packet_register reg; + memset(®, 0, sizeof(reg)); + strncpy(reg.username, "testuser1", sizeof(reg.username)); + strncpy(reg.password, "testpass123", sizeof(reg.password)); + + char buf[256]; + size_t len = build_packet(buf, PACKET_REGISTER, ®, sizeof(reg)); + network_handle_data(client, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_OK, "Should be OK packet for successful registration"); + + free(client); + return 0; +} + +static int test_register_duplicate_username(void) +{ + struct client *client = create_mock_client(); + reset_sent_packets(); + + /* Register first user */ + struct packet_register reg; + memset(®, 0, sizeof(reg)); + strncpy(reg.username, "dupuser", sizeof(reg.username)); + strncpy(reg.password, "pass123", sizeof(reg.password)); + + char buf[256]; + size_t len = build_packet(buf, PACKET_REGISTER, ®, sizeof(reg)); + network_handle_data(client, buf, len); + + /* Try to register same username again */ + reset_sent_packets(); + network_handle_data(client, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send error response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_ERROR, "Should be error packet"); + const struct packet_error *err = (const struct packet_error *)get_last_packet_payload(); + TEST_ASSERT_EQ(err->code, ERR_ALREADY_REGISTERED, "Should be ALREADY_REGISTERED error"); + + free(client); + return 0; +} + +static int test_register_empty_username(void) +{ + struct client *client = create_mock_client(); + reset_sent_packets(); + + struct packet_register reg; + memset(®, 0, sizeof(reg)); + /* Empty username, valid password */ + strncpy(reg.password, "testpass123", sizeof(reg.password)); + + char buf[256]; + size_t len = build_packet(buf, PACKET_REGISTER, ®, sizeof(reg)); + network_handle_data(client, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send error response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_ERROR, "Should be error packet"); + const struct packet_error *err = (const struct packet_error *)get_last_packet_payload(); + TEST_ASSERT_EQ(err->code, ERR_INVALID_PACKET, "Should be INVALID_PACKET error"); + + free(client); + return 0; +} + +static int test_register_empty_password(void) +{ + struct client *client = create_mock_client(); + reset_sent_packets(); + + struct packet_register reg; + memset(®, 0, sizeof(reg)); + strncpy(reg.username, "validuser", sizeof(reg.username)); + /* Empty password */ + + char buf[256]; + size_t len = build_packet(buf, PACKET_REGISTER, ®, sizeof(reg)); + network_handle_data(client, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send error response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_ERROR, "Should be error packet"); + + free(client); + return 0; +} + +static int test_register_packet_too_small(void) +{ + struct client *client = create_mock_client(); + reset_sent_packets(); + + /* Send register packet with payload smaller than expected */ + char small_payload[10] = {0}; + char buf[64]; + size_t len = build_packet(buf, PACKET_REGISTER, small_payload, sizeof(small_payload)); + network_handle_data(client, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send error response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_ERROR, "Should be error packet"); + const struct packet_error *err = (const struct packet_error *)get_last_packet_payload(); + TEST_ASSERT_EQ(err->code, ERR_INVALID_PACKET, "Should be INVALID_PACKET error"); + + free(client); + return 0; +} + +/* ==== AUTHENTICATION TESTS ==== */ + +static int test_authenticate_success(void) +{ + /* First register a user */ + struct client *client = create_mock_client(); + reset_sent_packets(); + + struct packet_register reg; + memset(®, 0, sizeof(reg)); + strncpy(reg.username, "authuser", sizeof(reg.username)); + strncpy(reg.password, "authpass", sizeof(reg.password)); + + char buf[256]; + size_t len = build_packet(buf, PACKET_REGISTER, ®, sizeof(reg)); + network_handle_data(client, buf, len); + + /* Now authenticate */ + reset_sent_packets(); + + struct packet_auth auth; + memset(&auth, 0, sizeof(auth)); + strncpy(auth.username, "authuser", sizeof(auth.username)); + strncpy(auth.password, "authpass", sizeof(auth.password)); + + len = build_packet(buf, PACKET_AUTHENTICATE, &auth, sizeof(auth)); + network_handle_data(client, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_OK, "Should be OK packet"); + TEST_ASSERT(client->is_authenticated, "Client should be authenticated"); + TEST_ASSERT(client->user_id > 0, "Client should have user_id"); + TEST_ASSERT_STR_EQ(client->username, "authuser", "Username should match"); + + free(client); + return 0; +} + +static int test_authenticate_wrong_password(void) +{ + struct client *client = create_mock_client(); + reset_sent_packets(); + + /* Register user */ + struct packet_register reg; + memset(®, 0, sizeof(reg)); + strncpy(reg.username, "wrongpassuser", sizeof(reg.username)); + strncpy(reg.password, "correctpass", sizeof(reg.password)); + + char buf[256]; + size_t len = build_packet(buf, PACKET_REGISTER, ®, sizeof(reg)); + network_handle_data(client, buf, len); + + /* Try to auth with wrong password */ + reset_sent_packets(); + + struct packet_auth auth; + memset(&auth, 0, sizeof(auth)); + strncpy(auth.username, "wrongpassuser", sizeof(auth.username)); + strncpy(auth.password, "wrongpass", sizeof(auth.password)); + + len = build_packet(buf, PACKET_AUTHENTICATE, &auth, sizeof(auth)); + network_handle_data(client, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send error response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_ERROR, "Should be error packet"); + const struct packet_error *err = (const struct packet_error *)get_last_packet_payload(); + TEST_ASSERT_EQ(err->code, ERR_INVALID_CREDENTIALS, "Should be INVALID_CREDENTIALS error"); + TEST_ASSERT(!client->is_authenticated, "Client should not be authenticated"); + + free(client); + return 0; +} + +static int test_authenticate_unknown_user(void) +{ + struct client *client = create_mock_client(); + reset_sent_packets(); + + struct packet_auth auth; + memset(&auth, 0, sizeof(auth)); + strncpy(auth.username, "nonexistent", sizeof(auth.username)); + strncpy(auth.password, "somepass", sizeof(auth.password)); + + char buf[256]; + size_t len = build_packet(buf, PACKET_AUTHENTICATE, &auth, sizeof(auth)); + network_handle_data(client, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send error response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_ERROR, "Should be error packet"); + const struct packet_error *err = (const struct packet_error *)get_last_packet_payload(); + TEST_ASSERT_EQ(err->code, ERR_INVALID_CREDENTIALS, "Should be INVALID_CREDENTIALS error"); + + free(client); + return 0; +} + +static int test_authenticate_empty_credentials(void) +{ + struct client *client = create_mock_client(); + reset_sent_packets(); + + struct packet_auth auth; + memset(&auth, 0, sizeof(auth)); + /* Both username and password empty */ + + char buf[256]; + size_t len = build_packet(buf, PACKET_AUTHENTICATE, &auth, sizeof(auth)); + network_handle_data(client, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send error response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_ERROR, "Should be error packet"); + const struct packet_error *err = (const struct packet_error *)get_last_packet_payload(); + TEST_ASSERT_EQ(err->code, ERR_INVALID_CREDENTIALS, "Should be INVALID_CREDENTIALS error"); + + free(client); + return 0; +} + +/* ==== ROOM OPERATION TESTS ==== */ + +static struct client *setup_authenticated_client(const char *username, const char *password) +{ + struct client *client = create_mock_client(); + + /* Register */ + struct packet_register reg; + memset(®, 0, sizeof(reg)); + strncpy(reg.username, username, sizeof(reg.username)); + strncpy(reg.password, password, sizeof(reg.password)); + + char buf[256]; + size_t len = build_packet(buf, PACKET_REGISTER, ®, sizeof(reg)); + network_handle_data(client, buf, len); + + /* Authenticate */ + struct packet_auth auth; + memset(&auth, 0, sizeof(auth)); + strncpy(auth.username, username, sizeof(auth.username)); + strncpy(auth.password, password, sizeof(auth.password)); + + len = build_packet(buf, PACKET_AUTHENTICATE, &auth, sizeof(auth)); + network_handle_data(client, buf, len); + + reset_sent_packets(); + return client; +} + +static int test_create_room_success(void) +{ + struct client *client = setup_authenticated_client("roomcreator", "pass123"); + + struct packet_create_room create; + memset(&create, 0, sizeof(create)); + strncpy(create.name, "TestRoom", sizeof(create.name)); + + char buf[256]; + size_t len = build_packet(buf, PACKET_CREATE_ROOM, &create, sizeof(create)); + network_handle_data(client, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_ROOM_CREATED, "Should be ROOM_CREATED packet"); + + const struct packet_room_created *resp = (const struct packet_room_created *)get_last_packet_payload(); + TEST_ASSERT(resp->room_id > 0, "Room ID should be positive"); + + free(client); + return 0; +} + +static int test_create_room_duplicate_name(void) +{ + struct client *client = setup_authenticated_client("roomcreator2", "pass123"); + + struct packet_create_room create; + memset(&create, 0, sizeof(create)); + strncpy(create.name, "DuplicateRoom", sizeof(create.name)); + + char buf[256]; + size_t len = build_packet(buf, PACKET_CREATE_ROOM, &create, sizeof(create)); + network_handle_data(client, buf, len); + + /* Try to create room with same name */ + reset_sent_packets(); + network_handle_data(client, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send error response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_ERROR, "Should be error packet"); + const struct packet_error *err = (const struct packet_error *)get_last_packet_payload(); + TEST_ASSERT_EQ(err->code, ERR_ROOM_NAME_TAKEN, "Should be ROOM_NAME_TAKEN error"); + + free(client); + return 0; +} + +static int test_create_room_not_authenticated(void) +{ + struct client *client = create_mock_client(); + reset_sent_packets(); + + struct packet_create_room create; + memset(&create, 0, sizeof(create)); + strncpy(create.name, "UnauthorizedRoom", sizeof(create.name)); + + char buf[256]; + size_t len = build_packet(buf, PACKET_CREATE_ROOM, &create, sizeof(create)); + network_handle_data(client, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send error response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_ERROR, "Should be error packet"); + const struct packet_error *err = (const struct packet_error *)get_last_packet_payload(); + TEST_ASSERT_EQ(err->code, ERR_NOT_AUTHENTICATED, "Should be NOT_AUTHENTICATED error"); + + free(client); + return 0; +} + +static int test_create_room_empty_name(void) +{ + struct client *client = setup_authenticated_client("roomcreator3", "pass123"); + + struct packet_create_room create; + memset(&create, 0, sizeof(create)); + /* Empty name */ + + char buf[256]; + size_t len = build_packet(buf, PACKET_CREATE_ROOM, &create, sizeof(create)); + network_handle_data(client, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send error response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_ERROR, "Should be error packet"); + const struct packet_error *err = (const struct packet_error *)get_last_packet_payload(); + TEST_ASSERT_EQ(err->code, ERR_INVALID_PACKET, "Should be INVALID_PACKET error"); + + free(client); + return 0; +} + +static int test_join_room_success(void) +{ + struct client *client = setup_authenticated_client("roomjoiner", "pass123"); + + /* Create a room first */ + struct packet_create_room create; + memset(&create, 0, sizeof(create)); + strncpy(create.name, "JoinableRoom", sizeof(create.name)); + + char buf[256]; + size_t len = build_packet(buf, PACKET_CREATE_ROOM, &create, sizeof(create)); + network_handle_data(client, buf, len); + + const struct packet_room_created *created = (const struct packet_room_created *)get_last_packet_payload(); + uint64_t room_id = created->room_id; + + /* Now join the room */ + reset_sent_packets(); + + struct packet_join join; + join.room_id = room_id; + + len = build_packet(buf, PACKET_JOIN, &join, sizeof(join)); + network_handle_data(client, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_OK, "Should be OK packet"); + TEST_ASSERT_EQ(client->current_room_id, room_id, "Client should be in room"); + + free(client); + return 0; +} + +static int test_join_room_not_authenticated(void) +{ + struct client *client = create_mock_client(); + reset_sent_packets(); + + struct packet_join join; + join.room_id = 1; + + char buf[64]; + size_t len = build_packet(buf, PACKET_JOIN, &join, sizeof(join)); + network_handle_data(client, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send error response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_ERROR, "Should be error packet"); + const struct packet_error *err = (const struct packet_error *)get_last_packet_payload(); + TEST_ASSERT_EQ(err->code, ERR_NOT_AUTHENTICATED, "Should be NOT_AUTHENTICATED error"); + + free(client); + return 0; +} + +static int test_join_room_nonexistent(void) +{ + struct client *client = setup_authenticated_client("roomjoiner2", "pass123"); + + struct packet_join join; + join.room_id = 99999; /* Non-existent room */ + + char buf[64]; + size_t len = build_packet(buf, PACKET_JOIN, &join, sizeof(join)); + network_handle_data(client, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send error response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_ERROR, "Should be error packet"); + const struct packet_error *err = (const struct packet_error *)get_last_packet_payload(); + TEST_ASSERT_EQ(err->code, ERR_ACCESS_DENIED, "Should be ACCESS_DENIED error"); + + free(client); + return 0; +} + +static int test_leave_room_success(void) +{ + struct client *client = setup_authenticated_client("roomleaver", "pass123"); + + /* Create and join a room */ + struct packet_create_room create; + memset(&create, 0, sizeof(create)); + strncpy(create.name, "LeavableRoom", sizeof(create.name)); + + char buf[256]; + size_t len = build_packet(buf, PACKET_CREATE_ROOM, &create, sizeof(create)); + network_handle_data(client, buf, len); + + const struct packet_room_created *created = (const struct packet_room_created *)get_last_packet_payload(); + uint64_t room_id = created->room_id; + + struct packet_join join; + join.room_id = room_id; + len = build_packet(buf, PACKET_JOIN, &join, sizeof(join)); + network_handle_data(client, buf, len); + + /* Now leave the room */ + reset_sent_packets(); + + struct packet_leave leave; + leave.room_id = room_id; + len = build_packet(buf, PACKET_LEAVE, &leave, sizeof(leave)); + network_handle_data(client, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_OK, "Should be OK packet"); + TEST_ASSERT_EQ(client->current_room_id, 0, "Client should not be in any room"); + + free(client); + return 0; +} + +static int test_delete_room_success(void) +{ + struct client *client = setup_authenticated_client("roomdeleter", "pass123"); + + /* Create a room */ + struct packet_create_room create; + memset(&create, 0, sizeof(create)); + strncpy(create.name, "DeletableRoom", sizeof(create.name)); + + char buf[256]; + size_t len = build_packet(buf, PACKET_CREATE_ROOM, &create, sizeof(create)); + network_handle_data(client, buf, len); + + const struct packet_room_created *created = (const struct packet_room_created *)get_last_packet_payload(); + uint64_t room_id = created->room_id; + + /* Delete the room */ + reset_sent_packets(); + + struct packet_delete_room del; + del.room_id = room_id; + len = build_packet(buf, PACKET_DELETE_ROOM, &del, sizeof(del)); + network_handle_data(client, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_OK, "Should be OK packet"); + + free(client); + return 0; +} + +static int test_delete_room_not_owner(void) +{ + /* Create room as one user */ + struct client *owner = setup_authenticated_client("roomowner", "pass123"); + + struct packet_create_room create; + memset(&create, 0, sizeof(create)); + strncpy(create.name, "OwnedRoom", sizeof(create.name)); + + char buf[256]; + size_t len = build_packet(buf, PACKET_CREATE_ROOM, &create, sizeof(create)); + network_handle_data(owner, buf, len); + + const struct packet_room_created *created = (const struct packet_room_created *)get_last_packet_payload(); + uint64_t room_id = created->room_id; + + /* Try to delete as different user */ + struct client *other = setup_authenticated_client("otheruserx", "pass123"); + + struct packet_delete_room del; + del.room_id = room_id; + len = build_packet(buf, PACKET_DELETE_ROOM, &del, sizeof(del)); + network_handle_data(other, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send error response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_ERROR, "Should be error packet"); + const struct packet_error *err = (const struct packet_error *)get_last_packet_payload(); + TEST_ASSERT_EQ(err->code, ERR_NOT_ROOM_OWNER, "Should be NOT_ROOM_OWNER error"); + + free(owner); + free(other); + return 0; +} + +static int test_delete_room_nonexistent(void) +{ + struct client *client = setup_authenticated_client("delnon", "pass123"); + + struct packet_delete_room del; + del.room_id = 99999; + + char buf[64]; + size_t len = build_packet(buf, PACKET_DELETE_ROOM, &del, sizeof(del)); + network_handle_data(client, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send error response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_ERROR, "Should be error packet"); + const struct packet_error *err = (const struct packet_error *)get_last_packet_payload(); + TEST_ASSERT_EQ(err->code, ERR_ROOM_NOT_FOUND, "Should be ROOM_NOT_FOUND error"); + + free(client); + return 0; +} + +static int test_list_rooms(void) +{ + struct client *client = setup_authenticated_client("lister", "pass123"); + + /* Create a few rooms */ + struct packet_create_room create; + char buf[256]; + + memset(&create, 0, sizeof(create)); + strncpy(create.name, "ListRoom1", sizeof(create.name)); + size_t len = build_packet(buf, PACKET_CREATE_ROOM, &create, sizeof(create)); + network_handle_data(client, buf, len); + + memset(&create, 0, sizeof(create)); + strncpy(create.name, "ListRoom2", sizeof(create.name)); + len = build_packet(buf, PACKET_CREATE_ROOM, &create, sizeof(create)); + network_handle_data(client, buf, len); + + /* List rooms */ + reset_sent_packets(); + len = build_packet(buf, PACKET_LIST_ROOMS, NULL, 0); + network_handle_data(client, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_ROOM_LIST, "Should be ROOM_LIST packet"); + + const char *payload = get_last_packet_payload(); + uint32_t count; + memcpy(&count, payload, sizeof(uint32_t)); + TEST_ASSERT(count >= 2, "Should have at least 2 rooms"); + + free(client); + return 0; +} + +static int test_list_rooms_not_authenticated(void) +{ + struct client *client = create_mock_client(); + reset_sent_packets(); + + char buf[64]; + size_t len = build_packet(buf, PACKET_LIST_ROOMS, NULL, 0); + network_handle_data(client, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send error response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_ERROR, "Should be error packet"); + const struct packet_error *err = (const struct packet_error *)get_last_packet_payload(); + TEST_ASSERT_EQ(err->code, ERR_NOT_AUTHENTICATED, "Should be NOT_AUTHENTICATED error"); + + free(client); + return 0; +} + +/* ==== DM TESTS ==== */ + +static int test_dm_open_success(void) +{ + /* Create two users */ + struct client *client1 = setup_authenticated_client("dmuser1", "pass123"); + struct client *client2 = setup_authenticated_client("dmuser2", "pass123"); + + reset_sent_packets(); + + /* Open DM with user2 */ + struct packet_dm_open dm; + memset(&dm, 0, sizeof(dm)); + strncpy(dm.username, "dmuser2", sizeof(dm.username)); + + char buf[64]; + size_t len = build_packet(buf, PACKET_DM_OPEN, &dm, sizeof(dm)); + network_handle_data(client1, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_DM_ROOM, "Should be DM_ROOM packet"); + + const struct packet_dm_room *resp = (const struct packet_dm_room *)get_last_packet_payload(); + TEST_ASSERT(resp->room_id > 0, "Room ID should be positive"); + + free(client1); + free(client2); + return 0; +} + +static int test_dm_open_user_not_found(void) +{ + struct client *client = setup_authenticated_client("dmuser3", "pass123"); + + struct packet_dm_open dm; + memset(&dm, 0, sizeof(dm)); + strncpy(dm.username, "nonexistentuser", sizeof(dm.username)); + + char buf[64]; + size_t len = build_packet(buf, PACKET_DM_OPEN, &dm, sizeof(dm)); + network_handle_data(client, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send error response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_ERROR, "Should be error packet"); + const struct packet_error *err = (const struct packet_error *)get_last_packet_payload(); + TEST_ASSERT_EQ(err->code, ERR_USER_NOT_FOUND, "Should be USER_NOT_FOUND error"); + + free(client); + return 0; +} + +static int test_dm_open_not_authenticated(void) +{ + struct client *client = create_mock_client(); + reset_sent_packets(); + + struct packet_dm_open dm; + memset(&dm, 0, sizeof(dm)); + strncpy(dm.username, "someuser", sizeof(dm.username)); + + char buf[64]; + size_t len = build_packet(buf, PACKET_DM_OPEN, &dm, sizeof(dm)); + network_handle_data(client, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send error response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_ERROR, "Should be error packet"); + const struct packet_error *err = (const struct packet_error *)get_last_packet_payload(); + TEST_ASSERT_EQ(err->code, ERR_NOT_AUTHENTICATED, "Should be NOT_AUTHENTICATED error"); + + free(client); + return 0; +} + +static int test_dm_room_reuse(void) +{ + /* Create two users */ + struct client *client1 = setup_authenticated_client("dmreuse1", "pass123"); + struct client *client2 = setup_authenticated_client("dmreuse2", "pass123"); + + /* Open DM from client1 to client2 */ + struct packet_dm_open dm; + memset(&dm, 0, sizeof(dm)); + strncpy(dm.username, "dmreuse2", sizeof(dm.username)); + + char buf[64]; + size_t len = build_packet(buf, PACKET_DM_OPEN, &dm, sizeof(dm)); + network_handle_data(client1, buf, len); + + const struct packet_dm_room *resp1 = (const struct packet_dm_room *)get_last_packet_payload(); + uint64_t room_id1 = resp1->room_id; + + /* Open DM from client2 to client1 - should get same room */ + reset_sent_packets(); + memset(&dm, 0, sizeof(dm)); + strncpy(dm.username, "dmreuse1", sizeof(dm.username)); + len = build_packet(buf, PACKET_DM_OPEN, &dm, sizeof(dm)); + network_handle_data(client2, buf, len); + + const struct packet_dm_room *resp2 = (const struct packet_dm_room *)get_last_packet_payload(); + uint64_t room_id2 = resp2->room_id; + + TEST_ASSERT_EQ(room_id1, room_id2, "DM room should be reused"); + + free(client1); + free(client2); + return 0; +} + +/* ==== MESSAGE TESTS ==== */ + +static int test_text_success(void) +{ + struct client *client = setup_authenticated_client("msgsender", "pass123"); + + /* Create and join a room */ + struct packet_create_room create; + memset(&create, 0, sizeof(create)); + strncpy(create.name, "MessageRoom", sizeof(create.name)); + + char buf[256]; + size_t len = build_packet(buf, PACKET_CREATE_ROOM, &create, sizeof(create)); + network_handle_data(client, buf, len); + + const struct packet_room_created *created = (const struct packet_room_created *)get_last_packet_payload(); + uint64_t room_id = created->room_id; + + struct packet_join join; + join.room_id = room_id; + len = build_packet(buf, PACKET_JOIN, &join, sizeof(join)); + network_handle_data(client, buf, len); + + /* Send a message */ + network_client_add(client); + reset_sent_packets(); + + char text_buf[128]; + uint64_t *room_ptr = (uint64_t *)text_buf; + *room_ptr = room_id; + const char *msg = "Hello, world!"; + memcpy(text_buf + sizeof(uint64_t), msg, strlen(msg)); + + len = build_packet(buf, PACKET_TEXT, text_buf, sizeof(uint64_t) + strlen(msg)); + network_handle_data(client, buf, len); + + /* Should receive broadcast back since we're in the room */ + TEST_ASSERT(sent_packet_count >= 1, "Should receive broadcast"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_TEXT, "Should be TEXT packet (broadcast)"); + + network_client_remove(client); + free(client); + return 0; +} + +static int test_text_not_authenticated(void) +{ + struct client *client = create_mock_client(); + reset_sent_packets(); + + char text_buf[64]; + uint64_t *room_ptr = (uint64_t *)text_buf; + *room_ptr = 1; + memcpy(text_buf + sizeof(uint64_t), "test", 4); + + char buf[128]; + size_t len = build_packet(buf, PACKET_TEXT, text_buf, sizeof(uint64_t) + 4); + network_handle_data(client, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send error response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_ERROR, "Should be error packet"); + const struct packet_error *err = (const struct packet_error *)get_last_packet_payload(); + TEST_ASSERT_EQ(err->code, ERR_NOT_AUTHENTICATED, "Should be NOT_AUTHENTICATED error"); + + free(client); + return 0; +} + +static int test_text_wrong_room(void) +{ + struct client *client = setup_authenticated_client("wrongroom", "pass123"); + + /* Create a room but don't join it */ + struct packet_create_room create; + memset(&create, 0, sizeof(create)); + strncpy(create.name, "WrongRoom", sizeof(create.name)); + + char buf[256]; + size_t len = build_packet(buf, PACKET_CREATE_ROOM, &create, sizeof(create)); + network_handle_data(client, buf, len); + + const struct packet_room_created *created = (const struct packet_room_created *)get_last_packet_payload(); + uint64_t room_id = created->room_id; + + /* Try to send message to room we haven't joined */ + reset_sent_packets(); + + char text_buf[64]; + uint64_t *room_ptr = (uint64_t *)text_buf; + *room_ptr = room_id; + memcpy(text_buf + sizeof(uint64_t), "test", 4); + + len = build_packet(buf, PACKET_TEXT, text_buf, sizeof(uint64_t) + 4); + network_handle_data(client, buf, len); + + TEST_ASSERT(sent_packet_count == 1, "Should send error response"); + struct packet_header *hdr = get_last_packet_header(); + TEST_ASSERT_EQ(hdr->type, PACKET_ERROR, "Should be error packet"); + + free(client); + return 0; +} + +/* ==== CLIENT MANAGEMENT TESTS ==== */ + +static int test_client_add_remove(void) +{ + struct client *client1 = create_mock_client(); + struct client *client2 = create_mock_client(); + + network_client_add(client1); + network_client_add(client2); + + /* Remove first client */ + network_client_remove(client1); + + /* Remove second client */ + network_client_remove(client2); + + /* Should handle removing non-existent client gracefully */ + network_client_remove(client1); + + free(client1); + free(client2); + return 0; +} + +/* ==== PASSWORD TESTS ==== */ + +static int test_password_hash_verify(void) +{ + const char *password = "testpassword123"; + char *hash = hash_password(password); + + TEST_ASSERT(hash != NULL, "Hash should not be NULL"); + TEST_ASSERT(strlen(hash) > 0, "Hash should not be empty"); + TEST_ASSERT(verify_password(password, hash), "Password should verify"); + TEST_ASSERT(!verify_password("wrongpassword", hash), "Wrong password should not verify"); + + free(hash); + return 0; +} + +static int test_password_different_hashes(void) +{ + const char *password = "samepassword"; + char *hash1 = hash_password(password); + /* Sleep briefly to ensure different time-based seed */ + /* Note: generate_salt uses srand(time(NULL)), so hashes within same + * second may be identical. This tests that both hashes verify correctly. */ + char *hash2 = hash_password(password); + + TEST_ASSERT(hash1 != NULL && hash2 != NULL, "Hashes should not be NULL"); + /* Both should verify regardless of whether salt differs */ + TEST_ASSERT(verify_password(password, hash1), "Password should verify with hash1"); + TEST_ASSERT(verify_password(password, hash2), "Password should verify with hash2"); + /* Wrong password should fail on both */ + TEST_ASSERT(!verify_password("wrongpass", hash1), "Wrong password should not verify with hash1"); + TEST_ASSERT(!verify_password("wrongpass", hash2), "Wrong password should not verify with hash2"); + + free(hash1); + free(hash2); + return 0; +} + +/* ==== MAIN ==== */ + +int main(void) +{ + printf("=== asfur unit tests ===\n\n"); + + /* Remove database file to ensure clean state for testing */ + /* CONFIG_DB_PATH is defined in config.h as "asfur.db" */ + remove("asfur.db"); + + if (network_init() != 0) { + fprintf(stderr, "Failed to initialize network for testing\n"); + return 1; + } + + /* Packet parsing tests */ + printf("\n--- Packet Parsing Tests ---\n"); + RUN_TEST(test_invalid_packet_too_short); + RUN_TEST(test_invalid_packet_size_mismatch); + RUN_TEST(test_invalid_packet_unknown_type); + + /* Registration tests */ + printf("\n--- Registration Tests ---\n"); + RUN_TEST(test_register_success); + RUN_TEST(test_register_duplicate_username); + RUN_TEST(test_register_empty_username); + RUN_TEST(test_register_empty_password); + RUN_TEST(test_register_packet_too_small); + + /* Authentication tests */ + printf("\n--- Authentication Tests ---\n"); + RUN_TEST(test_authenticate_success); + RUN_TEST(test_authenticate_wrong_password); + RUN_TEST(test_authenticate_unknown_user); + RUN_TEST(test_authenticate_empty_credentials); + + /* Room operation tests */ + printf("\n--- Room Operation Tests ---\n"); + RUN_TEST(test_create_room_success); + RUN_TEST(test_create_room_duplicate_name); + RUN_TEST(test_create_room_not_authenticated); + RUN_TEST(test_create_room_empty_name); + RUN_TEST(test_join_room_success); + RUN_TEST(test_join_room_not_authenticated); + RUN_TEST(test_join_room_nonexistent); + RUN_TEST(test_leave_room_success); + RUN_TEST(test_delete_room_success); + RUN_TEST(test_delete_room_not_owner); + RUN_TEST(test_delete_room_nonexistent); + RUN_TEST(test_list_rooms); + RUN_TEST(test_list_rooms_not_authenticated); + + /* DM tests */ + printf("\n--- DM Tests ---\n"); + RUN_TEST(test_dm_open_success); + RUN_TEST(test_dm_open_user_not_found); + RUN_TEST(test_dm_open_not_authenticated); + RUN_TEST(test_dm_room_reuse); + + /* Message tests */ + printf("\n--- Message Tests ---\n"); + RUN_TEST(test_text_success); + RUN_TEST(test_text_not_authenticated); + RUN_TEST(test_text_wrong_room); + + /* Client management tests */ + printf("\n--- Client Management Tests ---\n"); + RUN_TEST(test_client_add_remove); + + /* Password tests */ + printf("\n--- Password Tests ---\n"); + RUN_TEST(test_password_hash_verify); + RUN_TEST(test_password_different_hashes); + + network_shutdown(); + + TEST_SUMMARY(); +}