import FFPopSim as h
import numpy as np
from matplotlib import pyplot as plt
import random as rd

def get_mut_count_subtree(R):
	'''Recursively collect (leaf count, branch length) pairs for the subtree below node R.

	Returns a tuple (n_leaves, branches):
	  n_leaves -- total leaf weight below R. A terminal node's weight is parsed from
	              the second '_'-separated field of its name (presumably the number of
	              sampled individuals the leaf represents -- confirm against the tree export).
	  branches -- list of [n_below, branch_length] pairs, one per branch inside the
	              subtree, in post-order (a child's own branches precede the branch
	              leading to that child). Mutations on a branch are carried by the
	              n_below leaves beneath it; their number is proportional to branch_length.
	A terminal node contributes no branches of its own.
	'''
	if not R.is_terminal():
		n_leaves = 0
		branches = []
		for child in R.clades:
			n_child, child_branches = get_mut_count_subtree(child)
			#everything below the child first, then the branch leading to the child itself
			branches += child_branches
			branches.append([n_child, child.branch_length])
			n_leaves += n_child
		return n_leaves, branches
	return int(R.name.split('_')[1]), []


def get_SFS(T):
	'''Return (sample_size, SFS) for the tree T.

	sample_size -- total leaf weight of the tree, as returned by get_mut_count_subtree.
	SFS         -- float array of shape (n_branches, 2), one row per branch below the
	               root: column 0 is the fraction of the sample below that branch (the
	               derived allele frequency of mutations on it), column 1 is the branch
	               length (a weight proportional to the number of such mutations).
	A tree without internal branches yields an empty (0, 2) array.
	'''
	#get a list of all opportunities for mutations on the tree below the root.
	#all derived mutations happen there
	sample_size, branches = get_mut_count_subtree(T.root)
	#force a float array: with integer branch lengths np.asarray would build an int
	#array, and the in-place division below would then floor every frequency to 0
	#(Python 2) or raise a casting error (Python 3). reshape keeps the 2-column
	#layout even when the branch list is empty.
	SFS = np.asarray(branches, dtype=float).reshape(-1, 2)
	SFS[:, 0] /= sample_size
	return sample_size, SFS

print ("This script is meant to illustrate and explore the effect of\n"
       "purifying selection on neutral and deleterious site frequency spectra (SFS)\n\n")

# --- model parameters ---
L = 2000        # number of loci (segregating sites) in the simulated genome
s = 1e-2        # magnitude of the single-site selection effect
N = 10000       # population size (applied as the carrying capacity)
r = 0.05        # outcrossing rate

# --- sampling schedule ---
nsamples = 100  # number of genealogies / frequency snapshots to accumulate
burnin = 2000   # equilibration time: rule of thumb ~5*N (drift-dominated coalescence) or ~5/s (draft-dominated)
dt = 100        # generations between consecutive samples

#set up population
pop=h.haploid_highd(L,all_polymorphic=True)

#set the population size via the carrying capacity
pop.carrying_capacity= N

#set the crossover rate, outcrossing_rate and recombination model
pop.outcrossing_rate = r
pop.recombination_model = h.CROSSOVERS
pop.crossover_rate = 10.0/pop.L		#per-locus crossover rate of 10/L

#additive fitness: every site has |coefficient| = s/2; even sites get the negative
#sign (their derived allele is deleterious), odd sites the positive sign (beneficial)
selection_coefficients = np.ones(L)*s*0.5
selection_coefficients[::2] = -selection_coefficients[::2]
#the central locus is made effectively neutral (tiny nonzero coefficient); its
#genealogy supplies the neutral reference SFS. Floor division keeps the index an
#int under true division as well (Python 3 / __future__ division).
neutral_locus = L//2
selection_coefficients[neutral_locus] = 1e-10
pop.set_fitness_additive(selection_coefficients)

#track the genealogy at the effectively neutral central locus. With outcrossing
#(r > 0) different loci have different genealogies, so this must be the locus
#that was neutralized above.
pop.track_locus_genealogy([neutral_locus])

#initialize the population as wildtype; the argument is presumably the initial
#number of individuals (here the carrying capacity) -- confirm in FFPopSim docs
pop.set_wildtype(pop.carrying_capacity)

#burn in: advance the population in chunks of 100 generations, reporting progress,
#until pop.generation reaches the burnin threshold
print ("\nEquilibrate:")
burnin_chunk = 100
while pop.generation < burnin:
	print ("Burn in: at %s out of %s generations" % (pop.generation, burnin))
	pop.evolve(burnin_chunk)

#set up bins non-uniformly make histograms of the allele frequencies
bins=np.exp(np.linspace(-3.5*np.log(10),2.5*np.log(10),21)) \
			/(1+np.exp(np.linspace(-3.5*np.log(10),2.5*np.log(10),21)))
bins[0]=1.5/pop.carrying_capacity; bins[-1]=1
bincenters = 0.5*(bins[1:]+bins[:-1])
dx= bins[1:]-bins[:-1]

#allocate arrays to accumulate the SFS
neutralSFS = np.zeros_like(bincenters)
deleteriousSFS = np.zeros_like(bincenters)
beneficialSFS = np.zeros_like(bincenters)
for si in xrange(nsamples):
	pop.evolve(dt)
	print "sample", si, "out of", nsamples
	#get the tree and pass it to the function that calculate the neutral allele frequency spectrum
	BPtree = pop.genealogy.get_tree(L/2).to_Biopython_tree()
	sample_size,tmpSFS = get_SFS(BPtree)
	y,x = np.histogram(tmpSFS[:,0], weights = tmpSFS[:,1], bins=bins)
	neutralSFS+=y
	#get the frequencies of selected alleles. Partition into deleterious and beneficial ones
	derived_af = pop.get_derived_allele_frequencies()
	y,x = np.histogram(derived_af[::2], bins=bins)
	deleteriousSFS+=y
	y,x = np.histogram(derived_af[1::2], bins=bins)
	beneficialSFS+=y

#plot the allele frequency spectra on a logit(nu) axis so that both the nu->0 and the
#nu->1 asymptotics are visible. Each spectrum is turned into a density (divided by the
#bin width) and scaled by the raw count of its second bin (index 1).
logit_centers = np.log(bincenters/(1-bincenters))
for spectrum, spectrum_label in ((neutralSFS, 'neutral'),
                                 (deleteriousSFS, 'deleterious'),
                                 (beneficialSFS, 'beneficial')):
	plt.plot(logit_centers, spectrum/dx/spectrum[1], label=spectrum_label, lw=2)

#guide lines between nu=0.001 and nu=0.5: nu^-1 drops by a factor 500 (1e3 -> 2),
#nu^-2 by 500^2 = 2.5e5 (1e1 -> 4e-5)
ref_nu = np.asarray([0.001,0.5])
ref_x = np.log(ref_nu/(1-ref_nu))
plt.plot(ref_x, [1e3,2], label=r'$\sim \nu^{-1}$', lw=2)
plt.plot(ref_x, [1e1,1e-5*4], label=r'$\sim \nu^{-2}$', lw=2)
ax=plt.gca()
ax.set_yscale('log')

#label the logit axis manually; locations and labels come from the same list so
#their counts always match
tick_values = [0.001, 0.01, 0.1, 0.5, 0.9, 0.99]
tick_locations = np.asarray(tick_values)
plt.xticks(np.log(tick_locations/(1-tick_locations)), [str(v) for v in tick_values])
plt.title('Site frequency spectrum, r='+str(r), fontsize=18)
plt.ylabel('SFS')
plt.xlabel(r'derived allele frequency $\nu$')
plt.legend()

#NOTE: savefig does not create '../figures'; it must already exist relative to the
#working directory
plt.savefig('../figures/SFS_derived_'+"".join(map(str,['N=',N,'_r=',r,'_L=',L, '_s=',s,'.png'])))