/*
 *
 * 
 * Per-user traffic counting via netfilter hooks; counters are exported
 * as /proc/traffic/<uid>.
 * 4.10.2003 Klaus Rechert
 *
 * input socketlookup from ipt-owner-socketlookup.patch by Patrick McHardy
 *
 *
 * 
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
 *
 */


#include <linux/types.h>
#include <linux/errno.h>
#include <linux/kernel.h>
#include <linux/ip.h>
#include <linux/netfilter.h>
#include <linux/netfilter_ipv4.h>
#include <linux/module.h>
#include <linux/skbuff.h>
#include <linux/proc_fs.h>
#include <linux/version.h>
#include <linux/brlock.h>
#include <linux/vmalloc.h>
#include <net/sock.h>
#include <linux/file.h>

#include <linux/tcp.h>
#include <linux/udp.h>

#include <net/tcp.h>
#include <net/udp.h>
#include <net/route.h>


/*
 * Define PATCHED when the kernel carries socket_lookup.patch. The PATCHED
 * code calls tcp_v4_lookup()/udp_v4_lookup() to attribute incoming packets
 * to a socket (presumably the patch is what makes those usable from a
 * module -- verify). Without PATCHED only outgoing traffic is counted.
 */
// #define PATCHED


/* Number of hash buckets; a uid lands in bucket uid % HSIZE. */
#define HSIZE 128
/* The /proc/traffic directory holding one read-only file per uid. */
static struct proc_dir_entry *dir;
/*
 * Per-uid counter records, chained through user->next within a bucket.
 * NOTE(review): no lock protects this table although the netfilter hooks
 * may run concurrently on several CPUs -- confirm whether SMP matters.
 */
static struct user *htable[HSIZE];



/* Traffic counters for one uid. */
struct user {
	struct user *next;			/* next record in the same bucket */
	unsigned int uid;			/* uid the counters belong to */
	struct proc_dir_entry *procfile;	/* /proc/traffic/<uid> entry */
	unsigned long in;			/* bytes received (PATCHED only) */
	unsigned long out;			/* bytes sent */
};


/*
 * read_proc - read handler for /proc/traffic/<uid>.
 * @page:  buffer the output is formatted into
 * @start: not touched; with the 2.4 read_proc convention the proc core then
 *         applies @off to @page itself (see procfs guide -- verify per kernel)
 * @off:   read offset (handled by the proc core, see @start)
 * @count: caller's buffer size (output is only a few bytes)
 * @eof:   set to 1 because the whole file is produced in one call
 * @data:  the struct user registered by new_proc()
 *
 * Output is "in:out\n<in>:<out>\n" on PATCHED builds, otherwise "<out>\n".
 * Returns the number of bytes written to @page.
 */
static int read_proc(char *page, char **start,
			     off_t off, int count,
			     int *eof, void *data)
{
	struct user *u = (struct user *)data;
	int len;

	MOD_INC_USE_COUNT;
#ifdef PATCHED
	/* Counters are unsigned long: %lu, not %ld, so large values stay positive. */
	len = sprintf(page, "in:out\n%lu:%lu\n", u->in, u->out);
#else
	len = sprintf(page, "%lu\n", u->out);
#endif
	MOD_DEC_USE_COUNT;

	/* Everything fits in one call; no need for the proc core to call again. */
	*eof = 1;
	return len;
}


/*
 * new_proc - create /proc/traffic/<uid> for @u, backed by read_proc().
 * @u: user record; on success u->procfile points at the new entry.
 *
 * Returns 0 on success, -1 if the entry could not be created.
 *
 * NOTE(review): this is reached from the netfilter hook through
 * finduser()/adduser(); create_proc_read_entry() presumably allocates with
 * GFP_KERNEL and may sleep -- consider deferring to process context.
 */
static int new_proc(struct user *u)
{
	struct proc_dir_entry *proc;
	/* Largest name is "4294967295": 10 digits + NUL = 11 bytes. */
	char name[16];

	/* Must use the same "%u" format as freeuser() so removal finds it. */
	snprintf(name, sizeof(name), "%u", u->uid);
	proc = create_proc_read_entry(name, 0444, dir, read_proc, u);
	if (!proc)
		return -1;

	u->procfile = proc;
	return 0;
}
	
/*
 * freeuser - remove @u's /proc/traffic/<uid> entry and free the record.
 * @u: record to free; NULL is accepted and ignored.
 *
 * Returns the record that followed @u in its bucket (u->next), so callers
 * can walk and free a whole chain.
 */
static struct user *freeuser(struct user *u)
{
	struct user *next;
	/* Largest name is "4294967295": 10 digits + NUL = 11 bytes. */
	char name[16];

	if (!u)
		return NULL;

	next = u->next;

	/* The proc entry is looked up by name, formatted the same way new_proc() does. */
	snprintf(name, sizeof(name), "%u", u->uid);
	remove_proc_entry(name, dir);
	kfree(u);

	return next;
}


/*
 * new_user - allocate a zeroed counter record for @uid and its proc file.
 * @uid: user id the record will count for.
 *
 * Allocates with GFP_ATOMIC because it is reached from the packet hook.
 * Returns the new record (not yet linked into htable), or NULL if either
 * the allocation or the proc entry creation failed.
 */
static struct user *new_user(unsigned int uid)
{
	struct user *newuser;

	newuser = kmalloc(sizeof(*newuser), GFP_ATOMIC);
	if (!newuser)
		return NULL;

	newuser->uid = uid;
	newuser->next = NULL;
	newuser->procfile = NULL;
	newuser->in = 0;
	newuser->out = 0;

	if (new_proc(newuser)) {
		/*
		 * No proc entry was created, so there is nothing to remove;
		 * calling freeuser() here could delete another record's
		 * entry that happens to have the same name.
		 */
		kfree(newuser);
		return NULL;
	}

	return newuser;
}

/*
 * free_htable - free every user record and its proc file, emptying htable.
 *
 * Must run while the /proc/traffic directory still exists, because each
 * per-uid file is removed relative to it. Always returns 0.
 */
static int free_htable(void)
{
	int i;

	for (i = 0; i < HSIZE; i++) {
		struct user *u = htable[i];

		/* freeuser() hands back the successor, so this frees the whole chain. */
		while (u != NULL)
			u = freeuser(u);

		/* Leave no dangling bucket pointers behind. */
		htable[i] = NULL;
	}
	return 0;
}


/*
 * adduser - create a record for @uid and append it to the end of its bucket.
 * @uid: user id; this function does not check whether it is already present.
 *
 * Returns the new record, or NULL if new_user() failed (table unchanged).
 */
static struct user *adduser(unsigned int uid)
{
	struct user **link;
	struct user *fresh = new_user(uid);

	if (fresh == NULL)
		return NULL;

	/* Follow the next pointers to the terminating NULL slot and hook in there. */
	link = &htable[uid % HSIZE];
	while (*link != NULL)
		link = &(*link)->next;
	*link = fresh;

	return fresh;
}

/*
 * finduser - return the counter record for the owner of @sk, creating it
 * on first use.
 * @sk: socket the packet belongs to; may be NULL.
 *
 * The owner is sk->socket->file->f_uid. Returns NULL when @sk, its socket
 * or its file is missing, or when a new record cannot be created.
 *
 * NOTE(review): lookup and insertion are not locked although hooks can run
 * on several CPUs at once -- confirm whether SMP is a target.
 */
static struct user *finduser(struct sock *sk)
{
	struct user *u;
	unsigned int uid;

	if (!sk || !sk->socket || !sk->socket->file)
		return NULL;

	uid = sk->socket->file->f_uid;

	/* Examine every node, including the last one in the chain. */
	for (u = htable[uid % HSIZE]; u != NULL; u = u->next) {
		if (u->uid == uid)
			return u;
	}

	return adduser(uid);
}

#ifdef PATCHED
/*
 * lookup_local_sock - find the local TCP/UDP socket an incoming packet is for.
 * @skb: packet at NF_IP_LOCAL_IN, skb->nh.iph pointing at its IP header.
 *
 * Returns a socket with a reference held (caller must sock_put() it), or
 * NULL for other protocols, truncated packets, TIME_WAIT connections, or
 * when no socket matches.
 */
static struct sock *lookup_local_sock(struct sk_buff *skb)
{
	struct iphdr *iph = skb->nh.iph;
	unsigned int thoff = iph->ihl * 4;
	struct sock *sk;

	/*
	 * TCP and UDP both keep source/dest ports in their first 4 bytes.
	 * NOTE(review): assumes those bytes are in the linear skb area.
	 */
	if (skb->len < thoff + 4)
		return NULL;

	switch (iph->protocol) {
	case IPPROTO_TCP: {
		struct tcphdr *tcph = (struct tcphdr *)((u_int8_t *)iph + thoff);

		sk = tcp_v4_lookup(iph->saddr, tcph->source, iph->daddr,
				   tcph->dest, ((struct rtable *)skb->dst)->rt_iif);
		/*
		 * A TIME_WAIT bucket is not a full struct sock (no ->socket);
		 * drop its reference instead of treating it as one.
		 */
		if (sk && sk->state == TCP_TIME_WAIT) {
			tcp_tw_put((struct tcp_tw_bucket *)sk);
			return NULL;
		}
		return sk;
	}
	case IPPROTO_UDP: {
		struct udphdr *udph = (struct udphdr *)((u_int8_t *)iph + thoff);

		return udp_v4_lookup(iph->saddr, udph->source, iph->daddr,
				     udph->dest, skb->dev->ifindex);
	}
	default:
		return NULL;
	}
}
#endif

/*
 * count - netfilter hook adding each packet's length to its owner's counters.
 *
 * NF_IP_LOCAL_OUT: the owner comes from skb->sk and u->out is increased.
 * NF_IP_LOCAL_IN (PATCHED only): the owning socket is looked up from the
 * packet's addresses/ports and u->in is increased.
 * Packets are never modified or dropped; always returns NF_ACCEPT.
 */
static unsigned int count(unsigned int hooknum,
                                       struct sk_buff **pskb,
                                       const struct net_device *in,
                                       const struct net_device *out,
                                       int (*okfn)(struct sk_buff *))
{
	struct sk_buff *skb = *pskb;
	struct user *u;

	if (hooknum == NF_IP_LOCAL_OUT) {
		u = finduser(skb->sk);
		if (u)
			u->out += skb->len;
	}
#ifdef PATCHED
	else if (hooknum == NF_IP_LOCAL_IN) {
		struct sock *sk = lookup_local_sock(skb);

		if (sk) {
			u = finduser(sk);
			if (u)
				u->in += skb->len;
			/* The lookup took a reference; release it. */
			sock_put(sk);
		}
	}
#endif
	return NF_ACCEPT;
}

/*
 * Hook registrations, initialized positionally as { list, hook, pf, hooknum,
 * priority } -- presumably matching struct nf_hook_ops field order in this
 * kernel's <linux/netfilter.h>; verify before reordering.
 * Outgoing packets are counted at NF_IP_PRI_FIRST on LOCAL_OUT.
 */
static struct nf_hook_ops hook_out = { { NULL, NULL }, count, PF_INET, NF_IP_LOCAL_OUT, NF_IP_PRI_FIRST};
#ifdef PATCHED
/* Incoming packets are counted at NF_IP_PRI_LAST on LOCAL_IN. */
static struct nf_hook_ops hook_in = { { NULL, NULL }, count, PF_INET, NF_IP_LOCAL_IN, NF_IP_PRI_LAST};
#endif

/*
 * init - module entry: create /proc/traffic and register the hooks.
 *
 * Returns 0 on success. Returns -ENOMEM if the proc directory cannot be
 * created, or nf_register_hook()'s error. Everything set up so far is
 * undone on failure.
 */
static int __init init(void)
{
	int ret;

	dir = proc_mkdir("traffic", NULL);
	if (!dir)
		return -ENOMEM;

	ret = nf_register_hook(&hook_out);
	if (ret < 0) {
		printk(KERN_ERR "traffic:: could not register Hook: NF_LOKAL_OUT\n");
		goto err_remove_dir;
	}

#ifdef PATCHED
	ret = nf_register_hook(&hook_in);
	if (ret < 0) {
		printk(KERN_ERR "traffic:: could not register Hook: NF_LOKAL_IN\n");
		goto err_unreg_out;
	}
#endif
	return 0;

#ifdef PATCHED
err_unreg_out:
	nf_unregister_hook(&hook_out);
	/* hook_out may already have created records and proc files. */
	free_htable();
#endif
err_remove_dir:
	remove_proc_entry("traffic", NULL);
	return ret;
}

/*
 * fini - module exit: unregister the hooks, then tear down /proc/traffic.
 */
static void __exit fini(void)
{
	/* Stop the hooks first so no packet touches the table while it is freed. */
#ifdef PATCHED
	nf_unregister_hook(&hook_in);
#endif
	nf_unregister_hook(&hook_out);

	/*
	 * Remove the per-uid files while their parent still exists;
	 * removing "traffic" first would leave freeuser() walking a stale dir.
	 */
	free_htable();
	remove_proc_entry("traffic", NULL);
}


/* Module metadata and entry/exit points. */
MODULE_LICENSE("GPL");
module_init(init);
module_exit(fini);
