Comment tracer plusieurs Seaborn Jointplot dans un sous-Lot

je vais avoir des problèmes en plaçant Seaborn Jointplot à l'intérieur d'un multicolumn subplot.

import pandas as pd
import seaborn as sns

df = pd.DataFrame({'C1': {'a': 1,'b': 15,'c': 9,'d': 7,'e': 2,'f': 2,'g': 6,'h': 5,'k': 5,'l': 8},
          'C2': {'a': 6,'b': 18,'c': 13,'d': 8,'e': 6,'f': 6,'g': 8,'h': 9,'k': 13,'l': 15}})

fig = plt.figure();   
ax1 = fig.add_subplot(121);  
ax2 = fig.add_subplot(122);

sns.jointplot("C1", "C2", data=df, kind='reg', ax=ax1)
sns.jointplot("C1", "C2", data=df, kind='kde', ax=ax2)

remarquez comme seulement une partie du jointplot est placé à l'intérieur du sous-Lot et le reste est laissé à l'intérieur de deux autres cadres. Ce que je veux est à la fois le distributions également inséré à l'intérieur du subplots.

quelqu'un Peut-il aider?

19
demandé sur Fabio Lamanna 2016-01-27 19:05:11

2 réponses

cela ne peut pas être fait facilement sans piratage. jointplot appelle JointGrid méthode, qui à son tour crée un nouveau figure objet à chaque fois qu'il est appelé.

par conséquent, le hack est de faire deux lots (JG1JG2), puis faire une nouvelle figure, puis migrer les objets axes de JG1JG2 à la nouvelle figure créée.

enfin, nous ajustons les tailles et les positions des sous-lots dans la nouvelle figure que nous venons de créer.

JG1 = sns.jointplot("C1", "C2", data=df, kind='reg')
JG2 = sns.jointplot("C1", "C2", data=df, kind='kde')

#subplots migration
f = plt.figure()
for J in [JG1, JG2]:
    for A in J.fig.axes:
        f._axstack.add(f._make_key(A), A)

#subplots size adjustment
f.axes[0].set_position([0.05, 0.05, 0.4,  0.4])
f.axes[1].set_position([0.05, 0.45, 0.4,  0.05])
f.axes[2].set_position([0.45, 0.05, 0.05, 0.4])
f.axes[3].set_position([0.55, 0.05, 0.4,  0.4])
f.axes[4].set_position([0.55, 0.45, 0.4,  0.05])
f.axes[5].set_position([0.95, 0.05, 0.05, 0.4])

C'est un Hacker parce que nous utilisons maintenant _axstack et _add_key méthodes privées, qui pourraient et pourraient ne pas rester les mêmes qu'elles sont maintenant dans matplotlib versions futures.

enter image description here

20
répondu CT Zhu 2016-09-22 23:46:56

déplacer des axes dans matplotlib n'est pas aussi simple que dans les versions précédentes. Ce qui suit fonctionne avec la version actuelle de matplotlib.

Comme cela a été souligné à plusieurs endroits ( cette question également ce problème) plusieurs des commandes seaborn créent leur propre figure automatiquement. Ce code est codé en dur dans le code seaborn, de sorte qu'il n'y a actuellement aucun moyen de produire de telles parcelles dans les chiffres existants. Ceux-ci sont PairGrid, FacetGrid,JointGrid,pairplot,jointplot et lmplot.

il y a un seaborn fourche disponible qui permettrait de fournir une grille de sous-parcelles aux classes respectives de sorte que la parcelle soit créée dans une figure préexistante. Pour l'utiliser, vous devez copier le axisgrid.py de la fourche au dossier seaborn. Notez que ce n'est actuellement limités à être utilisé avec matplotlib 2.1 (éventuellement 2.0).

une alternative pourrait être de créer un seaborn et copier les axes sur une autre figure. Le principe de ceci est montré dans cette réponse et pourrait être étendu aux placettes Searborn. La mise en œuvre est un peu plus compliquée que je m'y attendais au départ. Ce qui suit est une classe SeabornFig2Grid qui peut être appelé avec une instance Seaborn grid (le retour de l'une des commandes ci-dessus), un chiffre matplotlib et un subplot_spec, qui est une position d'un gridspec grille.

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
import numpy as np

class SeabornFig2Grid():

    def __init__(self, seaborngrid, fig,  subplot_spec):
        self.fig = fig
        self.sg = seaborngrid
        self.subplot = subplot_spec
        if isinstance(self.sg, sns.axisgrid.FacetGrid) or \
            isinstance(self.sg, sns.axisgrid.PairGrid):
            self._movegrid()
        elif isinstance(self.sg, sns.axisgrid.JointGrid):
            self._movejointgrid()
        self._finalize()

    def _movegrid(self):
        """ Move PairGrid or Facetgrid """
        self._resize()
        n = self.sg.axes.shape[0]
        m = self.sg.axes.shape[1]
        self.subgrid = gridspec.GridSpecFromSubplotSpec(n,m, subplot_spec=self.subplot)
        for i in range(n):
            for j in range(m):
                self._moveaxes(self.sg.axes[i,j], self.subgrid[i,j])

    def _movejointgrid(self):
        """ Move Jointgrid """
        h= self.sg.ax_joint.get_position().height
        h2= self.sg.ax_marg_x.get_position().height
        r = int(np.round(h/h2))
        self._resize()
        self.subgrid = gridspec.GridSpecFromSubplotSpec(r+1,r+1, subplot_spec=self.subplot)

        self._moveaxes(self.sg.ax_joint, self.subgrid[1:, :-1])
        self._moveaxes(self.sg.ax_marg_x, self.subgrid[0, :-1])
        self._moveaxes(self.sg.ax_marg_y, self.subgrid[1:, -1])

    def _moveaxes(self, ax, gs):
        #https://stackoverflow.com/a/46906599/4124317
        ax.remove()
        ax.figure=self.fig
        self.fig.axes.append(ax)
        self.fig.add_axes(ax)
        ax._subplotspec = gs
        ax.set_position(gs.get_position(self.fig))
        ax.set_subplotspec(gs)

    def _finalize(self):
        plt.close(self.sg.fig)
        self.fig.canvas.mpl_connect("resize_event", self._resize)
        self.fig.canvas.draw()

    def _resize(self, evt=None):
        self.sg.fig.set_size_inches(self.fig.get_size_inches())

L'utilisation de cette classe ressembler à ceci:

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns; sns.set()
import SeabornFig2Grid as sfg


iris = sns.load_dataset("iris")
tips = sns.load_dataset("tips")

# An lmplot
g0 = sns.lmplot(x="total_bill", y="tip", hue="smoker", data=tips, 
                palette=dict(Yes="g", No="m"))
# A PairGrid
g1 = sns.PairGrid(iris, hue="species")
g1.map(plt.scatter, s=5)
# A FacetGrid
g2 = sns.FacetGrid(tips, col="time",  hue="smoker")
g2.map(plt.scatter, "total_bill", "tip", edgecolor="w")
# A JointGrid
g3 = sns.jointplot("sepal_width", "petal_length", data=iris,
                   kind="kde", space=0, color="g")


fig = plt.figure(figsize=(13,8))
gs = gridspec.GridSpec(2, 2)

mg0 = sfg.SeabornFig2Grid(g0, fig, gs[0])
mg1 = sfg.SeabornFig2Grid(g1, fig, gs[1])
mg2 = sfg.SeabornFig2Grid(g2, fig, gs[3])
mg3 = sfg.SeabornFig2Grid(g3, fig, gs[2])

gs.tight_layout(fig)
#gs.update(top=0.7)

plt.show()

enter image description here

notez qu'il peut y avoir plusieurs inconvénients à copier des axes et que ce qui précède n'est pas (encore) testé à fond.

15
répondu ImportanceOfBeingErnest 2017-12-05 23:44:16