shithub: choc

Download patch

ref: 7e4a268dc91d1e751e1274f60849d08b1dbf417a
parent: 76178ee71572e399ede72eac8d448595e6daac63
author: Simon Howard <fraggle@soulsphere.org>
date: Sat Feb 9 15:39:32 EST 2019

net: Add reference counting scheme for addresses.

This resolves a TODO in the NAT hole punching code.

As part of this, add comments to the net_io.h header that clarify the
reference semantics of the resolve and receive functions, and explain
how addresses must be released.

--- a/src/d_loop.c
+++ b/src/d_loop.c
@@ -448,6 +448,7 @@
 
         net_loop_client_module.InitClient();
         addr = net_loop_client_module.ResolveAddress(NULL);
+        NET_ReferenceAddress(addr);
     }
     else
     {
@@ -484,6 +485,7 @@
         {
             net_sdl_module.InitClient();
             addr = net_sdl_module.ResolveAddress(myargv[i+1]);
+            NET_ReferenceAddress(addr);
 
             if (addr == NULL)
             {
@@ -506,6 +508,7 @@
         }
 
         printf("D_InitNetGame: Connected to %s\n", NET_AddrToString(addr));
+        NET_ReleaseAddress(addr);
 
         // Wait for launch message received from server.
 
--- a/src/net_client.c
+++ b/src/net_client.c
@@ -290,7 +290,7 @@
     {
         net_client_connected = false;
 
-        NET_FreeAddress(server_addr);
+        NET_ReleaseAddress(server_addr);
 
         // Shut down network module, etc.  To do.
     }
@@ -969,12 +969,9 @@
         {
             NET_CL_ParsePacket(packet);
         }
-        else
-        {
-            NET_FreeAddress(addr);
-        }
 
         NET_FreePacket(packet);
+        NET_ReleaseAddress(addr);
     }
 
     // Run the common connection code to send any packets as needed
@@ -1029,6 +1026,7 @@
     int last_send_time;
 
     server_addr = addr;
+    NET_ReferenceAddress(addr);
 
     memcpy(net_local_wad_sha1sum, data->wad_sha1sum, sizeof(sha1_digest_t));
     memcpy(net_local_deh_sha1sum, data->deh_sha1sum, sizeof(sha1_digest_t));
--- a/src/net_defs.h
+++ b/src/net_defs.h
@@ -95,6 +95,7 @@
 struct _net_addr_s
 {
     net_module_t *module;
+    int refcount;
     void *handle;
 };
 
--- a/src/net_io.c
+++ b/src/net_io.c
@@ -59,8 +59,6 @@
     int i;
     net_addr_t *result;
 
-    result = NULL;
-
     for (i=0; i<context->num_modules; ++i)
     {
         result = context->modules[i]->ResolveAddress(addr);
@@ -67,11 +65,12 @@
 
         if (result != NULL)
         {
-            break;
+            NET_ReferenceAddress(result);
+            return result;
         }
     }
 
-    return result;
+    return NULL;
 }
 
 void NET_SendPacket(net_addr_t *addr, net_packet_t *packet)
@@ -101,6 +100,7 @@
     {
         if (context->modules[i]->RecvPacket(addr, packet))
         {
+            NET_ReferenceAddress(*addr);
             return true;
         }
     }
@@ -120,9 +120,28 @@
     return buf;
 }
 
-void NET_FreeAddress(net_addr_t *addr)
+void NET_ReferenceAddress(net_addr_t *addr)
 {
-    addr->module->FreeAddress(addr);
+    if (addr == NULL)
+    {
+        return;
+    }
+    ++addr->refcount;
+    //printf("%s: +refcount=%d\n", NET_AddrToString(addr), addr->refcount);
 }
 
+void NET_ReleaseAddress(net_addr_t *addr)
+{
+    if (addr == NULL)
+    {
+        return;
+    }
+
+    --addr->refcount;
+    //printf("%s: -refcount=%d\n", NET_AddrToString(addr), addr->refcount);
+    if (addr->refcount <= 0)
+    {
+        addr->module->FreeAddress(addr);
+    }
+}
 
--- a/src/net_io.h
+++ b/src/net_io.h
@@ -22,14 +22,39 @@
 
 extern net_addr_t net_broadcast_addr;
 
+// Create a new network context.
 net_context_t *NET_NewContext(void);
+
+// Add a network module to a context.
 void NET_AddModule(net_context_t *context, net_module_t *module);
+
+// Send a packet to the given address.
 void NET_SendPacket(net_addr_t *addr, net_packet_t *packet);
+
+// Send a broadcast using all modules in the given context.
 void NET_SendBroadcast(net_context_t *context, net_packet_t *packet);
-boolean NET_RecvPacket(net_context_t *context, net_addr_t **addr, 
+
+// Check all modules in the given context and receive a packet, returning true
+// if a packet was received. The result is stored in *packet and the source is
+// stored in *addr, with an implicit reference added. The packet must be freed
// by the caller and the reference released.
+boolean NET_RecvPacket(net_context_t *context, net_addr_t **addr,
                        net_packet_t **packet);
+
+// Return a string representation of the given address. The result points to a
+// static buffer and will become invalid with the next call.
 char *NET_AddrToString(net_addr_t *addr);
-void NET_FreeAddress(net_addr_t *addr);
+
+// Add a reference to the given address.
+void NET_ReferenceAddress(net_addr_t *addr);
+
+// Release a reference to the given address. When there are no more references,
+// the address will be freed.
+void NET_ReleaseAddress(net_addr_t *addr);
+
+// Resolve a string representation of an address. If successful, a net_addr_t
// pointer is returned with an implicit reference that must be released by the
// caller when it is no longer needed.
 net_addr_t *NET_ResolveAddress(net_context_t *context, const char *address);
 
 #endif  /* #ifndef NET_IO_H */
--- a/src/net_query.c
+++ b/src/net_query.c
@@ -205,11 +205,25 @@
     target->printed = false;
     target->query_attempts = 0;
     target->addr = addr;
+    NET_ReferenceAddress(addr);
     ++num_targets;
 
     return target;
 }
 
+static void FreeTargets(void)
+{
+    int i;
+
+    for (i = 0; i < num_targets; ++i)
+    {
+        NET_ReleaseAddress(targets[i].addr);
+    }
+    free(targets);
+    targets = NULL;
+    num_targets = 0;
+}
+
 // Transmit a query packet
 
 static void NET_Query_SendQuery(net_addr_t *addr)
@@ -333,10 +347,10 @@
         // there.
 
         addr = NET_ResolveAddress(query_context, addr_str);
-
         if (addr != NULL)
         {
             GetTargetForAddr(addr, true);
+            NET_ReleaseAddress(addr);
         }
     }
 
@@ -375,6 +389,7 @@
     if (NET_RecvPacket(query_context, &addr, &packet))
     {
         NET_Query_ParsePacket(addr, packet, callback, user_data);
+        NET_ReleaseAddress(addr);
         NET_FreePacket(packet);
     }
 }
@@ -631,6 +646,7 @@
 
     target = GetTargetForAddr(master, true);
     target->type = QUERY_TARGET_MASTER;
+    NET_ReleaseAddress(master);
 
     return 1;
 }
@@ -746,6 +762,7 @@
         NET_Query_QueryLoop(NET_QueryPrintCallback, NULL);
 
         printf("\n%i server(s) found.\n", GetNumResponses());
+        FreeTargets();
     }
 }
 
@@ -758,6 +775,7 @@
         NET_Query_QueryLoop(NET_QueryPrintCallback, NULL);
 
         printf("\n%i server(s) found.\n", GetNumResponses());
+        FreeTargets();
     }
 }
 
@@ -790,6 +808,8 @@
     if (target->state == QUERY_TARGET_RESPONDED)
     {
         NET_QueryPrintCallback(addr, &target->data, target->ping_time, NULL);
+        NET_ReleaseAddress(addr);
+        FreeTargets();
     }
     else
     {
@@ -801,6 +821,7 @@
 {
     query_target_t *target;
     query_target_t *responder;
+    net_addr_t *result;
 
     NET_Query_Init();
 
@@ -817,12 +838,16 @@
 
     if (responder != NULL)
     {
-        return responder->addr;
+        result = responder->addr;
+        NET_ReferenceAddress(result);
     }
     else
     {
-        return NULL;
+        result = NULL;
     }
+
+    FreeTargets();
+    return result;
 }
 
 // Block until a packet of the given type is received from the given
@@ -845,6 +870,9 @@
             I_Sleep(20);
             continue;
         }
+
        // The caller does not need the extra reference added by NET_RecvPacket.
+        NET_ReleaseAddress(packet_src);
 
         if (packet_src == addr
          && NET_ReadInt16(packet, &read_packet_type)
--- a/src/net_sdl.c
+++ b/src/net_sdl.c
@@ -127,6 +127,7 @@
     new_entry = Z_Malloc(sizeof(addrpair_t), PU_STATIC, 0);
 
     new_entry->sdl_addr = *addr;
+    new_entry->net_addr.refcount = 0;
     new_entry->net_addr.handle = &new_entry->sdl_addr;
     new_entry->net_addr.module = &net_sdl_module;
 
--- a/src/net_server.c
+++ b/src/net_server.c
@@ -565,6 +565,7 @@
     client->connect_time = I_GetTimeMS();
     NET_Conn_InitServer(&client->connection, addr, protocol);
     client->addr = addr;
+    NET_ReferenceAddress(addr);
     client->last_send_time = -1;
 
     // init the ticcmd send queue
@@ -1447,10 +1448,7 @@
     NET_WriteInt16(sendpacket, NET_PACKET_TYPE_NAT_HOLE_PUNCH);
     NET_SendPacket(addr, sendpacket);
     NET_FreePacket(sendpacket);
-
-    // TODO: We should NET_FreeAddress(addr) here, but this could cause a
-    // problem if the client has already connected. The address system needs
-    // to be changed to use a reference-counting system to prevent this.
+    NET_ReleaseAddress(addr);
 }
 
 static void NET_SV_MasterPacket(net_packet_t *packet)
@@ -1555,14 +1553,6 @@
                 break;
         }
     }
-
-    // If this address is not in the list of clients, be sure to
-    // free it back.
-
-    if (NET_SV_FindClient(addr) == NULL)
-    {
-        NET_FreeAddress(addr);
-    }
 }
 
 
@@ -1807,7 +1797,7 @@
         }
 
         free(client->name);
-        NET_FreeAddress(client->addr);
+        NET_ReleaseAddress(client->addr);
 
         // Are there any clients left connected?  If not, return the
         // server to the waiting-for-players state.
@@ -1895,15 +1885,9 @@
         net_addr_t *new_addr;
 
         new_addr = NET_Query_ResolveMaster(server_context);
+        NET_ReleaseAddress(master_server);
+        master_server = new_addr;
 
-        // Has the master server changed address?
-
-        if (new_addr != NULL && new_addr != master_server)
-        {
-            NET_FreeAddress(master_server);
-            master_server = new_addr;
-        }
-
         master_resolve_time = now;
     }
 
@@ -1962,6 +1946,7 @@
     {
         NET_SV_Packet(packet, addr);
         NET_FreePacket(packet);
+        NET_ReleaseAddress(addr);
     }
 
     if (master_server != NULL)