/*
 * asfur/network.c
 *
 * (source-viewer header from the original listing: 682 lines, 17 KiB, C)
 */

#define _GNU_SOURCE
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <arpa/inet.h>
#include <sqlite3.h>
#include "network.h"
#include "client.h"
#include "password.h"
#include "config.h"
/* Upper bound on registered clients; network_client_add() ignores extras. */
#define MAX_CLIENTS 1024
/* Shared SQLite connection: opened by network_init(), closed by network_shutdown(). */
static sqlite3 *db = NULL;
/* Registry of connected clients used for room broadcasts; only the first
 * client_count slots are live.
 * NOTE(review): accessed without any locking here — presumably a
 * single-threaded event loop; confirm. */
static struct client *clients[MAX_CLIENTS];
static int client_count = 0;
/*
 * Database schema, executed as a single multi-statement script by
 * network_init() via sqlite3_exec(); every statement is idempotent
 * (CREATE TABLE IF NOT EXISTS).
 *
 * NOTE(review): dm_rooms.id and rooms.id come from independent AUTOINCREMENT
 * sequences, so a DM room and a public room can share the same numeric id,
 * while handle_join()/broadcast_to_room() route on the bare id. Confirm the
 * id spaces are meant to be disjoint.
 * NOTE(review): the FOREIGN KEY clauses are only enforced by SQLite when
 * PRAGMA foreign_keys=ON is issued on the connection; nothing in this file
 * issues it.
 */
static const char *schema_sql =
"CREATE TABLE IF NOT EXISTS users ("
" id INTEGER PRIMARY KEY AUTOINCREMENT,"
" username TEXT UNIQUE NOT NULL,"
" password_hash TEXT NOT NULL,"
" created_at INTEGER DEFAULT (strftime('%s', 'now'))"
");"
"CREATE TABLE IF NOT EXISTS dm_rooms ("
" id INTEGER PRIMARY KEY AUTOINCREMENT,"
" user1_id INTEGER NOT NULL,"
" user2_id INTEGER NOT NULL,"
" created_at INTEGER DEFAULT (strftime('%s', 'now')),"
" UNIQUE(user1_id, user2_id),"
" FOREIGN KEY(user1_id) REFERENCES users(id),"
" FOREIGN KEY(user2_id) REFERENCES users(id)"
");"
"CREATE TABLE IF NOT EXISTS rooms ("
" id INTEGER PRIMARY KEY AUTOINCREMENT,"
" name TEXT UNIQUE NOT NULL,"
" owner_id INTEGER NOT NULL,"
" created_at INTEGER DEFAULT (strftime('%s', 'now')),"
" FOREIGN KEY(owner_id) REFERENCES users(id)"
");";
/*
 * Open the database at CONFIG_DB_PATH and create the schema if it is missing.
 *
 * Returns 0 on success, or -1 on failure after printing a message to stderr.
 * On failure the connection is closed and db is reset to NULL, so the handle
 * does not leak and a later network_shutdown() or retry is safe.
 */
int network_init(void)
{
    int rc = sqlite3_open(CONFIG_DB_PATH, &db);
    if (rc != SQLITE_OK) {
        /* sqlite3_open() normally returns a handle even on failure, and that
         * handle must still be closed. sqlite3_close(NULL) is a harmless no-op. */
        fprintf(stderr, "Cannot open database: %s\n", sqlite3_errmsg(db));
        sqlite3_close(db);
        db = NULL;
        return -1;
    }
    char *err_msg = NULL;
    rc = sqlite3_exec(db, schema_sql, NULL, NULL, &err_msg);
    if (rc != SQLITE_OK) {
        /* err_msg may be left NULL (e.g. on OOM); never hand NULL to %s. */
        fprintf(stderr, "SQL error: %s\n", err_msg ? err_msg : sqlite3_errmsg(db));
        sqlite3_free(err_msg);
        sqlite3_close(db);
        db = NULL;
        return -1;
    }
    return 0;
}
/* Close the shared database connection if one is open; safe to call repeatedly. */
void network_shutdown(void)
{
    if (db == NULL) {
        return;
    }
    sqlite3_close(db);
    db = NULL;
}
/*
 * Register a connected client so it can receive room broadcasts.
 * Once MAX_CLIENTS entries are registered, further calls are silently ignored.
 */
void network_client_add(struct client *client)
{
    if (client_count < MAX_CLIENTS) {
        clients[client_count] = client;
        client_count++;
    }
}
/*
 * Unregister a client. The last live entry is moved into the vacated slot,
 * so registry order is not preserved. Unknown clients are ignored; only the
 * first matching entry is removed.
 */
void network_client_remove(struct client *client)
{
    int idx = 0;
    while (idx < client_count && clients[idx] != client) {
        idx++;
    }
    if (idx == client_count) {
        return;
    }
    client_count--;
    clients[idx] = clients[client_count];
}
/* Forward declarations for helpers defined later in this file. */
static void send_error(struct client *client, error_code code);
static void send_ok(struct client *client);
static int is_dm_room_member(uint64_t room_id, uint32_t user_id);
static int is_public_room(uint64_t room_id);
/* Defined in another translation unit (not visible here); presumably writes
 * len bytes of data to the client's TLS connection — confirm. */
extern void ssl_write_msg(struct client *client, const char *data, size_t len);
/*
 * Frame and send one packet: a struct packet_header (total size in network
 * byte order, then the type) followed by data_len bytes of payload.
 *
 * The header's size field is 16 bits and counts the header itself. Frames
 * whose total length does not fit are dropped instead of being sent with a
 * wrapped length that would desynchronise the peer's framing. Allocation
 * failure also drops the packet silently.
 */
static void send_packet(struct client *client, uint8_t type, const void *data, size_t data_len)
{
    /* Compared this way round so the addition below cannot overflow either. */
    if (data_len > UINT16_MAX - sizeof(struct packet_header)) {
        return;
    }
    size_t total_len = sizeof(struct packet_header) + data_len;
    char *buf = malloc(total_len);
    if (!buf) {
        return;
    }
    struct packet_header *hdr = (struct packet_header *)buf;
    hdr->size = htons((uint16_t)total_len);
    hdr->type = type;
    if (data && data_len > 0) {
        memcpy(buf + sizeof(struct packet_header), data, data_len);
    }
    ssl_write_msg(client, buf, total_len);
    free(buf);
}
/* Reply with a PACKET_ERROR whose payload carries the error code as one byte. */
static void send_error(struct client *client, error_code code)
{
    struct packet_error payload;
    memset(&payload, 0, sizeof payload);
    payload.code = (uint8_t)code;
    send_packet(client, PACKET_ERROR, &payload, sizeof payload);
}
/* Acknowledge a request with a payload-less PACKET_OK frame. */
static void send_ok(struct client *client)
{
    const size_t no_payload = 0;
    send_packet(client, PACKET_OK, NULL, no_payload);
}
/*
 * PACKET_REGISTER: create a new account from a struct packet_register payload.
 *
 * The 20-byte username and 100-byte password wire fields are copied into
 * NUL-terminated local buffers, so unterminated fields are accepted. The
 * password is hashed with hash_password() and the hash is stored. The
 * plaintext copy is scrubbed from the stack on every exit path after it
 * has been copied.
 *
 * Replies: ERR_INVALID_PACKET (short payload or empty field),
 * ERR_REGISTRATION_DISABLED, ERR_ALREADY_REGISTERED (constraint violation,
 * i.e. duplicate username), ERR_DATABASE, or PACKET_OK on success.
 * Returns 0 on success, -1 on any failure.
 */
static int handle_register(struct client *client, const char *data, size_t len)
{
    if (len < sizeof(struct packet_register)) {
        send_error(client, ERR_INVALID_PACKET);
        return -1;
    }
#ifndef CONFIG_ALLOW_REGISTER
    send_error(client, ERR_REGISTRATION_DISABLED);
    return -1;
#else
    if (!CONFIG_ALLOW_REGISTER) {
        send_error(client, ERR_REGISTRATION_DISABLED);
        return -1;
    }
#endif
    const struct packet_register *reg = (const struct packet_register *)data;
    const char *sql = "INSERT INTO users (username, password_hash) VALUES (?, ?);";
    char username[21] = {0};
    char password[101] = {0};
    char *hash = NULL;
    sqlite3_stmt *stmt = NULL;
    int rc;
    int result = -1;

    /* The extra byte in each buffer keeps the copies NUL-terminated. */
    memcpy(username, reg->username, 20);
    memcpy(password, reg->password, 100);
    if (username[0] == '\0' || password[0] == '\0') {
        send_error(client, ERR_INVALID_PACKET);
        goto out;
    }
    hash = hash_password(password);
    if (!hash) {
        send_error(client, ERR_DATABASE);
        goto out;
    }
    if (sqlite3_prepare_v2(db, sql, -1, &stmt, NULL) != SQLITE_OK) {
        send_error(client, ERR_DATABASE);
        goto out;
    }
    /* SQLITE_STATIC is safe: both buffers outlive stmt, which is finalized
     * before hash is freed. */
    sqlite3_bind_text(stmt, 1, username, -1, SQLITE_STATIC);
    sqlite3_bind_text(stmt, 2, hash, -1, SQLITE_STATIC);
    rc = sqlite3_step(stmt);
    sqlite3_finalize(stmt);
    if (rc == SQLITE_CONSTRAINT) {
        send_error(client, ERR_ALREADY_REGISTERED);
    } else if (rc != SQLITE_DONE) {
        send_error(client, ERR_DATABASE);
    } else {
        send_ok(client);
        result = 0;
    }
out:
    free(hash);
    {
        /* Volatile stores so the compiler cannot elide the wipe as dead. */
        volatile char *wipe = password;
        for (size_t i = 0; i < sizeof(password); i++) {
            wipe[i] = '\0';
        }
    }
    return result;
}
/*
 * PACKET_AUTHENTICATE: log the connection in as an existing user.
 *
 * The 20-byte username and 100-byte password wire fields are copied into
 * NUL-terminated local buffers. The password is checked against the stored
 * hash with verify_password(). The plaintext copy is scrubbed from the stack
 * on every exit path after it has been copied.
 *
 * On success the client's user_id, username and is_authenticated are set,
 * and its current room is cleared. Room access was granted to whichever
 * identity the connection had before, so it must not carry over to a
 * different account.
 *
 * Replies: ERR_INVALID_PACKET (short payload), ERR_INVALID_CREDENTIALS (empty
 * field, unknown user or wrong password — deliberately indistinguishable),
 * ERR_DATABASE, or PACKET_OK. Returns 0 on success, -1 otherwise.
 */
static int handle_authenticate(struct client *client, const char *data, size_t len)
{
    if (len < sizeof(struct packet_auth)) {
        send_error(client, ERR_INVALID_PACKET);
        return -1;
    }
    const struct packet_auth *auth = (const struct packet_auth *)data;
    const char *sql = "SELECT id, password_hash FROM users WHERE username = ?;";
    char username[21] = {0};
    char password[101] = {0};
    sqlite3_stmt *stmt = NULL;
    const char *stored_hash = NULL;
    int user_id = 0;
    int result = -1;

    /* The extra byte in each buffer keeps the copies NUL-terminated. */
    memcpy(username, auth->username, 20);
    memcpy(password, auth->password, 100);
    if (username[0] == '\0' || password[0] == '\0') {
        send_error(client, ERR_INVALID_CREDENTIALS);
        goto out;
    }
    if (sqlite3_prepare_v2(db, sql, -1, &stmt, NULL) != SQLITE_OK) {
        send_error(client, ERR_DATABASE);
        goto out;
    }
    sqlite3_bind_text(stmt, 1, username, -1, SQLITE_STATIC);
    if (sqlite3_step(stmt) != SQLITE_ROW) {
        send_error(client, ERR_INVALID_CREDENTIALS);
        goto out;
    }
    user_id = sqlite3_column_int(stmt, 0);
    /* Points into the current row; valid only until stmt is finalized. */
    stored_hash = (const char *)sqlite3_column_text(stmt, 1);
    if (!stored_hash || !verify_password(password, stored_hash)) {
        send_error(client, ERR_INVALID_CREDENTIALS);
        goto out;
    }
    client->user_id = user_id;
    client->is_authenticated = true;
    /* handle_join authorises DM rooms against user_id, so a room joined under
     * a previous identity must not survive a switch to another account. */
    client->current_room_id = 0;
    snprintf(client->username, sizeof(client->username), "%s", username);
    send_ok(client);
    result = 0;
out:
    /* sqlite3_finalize(NULL) is a harmless no-op. */
    sqlite3_finalize(stmt);
    {
        /* Volatile stores so the compiler cannot elide the wipe as dead. */
        volatile char *wipe = password;
        for (size_t i = 0; i < sizeof(password); i++) {
            wipe[i] = '\0';
        }
    }
    return result;
}
/*
 * PACKET_JOIN: make the requested room the client's current room.
 *
 * Access is granted when room_id names a public room (any authenticated user
 * may join) or a DM room the client's user belongs to. Otherwise the reply is
 * ERR_ACCESS_DENIED. Returns 0 on success, -1 otherwise.
 *
 * NOTE(review): public and DM room ids come from separate AUTOINCREMENT
 * tables, so one numeric id may name both kinds of room — confirm.
 */
static int handle_join(struct client *client, const char *data, size_t len)
{
    if (!client->is_authenticated) {
        send_error(client, ERR_NOT_AUTHENTICATED);
        return -1;
    }
    if (len < sizeof(struct packet_join)) {
        send_error(client, ERR_INVALID_PACKET);
        return -1;
    }
    const struct packet_join *req = (const struct packet_join *)data;
    /* Public rooms are checked first; the DM lookup only runs if that fails. */
    const int allowed = is_public_room(req->room_id)
                        || is_dm_room_member(req->room_id, client->user_id);
    if (!allowed) {
        send_error(client, ERR_ACCESS_DENIED);
        return -1;
    }
    client->current_room_id = req->room_id;
    send_ok(client);
    return 0;
}
static int handle_leave(struct client *client, const char *data, size_t len)
{
if (!client->is_authenticated) {
send_error(client, ERR_NOT_AUTHENTICATED);
return -1;
}
if (len < sizeof(struct packet_leave)) {
send_error(client, ERR_INVALID_PACKET);
return -1;
}
const struct packet_leave *leave = (const struct packet_leave *)data;
if (client->current_room_id == leave->room_id) {
client->current_room_id = 0;
}
send_ok(client);
return 0;
}
/*
 * Return the id of the DM room shared by two users, creating it on first use.
 * The pair is stored ordered (smaller id in user1_id), so (a, b) and (b, a)
 * resolve to the same row. Returns -1 on any database error.
 */
static int64_t get_or_create_dm_room(uint32_t user1_id, uint32_t user2_id)
{
    const uint32_t first = (user2_id < user1_id) ? user2_id : user1_id;
    const uint32_t second = (user2_id < user1_id) ? user1_id : user2_id;
    sqlite3_stmt *stmt = NULL;

    /* 1. Look for an existing room for this ordered pair. */
    if (sqlite3_prepare_v2(db, "SELECT id FROM dm_rooms WHERE user1_id = ? AND user2_id = ?;",
                           -1, &stmt, NULL) != SQLITE_OK) {
        return -1;
    }
    sqlite3_bind_int(stmt, 1, first);
    sqlite3_bind_int(stmt, 2, second);
    const int found = (sqlite3_step(stmt) == SQLITE_ROW);
    const int64_t existing = found ? sqlite3_column_int64(stmt, 0) : 0;
    sqlite3_finalize(stmt);
    if (found) {
        return existing;
    }

    /* 2. None yet: create it and report the new rowid. */
    if (sqlite3_prepare_v2(db, "INSERT INTO dm_rooms (user1_id, user2_id) VALUES (?, ?);",
                           -1, &stmt, NULL) != SQLITE_OK) {
        return -1;
    }
    sqlite3_bind_int(stmt, 1, first);
    sqlite3_bind_int(stmt, 2, second);
    const int step_rc = sqlite3_step(stmt);
    sqlite3_finalize(stmt);
    if (step_rc != SQLITE_DONE) {
        return -1;
    }
    return sqlite3_last_insert_rowid(db);
}
/*
 * Return 1 when room_id is a DM room with user_id as one of its two members,
 * else 0. A database error is reported as "not a member".
 */
static int is_dm_room_member(uint64_t room_id, uint32_t user_id)
{
    static const char query[] =
        "SELECT 1 FROM dm_rooms WHERE id = ? AND (user1_id = ? OR user2_id = ?);";
    sqlite3_stmt *stmt = NULL;
    int member = 0;

    if (sqlite3_prepare_v2(db, query, -1, &stmt, NULL) == SQLITE_OK) {
        sqlite3_bind_int64(stmt, 1, room_id);
        sqlite3_bind_int(stmt, 2, user_id);
        sqlite3_bind_int(stmt, 3, user_id);
        member = (sqlite3_step(stmt) == SQLITE_ROW);
        sqlite3_finalize(stmt);
    }
    return member;
}
/*
 * Return 1 when room_id exists in the rooms table (a public room), else 0.
 * A database error is reported as "not found".
 */
static int is_public_room(uint64_t room_id)
{
    sqlite3_stmt *query = NULL;
    if (sqlite3_prepare_v2(db, "SELECT 1 FROM rooms WHERE id = ?;", -1, &query, NULL) != SQLITE_OK) {
        return 0;
    }
    sqlite3_bind_int64(query, 1, room_id);
    const int step_rc = sqlite3_step(query);
    sqlite3_finalize(query);
    return step_rc == SQLITE_ROW;
}
/*
 * Return 1 when the public room room_id is owned by user_id, else 0.
 * A database error is reported as "not the owner".
 */
static int is_room_owner(uint64_t room_id, uint32_t user_id)
{
    sqlite3_stmt *query = NULL;
    int owns = 0;
    if (sqlite3_prepare_v2(db, "SELECT 1 FROM rooms WHERE id = ? AND owner_id = ?;",
                           -1, &query, NULL) == SQLITE_OK) {
        sqlite3_bind_int64(query, 1, room_id);
        sqlite3_bind_int(query, 2, user_id);
        owns = (sqlite3_step(query) == SQLITE_ROW);
        sqlite3_finalize(query);
    }
    return owns;
}
/*
 * PACKET_CREATE_ROOM: create a public room owned by the caller.
 *
 * The 32-byte name wire field is copied into a NUL-terminated buffer. An empty
 * name gets ERR_INVALID_PACKET. A constraint violation (duplicate name; rooms.name
 * is UNIQUE) gets ERR_ROOM_NAME_TAKEN, and other failures get ERR_DATABASE.
 * On success the reply is PACKET_ROOM_CREATED with the new id and name.
 * Returns 0 on success, -1 otherwise.
 */
static int handle_create_room(struct client *client, const char *data, size_t len)
{
    if (!client->is_authenticated) {
        send_error(client, ERR_NOT_AUTHENTICATED);
        return -1;
    }
    if (len < sizeof(struct packet_create_room)) {
        send_error(client, ERR_INVALID_PACKET);
        return -1;
    }
    const struct packet_create_room *create = (const struct packet_create_room *)data;
    char name[33] = {0};
    /* The wire field may be unterminated; name[32] stays '\0'. */
    memcpy(name, create->name, 32);
    if (name[0] == '\0') {
        send_error(client, ERR_INVALID_PACKET);
        return -1;
    }
    sqlite3_stmt *stmt;
    const char *sql = "INSERT INTO rooms (name, owner_id) VALUES (?, ?);";
    int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL);
    if (rc != SQLITE_OK) {
        send_error(client, ERR_DATABASE);
        return -1;
    }
    sqlite3_bind_text(stmt, 1, name, -1, SQLITE_STATIC);
    sqlite3_bind_int(stmt, 2, client->user_id);
    rc = sqlite3_step(stmt);
    sqlite3_finalize(stmt);
    if (rc == SQLITE_CONSTRAINT) {
        send_error(client, ERR_ROOM_NAME_TAKEN);
        return -1;
    } else if (rc != SQLITE_DONE) {
        send_error(client, ERR_DATABASE);
        return -1;
    }
    int64_t room_id = sqlite3_last_insert_rowid(db);
    /* Zero the whole reply (padding and any other members included) so no
     * uninitialised stack bytes are transmitted to the client. */
    struct packet_room_created response;
    memset(&response, 0, sizeof(response));
    response.room_id = (uint64_t)room_id;
    snprintf(response.name, sizeof(response.name), "%s", name);
    send_packet(client, PACKET_ROOM_CREATED, &response, sizeof(response));
    return 0;
}
static int handle_delete_room(struct client *client, const char *data, size_t len)
{
if (!client->is_authenticated) {
send_error(client, ERR_NOT_AUTHENTICATED);
return -1;
}
if (len < sizeof(struct packet_delete_room)) {
send_error(client, ERR_INVALID_PACKET);
return -1;
}
const struct packet_delete_room *del = (const struct packet_delete_room *)data;
if (!is_public_room(del->room_id)) {
send_error(client, ERR_ROOM_NOT_FOUND);
return -1;
}
if (!is_room_owner(del->room_id, client->user_id)) {
send_error(client, ERR_NOT_ROOM_OWNER);
return -1;
}
sqlite3_stmt *stmt;
const char *sql = "DELETE FROM rooms WHERE id = ?;";
int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL);
if (rc != SQLITE_OK) {
send_error(client, ERR_DATABASE);
return -1;
}
sqlite3_bind_int64(stmt, 1, del->room_id);
rc = sqlite3_step(stmt);
sqlite3_finalize(stmt);
if (rc != SQLITE_DONE) {
send_error(client, ERR_DATABASE);
return -1;
}
for (int i = 0; i < client_count; i++) {
if (clients[i]->current_room_id == del->room_id) {
clients[i]->current_room_id = 0;
}
}
send_ok(client);
return 0;
}
static int handle_list_rooms(struct client *client)
{
if (!client->is_authenticated) {
send_error(client, ERR_NOT_AUTHENTICATED);
return -1;
}
sqlite3_stmt *stmt;
const char *sql = "SELECT r.id, r.name, u.username FROM rooms r "
"JOIN users u ON r.owner_id = u.id ORDER BY r.name;";
int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL);
if (rc != SQLITE_OK) {
send_error(client, ERR_DATABASE);
return -1;
}
struct room_list_entry entries[256];
uint32_t count = 0;
while ((rc = sqlite3_step(stmt)) == SQLITE_ROW && count < 256) {
entries[count].room_id = sqlite3_column_int64(stmt, 0);
memset(entries[count].name, 0, sizeof(entries[count].name));
const char *name = (const char *)sqlite3_column_text(stmt, 1);
if (name) strncpy(entries[count].name, name, 31);
memset(entries[count].owner, 0, sizeof(entries[count].owner));
const char *owner = (const char *)sqlite3_column_text(stmt, 2);
if (owner) strncpy(entries[count].owner, owner, 19);
count++;
}
sqlite3_finalize(stmt);
size_t data_len = sizeof(uint32_t) + count * sizeof(struct room_list_entry);
char *data = malloc(data_len);
if (!data) {
send_error(client, ERR_DATABASE);
return -1;
}
memcpy(data, &count, sizeof(uint32_t));
memcpy(data + sizeof(uint32_t), entries, count * sizeof(struct room_list_entry));
send_packet(client, PACKET_ROOM_LIST, data, data_len);
free(data);
return 0;
}
static int handle_dm_open(struct client *client, const char *data, size_t len)
{
if (!client->is_authenticated) {
send_error(client, ERR_NOT_AUTHENTICATED);
return -1;
}
if (len < sizeof(struct packet_dm_open)) {
send_error(client, ERR_INVALID_PACKET);
return -1;
}
const struct packet_dm_open *dm = (const struct packet_dm_open *)data;
char target_username[21] = {0};
memcpy(target_username, dm->username, 20);
if (strlen(target_username) == 0) {
send_error(client, ERR_INVALID_PACKET);
return -1;
}
sqlite3_stmt *stmt;
const char *sql = "SELECT id FROM users WHERE username = ?;";
int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL);
if (rc != SQLITE_OK) {
send_error(client, ERR_DATABASE);
return -1;
}
sqlite3_bind_text(stmt, 1, target_username, -1, SQLITE_STATIC);
rc = sqlite3_step(stmt);
if (rc != SQLITE_ROW) {
sqlite3_finalize(stmt);
send_error(client, ERR_USER_NOT_FOUND);
return -1;
}
uint32_t target_user_id = sqlite3_column_int(stmt, 0);
sqlite3_finalize(stmt);
int64_t room_id = get_or_create_dm_room(client->user_id, target_user_id);
if (room_id < 0) {
send_error(client, ERR_DATABASE);
return -1;
}
struct packet_dm_room response = { .room_id = (uint64_t)room_id };
send_packet(client, PACKET_DM_ROOM, &response, sizeof(response));
return 0;
}
static void broadcast_to_room(uint64_t room_id, const char *username,
const char *message, size_t msg_len)
{
size_t bcast_len = sizeof(uint64_t) + 32 + msg_len;
char *bcast_data = malloc(bcast_len);
if (!bcast_data) return;
memcpy(bcast_data, &room_id, sizeof(uint64_t));
memset(bcast_data + sizeof(uint64_t), 0, 32);
strncpy(bcast_data + sizeof(uint64_t), username, 31);
memcpy(bcast_data + sizeof(uint64_t) + 32, message, msg_len);
for (int i = 0; i < client_count; i++) {
struct client *c = clients[i];
if (c->is_authenticated && c->current_room_id == room_id) {
send_packet(c, PACKET_TEXT, bcast_data, bcast_len);
}
}
free(bcast_data);
}
static int handle_text(struct client *client, const char *data, size_t len)
{
if (!client->is_authenticated) {
send_error(client, ERR_NOT_AUTHENTICATED);
return -1;
}
if (len < sizeof(uint64_t)) {
send_error(client, ERR_INVALID_PACKET);
return -1;
}
const struct packet_text *text = (const struct packet_text *)data;
size_t msg_len = len - sizeof(uint64_t);
if (client->current_room_id != text->room_id) {
send_error(client, ERR_NOT_AUTHENTICATED);
return -1;
}
broadcast_to_room(text->room_id, client->username, text->message, msg_len);
return 0;
}
void network_handle_data(struct client *client, const char *data, size_t len)
{
if (len < sizeof(struct packet_header)) {
send_error(client, ERR_INVALID_PACKET);
return;
}
struct packet_header *hdr = (struct packet_header *)data;
hdr->size = ntohs(hdr->size);
if (hdr->size > len || hdr->size < sizeof(struct packet_header)) {
send_error(client, ERR_INVALID_PACKET);
return;
}
const char *payload = data + sizeof(struct packet_header);
size_t payload_len = hdr->size - sizeof(struct packet_header);
switch (hdr->type) {
case PACKET_REGISTER:
handle_register(client, payload, payload_len);
break;
case PACKET_AUTHENTICATE:
handle_authenticate(client, payload, payload_len);
break;
case PACKET_JOIN:
handle_join(client, payload, payload_len);
break;
case PACKET_LEAVE:
handle_leave(client, payload, payload_len);
break;
case PACKET_TEXT:
handle_text(client, payload, payload_len);
break;
case PACKET_DM_OPEN:
handle_dm_open(client, payload, payload_len);
break;
case PACKET_CREATE_ROOM:
handle_create_room(client, payload, payload_len);
break;
case PACKET_DELETE_ROOM:
handle_delete_room(client, payload, payload_len);
break;
case PACKET_LIST_ROOMS:
handle_list_rooms(client);
break;
default:
send_error(client, ERR_INVALID_PACKET);
break;
}
}