/*
 * Call statements: parsing of
 *
 *     NAME [using ARG...] [returning VAR]
 *
 * into a struct call attached to the current function, and code
 * generation for it (IL text written to stdout).
 */
#include "call.h"
#include "errloc.h"
#include "fn.h"
#include "lex.h"
#include "lit.h"
#include "parse.h"
#include "prv.h"
#include "stmt.h"
#include "type.h"
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

static int callpr(const struct lex *l, struct prv *p);
static int callp(const struct lex *l, struct prv *p);
static int callr(const struct lex *l, struct prv *p);
static int callv(const struct lex *l, struct prv *p);

/*
 * Token sequences that may follow a call, each mapped to the handler
 * that completes the statement. The first ID is the callee name.
 * NOTE(review): the longer forms come first, presumably because the
 * sequence matcher tries them in order -- confirm. The third field on
 * the ANY step looks like a "repeat" flag so the argument list can span
 * several tokens -- confirm against struct step.
 */
static const struct seq seqs[] = {
	{(struct step[]){
		{ID}, {ID, "using"}, {ANY, NULL, 1}, {ID, "returning"}, {ID},
		{0}},
		.fn = callpr},
	{(struct step[]){
		{ID}, {ID, "using"}, {ANY, NULL, 1}, {0}},
		.fn = callp},
	{(struct step[]){
		{ID}, {ID, "returning"}, {ID}, {0}},
		.fn = callr},
	{(struct step[]){
		{ID}, {0}},
		.fn = callv},
	{0}
};

/*
 * Resolve argument token `tk` and append it to c->params.
 *
 * An argument is one of:
 *  - a literal (LIT), registered through lit_push();
 *  - a number (NUM), parsed with strtoll() when it starts with '-'
 *    and strtoull() otherwise, with base autodetection;
 *  - the name of a type whose kind is C (a constant);
 *  - a variable resolved with fn_var() in the current function.
 *
 * Returns 0 on success, -1 after reporting an error.
 */
static int param(const struct lex *l, struct prv *p, const struct tk *tk,
	struct call *c)
{
	const struct fn *fn = fn_cur(p);
	size_t n = c->nparams + 1;
	struct callparam *params;
	const struct lit *lit = NULL;
	const struct stentry *e = NULL;
	const struct type *t = NULL;
	int sign = 0;
	union callv u = {0};

	if (tk->type == LIT) {
		if (!(lit = lit_push(tk, p)))
			return -1;
	} else if (tk->type == NUM) {
		if ((sign = *tk->s == '-'))
			u.v = strtoll(tk->s, NULL, 0);
		else
			u.uv = strtoull(tk->s, NULL, 0);
	} else if ((t = type_find(fn, tk->s))) {
		if (t->type != C) {
			errloc(tk, "type \"%s\" not a constant", tk->s);
			return -1;
		}
	} else if (!(e = fn_var(fn, tk))) {
		errloc(tk, "undefined reference to parameter \"%s\"", tk->s);
		return -1;
	}

	if (!(params = realloc(c->params, n * sizeof *params))) {
		perror("realloc(3)");
		return -1;
	}

	c->params = params;
	/*
	 * `sign` must be recorded: param_num() reads cp->sign to choose
	 * between the signed and unsigned member of the union.
	 */
	c->params[c->nparams++] = (struct callparam) {
		.entry = e,
		.lit = lit,
		.t = t,
		.sign = sign,
		.u = u
	};

	return 0;
}

/*
 * Complete "NAME using ARG... returning VAR".
 *
 * On entry p->stk points at NAME, so p->stk + 2 is the first argument.
 * Arguments are counted up to the next keyword ("returning"), checked
 * against the prototype, then resolved with param(). The token after
 * "returning" must name a variable of the current function.
 *
 * Returns 1 on success (p->stk moved past VAR), -1 on error.
 */
static int callpr(const struct lex *l, struct prv *p)
{
	const struct tk *tk = p->stk + 2;
	struct call *c = &stmt_cur(p)->u.call;
	const struct pr *pr = c->pr;
	const struct stentry *e;
	size_t n = 0;

	/* Count arguments; on a keyword, tk still advances past it. */
	while (!lex_eof(l, tk) && !kw((tk++)->s))
		n++;

	if (pr->variadic) {
		if (n < pr->nparams) {
			errloc(tk, "function \"%s\" expects at least %zu parameters, "
				"but %zu were given", c->tk->s, pr->nparams, n);
			return -1;
		}
	} else if (n != pr->nparams) {
		tk -= 2;
		errloc(tk, "function \"%s\" expects %zu parameters, "
			"but %zu were given", c->tk->s, pr->nparams, n);
		return -1;
	}

	/*
	 * Checked independently of the arity checks above, so variadic
	 * functions without a return value are also rejected here.
	 */
	if (!pr->ret) {
		errloc(tk, "function \"%s\" does not expect any return variable",
			c->tk->s);
		return -1;
	}

	tk = p->stk + 2;

	for (size_t i = 0; i < n; i++)
		if (param(l, p, tk++, c))
			return -1;

	/* tk now points at "returning"; the return variable follows it. */
	if (!(e = fn_var(fn_cur(p), ++tk))) {
		errloc(tk, "undefined reference to return variable \"%s\"", tk->s);
		return -1;
	} else if (pop(l, p))
		return -1;

	fprintf(stderr, "\t\tadding call to %s using %zu params "
		"and %s as return variable\n", c->tk->s, n, tk->s);
	c->ret = e;
	/* param() already advanced c->nparams; this keeps it explicit. */
	c->nparams = n;
	p->stk = ++tk;
	return 1;
}

/*
 * Complete "NAME returning VAR" (no arguments).
 *
 * On entry p->stk points at NAME, so p->stk + 2 is VAR. The prototype
 * must declare a return value and no fixed parameters.
 *
 * Returns 1 on success (p->stk moved past VAR), -1 on error.
 */
static int callr(const struct lex *l, struct prv *p)
{
	const struct tk *tk = p->stk + 2;
	const struct fn *fn = fn_cur(p);
	const struct stentry *e;
	struct call *c = &stmt_cur(p)->u.call;
	const struct pr *pr = c->pr;

	if (!pr->ret) {
		errloc(tk, "function \"%s\" does not expect any return variable",
			c->tk->s);
		return -1;
	} else if (pr->nparams) {
		errloc(tk, "function \"%s\" expects %zu parameters, "
			"but none were given", c->tk->s, pr->nparams);
		return -1;
	} else if (!(e = fn_var(fn, tk))) {
		errloc(tk, "undefined reference to return variable \"%s\"", tk->s);
		return -1;
	} else if (pop(l, p))
		return -1;

	fprintf(stderr, "\t\tadding call to %s with %s as return variable\n",
		c->tk->s, tk->s);
	c->ret = e;
	p->stk = ++tk;
	return 1;
}

/*
 * Complete "NAME using ARG..." (any return value is discarded).
 *
 * On entry p->stk points at NAME, so p->stk + 2 is the first argument.
 * Arguments run until the next keyword or end of input.
 *
 * Returns 1 on success (p->stk left at the token after the last
 * argument), -1 on error.
 */
static int callp(const struct lex *l, struct prv *p)
{
	const struct tk *tk = p->stk + 2;
	struct call *c = &stmt_cur(p)->u.call;
	const struct pr *pr = c->pr;
	size_t n = 0;

	while (!lex_eof(l, tk) && !kw((tk++)->s))
		n++;

	if (pr->variadic) {
		if (n < pr->nparams) {
			errloc(tk, "function \"%s\" expects at least %zu parameters, "
				"but %zu were given", c->tk->s, pr->nparams, n);
			return -1;
		}
	} else if (pr->nparams != n) {
		errloc(tk, "function \"%s\" expects %zu parameters, "
			"but %zu were given", c->tk->s, pr->nparams, n);
		return -1;
	}

	tk = p->stk + 2;

	for (size_t i = 0; i < n; i++)
		if (param(l, p, tk++, c))
			return -1;

	if (pop(l, p))
		return -1;

	fprintf(stderr, "\t\tadding call to %s with %zu params\n",
		c->tk->s, n);
	p->stk = tk;
	return 1;
}

/*
 * Complete a bare "NAME" call: no arguments, return value discarded.
 * p->stk points at NAME. Returns 1 on success (p->stk moved past NAME),
 * -1 on error.
 */
static int callv(const struct lex *l, struct prv *p)
{
	const struct tk *tk = p->stk;
	struct call *c = &stmt_cur(p)->u.call;
	const struct pr *pr = c->pr;

	if (pr->nparams) {
		errloc(tk, "function \"%s\" expects %zu parameters, "
			"but none were given", c->tk->s, pr->nparams);
		return -1;
	} else if (pop(l, p))
		return -1;

	fprintf(stderr, "\t\tadding call to %s\n", c->tk->s);
	p->stk = ++tk;
	return 1;
}

/*
 * Return the prototype of `s` if type_find() resolves it (from `fn`)
 * to a type of kind P; NULL otherwise.
 */
static const struct pr *find_proto(const struct fn *fn, const char *s)
{
	const struct type *t = type_find(fn, s);

	if (t && t->type == P)
		return t->u.p.fn->pr;

	return NULL;
}

/*
 * Look up the callee `s`: for each function in the AST, first try a
 * prototype type visible from it, then the function's own name.
 * Returns the matching prototype, or NULL if none is found.
 */
static const struct pr *find(const struct ast *ast, const char *s)
{
	for (size_t i = 0; i < ast->nfns; i++) {
		const struct fn *fn = &ast->fns[i];
		const struct pr *pr;

		if ((pr = find_proto(fn, s)))
			return pr;
		else if (!strcmp(fn->tk->s, s))
			return fn->pr;
	}

	return NULL;
}

/*
 * Advance the temporary name in `tmp` (at most `n` characters, kept
 * NUL-terminated by the caller): "" -> "a" ... "z" -> "za" ... "zz" ->
 * "zza" ... Returns EOF once all `n` characters are 'z'.
 */
static int gettmp(char *tmp, size_t n)
{
	size_t j;

	for (j = 0; j < n; j++) {
		char c = tmp[j];

		if (!c)
			return tmp[j] = 'a';
		else if (c != 'z')
			return ++(tmp[j]);
	}

	return EOF;
}

/* Emit a literal argument as a long-sized reference to its symbol. */
static int param_lit(const struct lit *lit, struct cgen *c)
{
	printf("l $%s", lit->name);
	return 0;
}

/*
 * Emit a variable argument. Globals are referenced as "$name".
 * Locals whose type satisfies cgen_abity() use the next temporary from
 * `tmp`; this must advance in the same order as in param_load(), which
 * emitted the loads into those temporaries. Other locals use "%name_".
 * NOTE(review): in QBE "$name" is the global's address, not its value;
 * confirm that passing globals by address is intended.
 * Returns 0, or -1 if temporaries are exhausted.
 */
static int param_id(const struct stentry *e, struct cgen *c, char *tmp,
	size_t n)
{
	const char *name = e->tk->s;
	const struct type *t = e->t;

	printf("%s ", cgen_sz(t->sz));

	if (cgen_global(c->fn, e))
		printf("$%s", name);
	else if (cgen_abity(t)) {
		if (gettmp(tmp, n) == EOF)
			return -1;

		printf("%%%s", tmp);
	} else
		printf("%%%s_", name);

	return 0;
}

/* Emit a named-constant argument with its signed or unsigned value. */
static int param_c(const struct type *t, struct cgen *cg)
{
	const union c *c = &t->u.c;

	/* TODO: adapt type to expected param size */
	fputs("l ", stdout);

	if (t->sign)
		printf("%lld", c->v);
	else
		printf("%llu", c->uv);

	return 0;
}

/* Emit a numeric-literal argument, honouring the sign recorded by param(). */
static int param_num(const struct callparam *cp, struct cgen *c)
{
	const union callv *u = &cp->u;

	/* TODO: adapt type to expected param size */
	fputs("w ", stdout);

	if (cp->sign)
		printf("%lld", u->v);
	else
		printf("%llu", u->uv);

	return 0;
}

/*
 * Before the call instruction, load every local variable argument whose
 * type satisfies cgen_abity() into a fresh temporary. The temporaries
 * are named in the same sequence param_id() later reproduces.
 * Returns 0, or -1 if temporaries are exhausted.
 */
static int param_load(const struct call *m, struct cgen *c)
{
	char tmp[sizeof "abcdefghijklmnopqrstuvwxyz"] = {0};

	for (size_t i = 0; i < m->nparams; i++) {
		const struct callparam *cp = &m->params[i];
		const struct stentry *e = cp->entry;

		if (!e || cgen_global(c->fn, e) || !cgen_abity(e->t))
			continue;
		else if (gettmp(tmp, sizeof tmp - 1) == EOF) {
			fprintf(stderr, "%s: exhausted temporaries\n", __func__);
			return -1;
		}

		printf("%%%s =%s load%s %%%s_\n", tmp, cgen_sz(e->t->sz),
			cgen_load(e->t), e->tk->s);
	}

	return 0;
}

/*
 * Emit IL for call `m`: argument loads, the call instruction itself
 * and, when a return variable is set, a store of the result into it.
 * Returns 0 on success, -1 on error.
 */
int call_cgen(const struct call *m, struct cgen *c)
{
	char tmps[sizeof "abcdefghijklmnopqrstuvwxyz"] = {0};
	const struct stentry *ret = m->ret;
	const char *tmp = NULL;
	size_t sz = 0;
	int var = m->pr->variadic;

	if (param_load(m, c))
		return -1;
	else if (ret) {
		tmp = cgen_tmp((sz = ret->t->sz));
		printf("%%%s =%s ", tmp, cgen_sz(sz));
	}

	printf("call $%s(", m->tk->s);

	for (size_t i = 0; i < m->nparams; i++) {
		const struct callparam *cp = &m->params[i];

		if (cp->lit)
			param_lit(cp->lit, c);
		else if (cp->entry) {
			if (param_id(cp->entry, c, tmps, sizeof tmps - 1))
				return -1;
		} else if (cp->t)
			param_c(cp->t, c);
		else
			param_num(cp, c);

		/*
		 * Variadic marker. NOTE(review): this relies on how
		 * pr->variadic is encoded (it is also used as a boolean by
		 * the parser) -- confirm the "var - 1" position.
		 */
		if (var && i + 1 == var - 1)
			fputs(", ...", stdout);

		if (i + 1 < m->nparams)
			fputs(", ", stdout);
	}

	puts(")");

	if (ret) {
		const char *name = ret->tk->s;

		printf("store%s %%%s, ", cgen_sz(sz), tmp);

		/* Globals are addressed by symbol, as in param_id(). */
		if (cgen_global(c->fn, ret))
			printf("$%s", name);
		else
			printf("%%%s_", name);

		putchar('\n');
	}

	return 0;
}

/* Release the argument array owned by `c`; `c` may be NULL. */
void call_free(struct call *c)
{
	if (c)
		free(c->params);
}

/*
 * Start parsing a call statement whose callee name is at p->tk.
 *
 * Resolves the callee's prototype, appends a CALL statement to the
 * current function and pushes `seqs` so that callpr/callp/callr/callv
 * complete the statement once the remaining tokens are matched.
 * Returns 1 on success, -1 after reporting an error.
 */
int call(const struct lex *l, struct prv *p)
{
	const struct tk *tk = p->tk;
	const struct pr *pr;
	struct fn *fn = fn_cur(p);
	size_t n = fn->nstmts + 1;
	struct stmt *stmts;
	struct pos init = {
		.seq = seqs,
		.stseq = seqs,
		.step = seqs->steps
	};

	/* TODO: function pointers */
	if (lex_eof(l, tk)) {
		errloc(tk - 1, "incomplete call statement");
		return -1;
	} else if (!(pr = find(p->ast, tk->s))) {
		errloc(tk, "undefined reference to function \"%s\"", tk->s);
		return -1;
	} else if (!(stmts = realloc(fn->stmts, n * sizeof *stmts))) {
		perror("realloc(3)");
		return -1;
	}

	fn->stmts = stmts;
	fn->stmts[fn->nstmts++] = (struct stmt) {
		.type = CALL,
		.u.call = {
			.tk = tk,
			.pr = pr
		}
	};

	if (push(&init, p))
		return -1;

	p->stk = p->tk;
	return 1;
}