2024-10-02 22:15:59 +04:00

268 lines
8.3 KiB
Python

# odbc test suite kindly contributed by Frank Millman.
import os
import sys
import tempfile
import unittest
import odbc
import pythoncom
from pywin32_testutil import TestSkipped, str2bytes, str2memory
from win32com.client import constants
# We use the DAO ODBC driver
from win32com.client.gencache import EnsureDispatch
class TestStuff(unittest.TestCase):
def setUp(self):
self.tablename = "pywin32test_users"
self.db_filename = None
self.conn = self.cur = None
try:
# Test any database if a connection string is supplied...
conn_str = os.environ["TEST_ODBC_CONNECTION_STRING"]
except KeyError:
# Create a local MSAccess DB for testing.
self.db_filename = tempfile.NamedTemporaryFile().name + ".mdb"
# Create a brand-new database - what is the story with these?
for suffix in (".36", ".35", ".30"):
try:
dbe = EnsureDispatch("DAO.DBEngine" + suffix)
break
except pythoncom.com_error:
pass
else:
raise TestSkipped("Can't find a DB engine")
workspace = dbe.Workspaces(0)
newdb = workspace.CreateDatabase(
self.db_filename, constants.dbLangGeneral, constants.dbEncrypt
)
newdb.Close()
conn_str = "Driver={Microsoft Access Driver (*.mdb)};dbq=%s;Uid=;Pwd=;" % (
self.db_filename,
)
## print 'Connection string:', conn_str
self.conn = odbc.odbc(conn_str)
# And we expect a 'users' table for these tests.
self.cur = self.conn.cursor()
## self.cur.setoutputsize(1000)
try:
self.cur.execute("""drop table %s""" % self.tablename)
except (odbc.error, odbc.progError):
pass
## This needs to be adjusted for sql server syntax for unicode fields
## - memo -> TEXT
## - varchar -> nvarchar
self.assertEqual(
self.cur.execute(
"""create table %s (
userid varchar(25),
username varchar(25),
bitfield bit,
intfield integer,
floatfield float,
datefield datetime,
rawfield varbinary(100),
longtextfield memo,
longbinaryfield image
)"""
% self.tablename
),
-1,
)
def tearDown(self):
if self.cur is not None:
try:
self.cur.execute("""drop table %s""" % self.tablename)
except (odbc.error, odbc.progError) as why:
print("Failed to delete test table %s" % self.tablename, why)
self.cur.close()
self.cur = None
if self.conn is not None:
self.conn.close()
self.conn = None
if self.db_filename is not None:
try:
os.unlink(self.db_filename)
except OSError:
pass
def test_insert_select(self, userid="Frank", username="Frank Millman"):
self.assertEqual(
self.cur.execute(
"insert into %s (userid, username) \
values (?,?)"
% self.tablename,
[userid, username],
),
1,
)
self.assertEqual(
self.cur.execute(
"select * from %s \
where userid = ?"
% self.tablename,
[userid.lower()],
),
0,
)
self.assertEqual(
self.cur.execute(
"select * from %s \
where username = ?"
% self.tablename,
[username.lower()],
),
0,
)
def test_insert_select_unicode(self, userid="Frank", username="Frank Millman"):
self.assertEqual(
self.cur.execute(
"insert into %s (userid, username)\
values (?,?)"
% self.tablename,
[userid, username],
),
1,
)
self.assertEqual(
self.cur.execute(
"select * from %s \
where userid = ?"
% self.tablename,
[userid.lower()],
),
0,
)
self.assertEqual(
self.cur.execute(
"select * from %s \
where username = ?"
% self.tablename,
[username.lower()],
),
0,
)
def test_insert_select_unicode_ext(self):
userid = "t-\xe0\xf2"
username = "test-\xe0\xf2 name"
self.test_insert_select_unicode(userid, username)
def _test_val(self, fieldName, value):
for x in range(100):
self.cur.execute("delete from %s where userid='Frank'" % self.tablename)
self.assertEqual(
self.cur.execute(
"insert into %s (userid, %s) values (?,?)"
% (self.tablename, fieldName),
["Frank", value],
),
1,
)
self.cur.execute(
"select %s from %s where userid = ?" % (fieldName, self.tablename),
["Frank"],
)
rows = self.cur.fetchmany()
self.assertEqual(1, len(rows))
row = rows[0]
self.assertEqual(row[0], value)
def testBit(self):
self._test_val("bitfield", 1)
self._test_val("bitfield", 0)
def testInt(self):
self._test_val("intfield", 1)
self._test_val("intfield", 0)
try:
big = sys.maxsize
except AttributeError:
big = sys.maxint
self._test_val("intfield", big)
def testFloat(self):
self._test_val("floatfield", 1.01)
self._test_val("floatfield", 0)
def testVarchar(
self,
):
self._test_val("username", "foo")
def testLongVarchar(self):
"""Test a long text field in excess of internal cursor data size (65536)"""
self._test_val("longtextfield", "abc" * 70000)
def testLongBinary(self):
"""Test a long raw field in excess of internal cursor data size (65536)"""
self._test_val("longbinaryfield", str2memory("\0\1\2" * 70000))
def testRaw(self):
## Test binary data
self._test_val("rawfield", str2memory("\1\2\3\4\0\5\6\7\8"))
def test_widechar(self):
"""Test a unicode character that would be mangled if bound as plain character.
For example, previously the below was returned as ascii 'a'
"""
self._test_val("username", "\u0101")
def testDates(self):
import datetime
for v in ((1900, 12, 25, 23, 39, 59),):
d = datetime.datetime(*v)
self._test_val("datefield", d)
def test_set_nonzero_length(self):
self.assertEqual(
self.cur.execute(
"insert into %s (userid,username) " "values (?,?)" % self.tablename,
["Frank", "Frank Millman"],
),
1,
)
self.assertEqual(
self.cur.execute("update %s set username = ?" % self.tablename, ["Frank"]),
1,
)
self.assertEqual(self.cur.execute("select * from %s" % self.tablename), 0)
self.assertEqual(len(self.cur.fetchone()[1]), 5)
def test_set_zero_length(self):
self.assertEqual(
self.cur.execute(
"insert into %s (userid,username) " "values (?,?)" % self.tablename,
[str2bytes("Frank"), ""],
),
1,
)
self.assertEqual(self.cur.execute("select * from %s" % self.tablename), 0)
self.assertEqual(len(self.cur.fetchone()[1]), 0)
def test_set_zero_length_unicode(self):
self.assertEqual(
self.cur.execute(
"insert into %s (userid,username) " "values (?,?)" % self.tablename,
["Frank", ""],
),
1,
)
self.assertEqual(self.cur.execute("select * from %s" % self.tablename), 0)
self.assertEqual(len(self.cur.fetchone()[1]), 0)
if __name__ == "__main__":
    # Run every TestCase in this module with the standard unittest runner.
    unittest.main()