diff --git a/parse.c b/parse.c index 87f59e5..16cdffb 100644 --- a/parse.c +++ b/parse.c @@ -3383,25 +3383,19 @@ static Node *new_add(Node *lhs, Node *rhs, Token *tok) { if (is_numeric(lhs->ty) && is_numeric(rhs->ty)) return new_binary(ND_ADD, lhs, rhs, tok); - if (lhs->ty->base && rhs->ty->base) - error_tok(tok, "invalid operands"); + Node **ofs = is_integer(lhs->ty) ? &lhs : is_integer(rhs->ty) ? &rhs : NULL; + Node *ptr = lhs->ty->base ? lhs : rhs->ty->base ? rhs : NULL; - // Canonicalize `num + ptr` to `ptr + num`. - if (!lhs->ty->base && rhs->ty->base) { - Node *tmp = lhs; - lhs = rhs; - rhs = tmp; - } + if (ptr && ofs) { + if (ptr->ty->base->kind == TY_VLA) + *ofs = new_binary(ND_MUL, *ofs, new_var_node(ptr->ty->base->vla_size, tok), tok); + else + *ofs = new_binary(ND_MUL, *ofs, new_long(ptr->ty->base->size, tok), tok); - // VLA + num - if (lhs->ty->base->kind == TY_VLA) { - rhs = new_binary(ND_MUL, rhs, new_var_node(lhs->ty->base->vla_size, tok), tok); return new_binary(ND_ADD, lhs, rhs, tok); } - // ptr + num - rhs = new_binary(ND_MUL, rhs, new_long(lhs->ty->base->size, tok), tok); - return new_binary(ND_ADD, lhs, rhs, tok); + error_tok(tok, "invalid operands"); } // Like `+`, `-` is overloaded for the pointer type. @@ -3413,15 +3407,13 @@ static Node *new_sub(Node *lhs, Node *rhs, Token *tok) { if (is_numeric(lhs->ty) && is_numeric(rhs->ty)) return new_binary(ND_SUB, lhs, rhs, tok); - // VLA - num - if (lhs->ty->base->kind == TY_VLA) { - rhs = new_binary(ND_MUL, rhs, new_var_node(lhs->ty->base->vla_size, tok), tok); - return new_binary(ND_SUB, lhs, rhs, tok); - } - // ptr - num if (lhs->ty->base && is_integer(rhs->ty)) { - rhs = new_binary(ND_MUL, rhs, new_long(lhs->ty->base->size, tok), tok); + if (lhs->ty->base->kind == TY_VLA) + rhs = new_binary(ND_MUL, rhs, new_var_node(lhs->ty->base->vla_size, tok), tok); + else + rhs = new_binary(ND_MUL, rhs, new_long(lhs->ty->base->size, tok), tok); + return new_binary(ND_SUB, lhs, rhs, tok); } diff --git a/type.c b/type.c index cc60bfa..8b022ea 100644 --- a/type.c +++ b/type.c @@ -399,17 +399,18 @@ void add_type(Node *node) { node->ty = ty_int; return; case ND_ADD: - case ND_SUB: - if (node->lhs->ty->base) { - if (node->lhs->ty->kind != TY_PTR) - node->lhs = new_cast(node->lhs, pointer_to(node->lhs->ty->base)); - node->rhs = new_cast(node->rhs, ty_ullong); - node->ty = node->lhs->ty; + case ND_SUB: { + Node **ptr = node->lhs->ty->base ? &node->lhs : node->rhs->ty->base ? &node->rhs : NULL; + if (ptr) { + if ((*ptr)->ty->kind != TY_PTR) + *ptr = new_cast(*ptr, pointer_to((*ptr)->ty->base)); + node->ty = (*ptr)->ty; return; } usual_arith_conv(&node->lhs, &node->rhs, false); node->ty = node->lhs->ty; return; + } case ND_MUL: case ND_DIV: case ND_MOD: