Twisted code review…

If you have a few minutes and speak Python and Twisted, an extra set of eyes on this section of code would be useful. The basic idea is a reconnecting Thrift client, so that I can write simple client.function(a, b, c) calls without worrying about whether a client is currently connected: requests are queued and the client reconnects as needed.

  1from thrift.transport import TTwisted
  2from thrift.protocol import TBinaryProtocol
  3from twisted.internet.protocol import ReconnectingClientFactory
  4from twisted.internet import defer, reactor
  5from twisted.python import failure
  6from collections import deque
  7from redback import log
  8
  9class ClientBusy(Exception):
 10    pass
 11
 12class ClientDead(Exception):
 13    pass
 14
 15class InvalidThriftRequest(Exception):
 16    pass
 17
 18class ManagedThriftRequest(object):
 19    def __init__(self, method, *args, **kw) :
 20        self.method = method
 21        self.args = args
 22        self.kw = kw
 23
 24class ManagedClient(object) :
 25    def __init__(self, factory) :
 26        self.__factory = factory
 27
 28    def _is_connected(self) :
 29        return self.__factory.is_connected()
 30
 31    def __getattr__(self, name) :
 32        if hasattr(self.__factory.client_class, name) :
 33            def f(*args, **kw) :
 34                return self.__factory.pushRequest(ManagedThriftRequest(name, *args, **kw))
 35                return f
 36        raise InvalidThriftRequest("Cant find method: %s" % name)
 37
 38class ManagedThriftClientProtocol(TTwisted.ThriftClientProtocol):
 39      
 40    def __init__(self, client_class, iprot_factory, oprot_factory=None):
 41        TTwisted.ThriftClientProtocol.__init__(self, client_class, iprot_factory, oprot_factory)
 42        self.client_class = client_class
 43        self.deferred = None
 44        self.alive = False
 45
 46    def connectionMade(self):
 47        log.debug(self, "Connection made to", "%s.%s" % (self.client_class.__module__, self.client_class.__name__), "[%s]:%s" % self.transport.addr)
 48        self.alive = True
 49        TTwisted.ThriftClientProtocol.connectionMade(self)
 50        self.client.protocol = self
 51        self.factory.clientIdle(self)
 52
 53    def connectionLost(self, reason=None):
 54        log.debug(self, "Connection lost to", "%s.%s" % (self.client_class.__module__, self.client_class.__name__), "[%s]:%s" % self.transport.addr)
 55        self.alive = False
 56          
 57        try :
 58            TTwisted.ThriftClientProtocol.connectionLost(self, reason)
 59            self.factory.clientGone(self)
 60        except Exception, e :
 61            log.error(self.connectionLost, e)
 62
 63    def _complete(self, res, request, dfd):
 64        self.deferred = None
 65        if isinstance(res, failure.Failure) and dfd :
 66            self.factory.pushRequest(request, dfd)
 67        else :
 68            if dfd :
 69                dfd.callback(res)
 70            self.factory.clientIdle(self)
 71        return res
 72
 73    def submitRequest(self, request, dfd):
 74        if not self.alive :
 75            raise ClientBusy
 76        if not self.deferred :
 77            fun = getattr(self.client, request.method, None)
 78            if not fun:
 79                raise InvalidThriftRequest("No such method as : %s" % request.method)
 80            else :
 81                try :
 82                    d = fun(*request.args, **request.kw)
 83                except Exception, e :
 84                    log.error(self.submitRequest, "calling : ", request.method, e)
 85                self.deferred = d
 86            d.addBoth(self._complete, request, dfd)
 87            return d
 88        else:
 89            raise ClientBusy
 90
 91class ManagedClientFactory(ReconnectingClientFactory):
 92    maxDelay = 5
 93    thriftFactory = TBinaryProtocol.TBinaryProtocolAcceleratedFactory
 94    protocol = ManagedThriftClientProtocol
 95    submitLoopSleep = 0
 96    client_class = None
 97
 98    def __init__(self, client_class=None):
 99        self._stack = deque()
100        self._protos = defer.DeferredQueue()
101        self.deferred = defer.Deferred()
102        self.client_class = client_class or self.client_class
103        self.client = ManagedClient(self)
104
105    def _errback(self, reason=None):
106        if self.deferred :
107            self.deferred.errback(reason)
108        self.deferred = None
109
110    def _callback(self, value=None):
111        if self.deferred :
112            self.deferred.callback(value)
113        self.deferred = None
114
115    def clientConnectionFailed(self, connector, reason):
116        ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
117        self._errback(reason)
118
119    def buildProtocol(self, addr) :
120        self.resetDelay()
121        p = self.protocol(self.client_class, self.thriftFactory())
122        p.factory = self
123        #self._protos.put(p)
124        return p
125
126    def clientIdle(self, proto) :
127        self._callback(True)
128        if proto.alive :
129            self._protos.put(proto)
130
131    def clientGone(self, proto):
132        pass
133        #if proto in self._protos :
134        # self._protos.remove(proto)
135
136    def _protoErr(self, proto):
137        pass
138        #import traceback
139        #traceback.print_stack()
140        #print "IN PROTO ERR", proto
141
142    def _protoReady(self, proto):
143        if proto.deferred :
144            log.msg(self._protoReady, "Proto currently active!")
145            return
146          
147        if not proto.alive :
148            log.msg(self._protoReady, "Proto currently dead!")
149            return
150        try:
151            request, deferred = self._stack.popleft()
152        except defer.QueueUnderflow :
153            pass
154        d = proto.submitRequest(request, deferred)
155        return d
156
157    def pushRequest(self, request, din=None) :
158        d = din or defer.Deferred()
159        self._stack.append((request, d))
160        dfd = self._protos.get()
161        dfd.addCallback(self._protoReady)
162        dfd.addErrback(self._protoErr)
163        return d
164
165    def shutdown(self) :
166        """Shutdown this factory"""
167        self.stopTrying()
168        for p in self._protos:
169            if p.transport:
170                p.transport.loseConnection()