|
|
|
@ -12,6 +12,7 @@ |
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
|
# limitations under the License. |
|
|
|
|
import socket |
|
|
|
|
|
|
|
|
|
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS |
|
|
|
|
from twisted.internet import defer, reactor |
|
|
|
@ -30,7 +31,10 @@ logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
SERVER_CACHE = {} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# our record of an individual server which can be tried to reach a destination. |
|
|
|
|
# |
|
|
|
|
# "host" is actually a dotted-quad or ipv6 address string. Except when there's |
|
|
|
|
# no SRV record, in which case it is the original hostname. |
|
|
|
|
_Server = collections.namedtuple( |
|
|
|
|
"_Server", "priority weight host port expires" |
|
|
|
|
) |
|
|
|
@ -219,9 +223,10 @@ class SRVClientEndpoint(object): |
|
|
|
|
return self.default_server |
|
|
|
|
else: |
|
|
|
|
raise ConnectError( |
|
|
|
|
"Not server available for %s" % self.service_name |
|
|
|
|
"No server available for %s" % self.service_name |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# look for all servers with the same priority |
|
|
|
|
min_priority = self.servers[0].priority |
|
|
|
|
weight_indexes = list( |
|
|
|
|
(index, server.weight + 1) |
|
|
|
@ -231,11 +236,22 @@ class SRVClientEndpoint(object): |
|
|
|
|
|
|
|
|
|
total_weight = sum(weight for index, weight in weight_indexes) |
|
|
|
|
target_weight = random.randint(0, total_weight) |
|
|
|
|
|
|
|
|
|
for index, weight in weight_indexes: |
|
|
|
|
target_weight -= weight |
|
|
|
|
if target_weight <= 0: |
|
|
|
|
server = self.servers[index] |
|
|
|
|
# XXX: this looks totally dubious: |
|
|
|
|
# |
|
|
|
|
# (a) we never reuse a server until we have been through |
|
|
|
|
# all of the servers at the same priority, so if the |
|
|
|
|
# weights are A: 100, B:1, we always do ABABAB instead of |
|
|
|
|
# AAAA...AAAB (approximately). |
|
|
|
|
# |
|
|
|
|
# (b) After using all the servers at the lowest priority, |
|
|
|
|
# we move onto the next priority. We should only use the |
|
|
|
|
# second priority if servers at the top priority are |
|
|
|
|
# unreachable. |
|
|
|
|
# |
|
|
|
|
del self.servers[index] |
|
|
|
|
self.used_servers.append(server) |
|
|
|
|
return server |
|
|
|
@ -280,26 +296,21 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t |
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
payload = answer.payload |
|
|
|
|
host = str(payload.target) |
|
|
|
|
srv_ttl = answer.ttl |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
answers, _, _ = yield dns_client.lookupAddress(host) |
|
|
|
|
except DNSNameError: |
|
|
|
|
continue |
|
|
|
|
hosts = yield _get_hosts_for_srv_record( |
|
|
|
|
dns_client, str(payload.target) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
for answer in answers: |
|
|
|
|
if answer.type == dns.A and answer.payload: |
|
|
|
|
ip = answer.payload.dottedQuad() |
|
|
|
|
host_ttl = min(srv_ttl, answer.ttl) |
|
|
|
|
for (ip, ttl) in hosts: |
|
|
|
|
host_ttl = min(answer.ttl, ttl) |
|
|
|
|
|
|
|
|
|
servers.append(_Server( |
|
|
|
|
host=ip, |
|
|
|
|
port=int(payload.port), |
|
|
|
|
priority=int(payload.priority), |
|
|
|
|
weight=int(payload.weight), |
|
|
|
|
expires=int(clock.time()) + host_ttl, |
|
|
|
|
)) |
|
|
|
|
servers.append(_Server( |
|
|
|
|
host=ip, |
|
|
|
|
port=int(payload.port), |
|
|
|
|
priority=int(payload.priority), |
|
|
|
|
weight=int(payload.weight), |
|
|
|
|
expires=int(clock.time()) + host_ttl, |
|
|
|
|
)) |
|
|
|
|
|
|
|
|
|
servers.sort() |
|
|
|
|
cache[service_name] = list(servers) |
|
|
|
@ -317,3 +328,68 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t |
|
|
|
|
raise e |
|
|
|
|
|
|
|
|
|
defer.returnValue(servers) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def _get_hosts_for_srv_record(dns_client, host): |
|
|
|
|
"""Look up each of the hosts in a SRV record |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
dns_client (twisted.names.dns.IResolver): |
|
|
|
|
host (basestring): host to look up |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
Deferred[list[(str, int)]]: a list of (host, ttl) pairs |
|
|
|
|
|
|
|
|
|
""" |
|
|
|
|
ip4_servers = [] |
|
|
|
|
ip6_servers = [] |
|
|
|
|
|
|
|
|
|
def cb(res): |
|
|
|
|
# lookupAddress and lookupIP6Address return a three-tuple |
|
|
|
|
# giving the answer, authority, and additional sections of the |
|
|
|
|
# response. |
|
|
|
|
# |
|
|
|
|
# we only care about the answers. |
|
|
|
|
|
|
|
|
|
return res[0] |
|
|
|
|
|
|
|
|
|
def eb(res): |
|
|
|
|
res.trap(DNSNameError) |
|
|
|
|
return [] |
|
|
|
|
|
|
|
|
|
# no logcontexts here, so we can safely fire these off and gatherResults |
|
|
|
|
d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb) |
|
|
|
|
d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb) |
|
|
|
|
results = yield defer.gatherResults([d1, d2], consumeErrors=True) |
|
|
|
|
|
|
|
|
|
for result in results: |
|
|
|
|
for answer in result: |
|
|
|
|
if not answer.payload: |
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
if answer.type == dns.A: |
|
|
|
|
ip = answer.payload.dottedQuad() |
|
|
|
|
ip4_servers.append((ip, answer.ttl)) |
|
|
|
|
elif answer.type == dns.AAAA: |
|
|
|
|
ip = socket.inet_ntop( |
|
|
|
|
socket.AF_INET6, answer.payload.address, |
|
|
|
|
) |
|
|
|
|
ip6_servers.append((ip, answer.ttl)) |
|
|
|
|
else: |
|
|
|
|
# the most likely candidate here is a CNAME record. |
|
|
|
|
# rfc2782 says srvs may not point to aliases. |
|
|
|
|
logger.warn( |
|
|
|
|
"Ignoring unexpected DNS record type %s for %s", |
|
|
|
|
answer.type, host, |
|
|
|
|
) |
|
|
|
|
continue |
|
|
|
|
except Exception as e: |
|
|
|
|
logger.warn("Ignoring invalid DNS response for %s: %s", |
|
|
|
|
host, e) |
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
# keep the ipv4 results before the ipv6 results, mostly to match historical |
|
|
|
|
# behaviour. |
|
|
|
|
defer.returnValue(ip4_servers + ip6_servers) |
|
|
|
|