/*	$NetBSD: gen_gr.c,v 1.1.1.2 2012/09/09 16:07:58 christos Exp $	*/

/*
 * Copyright (c) 2004 by Internet Systems Consortium, Inc. ("ISC")
 * Copyright (c) 1996-1999 by Internet Software Consortium.
 *
 * Permission to use, copy, modify, and distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS.  IN NO EVENT SHALL ISC BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
 * OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */

#if !defined(LINT) && !defined(CODECENTER)
static const char rcsid[] = "Id: gen_gr.c,v 1.8 2005/04/27 04:56:23 sra Exp ";
#endif

/* Imports */

#include "port_before.h"

#ifndef WANT_IRS_GR
static int __bind_irs_gr_unneeded;
#else

#include <sys/types.h>

#include <isc/assertions.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <netinet/in.h>
#include <arpa/nameser.h>
#include <resolv.h>

#include <isc/memcluster.h>
#include <irs.h>

#include "port_after.h"

#include "irs_p.h"
#include "gen_p.h"

/* Definitions */

/*
 * Private state of one generic group accessor created by irs_gen_gr().
 */
struct pvt {
	struct irs_rule *	rules;	/*%< First map rule for groups. */
	struct irs_rule *	rule;	/*%< Rule currently iterated by gr_next(). */
	struct irs_gr *		gr;	/*%< NOTE(review): not referenced in this file. */
	/*
	 * Need space to store the entries read from the group file.
	 * The members list also needs space per member, and the
	 * strings making up the user names must be allocated
	 * somewhere.  Rather than doing lots of small allocations,
	 * we keep one buffer and resize it as needed.
	 */
	struct group		group;	/*%< Merged result returned to callers. */
	size_t			nmemb;    /*%< Malloc'd max index of gr_mem[]. */
	char *			membuf;	/*%< Malloc'd storage for all strings of group. */
	size_t			membufsize;	/*%< Size of membuf in bytes. */
	struct __res_state *	res;	/*%< Resolver state, also given to each rule. */
	void			(*free_res)(void *);	/*%< Releases res; NULL if not ours. */
};

/* Forward */

/* irs_gr methods installed by irs_gen_gr(). */
static void		gr_close(struct irs_gr *);
static struct group *	gr_next(struct irs_gr *);
static struct group *	gr_byname(struct irs_gr *, const char *);
static struct group *	gr_bygid(struct irs_gr *, gid_t);
static void		gr_rewind(struct irs_gr *);
static int		gr_list(struct irs_gr *, const char *,
				gid_t, gid_t *, int *);
static void		gr_minimize(struct irs_gr *);
static struct __res_state * gr_res_get(struct irs_gr *);
static void		gr_res_set(struct irs_gr *,
				      struct __res_state *,
				      void (*)(void *));

/* Merges one rule's answer into the accessor's result entry. */
static int		grmerge(struct irs_gr *gr, const struct group *src,
				int preserve);

/* NULL-terminated string vector and gid list helpers. */
static int		countvec(char **vec);
static int		isnew(char **old, char *new);
static int		countnew(char **old, char **new);
static size_t		sizenew(char **old, char **new);
static int		newgid(int, gid_t *, gid_t);

/* Macros */

/* free() a heap pointer if set and clear it, so a repeated FREE_IF is harmless. */
#define FREE_IF(x) do { if ((x) != NULL) { free(x); (x) = NULL; } } while (0)

/* Public */

/*
 * Create a group accessor that consults the group map rules of the
 * generic accessor `this` in order.
 *
 * Returns the new accessor, or NULL with errno set to ENOMEM when
 * either allocation fails (nothing is leaked in that case).
 */
struct irs_gr *
irs_gen_gr(struct irs_acc *this) {
	struct gen_p *accpvt = (struct gen_p *)this->private;
	struct irs_gr *gr = memget(sizeof *gr);
	struct pvt *pvt;

	if (gr == NULL) {
		errno = ENOMEM;
		return (NULL);
	}
	/*
	 * Fill with a recognizable byte pattern before the method slots
	 * are assigned (presumably so an unset slot is conspicuous).
	 */
	memset(gr, 0x5e, sizeof *gr);

	pvt = memget(sizeof *pvt);
	if (pvt == NULL) {
		memput(gr, sizeof *gr);
		errno = ENOMEM;
		return (NULL);
	}
	memset(pvt, 0, sizeof *pvt);
	pvt->rules = accpvt->map_rules[irs_gr];
	pvt->rule = pvt->rules;

	/* Wire up the method table. */
	gr->private = pvt;
	gr->close = gr_close;
	gr->next = gr_next;
	gr->byname = gr_byname;
	gr->bygid = gr_bygid;
	gr->rewind = gr_rewind;
	gr->list = gr_list;
	gr->minimize = gr_minimize;
	gr->res_get = gr_res_get;
	gr->res_set = gr_res_set;
	return (gr);
}

/* Methods. */

static void
gr_close(struct irs_gr *this) {
	struct pvt *pvt = (struct pvt *)this->private;

	memput(pvt, sizeof *pvt);
	memput(this, sizeof *this);
}

/*
 * Return the next group entry from the enumeration.
 *
 * Entries come from the current rule's accessor.  When that accessor
 * is exhausted, the enumeration moves on (rewinding the next accessor)
 * only if the rule carries IRS_CONTINUE.  Returns NULL at the end.
 */
static struct group *
gr_next(struct irs_gr *this) {
	struct pvt *pvt = (struct pvt *)this->private;

	while (pvt->rule != NULL) {
		struct irs_gr *cur = pvt->rule->inst->gr;
		struct group *ent = (*cur->next)(cur);

		if (ent != NULL)
			return (ent);
		if ((pvt->rule->flags & IRS_CONTINUE) == 0)
			return (NULL);

		/* Advance to the next rule and start it from the top. */
		pvt->rule = pvt->rule->next;
		if (pvt->rule == NULL)
			return (NULL);
		cur = pvt->rule->inst->gr;
		(*cur->rewind)(cur);
	}
	return (NULL);
}

/*
 * Look up a group by name across the map rules.
 *
 * A hit is merged into pvt->group (the first hit replaces it, later
 * hits add members); the search continues after a hit only for
 * IRS_MERGE rules and after a miss only for IRS_CONTINUE rules.
 * Returns &pvt->group if anything was found, else NULL (also NULL if
 * a merge fails).
 */
static struct group *
gr_byname(struct irs_gr *this, const char *name) {
	struct pvt *pvt = (struct pvt *)this->private;
	struct irs_rule *rule;
	int merged = 0;

	for (rule = pvt->rules; rule != NULL; rule = rule->next) {
		struct irs_gr *src = rule->inst->gr;
		struct group *hit = (*src->byname)(src, name);

		if (hit == NULL) {
			if ((rule->flags & IRS_CONTINUE) == 0)
				break;
			continue;
		}
		/* Non-zero `merged` tells grmerge() to keep earlier data. */
		if (!grmerge(this, hit, merged++))
			return (NULL);
		if ((rule->flags & IRS_MERGE) == 0)
			break;
	}
	return (merged != 0 ? &pvt->group : NULL);
}

/*
 * Look up a group by gid across the map rules.
 *
 * Same merge/continue policy as gr_byname(): hits are merged into
 * pvt->group, IRS_MERGE keeps searching after a hit, IRS_CONTINUE
 * keeps searching after a miss.  Returns &pvt->group if anything was
 * found, else NULL (also NULL if a merge fails).
 */
static struct group *
gr_bygid(struct irs_gr *this, gid_t gid) {
	struct pvt *pvt = (struct pvt *)this->private;
	struct irs_rule *rule;
	int merged = 0;

	for (rule = pvt->rules; rule != NULL; rule = rule->next) {
		struct irs_gr *src = rule->inst->gr;
		struct group *hit = (*src->bygid)(src, gid);

		if (hit == NULL) {
			if ((rule->flags & IRS_CONTINUE) == 0)
				break;
			continue;
		}
		/* Non-zero `merged` tells grmerge() to keep earlier data. */
		if (!grmerge(this, hit, merged++))
			return (NULL);
		if ((rule->flags & IRS_MERGE) == 0)
			break;
	}
	return (merged != 0 ? &pvt->group : NULL);
}

static void
gr_rewind(struct irs_gr *this) {
	struct pvt *pvt = (struct pvt *)this->private;
	struct irs_gr *gr;

	pvt->rule = pvt->rules;
	if (pvt->rule) {
		gr = pvt->rule->inst->gr;
		(*gr->rewind)(gr);
	}
}

/*
 * Collect the gids of the groups `name` belongs to, merged across the
 * map rules.
 *
 * On entry *ngroups is the capacity of groups[]; on return it is the
 * number of gids stored.  A gid already present in groups[1..] is not
 * added again (newgid() does not examine groups[0]).  A rule whose list
 * method returns 0 counts as a hit (stop unless IRS_MERGE); any other
 * return counts as a miss (stop unless IRS_CONTINUE).
 *
 * Returns 0 on success, or -1 if groups[] filled up before a new gid
 * could be stored.  Also returns -1 with errno set to EINVAL for a
 * negative capacity, or ENOMEM if scratch space cannot be allocated.
 */
static int
gr_list(struct irs_gr *this, const char *name,
	gid_t basegid, gid_t *groups, int *ngroups)
{
	struct pvt *pvt = (struct pvt *)this->private;
	struct irs_rule *rule;
	struct irs_gr *gr;
	int t_ngroups, maxgroups;
	gid_t *t_groups;
	int n, t, rval = 0;

	maxgroups = *ngroups;
	*ngroups = 0;
	if (maxgroups < 0) {
		/* A negative capacity cannot describe groups[]. */
		errno = EINVAL;
		return (-1);
	}

	/*
	 * Scratch space for each rule's answer.  Request at least one
	 * element: malloc(0)/calloc(0, ...) may legitimately return NULL,
	 * which must not be reported as ENOMEM.  calloc() also rejects an
	 * element count whose byte size would overflow size_t.
	 */
	t_groups = calloc(maxgroups > 0 ? (size_t)maxgroups : 1,
			  sizeof *t_groups);
	if (t_groups == NULL) {
		errno = ENOMEM;
		return (-1);
	}

	for (rule = pvt->rules; rule != NULL; rule = rule->next) {
		t_ngroups = maxgroups;
		gr = rule->inst->gr;
		t = (*gr->list)(gr, name, basegid, t_groups, &t_ngroups);
		for (n = 0; n < t_ngroups; n++) {
			if (newgid(*ngroups, groups, t_groups[n])) {
				if (*ngroups == maxgroups) {
					/* No room for a gid we have not seen. */
					rval = -1;
					goto done;
				}
				groups[(*ngroups)++] = t_groups[n];
			}
		}
		if (t == 0) {
			if (!(rule->flags & IRS_MERGE))
				break;
		} else {
			if (!(rule->flags & IRS_CONTINUE))
				break;
		}
	}
 done:
	free(t_groups);
	return (rval);
}

/*
 * Ask every rule's accessor to minimize its resource usage.
 */
static void
gr_minimize(struct irs_gr *this) {
	struct pvt *pvt = (struct pvt *)this->private;
	struct irs_rule *r = pvt->rules;

	while (r != NULL) {
		struct irs_gr *inner = r->inst->gr;

		(*inner->minimize)(inner);
		r = r->next;
	}
}

/*
 * Return this accessor's resolver state, creating a zero-filled one on
 * first use.  A state created here is installed through gr_res_set()
 * with free() as its release routine.  Returns NULL with errno set to
 * ENOMEM if the allocation fails.
 */
static struct __res_state *
gr_res_get(struct irs_gr *this) {
	struct pvt *pvt = (struct pvt *)this->private;
	struct __res_state *fresh;

	if (pvt->res != NULL)
		return (pvt->res);

	fresh = malloc(sizeof *fresh);
	if (fresh == NULL) {
		errno = ENOMEM;
		return (NULL);
	}
	memset(fresh, 0, sizeof *fresh);
	gr_res_set(this, fresh, free);
	return (pvt->res);
}

/*
 * Install `res` as the resolver state of this accessor and of every
 * rule's inner accessor.
 *
 * If the currently installed state came with a release routine, it is
 * closed with res_nclose() and released first -- unless the caller is
 * re-installing that very object, in which case releasing it would
 * leave pvt->res (and each inner accessor's copy) pointing at freed
 * memory.  `free_res` may be NULL, meaning this accessor does not own
 * `res`.  Inner accessors always receive a NULL release routine.
 */
static void
gr_res_set(struct irs_gr *this, struct __res_state *res,
		void (*free_res)(void *)) {
	struct pvt *pvt = (struct pvt *)this->private;
	struct irs_rule *rule;

	if (pvt->res != NULL && pvt->free_res != NULL && pvt->res != res) {
		res_nclose(pvt->res);
		(*pvt->free_res)(pvt->res);
	}

	pvt->res = res;
	pvt->free_res = free_res;

	/* Share the state with every rule's accessor without giving away ownership. */
	for (rule = pvt->rules; rule != NULL; rule = rule->next) {
		struct irs_gr *gr = rule->inst->gr;

		if (gr->res_set)
			(*gr->res_set)(gr, pvt->res, NULL);
	}
}

/* Private. */

/*
 * Merge the group entry `src` into this accessor's result, pvt->group.
 *
 * preserve == 0: start over.  pvt->group takes src's gid, name and
 * password, and its member list is emptied before src's members are
 * added.
 * preserve != 0: keep the current gid, name and password, and append
 * only those members of src not already present (compared by strcmp).
 *
 * All strings of pvt->group live in one malloc'd buffer, pvt->membuf.
 * When there is anything to add, a fresh buffer of the needed size is
 * allocated and the existing strings are copied over (preserve case
 * only).  Existing pointers are rebased into it and the old buffer is
 * freed.  gr_mem is a separately malloc'd, NULL-terminated array whose
 * capacity is pvt->nmemb.
 *
 * Returns 1 on success and 0 on allocation failure.  Note that in the
 * preserve == 0 case the gid and member list have already been reset
 * by the time a later allocation can fail.
 */
static int
grmerge(struct irs_gr *this, const struct group *src, int preserve) {
	struct pvt *pvt = (struct pvt *)this->private;
	char *cp, **m, **p, *oldmembuf, *ep;
	int n, ndst, nnew;
	size_t used;

	if (!preserve) {
		/* Fresh entry: take src's gid and empty the member list. */
		pvt->group.gr_gid = src->gr_gid;
		if (pvt->nmemb < 1) {
			/* Need at least one slot for the NULL terminator. */
			m = malloc(sizeof *m);
			if (m == NULL) {
				/* No harm done, no work done. */
				return (0);
			}
			pvt->group.gr_mem = m;
			pvt->nmemb = 1;
		}
		pvt->group.gr_mem[0] = NULL;
	}
	/*
	 * Current member count, and how many of src's members are not yet
	 * present.  countnew() checks each src member only against the
	 * current list, so duplicates inside src are counted each time;
	 * the sizes below can therefore overestimate but never
	 * underestimate.
	 */
	ndst = countvec(pvt->group.gr_mem);
	nnew = countnew(pvt->group.gr_mem, src->gr_mem);

	/*
	 * Make sure destination member array is large enough.
	 * p points to new portion.
	 */
	n = ndst + nnew + 1;
	if ((size_t)n > pvt->nmemb) {
		m = realloc(pvt->group.gr_mem, n * sizeof *m);
		if (m == NULL) {
			/* No harm done, no work done. */
			return (0);
		}
		pvt->group.gr_mem = m;
		pvt->nmemb = n;
	}
	p = pvt->group.gr_mem + ndst;

	/*
	 * Enlarge destination membuf; cp points at new portion.
	 */
	n = sizenew(pvt->group.gr_mem, src->gr_mem);
	/* New members always need bytes (at least their NUL), and vice versa. */
	INSIST((nnew == 0) == (n == 0));
	if (!preserve) {
		/* A fresh entry also stores src's name and password. */
		n += strlen(src->gr_name) + 1;
		n += strlen(src->gr_passwd) + 1;
	}
	if (n == 0) {
		/* No work to do. */
		return (1);
	}
	/* Keep the existing strings only when merging into them. */
	used = preserve ? pvt->membufsize : 0;
	cp = malloc(used + n);
	if (cp == NULL) {
		/* No harm done, no work done. */
		return (0);
	}
	ep = cp + used + n;
	if (used != 0)
		memcpy(cp, pvt->membuf, used);
	oldmembuf = pvt->membuf;
	pvt->membuf = cp;
	pvt->membufsize = used + n;
	cp += used;

	/*
	 * Adjust group.gr_mem.
	 * Rebase each existing member pointer from the old buffer to the
	 * same offset in the new one (the list is empty when !preserve).
	 */
	if (pvt->membuf != oldmembuf)
		for (m = pvt->group.gr_mem; *m; m++)
			*m = pvt->membuf + (*m - oldmembuf);

	/*
	 * Add new elements.
	 * gr_mem is re-terminated after each addition, so isnew() also
	 * sees members added earlier in this loop; duplicates within src
	 * are stored once.  The bounds checks should not trigger, since
	 * the buffer was sized above.
	 */
	for (m = src->gr_mem; *m; m++)
		if (isnew(pvt->group.gr_mem, *m)) {
			*p++ = cp;
			*p = NULL;
			n = strlen(*m) + 1;
			if (n > ep - cp) {
				FREE_IF(oldmembuf);
				return (0);
			}
			strcpy(cp, *m);		/* (checked) */
			cp += n;
		}
	if (preserve) {
		/* Name and password were copied with the old strings; rebase them. */
		pvt->group.gr_name = pvt->membuf + 
				     (pvt->group.gr_name - oldmembuf);
		pvt->group.gr_passwd = pvt->membuf + 
				       (pvt->group.gr_passwd - oldmembuf);
	} else {
		/* Store src's name and password after the members. */
		pvt->group.gr_name = cp;
		n = strlen(src->gr_name) + 1;
		if (n > ep - cp) {
			FREE_IF(oldmembuf);
			return (0);
		}
		strcpy(cp, src->gr_name);	/* (checked) */
		cp += n;

		pvt->group.gr_passwd = cp;
		n = strlen(src->gr_passwd) + 1;
		if (n > ep - cp) {
			FREE_IF(oldmembuf);
			return (0);
		}
		strcpy(cp, src->gr_passwd);	/* (checked) */
		cp += n;
	}
	FREE_IF(oldmembuf);
	INSIST(cp >= pvt->membuf && cp <= &pvt->membuf[pvt->membufsize]);
	return (1);
}

/*
 * Return the number of entries in the NULL-terminated vector `vec`.
 */
static int
countvec(char **vec) {
	char **end;

	for (end = vec; *end != NULL; end++)
		continue;
	return ((int)(end - vec));
}

/*
 * Return 1 if string `new` is not in the NULL-terminated vector `old`
 * (compared with strcmp), else 0.
 */
static int
isnew(char **old, char *new) {
	size_t i;

	for (i = 0; old[i] != NULL; i++) {
		if (strcmp(old[i], new) == 0)
			return (0);
	}
	return (1);
}

/*
 * Return how many entries of the NULL-terminated vector `new` are not
 * present in `old`.  Each entry is checked against `old` only, so a
 * string repeated within `new` is counted each time it appears.
 */
static int
countnew(char **old, char **new) {
	int total = 0;
	size_t i;

	for (i = 0; new[i] != NULL; i++) {
		if (isnew(old, new[i]))
			total++;
	}
	return (total);
}

/*
 * Return the bytes (including each terminating NUL) needed to store
 * the entries of `new` that are not present in `old`.  Like
 * countnew(), repeats within `new` are counted each time.
 */
static size_t
sizenew(char **old, char **new) {
	size_t bytes = 0;
	char **cur;

	for (cur = new; *cur != NULL; cur++) {
		if (!isnew(old, *cur))
			continue;
		bytes += strlen(*cur) + 1;
	}
	return (bytes);
}

/*
 * Return 1 if `group` does not appear in groups[1..ngroups-1], else 0.
 *
 * groups[0] is deliberately not examined, so the gid stored there can
 * be added a second time (NOTE(review): presumably slot 0 holds the
 * base gid, as in getgrouplist(3) -- confirm).
 */
static int
newgid(int ngroups, gid_t *groups, gid_t group) {
	int i;

	for (i = 1; i < ngroups; i++) {
		if (groups[i] == group)
			return (0);
	}
	return (1);
}

#endif /* WANT_IRS_GR */
/*! \file */