From f3d619304a6294c89c71eca09cb0e90ab5334a38 Mon Sep 17 00:00:00 2001
From: Igor Yemelyanov <cybercjp@gmail.com>
Date: Sat, 22 Jul 2023 01:49:29 +0300
Subject: [PATCH] Fix usage python C API for version greater 3.9 and fix
 segfault

* Fix version macros

* Mov str_lower into global vars for use in istr and pair_list

* Fix macros usage
---
 multidict/_multidict.c          | 28 +++++++++++++++--
 multidict/_multilib/defs.h      |  4 +++
 multidict/_multilib/istr.h      |  4 +++
 multidict/_multilib/pair_list.h | 55 ++++++++++++++++++++-------------
 4 files changed, 67 insertions(+), 24 deletions(-)

diff --git a/multidict/_multidict.c b/multidict/_multidict.c
index 1ba79df..bcb3713 100644
--- a/multidict/_multidict.c
+++ b/multidict/_multidict.c
@@ -709,13 +709,21 @@ static inline void
 multidict_tp_dealloc(MultiDictObject *self)
 {
     PyObject_GC_UnTrack(self);
+#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 9
+    Py_TRASHCAN_BEGIN(self, multidict_tp_dealloc)
+#else
     Py_TRASHCAN_SAFE_BEGIN(self);
+#endif
     if (self->weaklist != NULL) {
         PyObject_ClearWeakRefs((PyObject *)self);
     };
     pair_list_dealloc(&self->pairs);
     Py_TYPE(self)->tp_free((PyObject *)self);
+#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 9
+    Py_TRASHCAN_END // there should be no code after this
+#else
     Py_TRASHCAN_SAFE_END(self);
+#endif
 }
 
 static inline int
@@ -777,9 +785,12 @@ multidict_add(MultiDictObject *self, PyObject *const *args,
         return NULL;
     }
 #else
-    static _PyArg_Parser _parser = {NULL, _keywords, "add", 0};
+    static _PyArg_Parser _parser = {
+        .keywords = _keywords,
+        .fname = "add",
+        .kwtuple = NULL,
+    };
     PyObject *argsbuf[2];
-
     args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames,
                                  &_parser, 2, 2, 0, argsbuf);
     if (!args) {
@@ -1655,6 +1666,9 @@ getversion(PyObject *self, PyObject *md)
 static inline void
 module_free(void *m)
 {
+#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 9
+    Py_CLEAR(multidict_str_lower);
+#endif
     Py_CLEAR(collections_abc_mapping);
     Py_CLEAR(collections_abc_mut_mapping);
     Py_CLEAR(collections_abc_mut_multi_mapping);
@@ -1683,6 +1697,13 @@ static PyModuleDef multidict_module = {
 PyMODINIT_FUNC
 PyInit__multidict()
 {
+#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 9
+    multidict_str_lower = PyUnicode_InternFromString("lower");
+    if (multidict_str_lower == NULL) {
+        goto fail;
+    }
+#endif
+
     PyObject *module = NULL,
              *reg_func_call_result = NULL;
 
@@ -1813,6 +1834,9 @@ PyInit__multidict()
     return module;
 
 fail:
+#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 9
+    Py_XDECREF(multidict_str_lower);
+#endif
     Py_XDECREF(collections_abc_mapping);
     Py_XDECREF(collections_abc_mut_mapping);
     Py_XDECREF(collections_abc_mut_multi_mapping);
diff --git a/multidict/_multilib/defs.h b/multidict/_multilib/defs.h
index c7027c8..55c2107 100644
--- a/multidict/_multilib/defs.h
+++ b/multidict/_multilib/defs.h
@@ -5,7 +5,11 @@
 extern "C" {
 #endif
 
+#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 9
+static PyObject *multidict_str_lower = NULL;
+#else
 _Py_IDENTIFIER(lower);
+#endif
 
 /* We link this module statically for convenience.  If compiled as a shared
    library instead, some compilers don't allow addresses of Python objects
diff --git a/multidict/_multilib/istr.h b/multidict/_multilib/istr.h
index 2688f48..61dc61a 100644
--- a/multidict/_multilib/istr.h
+++ b/multidict/_multilib/istr.h
@@ -43,7 +43,11 @@ istr_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
     if (!ret) {
         goto fail;
     }
+#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 9
+    s = PyObject_CallMethodNoArgs(ret, multidict_str_lower);
+#else
     s =_PyObject_CallMethodId(ret, &PyId_lower, NULL);
+#endif
     if (!s) {
         goto fail;
     }
diff --git a/multidict/_multilib/pair_list.h b/multidict/_multilib/pair_list.h
index 7eafd21..15291d4 100644
--- a/multidict/_multilib/pair_list.h
+++ b/multidict/_multilib/pair_list.h
@@ -8,8 +8,7 @@ extern "C" {
 #include <string.h>
 #include <stddef.h>
 #include <stdint.h>
-
-typedef PyObject * (*calc_identity_func)(PyObject *key);
+#include <stdbool.h>
 
 typedef struct pair {
     PyObject  *identity;  // 8
@@ -38,12 +37,12 @@ HTTP headers into the buffer without allocating an extra memory block.
 #define EMBEDDED_CAPACITY 29
 #endif
 
-typedef struct pair_list {  // 40
-    Py_ssize_t  capacity;   // 8
-    Py_ssize_t  size;       // 8
-    uint64_t  version;      // 8
-    calc_identity_func calc_identity;  // 8
-    pair_t *pairs;          // 8
+typedef struct pair_list {
+    Py_ssize_t capacity;
+    Py_ssize_t size;
+    uint64_t version;
+    bool calc_ci_indentity;
+    pair_t *pairs;
     pair_t buffer[EMBEDDED_CAPACITY];
 } pair_list_t;
 
@@ -111,7 +110,11 @@ ci_key_to_str(PyObject *key)
         return ret;
     }
     if (PyUnicode_Check(key)) {
+#if PY_VERSION_HEX < 0x03090000
         return _PyObject_CallMethodId(key, &PyId_lower, NULL);
+#else
+        return PyObject_CallMethodNoArgs(key, multidict_str_lower);
+#endif
     }
     PyErr_SetString(PyExc_TypeError,
                     "CIMultiDict keys should be either str "
@@ -199,30 +202,38 @@ pair_list_shrink(pair_list_t *list)
 
 
 static inline int
-_pair_list_init(pair_list_t *list, calc_identity_func calc_identity)
+_pair_list_init(pair_list_t *list, bool calc_ci_identity)
 {
+    list->calc_ci_indentity = calc_ci_identity;
     list->pairs = list->buffer;
     list->capacity = EMBEDDED_CAPACITY;
     list->size = 0;
     list->version = NEXT_VERSION();
-    list->calc_identity = calc_identity;
     return 0;
 }
 
 static inline int
 pair_list_init(pair_list_t *list)
 {
-    return _pair_list_init(list, key_to_str);
+    return _pair_list_init(list, /* calc_ci_identity = */ false);
 }
 
 
 static inline int
 ci_pair_list_init(pair_list_t *list)
 {
-    return _pair_list_init(list, ci_key_to_str);
+    return _pair_list_init(list, /* calc_ci_identity = */ true);
 }
 
 
+static inline PyObject *
+pair_list_calc_identity(pair_list_t *list, PyObject *key)
+{
+    if (list->calc_ci_indentity)
+        return ci_key_to_str(key);
+    return key_to_str(key);
+}
+
 static inline void
 pair_list_dealloc(pair_list_t *list)
 {
@@ -304,7 +315,7 @@ pair_list_add(pair_list_t *list,
     PyObject *identity = NULL;
     int ret;
 
-    identity = list->calc_identity(key);
+    identity = pair_list_calc_identity(list, key);
     if (identity == NULL) {
         goto fail;
     }
@@ -412,7 +423,7 @@ pair_list_del(pair_list_t *list, PyObject *key)
     Py_hash_t hash;
     int ret;
 
-    identity = list->calc_identity(key);
+    identity = pair_list_calc_identity(list, key);
     if (identity == NULL) {
         goto fail;
     }
@@ -486,7 +497,7 @@ pair_list_contains(pair_list_t *list, PyObject *key)
     PyObject *identity = NULL;
     int tmp;
 
-    ident = list->calc_identity(key);
+    ident = pair_list_calc_identity(list, key);
     if (ident == NULL) {
         goto fail;
     }
@@ -528,7 +539,7 @@ pair_list_get_one(pair_list_t *list, PyObject *key)
     PyObject *value = NULL;
     int tmp;
 
-    ident = list->calc_identity(key);
+    ident = pair_list_calc_identity(list, key);
     if (ident == NULL) {
         goto fail;
     }
@@ -573,7 +584,7 @@ pair_list_get_all(pair_list_t *list, PyObject *key)
     PyObject *res = NULL;
     int tmp;
 
-    ident = list->calc_identity(key);
+    ident = pair_list_calc_identity(list, key);
     if (ident == NULL) {
         goto fail;
     }
@@ -631,7 +642,7 @@ pair_list_set_default(pair_list_t *list, PyObject *key, PyObject *value)
     PyObject *value2 = NULL;
     int tmp;
 
-    ident = list->calc_identity(key);
+    ident = pair_list_calc_identity(list, key);
     if (ident == NULL) {
         goto fail;
     }
@@ -680,7 +691,7 @@ pair_list_pop_one(pair_list_t *list, PyObject *key)
     int tmp;
     PyObject *ident = NULL;
 
-    ident = list->calc_identity(key);
+    ident = pair_list_calc_identity(list, key);
     if (ident == NULL) {
         goto fail;
     }
@@ -730,7 +741,7 @@ pair_list_pop_all(pair_list_t *list, PyObject *key)
     PyObject *res = NULL;
     PyObject *ident = NULL;
 
-    ident = list->calc_identity(key);
+    ident = pair_list_calc_identity(list, key);
     if (ident == NULL) {
         goto fail;
     }
@@ -826,7 +837,7 @@ pair_list_replace(pair_list_t *list, PyObject * key, PyObject *value)
     PyObject *identity = NULL;
     Py_hash_t hash;
 
-    identity = list->calc_identity(key);
+    identity = pair_list_calc_identity(list, key);
     if (identity == NULL) {
         goto fail;
     }
@@ -1101,7 +1112,7 @@ pair_list_update_from_seq(pair_list_t *list, PyObject *seq)
         Py_INCREF(key);
         Py_INCREF(value);
 
-        identity = list->calc_identity(key);
+        identity = pair_list_calc_identity(list, key);
         if (identity == NULL) {
             goto fail_1;
         }
