audit: Add type-specific uid and gid comparators

The audit filter code guarantees that uids are always compared with
uids and gids are always compared with gids, as the comparison
operations are type specific.  Take advantage of this property to define
audit_uid_comparator and audit_gid_comparator which use the type safe
comparisons from uidgid.h.

Build on audit_uid_comparator and audit_gid_comparator and replace
audit_compare_id with audit_compare_uid and audit_compare_gid.  This
is one of those odd cases where being type safe and duplicating code
leads to simpler, shorter, and more concise code.

Don't allow bitmask operations in uid and gid comparisons in
audit_data_to_entry.  Bitmask operations are already denied in
audit_rule_to_entry.

Convert constants in audit_rule_to_entry and audit_data_to_entry into
kuids and kgids when appropriate.

Convert the uid and gid fields in struct audit_names to be of type
kuid_t and kgid_t respectively, so that the new uid and gid comparators
can be applied in a type safe manner.

Cc: Al Viro <viro@zeniv.linux.org.uk>
Cc: Eric Paris <eparis@redhat.com>
Signed-off-by: "Eric W. Biederman" <ebiederm@xmission.com>
diff --git a/kernel/audit.h b/kernel/audit.h
index 8167668..4b428bb 100644
--- a/kernel/audit.h
+++ b/kernel/audit.h
@@ -76,6 +76,8 @@
 
 extern int audit_match_class(int class, unsigned syscall);
 extern int audit_comparator(const u32 left, const u32 op, const u32 right);
+extern int audit_uid_comparator(kuid_t left, u32 op, kuid_t right);
+extern int audit_gid_comparator(kgid_t left, u32 op, kgid_t right);
 extern int audit_compare_dname_path(const char *dname, const char *path,
 				    int *dirlen);
 extern struct sk_buff *	    audit_make_reply(int pid, int seq, int type,
diff --git a/kernel/auditfilter.c b/kernel/auditfilter.c
index e242dd9..b30320c 100644
--- a/kernel/auditfilter.c
+++ b/kernel/auditfilter.c
@@ -342,6 +342,8 @@
 
 		f->type = rule->fields[i] & ~(AUDIT_NEGATE|AUDIT_OPERATORS);
 		f->val = rule->values[i];
+		f->uid = INVALID_UID;
+		f->gid = INVALID_GID;
 
 		err = -EINVAL;
 		if (f->op == Audit_bad)
@@ -350,16 +352,32 @@
 		switch(f->type) {
 		default:
 			goto exit_free;
-		case AUDIT_PID:
 		case AUDIT_UID:
 		case AUDIT_EUID:
 		case AUDIT_SUID:
 		case AUDIT_FSUID:
+		case AUDIT_LOGINUID:
+			/* bit ops not implemented for uid comparisons */
+			if (f->op == Audit_bitmask || f->op == Audit_bittest)
+				goto exit_free;
+
+			f->uid = make_kuid(current_user_ns(), f->val);
+			if (!uid_valid(f->uid))
+				goto exit_free;
+			break;
 		case AUDIT_GID:
 		case AUDIT_EGID:
 		case AUDIT_SGID:
 		case AUDIT_FSGID:
-		case AUDIT_LOGINUID:
+			/* bit ops not implemented for gid comparisons */
+			if (f->op == Audit_bitmask || f->op == Audit_bittest)
+				goto exit_free;
+
+			f->gid = make_kgid(current_user_ns(), f->val);
+			if (!gid_valid(f->gid))
+				goto exit_free;
+			break;
+		case AUDIT_PID:
 		case AUDIT_PERS:
 		case AUDIT_MSGTYPE:
 		case AUDIT_PPID:
@@ -437,19 +455,39 @@
 
 		f->type = data->fields[i];
 		f->val = data->values[i];
+		f->uid = INVALID_UID;
+		f->gid = INVALID_GID;
 		f->lsm_str = NULL;
 		f->lsm_rule = NULL;
 		switch(f->type) {
-		case AUDIT_PID:
 		case AUDIT_UID:
 		case AUDIT_EUID:
 		case AUDIT_SUID:
 		case AUDIT_FSUID:
+		case AUDIT_LOGINUID:
+		case AUDIT_OBJ_UID:
+			/* bit ops not implemented for uid comparisons */
+			if (f->op == Audit_bitmask || f->op == Audit_bittest)
+				goto exit_free;
+
+			f->uid = make_kuid(current_user_ns(), f->val);
+			if (!uid_valid(f->uid))
+				goto exit_free;
+			break;
 		case AUDIT_GID:
 		case AUDIT_EGID:
 		case AUDIT_SGID:
 		case AUDIT_FSGID:
-		case AUDIT_LOGINUID:
+		case AUDIT_OBJ_GID:
+			/* bit ops not implemented for gid comparisons */
+			if (f->op == Audit_bitmask || f->op == Audit_bittest)
+				goto exit_free;
+
+			f->gid = make_kgid(current_user_ns(), f->val);
+			if (!gid_valid(f->gid))
+				goto exit_free;
+			break;
+		case AUDIT_PID:
 		case AUDIT_PERS:
 		case AUDIT_MSGTYPE:
 		case AUDIT_PPID:
@@ -461,8 +499,6 @@
 		case AUDIT_ARG1:
 		case AUDIT_ARG2:
 		case AUDIT_ARG3:
-		case AUDIT_OBJ_UID:
-		case AUDIT_OBJ_GID:
 			break;
 		case AUDIT_ARCH:
 			entry->rule.arch_f = f;
@@ -707,6 +743,23 @@
 			if (strcmp(a->filterkey, b->filterkey))
 				return 1;
 			break;
+		case AUDIT_UID:
+		case AUDIT_EUID:
+		case AUDIT_SUID:
+		case AUDIT_FSUID:
+		case AUDIT_LOGINUID:
+		case AUDIT_OBJ_UID:
+			if (!uid_eq(a->fields[i].uid, b->fields[i].uid))
+				return 1;
+			break;
+		case AUDIT_GID:
+		case AUDIT_EGID:
+		case AUDIT_SGID:
+		case AUDIT_FSGID:
+		case AUDIT_OBJ_GID:
+			if (!gid_eq(a->fields[i].gid, b->fields[i].gid))
+				return 1;
+			break;
 		default:
 			if (a->fields[i].val != b->fields[i].val)
 				return 1;
@@ -1198,6 +1251,52 @@
 	}
 }
 
+int audit_uid_comparator(kuid_t left, u32 op, kuid_t right)
+{
+	switch (op) {
+	case Audit_equal:
+		return uid_eq(left, right);
+	case Audit_not_equal:
+		return !uid_eq(left, right);
+	case Audit_lt:
+		return uid_lt(left, right);
+	case Audit_le:
+		return uid_lte(left, right);
+	case Audit_gt:
+		return uid_gt(left, right);
+	case Audit_ge:
+		return uid_gte(left, right);
+	case Audit_bitmask:
+	case Audit_bittest:
+	default:
+		BUG();
+		return 0;
+	}
+}
+
+int audit_gid_comparator(kgid_t left, u32 op, kgid_t right)
+{
+	switch (op) {
+	case Audit_equal:
+		return gid_eq(left, right);
+	case Audit_not_equal:
+		return !gid_eq(left, right);
+	case Audit_lt:
+		return gid_lt(left, right);
+	case Audit_le:
+		return gid_lte(left, right);
+	case Audit_gt:
+		return gid_gt(left, right);
+	case Audit_ge:
+		return gid_gte(left, right);
+	case Audit_bitmask:
+	case Audit_bittest:
+	default:
+		BUG();
+		return 0;
+	}
+}
+
 /* Compare given dentry name with last component in given path,
  * return of 0 indicates a match. */
 int audit_compare_dname_path(const char *dname, const char *path,
@@ -1251,14 +1350,14 @@
 			result = audit_comparator(task_pid_vnr(current), f->op, f->val);
 			break;
 		case AUDIT_UID:
-			result = audit_comparator(current_uid(), f->op, f->val);
+			result = audit_uid_comparator(current_uid(), f->op, f->uid);
 			break;
 		case AUDIT_GID:
-			result = audit_comparator(current_gid(), f->op, f->val);
+			result = audit_gid_comparator(current_gid(), f->op, f->gid);
 			break;
 		case AUDIT_LOGINUID:
-			result = audit_comparator(audit_get_loginuid(current),
-						  f->op, f->val);
+			result = audit_uid_comparator(audit_get_loginuid(current),
+						  f->op, f->uid);
 			break;
 		case AUDIT_SUBJ_USER:
 		case AUDIT_SUBJ_ROLE:
diff --git a/kernel/auditsc.c b/kernel/auditsc.c
index 4b96415..0b5b8a2 100644
--- a/kernel/auditsc.c
+++ b/kernel/auditsc.c
@@ -113,8 +113,8 @@
 	unsigned long	ino;
 	dev_t		dev;
 	umode_t		mode;
-	uid_t		uid;
-	gid_t		gid;
+	kuid_t		uid;
+	kgid_t		gid;
 	dev_t		rdev;
 	u32		osid;
 	struct audit_cap_data fcap;
@@ -464,37 +464,47 @@
 	return 0;
 }
 
-static int audit_compare_id(uid_t uid1,
-			    struct audit_names *name,
-			    unsigned long name_offset,
-			    struct audit_field *f,
-			    struct audit_context *ctx)
+static int audit_compare_uid(kuid_t uid,
+			     struct audit_names *name,
+			     struct audit_field *f,
+			     struct audit_context *ctx)
 {
 	struct audit_names *n;
-	unsigned long addr;
-	uid_t uid2;
 	int rc;
-
-	BUILD_BUG_ON(sizeof(uid_t) != sizeof(gid_t));
-
+ 
 	if (name) {
-		addr = (unsigned long)name;
-		addr += name_offset;
-
-		uid2 = *(uid_t *)addr;
-		rc = audit_comparator(uid1, f->op, uid2);
+		rc = audit_uid_comparator(uid, f->op, name->uid);
 		if (rc)
 			return rc;
 	}
-
+ 
 	if (ctx) {
 		list_for_each_entry(n, &ctx->names_list, list) {
-			addr = (unsigned long)n;
-			addr += name_offset;
+			rc = audit_uid_comparator(uid, f->op, n->uid);
+			if (rc)
+				return rc;
+		}
+	}
+	return 0;
+}
 
-			uid2 = *(uid_t *)addr;
-
-			rc = audit_comparator(uid1, f->op, uid2);
+static int audit_compare_gid(kgid_t gid,
+			     struct audit_names *name,
+			     struct audit_field *f,
+			     struct audit_context *ctx)
+{
+	struct audit_names *n;
+	int rc;
+ 
+	if (name) {
+		rc = audit_gid_comparator(gid, f->op, name->gid);
+		if (rc)
+			return rc;
+	}
+ 
+	if (ctx) {
+		list_for_each_entry(n, &ctx->names_list, list) {
+			rc = audit_gid_comparator(gid, f->op, n->gid);
 			if (rc)
 				return rc;
 		}
@@ -511,80 +521,62 @@
 	switch (f->val) {
 	/* process to file object comparisons */
 	case AUDIT_COMPARE_UID_TO_OBJ_UID:
-		return audit_compare_id(cred->uid,
-					name, offsetof(struct audit_names, uid),
-					f, ctx);
+		return audit_compare_uid(cred->uid, name, f, ctx);
 	case AUDIT_COMPARE_GID_TO_OBJ_GID:
-		return audit_compare_id(cred->gid,
-					name, offsetof(struct audit_names, gid),
-					f, ctx);
+		return audit_compare_gid(cred->gid, name, f, ctx);
 	case AUDIT_COMPARE_EUID_TO_OBJ_UID:
-		return audit_compare_id(cred->euid,
-					name, offsetof(struct audit_names, uid),
-					f, ctx);
+		return audit_compare_uid(cred->euid, name, f, ctx);
 	case AUDIT_COMPARE_EGID_TO_OBJ_GID:
-		return audit_compare_id(cred->egid,
-					name, offsetof(struct audit_names, gid),
-					f, ctx);
+		return audit_compare_gid(cred->egid, name, f, ctx);
 	case AUDIT_COMPARE_AUID_TO_OBJ_UID:
-		return audit_compare_id(tsk->loginuid,
-					name, offsetof(struct audit_names, uid),
-					f, ctx);
+		return audit_compare_uid(tsk->loginuid, name, f, ctx);
 	case AUDIT_COMPARE_SUID_TO_OBJ_UID:
-		return audit_compare_id(cred->suid,
-					name, offsetof(struct audit_names, uid),
-					f, ctx);
+		return audit_compare_uid(cred->suid, name, f, ctx);
 	case AUDIT_COMPARE_SGID_TO_OBJ_GID:
-		return audit_compare_id(cred->sgid,
-					name, offsetof(struct audit_names, gid),
-					f, ctx);
+		return audit_compare_gid(cred->sgid, name, f, ctx);
 	case AUDIT_COMPARE_FSUID_TO_OBJ_UID:
-		return audit_compare_id(cred->fsuid,
-					name, offsetof(struct audit_names, uid),
-					f, ctx);
+		return audit_compare_uid(cred->fsuid, name, f, ctx);
 	case AUDIT_COMPARE_FSGID_TO_OBJ_GID:
-		return audit_compare_id(cred->fsgid,
-					name, offsetof(struct audit_names, gid),
-					f, ctx);
+		return audit_compare_gid(cred->fsgid, name, f, ctx);
 	/* uid comparisons */
 	case AUDIT_COMPARE_UID_TO_AUID:
-		return audit_comparator(cred->uid, f->op, tsk->loginuid);
+		return audit_uid_comparator(cred->uid, f->op, tsk->loginuid);
 	case AUDIT_COMPARE_UID_TO_EUID:
-		return audit_comparator(cred->uid, f->op, cred->euid);
+		return audit_uid_comparator(cred->uid, f->op, cred->euid);
 	case AUDIT_COMPARE_UID_TO_SUID:
-		return audit_comparator(cred->uid, f->op, cred->suid);
+		return audit_uid_comparator(cred->uid, f->op, cred->suid);
 	case AUDIT_COMPARE_UID_TO_FSUID:
-		return audit_comparator(cred->uid, f->op, cred->fsuid);
+		return audit_uid_comparator(cred->uid, f->op, cred->fsuid);
 	/* auid comparisons */
 	case AUDIT_COMPARE_AUID_TO_EUID:
-		return audit_comparator(tsk->loginuid, f->op, cred->euid);
+		return audit_uid_comparator(tsk->loginuid, f->op, cred->euid);
 	case AUDIT_COMPARE_AUID_TO_SUID:
-		return audit_comparator(tsk->loginuid, f->op, cred->suid);
+		return audit_uid_comparator(tsk->loginuid, f->op, cred->suid);
 	case AUDIT_COMPARE_AUID_TO_FSUID:
-		return audit_comparator(tsk->loginuid, f->op, cred->fsuid);
+		return audit_uid_comparator(tsk->loginuid, f->op, cred->fsuid);
 	/* euid comparisons */
 	case AUDIT_COMPARE_EUID_TO_SUID:
-		return audit_comparator(cred->euid, f->op, cred->suid);
+		return audit_uid_comparator(cred->euid, f->op, cred->suid);
 	case AUDIT_COMPARE_EUID_TO_FSUID:
-		return audit_comparator(cred->euid, f->op, cred->fsuid);
+		return audit_uid_comparator(cred->euid, f->op, cred->fsuid);
 	/* suid comparisons */
 	case AUDIT_COMPARE_SUID_TO_FSUID:
-		return audit_comparator(cred->suid, f->op, cred->fsuid);
+		return audit_uid_comparator(cred->suid, f->op, cred->fsuid);
 	/* gid comparisons */
 	case AUDIT_COMPARE_GID_TO_EGID:
-		return audit_comparator(cred->gid, f->op, cred->egid);
+		return audit_gid_comparator(cred->gid, f->op, cred->egid);
 	case AUDIT_COMPARE_GID_TO_SGID:
-		return audit_comparator(cred->gid, f->op, cred->sgid);
+		return audit_gid_comparator(cred->gid, f->op, cred->sgid);
 	case AUDIT_COMPARE_GID_TO_FSGID:
-		return audit_comparator(cred->gid, f->op, cred->fsgid);
+		return audit_gid_comparator(cred->gid, f->op, cred->fsgid);
 	/* egid comparisons */
 	case AUDIT_COMPARE_EGID_TO_SGID:
-		return audit_comparator(cred->egid, f->op, cred->sgid);
+		return audit_gid_comparator(cred->egid, f->op, cred->sgid);
 	case AUDIT_COMPARE_EGID_TO_FSGID:
-		return audit_comparator(cred->egid, f->op, cred->fsgid);
+		return audit_gid_comparator(cred->egid, f->op, cred->fsgid);
 	/* sgid comparison */
 	case AUDIT_COMPARE_SGID_TO_FSGID:
-		return audit_comparator(cred->sgid, f->op, cred->fsgid);
+		return audit_gid_comparator(cred->sgid, f->op, cred->fsgid);
 	default:
 		WARN(1, "Missing AUDIT_COMPARE define.  Report as a bug\n");
 		return 0;
@@ -630,28 +622,28 @@
 			}
 			break;
 		case AUDIT_UID:
-			result = audit_comparator(cred->uid, f->op, f->val);
+			result = audit_uid_comparator(cred->uid, f->op, f->uid);
 			break;
 		case AUDIT_EUID:
-			result = audit_comparator(cred->euid, f->op, f->val);
+			result = audit_uid_comparator(cred->euid, f->op, f->uid);
 			break;
 		case AUDIT_SUID:
-			result = audit_comparator(cred->suid, f->op, f->val);
+			result = audit_uid_comparator(cred->suid, f->op, f->uid);
 			break;
 		case AUDIT_FSUID:
-			result = audit_comparator(cred->fsuid, f->op, f->val);
+			result = audit_uid_comparator(cred->fsuid, f->op, f->uid);
 			break;
 		case AUDIT_GID:
-			result = audit_comparator(cred->gid, f->op, f->val);
+			result = audit_gid_comparator(cred->gid, f->op, f->gid);
 			break;
 		case AUDIT_EGID:
-			result = audit_comparator(cred->egid, f->op, f->val);
+			result = audit_gid_comparator(cred->egid, f->op, f->gid);
 			break;
 		case AUDIT_SGID:
-			result = audit_comparator(cred->sgid, f->op, f->val);
+			result = audit_gid_comparator(cred->sgid, f->op, f->gid);
 			break;
 		case AUDIT_FSGID:
-			result = audit_comparator(cred->fsgid, f->op, f->val);
+			result = audit_gid_comparator(cred->fsgid, f->op, f->gid);
 			break;
 		case AUDIT_PERS:
 			result = audit_comparator(tsk->personality, f->op, f->val);
@@ -717,10 +709,10 @@
 			break;
 		case AUDIT_OBJ_UID:
 			if (name) {
-				result = audit_comparator(name->uid, f->op, f->val);
+				result = audit_uid_comparator(name->uid, f->op, f->uid);
 			} else if (ctx) {
 				list_for_each_entry(n, &ctx->names_list, list) {
-					if (audit_comparator(n->uid, f->op, f->val)) {
+					if (audit_uid_comparator(n->uid, f->op, f->uid)) {
 						++result;
 						break;
 					}
@@ -729,10 +721,10 @@
 			break;
 		case AUDIT_OBJ_GID:
 			if (name) {
-				result = audit_comparator(name->gid, f->op, f->val);
+				result = audit_gid_comparator(name->gid, f->op, f->gid);
 			} else if (ctx) {
 				list_for_each_entry(n, &ctx->names_list, list) {
-					if (audit_comparator(n->gid, f->op, f->val)) {
+					if (audit_gid_comparator(n->gid, f->op, f->gid)) {
 						++result;
 						break;
 					}
@@ -750,7 +742,7 @@
 		case AUDIT_LOGINUID:
 			result = 0;
 			if (ctx)
-				result = audit_comparator(tsk->loginuid, f->op, f->val);
+				result = audit_uid_comparator(tsk->loginuid, f->op, f->uid);
 			break;
 		case AUDIT_SUBJ_USER:
 		case AUDIT_SUBJ_ROLE: