bpf: allow access into map value arrays

Suppose you have a map array value that looks something like this:

struct foo {
	unsigned iter;
	int array[SOME_CONSTANT];
};

You can easily insert this into an array, but you cannot modify the contents of
foo->array[] after the fact.  This is because the verifier has no way to prove,
at verification time, that we won't go off the end of the array.  This patch
provides a start for this work.  We accomplish this by keeping track of the
minimum and maximum values a register could hold while we're checking the code.
Then, at the time we try to do an access into a MAP_VALUE, we verify that the
maximum offset into that region is still a valid access into that memory
region.  So in practice, code such as this:

unsigned index = 0;

if (foo->iter >= SOME_CONSTANT)
	foo->iter = index;
else
	index = foo->iter++;
foo->array[index] = bar;

would be allowed, as we can verify that index will always be between 0 and
SOME_CONSTANT-1.  If you wish to use signed values, you'll have to add an extra
check to make sure the index isn't less than 0, or do something like index %=
SOME_CONSTANT.

Signed-off-by: Josef Bacik <jbacik@fb.com>
Acked-by: Alexei Starovoitov <ast@kernel.org>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 7ada315..99a7e5b 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -182,6 +182,7 @@
 	[CONST_PTR_TO_MAP]	= "map_ptr",
 	[PTR_TO_MAP_VALUE]	= "map_value",
 	[PTR_TO_MAP_VALUE_OR_NULL] = "map_value_or_null",
+	[PTR_TO_MAP_VALUE_ADJ]	= "map_value_adj",
 	[FRAME_PTR]		= "fp",
 	[PTR_TO_STACK]		= "fp",
 	[CONST_IMM]		= "imm",
@@ -209,10 +210,17 @@
 		else if (t == UNKNOWN_VALUE && reg->imm)
 			verbose("%lld", reg->imm);
 		else if (t == CONST_PTR_TO_MAP || t == PTR_TO_MAP_VALUE ||
-			 t == PTR_TO_MAP_VALUE_OR_NULL)
+			 t == PTR_TO_MAP_VALUE_OR_NULL ||
+			 t == PTR_TO_MAP_VALUE_ADJ)
 			verbose("(ks=%d,vs=%d)",
 				reg->map_ptr->key_size,
 				reg->map_ptr->value_size);
+		if (reg->min_value != BPF_REGISTER_MIN_RANGE)
+			verbose(",min_value=%llu",
+				(unsigned long long)reg->min_value);
+		if (reg->max_value != BPF_REGISTER_MAX_RANGE)
+			verbose(",max_value=%llu",
+				(unsigned long long)reg->max_value);
 	}
 	for (i = 0; i < MAX_BPF_STACK; i += BPF_REG_SIZE) {
 		if (state->stack_slot_type[i] == STACK_SPILL)
@@ -424,6 +432,8 @@
 	for (i = 0; i < MAX_BPF_REG; i++) {
 		regs[i].type = NOT_INIT;
 		regs[i].imm = 0;
+		regs[i].min_value = BPF_REGISTER_MIN_RANGE;
+		regs[i].max_value = BPF_REGISTER_MAX_RANGE;
 	}
 
 	/* frame pointer */
@@ -440,6 +450,12 @@
 	regs[regno].imm = 0;
 }
 
+static void reset_reg_range_values(struct bpf_reg_state *regs, u32 regno)
+{
+	regs[regno].min_value = BPF_REGISTER_MIN_RANGE;
+	regs[regno].max_value = BPF_REGISTER_MAX_RANGE;
+}
+
 enum reg_arg_type {
 	SRC_OP,		/* register is used as source operand */
 	DST_OP,		/* register is used as destination operand */
@@ -665,7 +681,7 @@
 static int check_ptr_alignment(struct bpf_verifier_env *env,
 			       struct bpf_reg_state *reg, int off, int size)
 {
-	if (reg->type != PTR_TO_PACKET) {
+	if (reg->type != PTR_TO_PACKET && reg->type != PTR_TO_MAP_VALUE_ADJ) {
 		if (off % size != 0) {
 			verbose("misaligned access off %d size %d\n",
 				off, size);
@@ -675,16 +691,6 @@
 		}
 	}
 
-	switch (env->prog->type) {
-	case BPF_PROG_TYPE_SCHED_CLS:
-	case BPF_PROG_TYPE_SCHED_ACT:
-	case BPF_PROG_TYPE_XDP:
-		break;
-	default:
-		verbose("verifier is misconfigured\n");
-		return -EACCES;
-	}
-
 	if (IS_ENABLED(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS))
 		/* misaligned access to packet is ok on x86,arm,arm64 */
 		return 0;
@@ -695,7 +701,8 @@
 	}
 
 	/* skb->data is NET_IP_ALIGN-ed */
-	if ((NET_IP_ALIGN + reg->off + off) % size != 0) {
+	if (reg->type == PTR_TO_PACKET &&
+	    (NET_IP_ALIGN + reg->off + off) % size != 0) {
 		verbose("misaligned packet access off %d+%d+%d size %d\n",
 			NET_IP_ALIGN, reg->off, off, size);
 		return -EACCES;
@@ -728,12 +735,52 @@
 	if (err)
 		return err;
 
-	if (reg->type == PTR_TO_MAP_VALUE) {
+	if (reg->type == PTR_TO_MAP_VALUE ||
+	    reg->type == PTR_TO_MAP_VALUE_ADJ) {
 		if (t == BPF_WRITE && value_regno >= 0 &&
 		    is_pointer_value(env, value_regno)) {
 			verbose("R%d leaks addr into map\n", value_regno);
 			return -EACCES;
 		}
+
+		/* If we adjusted the register to this map value at all then we
+		 * need to change off and size to min_value and max_value
+		 * respectively to make sure our theoretical access will be
+		 * safe.
+		 */
+		if (reg->type == PTR_TO_MAP_VALUE_ADJ) {
+			if (log_level)
+				print_verifier_state(state);
+			env->varlen_map_value_access = true;
+			/* The minimum value is only important with signed
+			 * comparisons where we can't assume the floor of a
+			 * value is 0.  If we are using signed variables for our
+			 * index'es we need to make sure that whatever we use
+			 * will have a set floor within our range.
+			 */
+			if ((s64)reg->min_value < 0) {
+				verbose("R%d min value is negative, either use unsigned index or do a if (index >=0) check.\n",
+					regno);
+				return -EACCES;
+			}
+			err = check_map_access(env, regno, reg->min_value + off,
+					       size);
+			if (err) {
+				verbose("R%d min value is outside of the array range\n",
+					regno);
+				return err;
+			}
+
+			/* If we haven't set a max value then we need to bail
+			 * since we can't be sure we won't do bad things.
+			 */
+			if (reg->max_value == BPF_REGISTER_MAX_RANGE) {
+				verbose("R%d unbounded memory access, make sure to bounds check any array access into a map\n",
+					regno);
+				return -EACCES;
+			}
+			off += reg->max_value;
+		}
 		err = check_map_access(env, regno, off, size);
 		if (!err && t == BPF_READ && value_regno >= 0)
 			mark_reg_unknown_value(state->regs, value_regno);
@@ -1195,6 +1242,7 @@
 		regs[BPF_REG_0].type = NOT_INIT;
 	} else if (fn->ret_type == RET_PTR_TO_MAP_VALUE_OR_NULL) {
 		regs[BPF_REG_0].type = PTR_TO_MAP_VALUE_OR_NULL;
+		regs[BPF_REG_0].max_value = regs[BPF_REG_0].min_value = 0;
 		/* remember map_ptr, so that check_map_access()
 		 * can check 'value_size' boundary of memory access
 		 * to map element returned from bpf_map_lookup_elem()
@@ -1416,6 +1464,106 @@
 	return 0;
 }
 
+static void check_reg_overflow(struct bpf_reg_state *reg)
+{
+	if (reg->max_value > BPF_REGISTER_MAX_RANGE)
+		reg->max_value = BPF_REGISTER_MAX_RANGE;
+	if ((s64)reg->min_value < BPF_REGISTER_MIN_RANGE)
+		reg->min_value = BPF_REGISTER_MIN_RANGE;
+}
+
+static void adjust_reg_min_max_vals(struct bpf_verifier_env *env,
+				    struct bpf_insn *insn)
+{
+	struct bpf_reg_state *regs = env->cur_state.regs, *dst_reg;
+	u64 min_val = BPF_REGISTER_MIN_RANGE, max_val = BPF_REGISTER_MAX_RANGE;
+	bool min_set = false, max_set = false;
+	u8 opcode = BPF_OP(insn->code);
+
+	dst_reg = &regs[insn->dst_reg];
+	if (BPF_SRC(insn->code) == BPF_X) {
+		check_reg_overflow(&regs[insn->src_reg]);
+		min_val = regs[insn->src_reg].min_value;
+		max_val = regs[insn->src_reg].max_value;
+
+		/* If the source register is a random pointer then the
+		 * min_value/max_value values represent the range of the known
+		 * accesses into that value, not the actual min/max value of the
+		 * register itself.  In this case we have to reset the reg range
+		 * values so we know it is not safe to look at.
+		 */
+		if (regs[insn->src_reg].type != CONST_IMM &&
+		    regs[insn->src_reg].type != UNKNOWN_VALUE) {
+			min_val = BPF_REGISTER_MIN_RANGE;
+			max_val = BPF_REGISTER_MAX_RANGE;
+		}
+	} else if (insn->imm < BPF_REGISTER_MAX_RANGE &&
+		   (s64)insn->imm > BPF_REGISTER_MIN_RANGE) {
+		min_val = max_val = insn->imm;
+		min_set = max_set = true;
+	}
+
+	/* We don't know anything about what was done to this register, mark it
+	 * as unknown.
+	 */
+	if (min_val == BPF_REGISTER_MIN_RANGE &&
+	    max_val == BPF_REGISTER_MAX_RANGE) {
+		reset_reg_range_values(regs, insn->dst_reg);
+		return;
+	}
+
+	switch (opcode) {
+	case BPF_ADD:
+		dst_reg->min_value += min_val;
+		dst_reg->max_value += max_val;
+		break;
+	case BPF_SUB:
+		dst_reg->min_value -= min_val;
+		dst_reg->max_value -= max_val;
+		break;
+	case BPF_MUL:
+		dst_reg->min_value *= min_val;
+		dst_reg->max_value *= max_val;
+		break;
+	case BPF_AND:
+		/* & is special since it could end up with 0 bits set. */
+		dst_reg->min_value &= min_val;
+		dst_reg->max_value = max_val;
+		break;
+	case BPF_LSH:
+		/* Gotta have special overflow logic here, if we're shifting
+		 * more than MAX_RANGE then just assume we have an invalid
+		 * range.
+		 */
+		if (min_val > ilog2(BPF_REGISTER_MAX_RANGE))
+			dst_reg->min_value = BPF_REGISTER_MIN_RANGE;
+		else
+			dst_reg->min_value <<= min_val;
+
+		if (max_val > ilog2(BPF_REGISTER_MAX_RANGE))
+			dst_reg->max_value = BPF_REGISTER_MAX_RANGE;
+		else
+			dst_reg->max_value <<= max_val;
+		break;
+	case BPF_RSH:
+		dst_reg->min_value >>= min_val;
+		dst_reg->max_value >>= max_val;
+		break;
+	case BPF_MOD:
+		/* % is special since it is an unsigned modulus, so the floor
+		 * will always be 0.
+		 */
+		dst_reg->min_value = 0;
+		dst_reg->max_value = max_val - 1;
+		break;
+	default:
+		reset_reg_range_values(regs, insn->dst_reg);
+		break;
+	}
+
+	check_reg_overflow(dst_reg);
+}
+
 /* check validity of 32-bit and 64-bit arithmetic operations */
 static int check_alu_op(struct bpf_verifier_env *env, struct bpf_insn *insn)
 {
@@ -1479,6 +1627,11 @@
 		if (err)
 			return err;
 
+		/* we are setting our register to something new, we need to
+		 * reset its range values.
+		 */
+		reset_reg_range_values(regs, insn->dst_reg);
+
 		if (BPF_SRC(insn->code) == BPF_X) {
 			if (BPF_CLASS(insn->code) == BPF_ALU64) {
 				/* case: R1 = R2
@@ -1500,6 +1653,8 @@
 			 */
 			regs[insn->dst_reg].type = CONST_IMM;
 			regs[insn->dst_reg].imm = insn->imm;
+			regs[insn->dst_reg].max_value = insn->imm;
+			regs[insn->dst_reg].min_value = insn->imm;
 		}
 
 	} else if (opcode > BPF_END) {
@@ -1552,6 +1707,9 @@
 
 		dst_reg = &regs[insn->dst_reg];
 
+		/* first we want to adjust our ranges. */
+		adjust_reg_min_max_vals(env, insn);
+
 		/* pattern match 'bpf_add Rx, imm' instruction */
 		if (opcode == BPF_ADD && BPF_CLASS(insn->code) == BPF_ALU64 &&
 		    dst_reg->type == FRAME_PTR && BPF_SRC(insn->code) == BPF_K) {
@@ -1586,8 +1744,17 @@
 			return -EACCES;
 		}
 
-		/* mark dest operand */
-		mark_reg_unknown_value(regs, insn->dst_reg);
+		/* If we did pointer math on a map value then just set it to our
+		 * PTR_TO_MAP_VALUE_ADJ type so we can deal with any stores or
+		 * loads to this register appropriately, otherwise just mark the
+		 * register as unknown.
+		 */
+		if (env->allow_ptr_leaks &&
+		    (dst_reg->type == PTR_TO_MAP_VALUE ||
+		     dst_reg->type == PTR_TO_MAP_VALUE_ADJ))
+			dst_reg->type = PTR_TO_MAP_VALUE_ADJ;
+		else
+			mark_reg_unknown_value(regs, insn->dst_reg);
 	}
 
 	return 0;
@@ -1642,6 +1809,104 @@
 	}
 }
 
+/* Adjusts the register min/max values in the case that the dst_reg is the
+ * variable register that we are working on, and src_reg is a constant or we're
+ * simply doing a BPF_K check.
+ */
+static void reg_set_min_max(struct bpf_reg_state *true_reg,
+			    struct bpf_reg_state *false_reg, u64 val,
+			    u8 opcode)
+{
+	switch (opcode) {
+	case BPF_JEQ:
+		/* If this is false then we know nothing Jon Snow, but if it is
+		 * true then we know for sure.
+		 */
+		true_reg->max_value = true_reg->min_value = val;
+		break;
+	case BPF_JNE:
+		/* If this is true we know nothing Jon Snow, but if it is false
+		 * we know the value for sure;
+		 */
+		false_reg->max_value = false_reg->min_value = val;
+		break;
+	case BPF_JGT:
+		/* Unsigned comparison, the minimum value is 0. */
+		false_reg->min_value = 0;
+	case BPF_JSGT:
+		/* If this is false then we know the maximum val is val,
+		 * otherwise we know the min val is val+1.
+		 */
+		false_reg->max_value = val;
+		true_reg->min_value = val + 1;
+		break;
+	case BPF_JGE:
+		/* Unsigned comparison, the minimum value is 0. */
+		false_reg->min_value = 0;
+	case BPF_JSGE:
+		/* If this is false then we know the maximum value is val - 1,
+		 * otherwise we know the minimum value is val.
+		 */
+		false_reg->max_value = val - 1;
+		true_reg->min_value = val;
+		break;
+	default:
+		break;
+	}
+
+	check_reg_overflow(false_reg);
+	check_reg_overflow(true_reg);
+}
+
+/* Same as above, but for the case that dst_reg is a CONST_IMM reg and src_reg
+ * is the variable reg.
+ */
+static void reg_set_min_max_inv(struct bpf_reg_state *true_reg,
+				struct bpf_reg_state *false_reg, u64 val,
+				u8 opcode)
+{
+	switch (opcode) {
+	case BPF_JEQ:
+		/* If this is false then we know nothing Jon Snow, but if it is
+		 * true then we know for sure.
+		 */
+		true_reg->max_value = true_reg->min_value = val;
+		break;
+	case BPF_JNE:
+		/* If this is true we know nothing Jon Snow, but if it is false
+		 * we know the value for sure;
+		 */
+		false_reg->max_value = false_reg->min_value = val;
+		break;
+	case BPF_JGT:
+		/* Unsigned comparison, the minimum value is 0. */
+		true_reg->min_value = 0;
+	case BPF_JSGT:
+		/*
+		 * If this is false, then the val is <= the register, if it is
+		 * true the register <= to the val.
+		 */
+		false_reg->min_value = val;
+		true_reg->max_value = val - 1;
+		break;
+	case BPF_JGE:
+		/* Unsigned comparison, the minimum value is 0. */
+		true_reg->min_value = 0;
+	case BPF_JSGE:
+		/* If this is false then constant < register, if it is true then
+		 * the register < constant.
+		 */
+		false_reg->min_value = val + 1;
+		true_reg->max_value = val;
+		break;
+	default:
+		break;
+	}
+
+	check_reg_overflow(false_reg);
+	check_reg_overflow(true_reg);
+}
+
 static int check_cond_jmp_op(struct bpf_verifier_env *env,
 			     struct bpf_insn *insn, int *insn_idx)
 {
@@ -1708,6 +1973,23 @@
 	if (!other_branch)
 		return -EFAULT;
 
+	/* detect if we are comparing against a constant value so we can adjust
+	 * our min/max values for our dst register.
+	 */
+	if (BPF_SRC(insn->code) == BPF_X) {
+		if (regs[insn->src_reg].type == CONST_IMM)
+			reg_set_min_max(&other_branch->regs[insn->dst_reg],
+					dst_reg, regs[insn->src_reg].imm,
+					opcode);
+		else if (dst_reg->type == CONST_IMM)
+			reg_set_min_max_inv(&other_branch->regs[insn->src_reg],
+					    &regs[insn->src_reg], dst_reg->imm,
+					    opcode);
+	} else {
+		reg_set_min_max(&other_branch->regs[insn->dst_reg],
+					dst_reg, insn->imm, opcode);
+	}
+
 	/* detect if R == 0 where R is returned from bpf_map_lookup_elem() */
 	if (BPF_SRC(insn->code) == BPF_K &&
 	    insn->imm == 0 && (opcode == BPF_JEQ || opcode == BPF_JNE) &&
@@ -2144,7 +2426,8 @@
  * whereas register type in current state is meaningful, it means that
  * the current state will reach 'bpf_exit' instruction safely
  */
-static bool states_equal(struct bpf_verifier_state *old,
+static bool states_equal(struct bpf_verifier_env *env,
+			 struct bpf_verifier_state *old,
 			 struct bpf_verifier_state *cur)
 {
 	struct bpf_reg_state *rold, *rcur;
@@ -2157,6 +2440,13 @@
 		if (memcmp(rold, rcur, sizeof(*rold)) == 0)
 			continue;
 
+		/* If the ranges were not the same, but everything else was and
+		 * we didn't do a variable access into a map then we are a-ok.
+		 */
+		if (!env->varlen_map_value_access &&
+		    rold->type == rcur->type && rold->imm == rcur->imm)
+			continue;
+
 		if (rold->type == NOT_INIT ||
 		    (rold->type == UNKNOWN_VALUE && rcur->type != NOT_INIT))
 			continue;
@@ -2213,7 +2503,7 @@
 		return 0;
 
 	while (sl != STATE_LIST_MARK) {
-		if (states_equal(&sl->state, &env->cur_state))
+		if (states_equal(env, &sl->state, &env->cur_state))
 			/* reached equivalent register/stack state,
 			 * prune the search
 			 */
@@ -2259,6 +2549,7 @@
 
 	init_reg_state(regs);
 	insn_idx = 0;
+	env->varlen_map_value_access = false;
 	for (;;) {
 		struct bpf_insn *insn;
 		u8 class;
@@ -2339,6 +2630,7 @@
 			if (err)
 				return err;
 
+			reset_reg_range_values(regs, insn->dst_reg);
 			if (BPF_SIZE(insn->code) != BPF_W &&
 			    BPF_SIZE(insn->code) != BPF_DW) {
 				insn_idx++;
@@ -2509,6 +2801,7 @@
 				verbose("invalid BPF_LD mode\n");
 				return -EINVAL;
 			}
+			reset_reg_range_values(regs, insn->dst_reg);
 		} else {
 			verbose("unknown insn class %d\n", class);
 			return -EINVAL;