#define _GNU_SOURCE #include #include #include #include #include #include "network.h" #include "client.h" #include "password.h" #include "config.h" #define MAX_CLIENTS 1024 static sqlite3 *db = NULL; static struct client *clients[MAX_CLIENTS]; static int client_count = 0; static const char *schema_sql = "CREATE TABLE IF NOT EXISTS users (" " id INTEGER PRIMARY KEY AUTOINCREMENT," " username TEXT UNIQUE NOT NULL," " password_hash TEXT NOT NULL," " created_at INTEGER DEFAULT (strftime('%s', 'now'))" ");" "CREATE TABLE IF NOT EXISTS dm_rooms (" " id INTEGER PRIMARY KEY AUTOINCREMENT," " user1_id INTEGER NOT NULL," " user2_id INTEGER NOT NULL," " created_at INTEGER DEFAULT (strftime('%s', 'now'))," " UNIQUE(user1_id, user2_id)," " FOREIGN KEY(user1_id) REFERENCES users(id)," " FOREIGN KEY(user2_id) REFERENCES users(id)" ");" "CREATE TABLE IF NOT EXISTS rooms (" " id INTEGER PRIMARY KEY AUTOINCREMENT," " name TEXT UNIQUE NOT NULL," " owner_id INTEGER NOT NULL," " created_at INTEGER DEFAULT (strftime('%s', 'now'))," " FOREIGN KEY(owner_id) REFERENCES users(id)" ");"; int network_init(void) { int rc = sqlite3_open(CONFIG_DB_PATH, &db); if (rc != SQLITE_OK) { fprintf(stderr, "Cannot open database: %s\n", sqlite3_errmsg(db)); return -1; } char *err_msg = NULL; rc = sqlite3_exec(db, schema_sql, NULL, NULL, &err_msg); if (rc != SQLITE_OK) { fprintf(stderr, "SQL error: %s\n", err_msg); sqlite3_free(err_msg); return -1; } return 0; } void network_shutdown(void) { if (db) { sqlite3_close(db); db = NULL; } } void network_client_add(struct client *client) { if (client_count >= MAX_CLIENTS) return; clients[client_count++] = client; } void network_client_remove(struct client *client) { for (int i = 0; i < client_count; i++) { if (clients[i] == client) { clients[i] = clients[--client_count]; return; } } } static void send_error(struct client *client, error_code code); static void send_ok(struct client *client); static int is_dm_room_member(uint64_t room_id, uint32_t user_id); static int is_public_room(uint64_t room_id); extern void ssl_write_msg(struct client *client, const char *data, size_t len); static void send_packet(struct client *client, uint8_t type, const void *data, size_t data_len) { size_t total_len = sizeof(struct packet_header) + data_len; char *buf = malloc(total_len); if (!buf) return; struct packet_header *hdr = (struct packet_header *)buf; hdr->size = htons((uint16_t)total_len); hdr->type = type; if (data && data_len > 0) { memcpy(buf + sizeof(struct packet_header), data, data_len); } ssl_write_msg(client, buf, total_len); free(buf); } static void send_error(struct client *client, error_code code) { struct packet_error err = { .code = (uint8_t)code }; send_packet(client, PACKET_ERROR, &err, sizeof(err)); } static void send_ok(struct client *client) { send_packet(client, PACKET_OK, NULL, 0); } static int handle_register(struct client *client, const char *data, size_t len) { if (len < sizeof(struct packet_register)) { send_error(client, ERR_INVALID_PACKET); return -1; } #ifndef CONFIG_ALLOW_REGISTER send_error(client, ERR_REGISTRATION_DISABLED); return -1; #else if (!CONFIG_ALLOW_REGISTER) { send_error(client, ERR_REGISTRATION_DISABLED); return -1; } #endif const struct packet_register *reg = (const struct packet_register *)data; char username[21] = {0}; char password[101] = {0}; memcpy(username, reg->username, 20); memcpy(password, reg->password, 100); if (strlen(username) == 0 || strlen(password) == 0) { send_error(client, ERR_INVALID_PACKET); return -1; } char *hash = hash_password(password); if (!hash) { send_error(client, ERR_DATABASE); return -1; } sqlite3_stmt *stmt; const char *sql = "INSERT INTO users (username, password_hash) VALUES (?, ?);"; int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL); if (rc != SQLITE_OK) { free(hash); send_error(client, ERR_DATABASE); return -1; } sqlite3_bind_text(stmt, 1, username, -1, SQLITE_STATIC); sqlite3_bind_text(stmt, 2, hash, -1, SQLITE_STATIC); rc = sqlite3_step(stmt); sqlite3_finalize(stmt); free(hash); if (rc == SQLITE_CONSTRAINT) { send_error(client, ERR_ALREADY_REGISTERED); return -1; } else if (rc != SQLITE_DONE) { send_error(client, ERR_DATABASE); return -1; } send_ok(client); return 0; } static int handle_authenticate(struct client *client, const char *data, size_t len) { if (len < sizeof(struct packet_auth)) { send_error(client, ERR_INVALID_PACKET); return -1; } const struct packet_auth *auth = (const struct packet_auth *)data; char username[21] = {0}; char password[101] = {0}; memcpy(username, auth->username, 20); memcpy(password, auth->password, 100); if (strlen(username) == 0 || strlen(password) == 0) { send_error(client, ERR_INVALID_CREDENTIALS); return -1; } sqlite3_stmt *stmt; const char *sql = "SELECT id, password_hash FROM users WHERE username = ?;"; int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL); if (rc != SQLITE_OK) { send_error(client, ERR_DATABASE); return -1; } sqlite3_bind_text(stmt, 1, username, -1, SQLITE_STATIC); rc = sqlite3_step(stmt); if (rc != SQLITE_ROW) { sqlite3_finalize(stmt); send_error(client, ERR_INVALID_CREDENTIALS); return -1; } int user_id = sqlite3_column_int(stmt, 0); const char *stored_hash = (const char *)sqlite3_column_text(stmt, 1); if (!verify_password(password, stored_hash)) { sqlite3_finalize(stmt); send_error(client, ERR_INVALID_CREDENTIALS); return -1; } sqlite3_finalize(stmt); client->user_id = user_id; client->is_authenticated = true; strncpy(client->username, username, sizeof(client->username) - 1); client->username[sizeof(client->username) - 1] = '\0'; send_ok(client); return 0; } static int handle_join(struct client *client, const char *data, size_t len) { if (!client->is_authenticated) { send_error(client, ERR_NOT_AUTHENTICATED); return -1; } if (len < sizeof(struct packet_join)) { send_error(client, ERR_INVALID_PACKET); return -1; } const struct packet_join *join = (const struct packet_join *)data; if (is_public_room(join->room_id)) { client->current_room_id = join->room_id; send_ok(client); return 0; } if (is_dm_room_member(join->room_id, client->user_id)) { client->current_room_id = join->room_id; send_ok(client); return 0; } send_error(client, ERR_ACCESS_DENIED); return -1; } static int handle_leave(struct client *client, const char *data, size_t len) { if (!client->is_authenticated) { send_error(client, ERR_NOT_AUTHENTICATED); return -1; } if (len < sizeof(struct packet_leave)) { send_error(client, ERR_INVALID_PACKET); return -1; } const struct packet_leave *leave = (const struct packet_leave *)data; if (client->current_room_id == leave->room_id) { client->current_room_id = 0; } send_ok(client); return 0; } static int64_t get_or_create_dm_room(uint32_t user1_id, uint32_t user2_id) { uint32_t low_id = user1_id < user2_id ? user1_id : user2_id; uint32_t high_id = user1_id < user2_id ? user2_id : user1_id; sqlite3_stmt *stmt; const char *select_sql = "SELECT id FROM dm_rooms WHERE user1_id = ? AND user2_id = ?;"; int rc = sqlite3_prepare_v2(db, select_sql, -1, &stmt, NULL); if (rc != SQLITE_OK) return -1; sqlite3_bind_int(stmt, 1, low_id); sqlite3_bind_int(stmt, 2, high_id); rc = sqlite3_step(stmt); if (rc == SQLITE_ROW) { int64_t room_id = sqlite3_column_int64(stmt, 0); sqlite3_finalize(stmt); return room_id; } sqlite3_finalize(stmt); const char *insert_sql = "INSERT INTO dm_rooms (user1_id, user2_id) VALUES (?, ?);"; rc = sqlite3_prepare_v2(db, insert_sql, -1, &stmt, NULL); if (rc != SQLITE_OK) return -1; sqlite3_bind_int(stmt, 1, low_id); sqlite3_bind_int(stmt, 2, high_id); rc = sqlite3_step(stmt); sqlite3_finalize(stmt); if (rc != SQLITE_DONE) return -1; return sqlite3_last_insert_rowid(db); } static int is_dm_room_member(uint64_t room_id, uint32_t user_id) { sqlite3_stmt *stmt; const char *sql = "SELECT 1 FROM dm_rooms WHERE id = ? AND (user1_id = ? OR user2_id = ?);"; int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL); if (rc != SQLITE_OK) return 0; sqlite3_bind_int64(stmt, 1, room_id); sqlite3_bind_int(stmt, 2, user_id); sqlite3_bind_int(stmt, 3, user_id); rc = sqlite3_step(stmt); sqlite3_finalize(stmt); return rc == SQLITE_ROW; } static int is_public_room(uint64_t room_id) { sqlite3_stmt *stmt; const char *sql = "SELECT 1 FROM rooms WHERE id = ?;"; int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL); if (rc != SQLITE_OK) return 0; sqlite3_bind_int64(stmt, 1, room_id); rc = sqlite3_step(stmt); sqlite3_finalize(stmt); return rc == SQLITE_ROW; } static int is_room_owner(uint64_t room_id, uint32_t user_id) { sqlite3_stmt *stmt; const char *sql = "SELECT 1 FROM rooms WHERE id = ? AND owner_id = ?;"; int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL); if (rc != SQLITE_OK) return 0; sqlite3_bind_int64(stmt, 1, room_id); sqlite3_bind_int(stmt, 2, user_id); rc = sqlite3_step(stmt); sqlite3_finalize(stmt); return rc == SQLITE_ROW; } static int handle_create_room(struct client *client, const char *data, size_t len) { if (!client->is_authenticated) { send_error(client, ERR_NOT_AUTHENTICATED); return -1; } if (len < sizeof(struct packet_create_room)) { send_error(client, ERR_INVALID_PACKET); return -1; } const struct packet_create_room *create = (const struct packet_create_room *)data; char name[33] = {0}; memcpy(name, create->name, 32); if (strlen(name) == 0) { send_error(client, ERR_INVALID_PACKET); return -1; } sqlite3_stmt *stmt; const char *sql = "INSERT INTO rooms (name, owner_id) VALUES (?, ?);"; int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL); if (rc != SQLITE_OK) { send_error(client, ERR_DATABASE); return -1; } sqlite3_bind_text(stmt, 1, name, -1, SQLITE_STATIC); sqlite3_bind_int(stmt, 2, client->user_id); rc = sqlite3_step(stmt); sqlite3_finalize(stmt); if (rc == SQLITE_CONSTRAINT) { send_error(client, ERR_ROOM_NAME_TAKEN); return -1; } else if (rc != SQLITE_DONE) { send_error(client, ERR_DATABASE); return -1; } int64_t room_id = sqlite3_last_insert_rowid(db); struct packet_room_created response; response.room_id = (uint64_t)room_id; memset(response.name, 0, sizeof(response.name)); strncpy(response.name, name, sizeof(response.name) - 1); send_packet(client, PACKET_ROOM_CREATED, &response, sizeof(response)); return 0; } static int handle_delete_room(struct client *client, const char *data, size_t len) { if (!client->is_authenticated) { send_error(client, ERR_NOT_AUTHENTICATED); return -1; } if (len < sizeof(struct packet_delete_room)) { send_error(client, ERR_INVALID_PACKET); return -1; } const struct packet_delete_room *del = (const struct packet_delete_room *)data; if (!is_public_room(del->room_id)) { send_error(client, ERR_ROOM_NOT_FOUND); return -1; } if (!is_room_owner(del->room_id, client->user_id)) { send_error(client, ERR_NOT_ROOM_OWNER); return -1; } sqlite3_stmt *stmt; const char *sql = "DELETE FROM rooms WHERE id = ?;"; int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL); if (rc != SQLITE_OK) { send_error(client, ERR_DATABASE); return -1; } sqlite3_bind_int64(stmt, 1, del->room_id); rc = sqlite3_step(stmt); sqlite3_finalize(stmt); if (rc != SQLITE_DONE) { send_error(client, ERR_DATABASE); return -1; } for (int i = 0; i < client_count; i++) { if (clients[i]->current_room_id == del->room_id) { clients[i]->current_room_id = 0; } } send_ok(client); return 0; } static int handle_list_rooms(struct client *client) { if (!client->is_authenticated) { send_error(client, ERR_NOT_AUTHENTICATED); return -1; } sqlite3_stmt *stmt; const char *sql = "SELECT r.id, r.name, u.username FROM rooms r " "JOIN users u ON r.owner_id = u.id ORDER BY r.name;"; int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL); if (rc != SQLITE_OK) { send_error(client, ERR_DATABASE); return -1; } struct room_list_entry entries[256]; uint32_t count = 0; while ((rc = sqlite3_step(stmt)) == SQLITE_ROW && count < 256) { entries[count].room_id = sqlite3_column_int64(stmt, 0); memset(entries[count].name, 0, sizeof(entries[count].name)); const char *name = (const char *)sqlite3_column_text(stmt, 1); if (name) strncpy(entries[count].name, name, 31); memset(entries[count].owner, 0, sizeof(entries[count].owner)); const char *owner = (const char *)sqlite3_column_text(stmt, 2); if (owner) strncpy(entries[count].owner, owner, 19); count++; } sqlite3_finalize(stmt); size_t data_len = sizeof(uint32_t) + count * sizeof(struct room_list_entry); char *data = malloc(data_len); if (!data) { send_error(client, ERR_DATABASE); return -1; } memcpy(data, &count, sizeof(uint32_t)); memcpy(data + sizeof(uint32_t), entries, count * sizeof(struct room_list_entry)); send_packet(client, PACKET_ROOM_LIST, data, data_len); free(data); return 0; } static int handle_dm_open(struct client *client, const char *data, size_t len) { if (!client->is_authenticated) { send_error(client, ERR_NOT_AUTHENTICATED); return -1; } if (len < sizeof(struct packet_dm_open)) { send_error(client, ERR_INVALID_PACKET); return -1; } const struct packet_dm_open *dm = (const struct packet_dm_open *)data; char target_username[21] = {0}; memcpy(target_username, dm->username, 20); if (strlen(target_username) == 0) { send_error(client, ERR_INVALID_PACKET); return -1; } sqlite3_stmt *stmt; const char *sql = "SELECT id FROM users WHERE username = ?;"; int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL); if (rc != SQLITE_OK) { send_error(client, ERR_DATABASE); return -1; } sqlite3_bind_text(stmt, 1, target_username, -1, SQLITE_STATIC); rc = sqlite3_step(stmt); if (rc != SQLITE_ROW) { sqlite3_finalize(stmt); send_error(client, ERR_USER_NOT_FOUND); return -1; } uint32_t target_user_id = sqlite3_column_int(stmt, 0); sqlite3_finalize(stmt); int64_t room_id = get_or_create_dm_room(client->user_id, target_user_id); if (room_id < 0) { send_error(client, ERR_DATABASE); return -1; } struct packet_dm_room response = { .room_id = (uint64_t)room_id }; send_packet(client, PACKET_DM_ROOM, &response, sizeof(response)); return 0; } static void broadcast_to_room(uint64_t room_id, const char *username, const char *message, size_t msg_len) { size_t bcast_len = sizeof(uint64_t) + 32 + msg_len; char *bcast_data = malloc(bcast_len); if (!bcast_data) return; memcpy(bcast_data, &room_id, sizeof(uint64_t)); memset(bcast_data + sizeof(uint64_t), 0, 32); strncpy(bcast_data + sizeof(uint64_t), username, 31); memcpy(bcast_data + sizeof(uint64_t) + 32, message, msg_len); for (int i = 0; i < client_count; i++) { struct client *c = clients[i]; if (c->is_authenticated && c->current_room_id == room_id) { send_packet(c, PACKET_TEXT, bcast_data, bcast_len); } } free(bcast_data); } static int handle_text(struct client *client, const char *data, size_t len) { if (!client->is_authenticated) { send_error(client, ERR_NOT_AUTHENTICATED); return -1; } if (len < sizeof(uint64_t)) { send_error(client, ERR_INVALID_PACKET); return -1; } const struct packet_text *text = (const struct packet_text *)data; size_t msg_len = len - sizeof(uint64_t); if (client->current_room_id != text->room_id) { send_error(client, ERR_NOT_AUTHENTICATED); return -1; } broadcast_to_room(text->room_id, client->username, text->message, msg_len); return 0; } void network_handle_data(struct client *client, const char *data, size_t len) { if (len < sizeof(struct packet_header)) { send_error(client, ERR_INVALID_PACKET); return; } struct packet_header *hdr = (struct packet_header *)data; hdr->size = ntohs(hdr->size); if (hdr->size > len || hdr->size < sizeof(struct packet_header)) { send_error(client, ERR_INVALID_PACKET); return; } const char *payload = data + sizeof(struct packet_header); size_t payload_len = hdr->size - sizeof(struct packet_header); switch (hdr->type) { case PACKET_REGISTER: handle_register(client, payload, payload_len); break; case PACKET_AUTHENTICATE: handle_authenticate(client, payload, payload_len); break; case PACKET_JOIN: handle_join(client, payload, payload_len); break; case PACKET_LEAVE: handle_leave(client, payload, payload_len); break; case PACKET_TEXT: handle_text(client, payload, payload_len); break; case PACKET_DM_OPEN: handle_dm_open(client, payload, payload_len); break; case PACKET_CREATE_ROOM: handle_create_room(client, payload, payload_len); break; case PACKET_DELETE_ROOM: handle_delete_room(client, payload, payload_len); break; case PACKET_LIST_ROOMS: handle_list_rooms(client); break; default: send_error(client, ERR_INVALID_PACKET); break; } }