From 0d149b84ec8d401784c2361b072f97fcedca7021 Mon Sep 17 00:00:00 2001
From: valentin <valentin.bruch@rwth-aachen.de>
Date: Sun, 9 Feb 2020 23:17:18 +0100
Subject: [PATCH] new feature: Keldysh diagrams

---
 diagrams.py | 185 +++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 131 insertions(+), 54 deletions(-)

diff --git a/diagrams.py b/diagrams.py
index 00f7bb3..064ee4e 100644
--- a/diagrams.py
+++ b/diagrams.py
@@ -26,7 +26,7 @@ class BaseTeX(object):
 
     def tikz(self, position=''):
         'Return TikZ code for drawing self.'
-        return ''
+        return r'\coordinate[%s] (%s);'%(position, self.identifier)
 
     def left(self):
         'Return TikZ node id of leftmost subnode.'
@@ -43,6 +43,7 @@ class Vertex(BaseTeX):
     Each vertex consists of n circles to which contraction lines are attached.
     '''
     def __init__(self,
+            contour,
             n,
             label='',
             radius=0.5,
@@ -51,6 +52,7 @@ class Vertex(BaseTeX):
             draw_color='black'
             ):
         super().__init__()
+        self.contour = contour
         self.n = n # rank of the vertex
         self.label = label
         self.radius = radius
@@ -102,7 +104,10 @@ class Vertex(BaseTeX):
                     )
         if self.label:
             string += '\n'
-            string += r'\node[above of=%s] {%s};'%(self.identifier, self.label);
+            if self.contour.keldysh_order > 0:
+                string += r'\node[below of=%s] {%s};'%(self.identifier, self.label);
+            else:
+                string += r'\node[above of=%s] {%s};'%(self.identifier, self.label);
         return string
 
 
@@ -141,9 +146,21 @@ class Contraction:
             shape_style = ', ' + self.shape_style
         else:
             shape_style = ''
-        return r'\draw[%s] (%s) to[out=270, in=270%s] (%s);'%(
+        try:
+            if   self.v1.contour.keldysh_order > self.v2.contour.keldysh_order:
+                directions = (90, 270)
+            elif self.v1.contour.keldysh_order < self.v2.contour.keldysh_order:
+                directions = (270, 90)
+            elif self.v1.contour.keldysh_order != 0:
+                directions = (270, 270)
+            else:
+                directions = (90, 90)
+        except:
+            directions = (270, 270)
+        return r'\draw[%s] (%s) to[out=%d, in=%d%s] (%s);'%(
                 self.style,
                 self.v1.getId(self.idx1),
+                *directions,
                 shape_style,
                 self.v2.getId(self.idx2),
                 )
@@ -154,15 +171,41 @@ class Interrupt(BaseTeX):
     Node interrupting the diagram contour.
     This can be, e.g., a node ρ separating the two branches of the Keldysh contour.
     '''
-    def __init__(self, label=''):
+    def __init__(self, label='', **kwargs):
         super().__init__()
         self.label = label
+        for key, value in kwargs.items():
+            setattr(self, key, value)
 
-    def tikz(self, position=''):
+    def isKeldysh(self):
+        return hasattr(self, 'keldysh_sep')
+
+    def tikzKeldyshConnect(self, previous):
+        if getattr(self, 'keldysh_connect', None) is None:
+            return ''
+        try:
+            assert self.keldysh_direction == 'right'
+            return r'\draw[%s] (%s) to[out=0, in=0] (%s);'%(self.keldysh_connect, previous.identifier, self.identifier)
+        except:
+            return r'\draw[%s] (%s) to[out=180, in=180] (%s);'%(self.keldysh_connect, previous.identifier, self.identifier)
+
+    def tikz(self, position='', previous=None):
         'Return TikZ code for this interruption.'
-        if position:
-            position = ', ' + position
-        return r'\node[fill=white%s] (%s) {%s};'%(position, self.identifier, self.label)
+        if self.isKeldysh():
+            string = r'\coordinate[yshift=-%s] (%s) at (%s);'%(self.keldysh_sep, self.identifier, previous.identifier)
+            if getattr(self, 'keldysh_label', ''):
+                try:
+                    assert self.keldysh_direction == 'right'
+                    keldysh_label_shift = '1em'
+                except:
+                    keldysh_label_shift = '-1em'
+                string += '\n'
+                string += r'\path (%s) -- (%s) node[pos=0.5, xshift=%s] {%s};'%(previous.identifier, self.identifier, keldysh_label_shift, self.keldysh_label)
+            return string
+        else:
+            if position:
+                position = ', ' + position
+            return r'\node[fill=white%s] (%s) {%s};'%(position, self.identifier, self.label)
 
 
 class BaseLine:
@@ -190,6 +233,9 @@ class BaseLine:
     def pprint(self):
         print(self.string, end='')
 
+    def tikz(self, node1, node2):
+        return r'\draw[%s] (%s) -- (%s);'%(self.style, node1.identifier, node2.identifier)
+
 
 class Contour:
     '''
@@ -198,6 +244,7 @@ class Contour:
     '''
     def __init__(self):
         self.elements = []
+        self.keldysh_order = -1
 
     def pprint(self):
         for e in self.elements:
@@ -246,6 +293,8 @@ class Diagram:
     \end{pgfonlayer}
     \end{tikzpicture}
 
+    >>> # Keldysh diagram (still quite ugly)
+    >>> Diagram('- g12 - g23 - | -- g13 -')
     >>> # Diagram with customized vertice, contraction, base line, and interruption.
     >>> d = Diagram()
     >>> d.setVertexStyle('V', n=2, label='$V$')
@@ -262,6 +311,7 @@ class Diagram:
         self.contractions = []
         self.contours = []
         self.interruptions = []
+        self.isKeldysh = False
         # Properties
         self.sep = sep
         self.sep_unit = sep_unit
@@ -281,7 +331,7 @@ class Diagram:
                 '~':dict(style='wiggly, thick')
                 }
         self.interrupt_styles = {
-                '|':dict(label=''),
+                '|':dict(label='', keldysh_sep='4ex', keldysh_direction='left', keldysh_label='', keldysh_connect='thick'),
                 'ρ':dict(label=r'$\rho$'),
                 }
         if string:
@@ -362,9 +412,11 @@ class Diagram:
         # Start by an interrupt (possibly empty).
         try:
             self.interruptions = [Interrupt(**self.interrupt_styles[string[0]])]
+            if self.interruptions[-1].isKeldysh():
+                self.isKeldysh = True
             string = string[1:]
         except KeyError:
-            self.interruptions = [Interrupt()]
+            self.interruptions = [BaseTeX()]
 
         # Map of open indices: every open index is mapped to it origin (vertex, subindex).
         vertex_indices = {}
@@ -379,7 +431,7 @@ class Diagram:
             # Check if c represents a vertex.
             if c in self.vertex_styles:
                 # Add a new vertex to the current contour.
-                self.contours[-1].elements.append(Vertex(**self.vertex_styles[c]))
+                self.contours[-1].elements.append(Vertex(self.contours[-1], **self.vertex_styles[c]))
                 # Read the indices for this vertex.
                 for j in range(self.vertex_styles[c]['n']):
                     s = string[i+1+j]
@@ -412,6 +464,8 @@ class Diagram:
             elif c in self.interrupt_styles.keys():
                 # Add an interruption and start a new contour.
                 self.interruptions.append(Interrupt(**self.interrupt_styles[c]))
+                if self.interruptions[-1].isKeldysh():
+                    self.isKeldysh = True
                 self.contours.append(Contour())
             else:
                 # Invalid character in diagram.
@@ -465,62 +519,85 @@ class Diagram:
         \tikzset{wiggly/.style={decorate, decoration=snake}}
         \tikzset{zigzag/.style={decorate, decoration=zigzag}}
         '''
-        # List of all TikZ nodes on the diagram axis combined with the number of prior base lines.
-        nodes = []
-        # List of all base line objects with the indices of their starting node: (baseline, index).
-        baselines = {}
-        # Iterate over contours, interruptions and elements to collect all nodes and baselines.
-        for i, contour in enumerate(self.contours):
+
+        # Begin creating the TikZ picture.
+        print(r'\begin{tikzpicture}[node distance=%s, baseline=(%s.base)]'%(node_distance, self.interruptions[0].identifier), file=file)
+
+        # Create the first node.
+        print(self.interruptions[0].tikz(), file=file)
+
+        # Iterate over contours, interruptions and elements to draw vertices and interruptions.
+        last_node = self.interruptions[0]
+        keldysh_order = -1 + self.isKeldysh
+        for i, contour in enumerate(self.contours, 1):
+            # First draw the contour.
+            contour.keldysh_order = keldysh_order
+            for j, e in enumerate(contour.elements):
+                if type(e) != BaseLine:
+                    factor = getattr(contour.elements[j-1], 'factor', 1) or 1
+                    print(e.tikz(position='right of=%s, xshift=%g%s'%(last_node.right(), factor*self.sep, self.sep_unit)), file=file)
+                    last_node = e
+            # Then draw the interruption.
             try:
-                if self.interruptions[i].label:
-                    nodes.append(self.interruptions[i])
+                if self.interruptions[i].isKeldysh():
+                    if getattr(self.interruptions[i], 'keldysh_direction', 'left') == 'left':
+                        print(self.interruptions[i].tikz(previous=self.interruptions[i-1]), file=file)
+                    else:
+                        assert type(last_node) == Interrupt
+                    last_node = self.interruptions[i]
+                    keldysh_order += 1
+                elif self.interruptions[i].label:
+                    factor = getattr(contour.elements[-1], 'factor', 1) or 1
+                    print(self.interruptions[i].tikz(position='right of=%s, xshift=%g%s'%(last_node.right(), factor*self.sep, self.sep_unit)), file=file)
+                    last_node = self.interruptions[i]
             except:
                 pass
-            for e in contour.elements:
-                if type(e) == BaseLine:
-                    baselines[len(nodes)-1] = e
-                else:
-                    nodes.append(e)
         # Collect trailing interruptions.
         for i in range(len(self.contours), len(self.interruptions)):
             try:
-                if self.interruptions[i].label:
-                    nodes.append(self.interruptions[i])
+                if self.interruptions[i].label and not self.interruptions[i].isKeldysh():
+                    print(self.interruptions[i].tikz(position='right of=%s, xshift=%g%s'%(self.interruptions[i-1].identifier, self.sep, self.sep_unit)), file=file)
             except:
-                pass
+                break
 
-        # Begin creating the TikZ picture.
-        print(r'\begin{tikzpicture}[node distance=%s, baseline=(%s.base)]'%(node_distance, nodes[0].identifier), file=file)
-        # Create the first node.
-        try:
-            print(nodes[0].tikz(), file=file)
-        except:
-            print('Error while printing node tikz:', nodes[0], file=sys.stderr)
-
-        # Draw vertices and interruptions (including their labels).
-        for i, n in enumerate(nodes[1:]):
-            try:
-                factor = baselines[i].factor or 1
-            except:
-                factor = 1
-            print(n.tikz(position='right of=%s, xshift=%g%s'%(nodes[i].right(), factor*self.sep, self.sep_unit)), file=file)
 
         # Draw the rest on the background.
         print(r'\begin{pgfonlayer}{background}', file=file)
 
         # Draw base lines.
-        for i, b in baselines.items():
-            if i < 0:
-                # Base line left of all nodes.
-                print(r'\coordinate[left of=%s, xshift=-%s%s] (left);'%(nodes[0].identifier, b.factor*self.sep, self.sep_unit), file=file)
-                print(r'\draw[%s] (left) -- (%s);'%(b.style, nodes[0].identifier), file=file)
-            elif i > len(nodes)-2:
-                # Base line right of all nodes.
-                print(r'\coordinate[right of=%s, xshift=%s%s] (right);'%(nodes[-1].identifier, b.factor*self.sep, self.sep_unit), file=file)
-                print(r'\draw[%s] (%s) -- (right);'%(b.style, nodes[-1].identifier), file=file)
-            else:
-                # Base line between two nodes.
-                print(r'\draw[%s] (%s) -- (%s);'%(b.style, nodes[i].identifier, nodes[i+1].identifier), file=file)
+        for i, contour in enumerate(self.contours, 1):
+            # First draw the contour.
+            if type(contour.elements[0]) == BaseLine:
+                print(contour.elements[0].tikz(self.interruptions[i-1], contour.elements[1]), file=file)
+            for j, e in enumerate(contour.elements[1:], 1):
+                if type(e) == BaseLine:
+                    try:
+                        # Base line between last_node and e.
+                        print(e.tikz(contour.elements[j-1], contour.elements[j+1]), file=file)
+                    except:
+                        pass
+            if type(e) == BaseLine:
+                try:
+                    assert not self.interruptions[i].isKeldysh()
+                    print(e.tikz(contour.elements[-2], self.interruptions[i]), file=file)
+                except:
+                    try:
+                        print(r'\coordinate[right of=%s, xshift=%g%s] (right%d);'%(contour.elements[-2].identifier, e.factor*self.sep, self.sep_unit, i), file=file)
+                        print(r'\draw[%s] (%s) -- (right%d);'%(e.style, contour.elements[-2].identifier, i), file=file)
+                    except:
+                        pass
+            try:
+                if self.interruptions[i].isKeldysh():
+                    # Draw connection of the contours.
+                    if self.interruptions[i].keldysh_direction == 'right':
+                        self.interruptions[i].tikz(previous=last_node)
+                        string = self.interruptions[i].tikzKeldyshConnect(last_node)
+                    else:
+                        string = self.interruptions[i].tikzKeldyshConnect(self.interruptions[i-1])
+                    if string:
+                        print(string, file=file)
+            except:
+                pass
 
         # Draw all contractions.
         for c in self.contractions:
-- 
GitLab