jancity/src/transport/src/transport.c

606 lines
14 KiB
C

#include <transport.h>
#include <transport_private.h>
#include <inttypes.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
/* On-wire field types used in this file: packet_header carries the packet
 * type byte read/written by read_header()/write_header(), and packet_len
 * carries the data payload length read/written by read_len()/write_len(). */
typedef uint8_t packet_header, packet_len;
/*
 * Allocates a CONNECT packet and hands it to transport_append().
 *
 * Returns -1 if the packet could not be allocated; otherwise returns the
 * result of transport_append().
 */
int transport_connect(struct transport_handle *const h)
{
    union transport_packet *const pkt =
        transport_packet_alloc(TRANSPORT_PACKET_TYPE_CONNECT);

    return pkt ? transport_append(h, pkt) : -1;
}
/*
 * Allocates a DISCONNECT packet and hands it to transport_append().
 *
 * Returns -1 if the packet could not be allocated; otherwise returns the
 * result of transport_append().
 */
int transport_disconnect(struct transport_handle *const h)
{
    union transport_packet *const pkt =
        transport_packet_alloc(TRANSPORT_PACKET_TYPE_DISCONNECT);

    return pkt ? transport_append(h, pkt) : -1;
}
/*
 * Splits the n bytes at buf into DATA packets, each holding at most
 * sizeof (packet data buffer) bytes, and passes every packet to
 * transport_append() in order.
 *
 * Returns 0 once all bytes have been packetised (including n == 0), or -1
 * as soon as a packet cannot be allocated or transport_append() fails.
 * Packets appended before a failure are not rolled back.
 *
 * NOTE(review): when transport_append() fails the freshly allocated packet
 * is not released here -- confirm whether transport_append() takes
 * ownership on failure.
 */
int transport_send(struct transport_handle *const h,
    const void *const buf, size_t n)
{
    const uint8_t *src = buf;

    while (n)
    {
        union transport_packet *const pkt =
            transport_packet_alloc(TRANSPORT_PACKET_TYPE_DATA);

        if (!pkt)
            return -1;

        struct transport_packet_data *const data = &pkt->data;
        /* Fill the packet completely unless fewer bytes remain. */
        size_t chunk = sizeof data->buf;

        if (n < chunk)
            chunk = n;

        memcpy(data->buf, src, chunk);
        data->n = chunk;

        if (transport_append(h, pkt))
            return -1;

        src += chunk;
        n -= chunk;
    }

    return 0;
}
/*
 * Maps a packet type to the state that follows the header byte:
 *   DATA                        -> length field
 *   ACK                         -> body
 *   CONNECT / DISCONNECT / NACK -> checksum (these carry no body)
 * Any other value yields 0.
 */
static enum transport_state header_next_state(const enum transport_packet_type t)
{
    if (t == TRANSPORT_PACKET_TYPE_DATA)
        return TRANSPORT_STATE_LEN;

    if (t == TRANSPORT_PACKET_TYPE_ACK)
        return TRANSPORT_STATE_BODY;

    if (t == TRANSPORT_PACKET_TYPE_CONNECT
        || t == TRANSPORT_PACKET_TYPE_DISCONNECT
        || t == TRANSPORT_PACKET_TYPE_NACK)
        return TRANSPORT_STATE_CHECKSUM;

    return 0;
}
/*
 * Reads the one-byte packet header through cfg->read, validates it and
 * allocates the input packet for that type, storing it in in->p.
 *
 * Returns a negative value if cfg->read fails, the header is not below
 * MAX_TRANSPORT_PACKET_TYPES, or the allocation fails; 1 if cfg->read did
 * not deliver the whole header; 0 on success, with the packet state
 * advanced according to header_next_state(). The done flag is not touched.
 */
static int read_header(const struct transport_cfg *const cfg,
    struct transport_input *const in, bool *const done)
{
    packet_header type;
    const int n = cfg->read(&type, sizeof type, cfg->arg);

    if (n < 0)
        return n;

    if (n != sizeof type)
        return 1;

    if (type >= MAX_TRANSPORT_PACKET_TYPES)
    {
        fprintf(stderr, "%s: invalid header %#" PRIx8 "\n",
            __func__, type);
        return -1;
    }

    in->p = transport_packet_alloc(type);

    if (!in->p)
    {
        fprintf(stderr, "%s: transport_packet_alloc failed\n", __func__);
        return -1;
    }

    struct transport_common *const common = &in->p->common;

    common->state = header_next_state(common->type);
    return 0;
}
/*
 * Reads the one-byte payload length of a DATA packet through cfg->read and
 * stores it in the input packet.
 *
 * Returns a negative value if cfg->read fails or the length exceeds the
 * packet's data buffer; 1 if cfg->read did not deliver the whole field;
 * 0 on success, with the packet state advanced to TRANSPORT_STATE_BODY.
 * The done flag is not touched.
 */
static int read_len(const struct transport_cfg *const cfg,
    struct transport_input *const in, bool *const done)
{
    packet_len len;
    const int res = cfg->read(&len, sizeof len, cfg->arg);

    if (res < 0)
        return res;
    else if (res != sizeof len)
        return 1;

    struct transport_packet_data *const d = &in->p->data;

    /* transport_send() emits packets holding up to sizeof d->buf bytes, so
     * a length equal to the buffer size is valid; only a larger length
     * would overflow the buffer and must be rejected. */
    if (len > sizeof d->buf)
    {
        fprintf(stderr, "%s: invalid length %" PRIu8 "\n", __func__, len);
        return -1;
    }

    d->n = len;
    d->common.state = TRANSPORT_STATE_BODY;
    return 0;
}
/*
 * Reads the outstanding payload bytes of a DATA packet through cfg->read,
 * resuming at the offset recorded in the packet's written counter.
 *
 * Returns a negative value if cfg->read fails; 1 if the returned count
 * differs from the outstanding amount (the counter is advanced by that
 * count); 0 once the payload is complete, with the packet state advanced
 * to TRANSPORT_STATE_CHECKSUM.
 */
static int read_data_body(const struct transport_cfg *const cfg,
    struct transport_input *const in)
{
    struct transport_packet_data *const data = &in->p->data;
    const size_t pending = data->n - data->written;
    const int got = cfg->read(data->buf + data->written, pending, cfg->arg);

    if (got < 0)
        return got;

    if (got == pending)
    {
        in->p->common.state = TRANSPORT_STATE_CHECKSUM;
        return 0;
    }

    data->written += got;
    return 1;
}
/*
 * Reads the body of an ACK packet, i.e. the checksum of the packet being
 * acknowledged, through cfg->read.
 *
 * Returns a negative value if cfg->read fails; 1 if it did not deliver the
 * whole field; 0 on success, with the packet state advanced to
 * TRANSPORT_STATE_CHECKSUM.
 */
static int read_ack_body(const struct transport_cfg *const cfg,
    struct transport_input *const in)
{
    struct transport_packet_ack *const ack = &in->p->ack;
    const int n = cfg->read(&ack->checksum, sizeof ack->checksum, cfg->arg);

    if (n < 0)
        return n;

    if (n != sizeof ack->checksum)
        return 1;

    in->p->common.state = TRANSPORT_STATE_CHECKSUM;
    return 0;
}
/*
 * Dispatches body reading by the type of the packet in in->p: ACK packets
 * go to read_ack_body(), DATA packets to read_data_body(). Any other type
 * has no body and yields -1. The done flag is not touched.
 */
static int read_body(const struct transport_cfg *const cfg,
    struct transport_input *const in, bool *const done)
{
    const struct transport_common *const common = &in->p->common;

    if (common->type == TRANSPORT_PACKET_TYPE_ACK)
        return read_ack_body(cfg, in);

    if (common->type == TRANSPORT_PACKET_TYPE_DATA)
        return read_data_body(cfg, in);

    return -1;
}
/*
 * Computes the checksum of a packet: the bitwise complement of the sum of
 * the packet type plus, for DATA packets, the payload length and every
 * payload byte, or, for ACK packets, the acknowledged checksum. Other
 * packet types contribute only their type.
 */
static transport_checksum calc_checksum(const union transport_packet *const p)
{
    const struct transport_common *const common = &p->common;
    transport_checksum sum = common->type;

    if (common->type == TRANSPORT_PACKET_TYPE_DATA)
    {
        const struct transport_packet_data *const data = &p->data;

        sum += data->n;

        for (size_t i = 0; i < data->n; i++)
            sum += data->buf[i];
    }
    else if (common->type == TRANSPORT_PACKET_TYPE_ACK)
        sum += p->ack.checksum;

    return ~sum;
}
/*
 * Reads the trailing checksum through cfg->read and compares it with the
 * checksum computed over the packet in in->p.
 *
 * Returns a negative value if cfg->read fails; 1 if it did not deliver the
 * whole field; otherwise sets *done to true and returns 0 when the
 * checksums match or -1 (after logging both values) when they differ.
 */
static int read_checksum(const struct transport_cfg *const cfg,
    struct transport_input *const in, bool *const done)
{
    transport_checksum received;
    const int n = cfg->read(&received, sizeof received, cfg->arg);

    if (n < 0)
        return n;

    if (n != sizeof received)
        return 1;

    const transport_checksum expected = calc_checksum(in->p);

    /* The packet is complete whether or not the checksum is valid. */
    *done = true;

    if (received == expected)
        return 0;

    fprintf(stderr, "%s: invalid checksum %#" PRIx8
        ", expected %#" PRIx8 "\n",
        __func__, received, expected);
    return -1;
}
/*
 * Allocates a NACK packet and hands it to transport_append().
 *
 * Returns -1 if the packet could not be allocated; otherwise returns the
 * result of transport_append().
 */
static int send_nack(struct transport_handle *const h)
{
    union transport_packet *const nack =
        transport_packet_alloc(TRANSPORT_PACKET_TYPE_NACK);

    return nack ? transport_append(h, nack) : -1;
}
/*
 * Tells whether a received packet of type t must be answered with an ACK:
 * true for CONNECT, DISCONNECT and DATA; false for ACK and NACK.
 */
static bool requires_ack(const enum transport_packet_type t)
{
    switch (t)
    {
        case TRANSPORT_PACKET_TYPE_CONNECT:
        case TRANSPORT_PACKET_TYPE_DISCONNECT:
        case TRANSPORT_PACKET_TYPE_DATA:
            return true;

        case TRANSPORT_PACKET_TYPE_ACK:
        case TRANSPORT_PACKET_TYPE_NACK:
        default:
            break;
    }

    return false;
}
/*
 * Allocates an ACK packet carrying the checksum of the current input
 * packet (h->input.p) and hands it to transport_append().
 *
 * Returns -1 if the packet could not be allocated; otherwise returns the
 * result of transport_append().
 */
static int send_ack(struct transport_handle *const h)
{
    union transport_packet *const reply =
        transport_packet_alloc(TRANSPORT_PACKET_TYPE_ACK);

    if (reply)
    {
        reply->ack.checksum = calc_checksum(h->input.p);
        return transport_append(h, reply);
    }

    return -1;
}
/*
 * Releases the input packet through transport_packet_free() and clears
 * in->p so that the next read starts from the header state.
 */
static void delete_input(struct transport_input *const in)
{
    union transport_packet *const stale = in->p;

    in->p = NULL;
    transport_packet_free(stale);
}
static void send_event(const struct transport_cfg *const cfg,
const union transport_packet *const p,
struct transport_event *const ev)
{
switch (p->common.type)
{
case TRANSPORT_PACKET_TYPE_CONNECT:
ev->common.type = TRANSPORT_EVENT_TYPE_CONNECT;
break;
case TRANSPORT_PACKET_TYPE_DISCONNECT:
ev->common.type = TRANSPORT_EVENT_TYPE_DISCONNECT;
break;
case TRANSPORT_PACKET_TYPE_DATA:
ev->common.type = TRANSPORT_EVENT_TYPE_DATA;
ev->u.data.buf = p->data.buf;
ev->u.data.n = p->data.n;
break;
default:
return;
}
if (cfg->received)
cfg->received(ev, cfg->arg);
}
/*
 * Removes the entry at pp from the h->packets array by shifting every
 * later entry down by one slot, then calls transport_pop() on the handle.
 *
 * pp must point into h->packets. Returns the result of transport_pop().
 *
 * NOTE(review): the packet that *pp referred to is not released here --
 * confirm whether transport_pop() or the caller is responsible for it.
 */
static int remove_packet(struct transport_handle *const h,
    union transport_packet **const pp)
{
    const size_t i = pp - h->packets;
    const size_t n = i + 1;

    /* memmove() takes a size in bytes, so the number of trailing entries
     * must be scaled by the size of one array element (a pointer). */
    if (n < h->n_packets)
        memmove(pp, pp + 1, (h->n_packets - n) * sizeof *pp);

    return transport_pop(h);
}
static void get_event(const struct transport_cfg *const cfg,
const union transport_packet *const p,
struct transport_event *const ev)
{
switch (p->common.type)
{
case TRANSPORT_PACKET_TYPE_CONNECT:
ev->common.type = TRANSPORT_EVENT_TYPE_CONNECT;
break;
case TRANSPORT_PACKET_TYPE_DISCONNECT:
ev->common.type = TRANSPORT_EVENT_TYPE_DISCONNECT;
break;
default:
return;
}
if (cfg->received)
cfg->received(ev, cfg->arg);
}
/*
 * Handles a received ACK: looks for the first queued packet with a
 * non-zero ttl whose checksum equals the acknowledged one, reports it via
 * get_event() and removes it from the queue.
 *
 * Returns -1 if remove_packet() fails, 0 otherwise (including when no
 * queued packet matches).
 */
static int process_ack(struct transport_handle *const h,
    const struct transport_packet_ack *const a)
{
    for (size_t i = 0; i < h->n_packets; i++)
    {
        union transport_packet **const slot = h->packets + i;
        const union transport_packet *const sent = *slot;

        if (!sent->common.ttl || a->checksum != calc_checksum(sent))
            continue;

        get_event(&h->cfg, sent, &h->input.ev);
        return remove_packet(h, slot) ? -1 : 0;
    }

    return 0;
}
/*
 * Handles a received NACK by zeroing the ttl of the first queued packet.
 *
 * Returns 0 on success, or -1 (after logging) when the queue is empty.
 */
static int process_nack(struct transport_handle *const h)
{
    if (h->n_packets)
    {
        h->packets[0]->common.ttl = 0;
        return 0;
    }

    fprintf(stderr, "%s: received nack without previous sent packet\n",
        __func__);
    return -1;
}
/*
 * Applies the protocol-level effect of the fully received input packet:
 * ACK packets go to process_ack(), NACK packets to process_nack(). Every
 * other type needs no processing here and yields 0.
 */
static int process_input(struct transport_handle *const h)
{
    const union transport_packet *const pkt = h->input.p;

    if (pkt->common.type == TRANSPORT_PACKET_TYPE_ACK)
        return process_ack(h, &pkt->ack);

    if (pkt->common.type == TRANSPORT_PACKET_TYPE_NACK)
        return process_nack(h);

    return 0;
}
static int update_input(struct transport_handle *const h)
{
int ret = -1;
bool done = false;
struct transport_input *const in = &h->input;
static int (*const f[])(const struct transport_cfg *,
struct transport_input *, bool *) =
{
[TRANSPORT_STATE_HEADER] = read_header,
[TRANSPORT_STATE_LEN] = read_len,
[TRANSPORT_STATE_BODY] = read_body,
[TRANSPORT_STATE_CHECKSUM] = read_checksum
};
const struct transport_cfg *const cfg = &h->cfg;
while (!done)
{
const enum transport_state s = in->p ? in->p->common.state : 0;
const int res = f[s](cfg, in, &done);
if (res < 0)
{
ret = send_nack(h);
goto end;
}
else if (res)
return 0;
}
const union transport_packet *const p = in->p;
const enum transport_packet_type t = p->common.type;
if (process_input(h))
goto end;
else if (requires_ack(t))
{
if (send_ack(h))
{
fprintf(stderr, "%s: send_ack failed\n", __func__);
goto end;
}
send_event(&h->cfg, p, &h->input.ev);
}
ret = 0;
end:
delete_input(in);
return ret;
}
/*
 * Writes the one-byte header (the packet type) through cfg->write.
 *
 * Returns a negative value if cfg->write fails; 1 if it did not accept the
 * whole header; 0 on success, with the packet state advanced according to
 * header_next_state().
 */
static int write_header(const struct transport_cfg *const cfg,
    union transport_packet *const p)
{
    struct transport_common *const common = &p->common;
    const packet_header type = common->type;
    const int n = cfg->write(&type, sizeof type, cfg->arg);

    if (n < 0)
        return n;

    if (n != sizeof type)
        return 1;

    common->state = header_next_state(common->type);
    return 0;
}
/*
 * Writes the one-byte payload length of a DATA packet through cfg->write.
 *
 * Returns a negative value if cfg->write fails; 1 if it did not accept the
 * whole field; 0 on success, with the packet state advanced to
 * TRANSPORT_STATE_BODY.
 */
static int write_len(const struct transport_cfg *const cfg,
    union transport_packet *const p)
{
    const packet_len payload_len = p->data.n;
    const int n = cfg->write(&payload_len, sizeof payload_len, cfg->arg);

    if (n < 0)
        return n;

    if (n != sizeof payload_len)
        return 1;

    p->common.state = TRANSPORT_STATE_BODY;
    return 0;
}
/*
 * Writes the body of an ACK packet, i.e. the acknowledged checksum,
 * through cfg->write.
 *
 * Returns a negative value if cfg->write fails; 1 if it did not accept the
 * whole field; 0 on success, with the packet state advanced to
 * TRANSPORT_STATE_CHECKSUM.
 */
static int write_ack_body(const struct transport_cfg *const cfg,
    struct transport_packet_ack *const a)
{
    const int n = cfg->write(&a->checksum, sizeof a->checksum, cfg->arg);

    if (n < 0)
        return n;

    if (n != sizeof a->checksum)
        return 1;

    a->common.state = TRANSPORT_STATE_CHECKSUM;
    return 0;
}
/*
 * Writes the outstanding payload bytes of a DATA packet through
 * cfg->write, resuming at the offset recorded in the written counter.
 *
 * Returns a negative value if cfg->write fails; 1 if the returned count
 * differs from the outstanding amount (the counter is advanced by that
 * count); 0 once the payload is complete, with the packet state advanced
 * to TRANSPORT_STATE_CHECKSUM.
 */
static int write_data_body(const struct transport_cfg *const cfg,
    struct transport_packet_data *const d)
{
    const size_t pending = d->n - d->written;
    const int sent = cfg->write(d->buf + d->written, pending, cfg->arg);

    if (sent < 0)
        return sent;

    if (sent == pending)
    {
        d->common.state = TRANSPORT_STATE_CHECKSUM;
        return 0;
    }

    d->written += sent;
    return 1;
}
/*
 * Dispatches body writing by packet type: ACK packets go to
 * write_ack_body(), DATA packets to write_data_body(). Any other type has
 * no body and yields -1.
 */
static int write_body(const struct transport_cfg *const cfg,
    union transport_packet *const p)
{
    struct transport_common *const common = &p->common;

    if (common->type == TRANSPORT_PACKET_TYPE_ACK)
        return write_ack_body(cfg, &p->ack);

    if (common->type == TRANSPORT_PACKET_TYPE_DATA)
        return write_data_body(cfg, &p->data);

    return -1;
}
/*
 * Computes the packet checksum and writes it through cfg->write.
 *
 * Returns a negative value if cfg->write fails; 1 if it did not accept the
 * whole field; 0 on success, marking the packet as done.
 */
static int write_checksum(const struct transport_cfg *const cfg,
    union transport_packet *const p)
{
    const transport_checksum sum = calc_checksum(p);
    const int n = cfg->write(&sum, sizeof sum, cfg->arg);

    if (n < 0)
        return n;

    if (n != sizeof sum)
        return 1;

    p->common.done = true;
    return 0;
}
/*
 * Walks the queued packets in order. A packet whose ttl is non-zero only
 * has its ttl decremented. A packet whose ttl is zero is pushed through
 * the writer for its current state until it is marked done, after which
 * its ttl is reloaded.
 *
 * Returns a negative writer result on failure, or 0 otherwise. A positive
 * writer result ("call again later") also returns 0 immediately, leaving
 * the remaining packets untouched for this call.
 *
 * NOTE(review): nothing here clears the done flag or rewinds the state, so
 * a packet whose ttl reaches zero again is not re-written by this loop --
 * confirm whether retransmission is reset elsewhere.
 */
static int update_output(struct transport_handle *const h)
{
    /* Arbitrary value, also depends on frame rate, yet good enough. */
    enum {TTL = 200};

    static int (*const writers[])(const struct transport_cfg *,
        union transport_packet *) =
    {
        [TRANSPORT_STATE_HEADER] = write_header,
        [TRANSPORT_STATE_LEN] = write_len,
        [TRANSPORT_STATE_BODY] = write_body,
        [TRANSPORT_STATE_CHECKSUM] = write_checksum
    };

    for (size_t i = 0; i < h->n_packets; i++)
    {
        union transport_packet *const pkt = h->packets[i];
        struct transport_common *const common = &pkt->common;

        if (common->ttl)
        {
            common->ttl--;
            continue;
        }

        while (!common->done)
        {
            const int res = writers[common->state](&h->cfg, pkt);

            if (res)
                return res < 0 ? res : 0;
        }

        common->ttl = TTL;
    }

    return 0;
}
/*
 * Runs one step of the transport: first the input side, then, only if that
 * succeeded, the output side.
 *
 * Returns 0 when both steps return 0 and 1 when either reports a non-zero
 * result.
 */
int transport_update(struct transport_handle *const h)
{
    if (update_input(h))
        return 1;

    return update_output(h) ? 1 : 0;
}