import FFPopSim as h
import numpy as np
from matplotlib import pyplot as plt
import random as rd
from Bio import Phylo

#tell the user what this script demonstrates
print("This script is meant to illustrate and explore the effect of\n"
      "fitness variation and clonal interference on the fixation probability of \n"
      "beneficial mutations.\n\n")

#model parameters
L = 500      #number of segregating sites
s = 1e-2     #single site effect
N = 2000     #population size
#outcrossing rates to scan; simulations take longer for larger r
rvalues = [0.0, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5]

#sampling schedule (all times in generations)
nsamples = 100   #number of samples
burnin = 1000    #5/s or more
dt = 100         #time between samples, 1/s should be ok

#bin edges used to calculate histograms of the population fitness distribution;
#np.arange works on a half-open interval, so the upper limit 0.2 is normally not an edge itself
fit_bins = np.arange(-0.2, 0.2, s)
#midpoint of each pair of neighbouring edges
bin_centers = (fit_bins[:-1] + fit_bins[1:]) / 2.0

#lists to store the output, one entry per value in rvalues
fitness_distribution, fixed_distribution, pfix = [], [], []

#run one simulation per outcrossing rate and collect, for each of them,
#  - the time-averaged fitness distribution of the population (fitness_distribution),
#  - the distribution of initial fitness of mutations that fixed after burn-in (fixed_distribution),
#  - the overall fixation probability (pfix).
for r in rvalues:
	#set up population, switch on infinite sites mode
	pop = h.haploid_highd(L, all_polymorphic=True)
	#set the population size via the carrying capacity
	pop.carrying_capacity = N
	#set the effect sizes of the mutations that are injected (the same at each site in this case)
	#note that coefficients are multiplied by 0.5 since FFPopSim represents alleles by +/- 1 internally
	pop.set_fitness_additive(np.ones(L)*s*0.5)
	#set the outcrossing rate and the recombination model
	pop.outcrossing_rate = r
	pop.recombination_model = h.FREE_RECOMBINATION	#unlinked loci
	#initialize the population
	pop.set_wildtype(pop.carrying_capacity)

	print("Population parameters:")
	pop.status()

	#burn in, in steps of 100 generations so that progress can be reported
	print("\nEquilibrate:")
	while pop.generation<burnin:
		print("r= %s Burn in: at %s out of %s generations" % (r, pop.generation, burnin))
		pop.evolve(100)

	print("\nMeasure:")
	#running sum of the per-sample fitness histograms
	summed_fit_dis = np.zeros_like(bin_centers)
	for si in range(nsamples):
		print("r= %s sample %s out of %s" % (r, si, nsamples))
		#evolve a while before drawing the next sample
		pop.evolve(dt)
		#clone sizes and the fitness of each clone relative to the population mean
		clone_sizes = np.asarray(pop.get_clone_sizes(), dtype=float)
		rel_fitness = pop.get_fitnesses() - pop.get_fitness_statistics().mean
		#bin the fitness values weighted by clone size; density=True normalizes the histogram to unit area
		sample_fit_dis = np.histogram(rel_fitness, weights=clone_sizes, bins=fit_bins, density=True)[0]
		#increment the histogram
		summed_fit_dis += sample_fit_dis

	#average the population histogram over samples and append to the collection for different r
	fitness_distribution.append(summed_fit_dis/nsamples)
	#make a table of the fixation time and the initial fitness of each fixed mutation;
	#the reshape keeps the table two-dimensional even when no mutation has fixed
	fixed_mutations = np.asarray([(a.birth+a.sweep_time, a.fitness) for a in pop.fixed_mutations],
	                             dtype=float).reshape(-1, 2)
	#reduce to the mutations that fixed after the burnin period
	late_fixed_mutations = fixed_mutations[fixed_mutations[:,0]>burnin,:]
	#calculate the total number of mutations introduced during this time
	total_mutations = np.sum(pop.number_of_mutations[burnin:])
	#make and save histogram of fixed mutations, as well as the overall fixation probability
	if len(late_fixed_mutations):
		fix_dis = np.histogram(late_fixed_mutations[:,1], bins=fit_bins, density=True)[0]
	else:
		#nothing fixed after burn-in: there is nothing to normalize, record an all-zero histogram
		fix_dis = np.zeros_like(bin_centers)
	fixed_distribution.append(fix_dis)
	pfix.append(1.0*len(late_fixed_mutations)/total_mutations)

	#report the result for the outcrossing rate that was just simulated
	print("\n\nr= %s L= %s Fixation probability: %s" % (r, L, pfix[-1]))



#figure 1: overall fixation probability, divided by 2s, as a function of the outcrossing rate
#one color per entry of rvalues (also used by the fitness distribution figure)
colors = ['g', 'b', 'r', 'm', 'c', 'g', 'b']
pfix_rescaled = np.asarray(pfix)/2/s
#parameter tag that makes the file name unique for this parameter set
param_tag = '_'.join(str(token) for token in ('N', N, 's', s, 'L', L))

plt.figure()
plt.plot(rvalues, pfix_rescaled, lw=2)
plt.title('Fixation Probability')
plt.ylabel(r'$P_{fix}/2s$')
plt.xlabel('outcrossing rate')
#NOTE(review): rvalues contains 0.0, which has no position on a logarithmic axis -- confirm how that point should be shown
ax = plt.gca()
ax.set_xscale('log')
plt.savefig('../figures/' + 'pfix_vs_r_' + param_tag + '.png')

#figure 2, left panel: fitness distribution of the population (solid) and of the
#backgrounds from which fixed mutations originate (dashed);
#right panel: their ratio times pfix/2s, i.e. the fixation probability as a function of background fitness
plt.figure(figsize=(8,4))
left_panel = plt.subplot(1,2,1)
right_panel = plt.subplot(1,2,2)
#fitness axis in units of the single site effect
chi_over_s = bin_centers/s

#show only every second outcrossing rate
for ri in range(0, len(rvalues), 2):
	r = rvalues[ri]
	curve_label = r'$r='+str(r)+r'$'
	curve_color = colors[ri]
	pop_density = fitness_distribution[ri]
	fixed_density = fixed_distribution[ri]
	left_panel.plot(chi_over_s, pop_density, label=curve_label, c=curve_color, lw=2)
	left_panel.plot(chi_over_s, fixed_density, c=curve_color, ls='--', lw=2)
	#bins in which the population density is zero give inf or nan in this ratio
	right_panel.plot(chi_over_s, fixed_density/pop_density*pfix[ri]/2/s, label=curve_label, c=curve_color, lw=2)

#annotate the left panel
left_panel.set_title(r'$\mathrm{solid:}\ n(\chi).\ \mathrm{dashed:}\ n(\chi)\phi(\chi,s)$')
left_panel.set_xlabel(r'fitness $\chi/s$')
left_panel.legend(loc=2)

#annotate the right panel, which uses a logarithmic y axis
ax = right_panel
ax.set_title(r'Fixation probability $\phi(\chi,s)/2s$')
ax.set_yscale('log')
ax.set_ylim([1e-3,50])
ax.set_xlabel(r'fitness $\chi/s$')
#plt.ylabel(r'$P_{fix}/2s$')

plt.savefig('../figures/'+'fit_dis_vs_r_'+'_'.join(map(str,['N',N, 's', s, 'L', L]))+'.png')