import psycopg2
import re
import time

from simple_naive_bayes import naivebayes,getwords,ident

# Open the connection to the 'common' database and grab a cursor that the
# rest of the script uses for all of its queries.
con = psycopg2.connect(
    user='postgres',
    database='common',
    host='dw02.dev01.corp.metaweb.com',
)
cur = con.cursor()

# Wikipedia categories to learn.  One classifier is trained per entry, so
# adding a category here adds a classifier below.
category_list = [
    'Category:Ninjas',
    'Category:Pirates',
    'Category:American_assassins',
    'Category:Jesters',
]
print("Training classifiers...")

# Map each article id (wpid) to the set of target categories it belongs to.
# An article listed under several of the categories ends up with several
# entries in its set.
id_classes = {}
for category in category_list:
    # The category is passed as a bound parameter instead of being formatted
    # into the SQL text, so the driver takes care of quoting.  LIKE is kept
    # from the original query; note that "_" in a name such as
    # 'Category:American_assassins' is a single-character LIKE wildcard.
    cur.execute("select article_wpid from wikipedia.category_members "
                "where category_name like %s", (category,))

    for wpid, in cur.fetchall():
        id_classes.setdefault(wpid, set()).add(category)

# Hold Henri Caesar (wpid 17574574) out of the training data; the same id is
# classified as a test article further down.  pop() with a default is used so
# the script does not abort with KeyError if the article was not returned by
# any of the category queries.
id_classes.pop(17574574, None)

# Train one classifier per category.  Every article is fed to every
# classifier, labelled 1 when the article belongs to that classifier's
# category and 0 otherwise.
classifiers = [(cat, naivebayes(ident)) for cat in category_list]

for wpid in id_classes:
    # Bound parameter rather than string formatting for the article id.
    cur.execute("select name,text from wikipedia.articles where wpid=%s",
                (wpid,))
    row = cur.fetchone()
    if row is None:
        # The category table references an article that is not in the
        # articles table; skip it instead of failing on tuple unpacking.
        continue
    name, text = row

    # Category pages themselves are not used as training examples.
    if name.startswith('Category:'):
        continue

    # Only the first 1024 characters of the article text are used.
    words = getwords(text[0:1024])
    member_of = id_classes[wpid]

    for cat, cl in classifiers:
        cl.train(words, 1 if cat in member_of else 0)

# Blank line separating the training banner from the test output.
print("")

# Bill Gates and Henri Caesar
test_set = [3747, 17574574]

# For each test article, print its name followed by one line per category:
# category, the classifier's decision, and the ratio prob(1)/prob(0)
# (reported as 100 when prob(0) is not positive, to avoid dividing by zero).
try:
    for wpid in test_set:
        # Bound parameter rather than string formatting for the article id.
        cur.execute("select name,text from wikipedia.articles where wpid=%s",
                    (wpid,))
        row = cur.fetchone()
        if row is None:
            # No such article in the database; nothing to classify.
            continue
        name, text = row
        if name.startswith('Category:'):
            continue

        print(name)
        # Same 1024-character prefix as used during training.
        words = getwords(text[0:1024])

        for cat, cl in classifiers:
            py, pn = cl.prob(words, 1), cl.prob(words, 0)
            print('%s\t%s\t%f' % (cat, cl.classify(words),
                                  py/pn if pn > 0 else 100))

        print("")
finally:
    # Release the database resources even if a query or the classifier
    # raised part-way through the loop.
    cur.close()
    con.close()

# Try building a string that will fall into multiple categories
test_string = "Toby Segaran lived during the sengoku period in Japan. He spent many years at sea battling big Japanese ships."
print(test_string)
words = getwords(test_string)
for cat, cl in classifiers:
    # The classifiers are trained with the labels 1 (member) and 0
    # (non-member), so the probabilities must be queried with those same
    # labels; asking for "yes"/"no" would refer to classes that were never
    # seen during training.
    py, pn = cl.prob(words, 1), cl.prob(words, 0)
    print('%s\t%s\t%f' % (cat, cl.classify(words), py/pn if pn > 0 else 100))

