/*
 * Install a module in the kernel.
 * Modified by Jon Tombs.
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <a.out.h>
#include <sys/stat.h>
#include <sys/utsname.h>
#include <linux/unistd.h>
#include <linux/module.h>

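/*
 * syscall() is declared here by hand: including <syscall.h> or
 * <sys/syscall.h> to obtain it would redefine macros (such as the
 * system-call numbers) that <linux/unistd.h> already defines.  (Why
 * doesn't <linux/unistd.h> include them itself?)
 */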

/* Generic system-call entry point; used by the module syscall wrappers below. */
extern int syscall(int, ...);


/*
 * One entry of the module's a.out symbol table.
 *   n      the raw nlist record read from the object file; findsym()
 *          later overwrites n.n_un (the string-table index) with a pointer
 *          to the symbol's name inside `stringtab`.
 *   child  left [0] and right [1] links of the splay tree rooted at
 *          `symroot` (only external symbols are entered into the tree).
 */
struct symbol {
	struct nlist n;
	struct symbol *child[2];
};

/* a.out header of the object file being loaded */
struct exec header;
/* buffer holding the text segment immediately followed by the data segment */
char *textseg;
/* start of the data segment inside the textseg buffer */
char *dataseg;
/* root of the splay tree of the module's external symbols */
struct symbol *symroot;
/* number of entries in symtab */
int nsymbols;
/* the module's symbol table, in file order */
struct symbol *symtab;
/* the module's string table (symbol names) */
char *stringtab;
/* base address used for relocation: create_module() result plus sizeof (int) */
unsigned long addr;
/* symbols exported by the running kernel, from get_kernel_syms() */
struct kernel_sym *ksymtab;
/* number of entries in ksymtab */
int nksyms;
/* uname() result; its release string is compared with the module's kernel_version */
struct utsname uts_info;

/* Helpers defined later in this file. */
void relocate(char *, int, int, FILE *);
void defsym(const char *, unsigned long, int);
struct symbol *findsym(const char *, struct symbol *);
unsigned long looksym(const char *);
void *ckalloc(size_t);

/* Thin wrappers around the module-related system calls (see end of file). */
static int create_module(const char *name, unsigned long size);
static int init_module(const char *, void *, unsigned, struct mod_routines *);
static int delete_module(const char *);
static int get_kernel_syms(struct kernel_sym *);

/*
 * Read exactly `size` bytes from `fp` into `buf`.  On a short read or an
 * I/O error, report "Error reading <filename>" and exit with status 2.
 * A zero-length request succeeds trivially (fread() would report 0 items
 * for it, which is not an error here).
 */
static void
read_or_die(void *buf, size_t size, FILE *fp, const char *filename) {
	if (size == 0)
		return;
	if (fread(buf, size, 1, fp) != 1) {
		fprintf(stderr, "Error reading %s\n", filename);
		exit(2);
	}
}


/*
 * Seek `fp` to the absolute file offset `offset`.  On failure, report
 * "Error reading <filename>" and exit with status 2.
 */
static void
seek_or_die(FILE *fp, long offset, const char *filename) {
	if (fseek(fp, offset, SEEK_SET) != 0) {
		fprintf(stderr, "Error reading %s\n", filename);
		exit(2);
	}
}


/*
 * insmod: load the a.out (OMAGIC) object file named by argv[1] into the
 * running kernel as a module.
 *
 * The module is named after the file's basename, with a trailing ".o"
 * removed.  The steps are:
 *   - read the header, the text/data segments, the symbol table and the
 *     string table;
 *   - refuse the module unless its `kernel_version' string equals the
 *     release reported by uname();
 *   - bind undefined externals to the kernel's exported symbols;
 *   - allocate space for common symbols after the bss;
 *   - create the module, relocate the image for its kernel address and
 *     pass it to init_module().
 *
 * Exits with status 0 on success and 2 on any error.  This includes a
 * failed module initialisation, after which the module is deleted again.
 */
int
main(int argc, char **argv) {
	char *filename;
	char *modname;
	int progsize;
	char *p;
	long pos;
	int len;
	FILE *fp;
	long filesize;
	struct symbol *sp;
	int i;
	struct kernel_sym *ksym;
	unsigned long init_func, cleanup_func;
	unsigned long kernel_version;
	unsigned long imagesize;
	struct mod_routines routines;
	int fatal_error;

	if (argc != 2) {
		fputs("Usage: insmod module\n", stderr);
		exit(2);
	}
	filename = argv[1];

	/* construct the module name: the basename without a trailing ".o" */
	if ((p = strrchr(filename, '/')) != NULL)
		p++;
	else
		p = filename;
	len = strlen(p);
	if (len > 2 && strcmp(p + len - 2, ".o") == 0)
		len -= 2;
	modname = ckalloc(len + 1);
	memcpy(modname, p, len);
	modname[len] = '\0';

	/* open file and read header */
	if ((fp = fopen(filename, "r")) == NULL) {
		fprintf(stderr, "Cannot open %s\n", filename);
		exit(2);
	}
	if (fread(&header, sizeof header, 1, fp) != 1) {
		fprintf(stderr, "Could not read header of %s\n", filename);
		exit(2);
	}
	if (N_MAGIC(header) != OMAGIC) {
		fprintf(stderr, "%s: not an object file\n", filename);
		exit(2);
	}

	/* read in text and data segments; the data directly follows the text */
	imagesize = header.a_text + header.a_data;
	textseg = ckalloc(imagesize);
	seek_or_die(fp, N_TXTOFF(header), filename);
	read_or_die(textseg, imagesize, fp, filename);
	dataseg = textseg + header.a_text;

	/* read in the symbol table */
	if (fseek(fp, 0L, SEEK_END) != 0 || (filesize = ftell(fp)) < 0) {
		fprintf(stderr, "Error reading %s\n", filename);
		exit(2);
	}
	seek_or_die(fp, N_SYMOFF(header), filename);
	nsymbols = header.a_syms / sizeof (struct nlist);
	symtab = ckalloc(nsymbols * sizeof (*symtab));
	for (i = nsymbols, sp = symtab ; --i >= 0 ; sp++) {
		read_or_die(&sp->n, sizeof sp->n, fp, filename);
		sp->child[0] = sp->child[1] = NULL;
	}

	/*
	 * The string table runs from N_STROFF to the end of the file.  One
	 * extra byte is kept as a terminator, so that a final unterminated name
	 * cannot make strncmp()/printf() read past the buffer.
	 */
	if (filesize < (long) N_STROFF(header)) {
		fprintf(stderr, "Error reading %s\n", filename);
		exit(2);
	}
	len = filesize - N_STROFF(header);
	stringtab = ckalloc(len + 1);
	seek_or_die(fp, N_STROFF(header), filename);
	read_or_die(stringtab, len, fp, filename);
	stringtab[len] = '\0';

	symroot = NULL;
	for (i = nsymbols, sp = symtab ; --i >= 0 ; sp++) {
		pos = sp->n.n_un.n_strx;
		if (pos < 0 || pos >= len) {
			fprintf(stderr, "Bad nlist entry\n");
			exit(2);
		}
		/* attach the name to sp and add sp to the symbol tree */
		findsym(stringtab + pos, sp);
	}

	/* Now check that the module and kernel versions match. */
	if ((i = uname(&uts_info)) != 0) {
		fprintf(stderr, "uname call failed with code %d\n", i);
		exit(2);
	}
	kernel_version = looksym("_kernel_version");
	/*
	 * The symbol's value is used as an offset into the text+data image;
	 * make sure the string starts and is terminated inside that image.
	 */
	if (kernel_version >= imagesize
	    || memchr(textseg + kernel_version, '\0', imagesize - kernel_version) == NULL) {
		fprintf(stderr,
			"Error: module's `kernel_version' is not a string inside its text or data segments.\n");
		exit(2);
	}
	if (strcmp(textseg + kernel_version, uts_info.release) != 0) {
		fprintf(stderr,
			"Error: module's `kernel_version' doesn't match the current kernel.\n%s\n%s\n",
			"       Check the module is correct for the current kernel, then",
			"       recompile the module with the correct `kernel_version'.");
		exit(2);
	}

	/* get initialization and cleanup routines */
	init_func = looksym("_init_module");
	cleanup_func = looksym("_cleanup_module");

	/* get the kernel symbol table: first its size, then its contents */
	nksyms = get_kernel_syms(NULL);
	if (nksyms <= 0) {
		fprintf(stderr, "get_kernel_syms failed: Cannot find Kernel symbols!\n");
		exit(2);
	}
	ksymtab = ckalloc(nksyms * sizeof *ksymtab);
	if (get_kernel_syms(ksymtab) != nksyms) {
		fprintf(stderr, "Kernel symbol problem\n");
		exit(2);
	}

	/*
	 * Bind undefined symbols.  The use count lives in the int just below
	 * `addr` (see the create_module() call), hence offset -sizeof (int).
	 */
	defsym("_mod_use_count_", 0 - sizeof (int), N_BSS | N_EXT);
	progsize = header.a_text + header.a_data + header.a_bss;
	ksym = ksymtab;
	for (i = nksyms ; --i >= 0 ; ksym++) {
		defsym(ksym->name, ksym->value, N_ABS | N_EXT);
	}

	/*
	 * Allocate space for "common" symbols (undefined externals with a
	 * non-zero size) after the bss, word-aligned, and report any symbol
	 * that is still undefined.
	 */
	fatal_error = 0;
	for (sp = symtab ; sp < symtab + nsymbols ; sp++) {
		if (sp->n.n_type == (N_UNDF | N_EXT) && sp->n.n_value != 0) {
			sp->n.n_type = N_BSS | N_EXT;
			len = sp->n.n_value;
			sp->n.n_value = progsize;
			progsize += len;
			progsize = (progsize + 3) &~ 03;
		} else if ((sp->n.n_type &~ N_EXT) == N_UNDF) {
			fprintf(stderr, "%s undefined\n", sp->n.n_un.n_name);
			fatal_error = 1;
		}
	}
	if (fatal_error)
		exit(2);

	/*
	 * Create the module.  Failure is detected through errno rather than
	 * the return value, which is an address.  We add "sizeof (int)" to
	 * skip over the use count.
	 */
	errno = 0;
	addr = create_module(modname, progsize) + sizeof (int);
	switch (errno) {
	case EEXIST:
		fprintf(stderr, "A module named %s already exists\n", modname);
		exit(2);
	case ENOMEM:
		fprintf(stderr, "Cannot allocate space for module\n");
		exit(2);
	case 0:
		break;
	default:
		perror("create_module");
		exit(2);
	}

	/* perform relocation: text relocations first, then data relocations */
	seek_or_die(fp, N_TRELOFF(header), filename);
	relocate(textseg, header.a_text, header.a_trsize, fp);
	seek_or_die(fp, N_DRELOFF(header), filename);
	relocate(dataseg, header.a_data, header.a_drsize, fp);
	fclose(fp);
	init_func += addr;
	cleanup_func += addr;

	/* load the module into the kernel */
	routines.init = (int (*)(void)) init_func;
	routines.cleanup = (void (*)(void)) cleanup_func;
	if (init_module(modname, textseg, imagesize, &routines) < 0) {
		if (errno == EBUSY) {
			fprintf(stderr, "Initialization of %s failed\n", modname);
		} else {
			perror("init_module");
		}
		/* undo create_module() and report the failure to the caller */
		delete_module(modname);
		exit(2);
	}
	exit(0);
}


/*
 * Apply the a.out relocation records for one segment.
 *
 * seg:     in-memory image of the segment (text or data) to patch.
 * segsize: size of that segment in bytes.  Every patched word must lie
 *          entirely inside it.
 * len:     size in bytes of this segment's relocation records.  They are
 *          read sequentially from the current position of `fp`; a trailing
 *          partial record is ignored.
 *
 * Only r_length == 2 (a long-sized word) is supported.  Each word at
 * r_address is adjusted as follows:
 *   - pc-relative words first have `addr` subtracted;
 *   - external references add the referenced symbol's value, plus `addr`
 *     unless the symbol is absolute;
 *   - local references are rebased by `addr` unless they are N_ABS.
 *
 * A truncated, out-of-range or unsupported record is fatal (exit status 2).
 */
void
relocate(char *seg, int segsize, int len, FILE *fp) {
	struct relocation_info rel;
	unsigned long val;
	long word;
	struct symbol *sp;

	for (; len >= (int) sizeof rel; len -= sizeof rel) {
		if (fread(&rel, sizeof rel, 1, fp) != 1) {
			fprintf(stderr, "Error reading relocation records\n");
			exit(2);
		}
#ifdef DEBUG
		printf("relocate %s:%d\n", seg == textseg? "text" : "data",
			rel.r_address);
#endif
		/* the whole word being patched must lie inside the segment */
		if (rel.r_address < 0 || segsize < (int) sizeof word
		    || rel.r_address > segsize - (int) sizeof word) {
			fprintf(stderr, "Bad relocation\n");
			exit(2);
		}
		if (rel.r_length != 2) {
			fprintf(stderr, "Unimplemented relocation:  r_length = %d\n", rel.r_length);
			exit(2);
		}
		/* memcpy: r_address need not be suitably aligned for a long */
		memcpy(&word, seg + rel.r_address, sizeof word);
		val = word;
		if (rel.r_pcrel) {
			val -= addr;
#ifdef DEBUG
			printf("pc relative\n");
#endif
		}
		if (rel.r_extern) {
			if (rel.r_symbolnum >= nsymbols) {
				fprintf(stderr, "Bad relocation\n");
				exit(2);
			}
			sp = symtab + rel.r_symbolnum;
			val += sp->n.n_value;
			/* absolute symbols (e.g. kernel symbols bound by defsym) are not rebased */
			if ((sp->n.n_type &~ N_EXT) != N_ABS)
				val += addr;
		} else if (rel.r_symbolnum != N_ABS) {
			/* reference into the module's own segments: rebase onto addr */
			val += addr;
		}
		word = val;
		memcpy(seg + rel.r_address, &word, sizeof word);
	}
}


/*
 * Give the module symbol `name` the value `value` and type `type`.
 *
 * Only a symbol the module declares as an undefined external
 * (N_UNDF | N_EXT, which also covers common symbols) is updated.  If the
 * module already defines the symbol, "<name> multiply defined" is printed
 * and the symbol is left alone.  Names the module does not know are
 * silently ignored.
 */
void
defsym(const char *name, unsigned long value, int type) {
	struct symbol *sym = findsym(name, NULL);

	if (sym == NULL)
		return;
	if (sym->n.n_type == (N_UNDF | N_EXT)) {
		sym->n.n_type = type;
		sym->n.n_value = value;
	} else {
		fprintf(stderr, "%s multiply defined\n", name);
	}
}


/*
 * Look up a name in the symbol table, which is stored as a top-down splay
 * tree of the module's external symbols rooted at `symroot`.  Every lookup
 * restructures the tree so that the node where the search ended becomes
 * the new root.  Names are compared with strncmp() over at most
 * SYM_MAX_NAME characters.
 *
 * key: the name to search for.
 * add: if not NULL, a symtab entry whose name is `key`.  Its n_un.n_name
 *      is set to `key`.  A non-external entry (no N_EXT bit) is returned
 *      immediately and never entered into the tree.  An external entry is
 *      inserted unless a node with an equal name already exists.
 *
 * Returns the tree node whose name matches `key` (an existing node takes
 * precedence over `add`), or NULL when there is no match and `add` is NULL.
 */
struct symbol *
findsym(const char *key, struct symbol *add) {
	struct symbol *left, *right;
	struct symbol **leftp, **rightp;
	struct symbol *sp1, *sp2, *sp3;
	int cmp;
	int path1, path2;

	if (add) {
		add->n.n_un.n_name = (char *)key;
		/* only external symbols are looked up by name, so keep locals out of the tree */
		if ((add->n.n_type & N_EXT) == 0)
			return add;
	}
	sp1 = symroot;
	if (sp1 == NULL)
		return add? symroot = add : NULL;
	/*
	 * `left` collects nodes whose names sort before key and `right` those
	 * that sort after it; leftp/rightp point at the link where the next
	 * such node is to be hung.
	 */
	leftp = &left, rightp = &right;
	for (;;) {
		cmp = strncmp( sp1->n.n_un.n_name, key, SYM_MAX_NAME);
		if (cmp == 0)
			break;
		/* descend: left (path 0) if sp1 sorts after key, else right (path 1) */
		if (cmp > 0) {
			sp2 = sp1->child[0];
			path1 = 0;
		} else {
			sp2 = sp1->child[1];
			path1 = 1;
		}
		if (sp2 == NULL) {
			if (! add)
				break;
			/* inserting: `add` takes the place of the missing child */
			sp2 = add;
		}
		cmp = strncmp( sp2->n.n_un.n_name, key, SYM_MAX_NAME);
		if (cmp == 0) {
			/*
			 * Single ("zig") step: link sp1 into the left or right tree
			 * and stop at sp2.  Also reached when sp2 has no child in
			 * the search direction and nothing is being added.
			 */
one_level_only:
			if (path1 == 0) {	/* sp2 is left child of sp1 */
				*rightp = sp1;
				rightp = &sp1->child[0];
			} else {
				*leftp = sp1;
				leftp = &sp1->child[1];
			}
			sp1 = sp2;
			break;
		}
		/* second level: choose sp2's child in the direction of key */
		if (cmp > 0) {
			sp3 = sp2->child[0];
			path2 = 0;
		} else {
			sp3 = sp2->child[1];
			path2 = 1;
		}
		if (sp3 == NULL) {
			if (! add)
				goto one_level_only;
			sp3 = add;
		}
		/*
		 * Same direction twice ("zig-zig"): rotate sp2 above sp1, then
		 * link sp2.  Opposite directions ("zig-zag"): link sp1 and sp2
		 * into the right and left trees separately.
		 */
		if (path1 == 0) {
			if (path2 == 0) {
				sp1->child[0] = sp2->child[1];
				sp2->child[1] = sp1;
				*rightp = sp2;
				rightp = &sp2->child[0];
			} else {
				*rightp = sp1;
				rightp = &sp1->child[0];
				*leftp = sp2;
				leftp = &sp2->child[1];
			}
		} else {
			if (path2 == 0) {
				*leftp = sp1;
				leftp = &sp1->child[1];
				*rightp = sp2;
				rightp = &sp2->child[0];
			} else {
				sp1->child[1] = sp2->child[0];
				sp2->child[0] = sp1;
				*leftp = sp2;
				leftp = &sp2->child[1];
			}
		}
		/* continue the search two levels down */
		sp1 = sp3;
	}
	/*
	 * Now sp1 points to the result of the search.  If cmp is zero,
	 * we had a match; otherwise not.
	 * Reassemble: sp1's subtrees are appended to the left/right trees,
	 * which then become sp1's children, and sp1 becomes the root.
	 */
	*leftp = sp1->child[0];
	*rightp = sp1->child[1];
	sp1->child[0] = left;
	sp1->child[1] = right;
	symroot = sp1;
	return cmp == 0? sp1 : NULL;
}


/*
 * Return the value of the module symbol `name`.  An unknown name is
 * fatal: "<name> undefined" is printed on stderr and the program exits
 * with status 2.
 */
unsigned long
looksym(const char *name) {
	struct symbol *found = findsym(name, NULL);

	if (found != NULL)
		return found->n.n_value;
	fprintf(stderr, "%s undefined\n", name);
	exit(2);
}


/*
 * Allocate `nbytes` bytes with malloc().  If the allocation fails, print a
 * message and exit with status 2, so callers never see NULL.
 *
 * A zero-byte request is rounded up to one byte.  malloc(0) is allowed to
 * return NULL, which would otherwise be mistaken for an out-of-memory
 * failure (e.g. for an object file with an empty symbol table).
 */
void *
ckalloc(size_t nbytes) {
	void *p;

	if (nbytes == 0)
		nbytes = 1;
	if ((p = malloc(nbytes)) == NULL) {
		fputs("insmod:  malloc failed\n", stderr);
		exit(2);
	}
	return p;
}


/*
 * create_module system call: ask the kernel to reserve `size` bytes for a
 * module called `name`.  main() derives the module's load address from
 * the result and detects failure through errno.
 */
static int create_module(const char *name, unsigned long size) {
	int result;

	result = syscall(__NR_create_module, name, size);
	return result;
}

/*
 * init_module system call: hand the relocated image (`codesize` bytes at
 * `code`) and the init/cleanup entry points in `routines` to the kernel
 * for the module previously created as `name`.
 */
static int init_module(const char *name, void *code, unsigned codesize,
		struct mod_routines *routines) {
	int result;

	result = syscall(__NR_init_module, name, code, codesize, routines);
	return result;
}

/*
 * delete_module system call: remove the module called `name`.  main()
 * uses it to undo a load whose initialisation failed.
 */
static int delete_module(const char *name) {
	int result;

	result = syscall(__NR_delete_module, name);
	return result;
}

/*
 * get_kernel_syms system call.  main() calls it first with a NULL buffer
 * to learn how many symbols the kernel exports, then again with a buffer
 * of that many entries to fetch them.
 */
static int get_kernel_syms(struct kernel_sym *buffer) {
	int result;

	result = syscall(__NR_get_kernel_syms, buffer);
	return result;
}
