import random as rd
import re

import Bio
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
from Bio import Phylo

import FFPopSim as h

# Describe what the script demonstrates before the (slow) simulation starts.
# The text is assembled from adjacent string literals and printed as a single
# parenthesised string, which emits exactly the same output as before.
banner = (
    'This script illustrates how the dynamics of neutral alleles and neutral genealogies\n'
    'are affected by linked selection. To this end, the population is set up in linkage\n'
    'equilibrium with mostly neutral and a few strongly deleterious sites. These deleterious \n'
    'sites become beneficial one after the other and sweep through the population. This perturbes\n'
    'allele frequencies and causes rapid coalescence. As output, the script plots allele frequency\n'
    'trajectories and genealogies.\n'
)
print(banner)

# specify parameters
L = 256                                           # number of loci to simulate

# set up population
# NOTE(review): the meaning of each attribute below follows the script's own
# comments; confirm against the FFPopSim documentation before changing values.
pop = h.haploid_highd(L)                        # produce an instance of haploid_highd with L loci
pop.carrying_capacity = 10000                   # carrying capacity N = 10000 (described as the average population size)
pop.outcrossing_rate = 1                        # make the species obligate outcrossing
pop.crossover_rate = 0.02 / pop.L               # 0.02 divided over pop.L loci, i.e. 2 centimorgans for the whole segment
pop.mutation_rate = 0.1 / pop.carrying_capacity # per locus mutation rate equal to 0.1/N
s = 0.01                                        # selection coefficient given to a site once it becomes beneficial
sweeps = False                                   # if False, no site is ever made beneficial (neutral comparison run)

# Fitness landscape: every m-th locus (0, m, 2m, ...) is strongly deleterious,
# all remaining loci are neutral.
m = 10
selected_sites = slice(0, None, m)
selection_coefficients = np.zeros(pop.L)
selection_coefficients[selected_sites] = -0.1
# trait 0 is by default fitness
pop.set_trait_additive(selection_coefficients)

# Record genealogies at a few positions along the segment.
tracked_loci = [35, 95, 155, 215]
pop.track_locus_genealogy(tracked_loci)

# Start in linkage equilibrium: frequency 1/2 at the neutral loci and
# frequency 0 at the selected ones.
initial_allele_frequencies = np.ones(pop.L) * 0.5
initial_allele_frequencies[selected_sites] = 0.0
pop.set_allele_frequencies(initial_allele_frequencies, pop.carrying_capacity)


# Evolve the population up to generation `maxgen` in steps of 10 generations,
# recording the allele frequencies and the generation after every step.
maxgen = 2000
allele_frequencies = [pop.get_allele_frequencies()]
tp = [pop.generation]
while pop.generation < maxgen:
    pop.evolve(10)

    # save allele frequencies and time
    allele_frequencies.append(pop.get_allele_frequencies())
    tp.append(pop.generation)

    # progress report every 50 generations
    if pop.generation % 50 == 0:
        print("generation: %s out of %s" % (pop.generation, maxgen))

    # every 200 generations, make one of the deleterious mutations beneficial
    if sweeps and pop.generation % 200 == 0:
        # Pick among the sites that are still deleterious instead of using a
        # hard-coded index range: the previous randint(0, 25) could never
        # select the last selected site (locus 250 for L=256, m=10), silently
        # depended on L and m, and could re-pick a site that had already been
        # made beneficial, which turned the sweep event into a no-op.
        deleterious_loci = [site for site, coeff in enumerate(selection_coefficients)
                            if coeff < 0]
        if deleterious_loci:
            # np.random.randint excludes the upper bound
            sweep_locus = deleterious_loci[np.random.randint(0, len(deleterious_loci))]
            selection_coefficients[sweep_locus] = s
            # update fitness function
            pop.set_trait_additive(selection_coefficients)

# convert to an array (rows: time points, columns: loci) to enable slicing
allele_frequencies = np.array(allele_frequencies)

# Plot the allele frequency trajectories: selected sites dashed, a few neutral
# sites solid, coloured by position along the genome.
plt.figure()

# Colormaps interpret a float in [0, 1] as a position along the map, whereas
# an integer is a raw index into the lookup table. Passing the bare locus
# index only spanned the colormap because L happened to equal the table size
# (256). Normalising by the number of loci gives the same colours for L = 256
# and keeps "color indicates position" true for any other L.
n_loci = float(pop.L)

# plot the selected mutations
for locus in range(0, pop.L, m):
    plt.plot(tp, allele_frequencies[:, locus], c=cm.cool(locus / n_loci), lw=2, ls='--')

# plot some neutral sites
for locus in range(5, pop.L, 50):
    plt.plot(tp, allele_frequencies[:, locus], c=cm.cool(locus / n_loci), lw=2)

plt.title('Effect of linked selection on neutral alleles')
plt.xlabel('Time [generations]')
plt.ylabel('Allele frequencies')
plt.text(100,0.85, "neutral alleles: solid")
plt.text(100,0.9, "sweeping alleles: dashed")
plt.text(100,0.765, "color indicates position \non the genome")
plt.ylim([0,1])

# The file name encodes the run parameters; a neutral run is labelled s=0.
af_s_label = s if sweeps else 0
plt.savefig('../figures/af_' + "".join(map(str, ['N=', pop.carrying_capacity,
                                                  '_r=', pop.outcrossing_rate,
                                                  '_L=', L,
                                                  '_s=', af_s_label, '.png'])))

# Plot the genealogies of the loci that were tracked, one subplot per locus.
fig = plt.figure()
if sweeps:
    plt.suptitle('Genealogies with linked selection  (total # of generations '+str(maxgen)+')', fontsize=18)
else:
    plt.suptitle('Neutral genealogies (total # of generations '+str(maxgen)+')', fontsize=18)

# Decide once whether Phylo.draw is called with the `axes` keyword (the script
# does so for Biopython >= 1.60). Bio.__version__ is a string such as "1.60"
# or "1.61+"; comparing it directly with the float 1.60 is always True on
# Python 2 and a TypeError on Python 3, so compare (major, minor) as integers.
version_match = re.match(r'(\d+)\.(\d+)', Bio.__version__)
if version_match:
    draw_accepts_axes = tuple(int(part) for part in version_match.groups()) >= (1, 60)
else:
    # unrecognised version string: assume a recent Biopython
    draw_accepts_axes = True

sample_size = 30
for li, locus in enumerate(pop.genealogy.loci):
    print("\nretrieve tree at locus %s" % (locus,))
    tree = pop.genealogy.get_tree(locus)

    # Draw a random subsample of the leaves so the tree stays readable; never
    # request more leaves than the tree has (random.sample would raise).
    leaves = tree.leafs
    sampled_keys = rd.sample(leaves, min(sample_size, len(leaves)))
    subtree = tree.create_subtree_from_keys(sampled_keys).to_Biopython_tree()
    subtree.ladderize()
    ax = plt.subplot(2,2,li+1)
    if draw_accepts_axes:
        Phylo.draw(subtree,label_func=lambda x:"", axes=ax)
        plt.text(100,sample_size, "locus: "+str(locus))
    else:
        Phylo.draw(subtree,label_func=lambda x:"")
        plt.text(100,3, "locus: "+str(locus))
    plt.draw()

# The file name encodes the run parameters; a neutral run is labelled s=0.
trees_s_label = s if sweeps else 0
plt.savefig('../figures/trees_draft_' + "".join(map(str, ['N=', pop.carrying_capacity,
                                                           '_r=', pop.outcrossing_rate,
                                                           '_L=', L,
                                                           '_s=', trees_s_label, '.png'])))