Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ struct DuckDBPyConnection : public enable_shared_from_this<DuckDBPyConnection> {
//! MemoryFileSystem used to temporarily store file-like objects for reading
shared_ptr<ModifiedMemoryFileSystem> internal_object_filesystem;
case_insensitive_map_t<unique_ptr<ExternalDependency>> registered_functions;
case_insensitive_set_t registered_objects;

public:
explicit DuckDBPyConnection() {
Expand Down
15 changes: 15 additions & 0 deletions src/duckdb_py/include/duckdb_python/python_replacement_scan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,25 @@
#include "duckdb/common/case_insensitive_map.hpp"
#include "duckdb/parser/tableref.hpp"
#include "duckdb/function/replacement_scan.hpp"
#include "duckdb_python/python_dependency.hpp"
#include "duckdb_python/pybind11/pybind_wrapper.hpp"

namespace duckdb {

class PythonRegisteredObjectState : public ClientContextState {
public:
static constexpr const char *Key = "python_registered_objects";

void Register(const string &name, const py::object &object);
void Unregister(const string &name);
py::object Get(const string &name);
bool Contains(const string &name);

private:
mutex lock;
case_insensitive_map_t<shared_ptr<DependencyItem>> registered_objects;
};

struct PythonReplacementScan {
public:
static unique_ptr<TableRef> Replace(ClientContext &context, ReplacementScanInput &input,
Expand Down
37 changes: 26 additions & 11 deletions src/duckdb_py/pyconnection.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "duckdb_python/pyconnection/pyconnection.hpp"

#include "duckdb/catalog/catalog.hpp"
#include "duckdb/catalog/default/default_types.hpp"
#include "duckdb/common/arrow/arrow.hpp"
#include "duckdb/common/enums/profiler_format.hpp"
#include "duckdb/common/types.hpp"
Expand Down Expand Up @@ -52,6 +54,17 @@ shared_ptr<PythonImportCache> DuckDBPyConnection::import_cache = nullptr;
PythonEnvironmentType DuckDBPyConnection::environment = PythonEnvironmentType::NORMAL; // NOLINT: allow global
std::string DuckDBPyConnection::formatted_python_version = "";

static shared_ptr<PythonRegisteredObjectState> GetPythonRegisteredObjectState(ClientContext &context) {
return context.registered_state->GetOrCreate<PythonRegisteredObjectState>(PythonRegisteredObjectState::Key);
}

static bool TemporaryObjectExists(ClientContext &context, const string &name) {
auto &catalog = Catalog::GetCatalog(context, TEMP_CATALOG);
EntryLookupInfo lookup_info(CatalogType::TABLE_ENTRY, name);
auto entry = catalog.GetEntry(context, DEFAULT_SCHEMA, lookup_info, OnEntryNotFound::RETURN_NULL);
return entry != nullptr;
}

DuckDBPyConnection::~DuckDBPyConnection() {
try {
py::gil_scoped_release gil;
Expand Down Expand Up @@ -743,11 +756,16 @@ shared_ptr<DuckDBPyConnection> DuckDBPyConnection::RegisterPythonObject(const st
const py::object &python_object) {
auto &connection = con.GetConnection();
auto &client = *connection.context;
auto object = PythonReplacementScan::ReplacementObject(python_object, name, client);
auto view_rel = make_shared_ptr<ViewRelation>(connection.context, std::move(object), name);
bool replace = registered_objects.count(name);
view_rel->CreateView(name, replace, true);
registered_objects.insert(name);
auto registered_state = GetPythonRegisteredObjectState(client);
if (!registered_state->Contains(name)) {
bool temp_object_exists = false;
client.RunFunctionInTransaction([&]() { temp_object_exists = TemporaryObjectExists(client, name); }, false);
if (temp_object_exists) {
throw CatalogException("View with name \"%s\" already exists!", name);
}
}
PythonReplacementScan::ReplacementObject(python_object, name, client);
registered_state->Register(name, python_object);
return shared_from_this();
}

Expand Down Expand Up @@ -1821,15 +1839,12 @@ unordered_set<string> DuckDBPyConnection::GetTableNames(const string &query, boo

shared_ptr<DuckDBPyConnection> DuckDBPyConnection::UnregisterPythonObject(const string &name) {
auto &connection = con.GetConnection();
if (!registered_objects.count(name)) {
auto registered_state = GetPythonRegisteredObjectState(*connection.context);
if (!registered_state->Contains(name)) {
return shared_from_this();
}
D_ASSERT(py::gil_check());
py::gil_scoped_release release;
// FIXME: DROP TEMPORARY VIEW? doesn't exist?
const auto quoted_name = SQLQuotedIdentifier::ToString(name);
connection.Query("DROP VIEW " + quoted_name + "");
registered_objects.erase(name);
registered_state->Unregister(name);
return shared_from_this();
}

Expand Down
38 changes: 38 additions & 0 deletions src/duckdb_py/python_replacement_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,34 @@

namespace duckdb {

void PythonRegisteredObjectState::Register(const string &name, const py::object &object) {
py::gil_scoped_acquire gil;
lock_guard<mutex> guard(lock);
registered_objects[name] = PythonDependencyItem::Create(object);
}

void PythonRegisteredObjectState::Unregister(const string &name) {
py::gil_scoped_acquire gil;
lock_guard<mutex> guard(lock);
registered_objects.erase(name);
}

py::object PythonRegisteredObjectState::Get(const string &name) {
py::gil_scoped_acquire gil;
lock_guard<mutex> guard(lock);
auto entry = registered_objects.find(name);
if (entry == registered_objects.end()) {
return py::none();
}
auto &dependency = entry->second->Cast<PythonDependencyItem>();
return dependency.object->obj;
}

bool PythonRegisteredObjectState::Contains(const string &name) {
lock_guard<mutex> guard(lock);
return registered_objects.find(name) != registered_objects.end();
}

static void CreateArrowScan(const string &name, py::object entry, TableFunctionRef &table_function,
vector<unique_ptr<ParsedExpression>> &children, ClientProperties &client_properties,
PyArrowObjectType type, DatabaseInstance &db) {
Expand Down Expand Up @@ -238,6 +266,16 @@ static unique_ptr<TableRef> ReplaceInternal(ClientContext &context, const string
return nullptr;
}

auto registered_objects =
context.registered_state->Get<PythonRegisteredObjectState>(PythonRegisteredObjectState::Key);
if (registered_objects) {
py::gil_scoped_acquire acquire;
auto entry = registered_objects->Get(table_name);
if (!entry.is_none()) {
return PythonReplacementScan::TryReplacementObject(entry, table_name, context);
}
}

lookup_result = context.TryGetCurrentSetting("python_scan_all_frames", result);
D_ASSERT((bool)lookup_result);
auto scan_all_frames = result.GetValue<bool>();
Expand Down
18 changes: 18 additions & 0 deletions tests/fast/pandas/test_pandas_unregister.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import gc
import tempfile
import weakref

import pandas as pd
import pytest
Expand Down Expand Up @@ -50,3 +51,20 @@ def test_pandas_unregister2(self, duckdb_cursor):
with pytest.raises(duckdb.CatalogException, match="Table with name dataframe does not exist"):
connection.execute("SELECT * FROM dataframe;").fetchdf()
connection.close()

def test_pandas_unregister_releases_object_inside_transaction(self, duckdb_cursor):
duckdb_cursor.execute("CREATE TABLE t(i BIGINT)")
duckdb_cursor.begin()

df = pd.DataFrame({"i": [1, 2, 3]})
ref = weakref.ref(df)

duckdb_cursor.register("dataframe", df)
duckdb_cursor.execute("INSERT INTO t SELECT * FROM dataframe")
duckdb_cursor.unregister("dataframe")

del df
gc.collect()

assert ref() is None
duckdb_cursor.rollback()