Ei kuvausta

Collaborative Filtering.ipynb 63KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "# Collaborative Filtering\n",
  8. "\n",
  9. "## The Framework"
  10. ]
  11. },
  12. {
  13. "cell_type": "code",
  14. "execution_count": 1,
  15. "metadata": {
  16. "collapsed": true
  17. },
  18. "outputs": [],
  19. "source": [
  20. "import pandas as pd\n",
  21. "import numpy as np"
  22. ]
  23. },
  24. {
  25. "cell_type": "code",
  26. "execution_count": 2,
  27. "metadata": {},
  28. "outputs": [
  29. {
  30. "data": {
  31. "text/html": [
  32. "<div>\n",
  33. "<style>\n",
  34. " .dataframe thead tr:only-child th {\n",
  35. " text-align: right;\n",
  36. " }\n",
  37. "\n",
  38. " .dataframe thead th {\n",
  39. " text-align: left;\n",
  40. " }\n",
  41. "\n",
  42. " .dataframe tbody tr th {\n",
  43. " vertical-align: top;\n",
  44. " }\n",
  45. "</style>\n",
  46. "<table border=\"1\" class=\"dataframe\">\n",
  47. " <thead>\n",
  48. " <tr style=\"text-align: right;\">\n",
  49. " <th></th>\n",
  50. " <th>user_id</th>\n",
  51. " <th>age</th>\n",
  52. " <th>sex</th>\n",
  53. " <th>occupation</th>\n",
  54. " <th>zip_code</th>\n",
  55. " </tr>\n",
  56. " </thead>\n",
  57. " <tbody>\n",
  58. " <tr>\n",
  59. " <th>0</th>\n",
  60. " <td>1</td>\n",
  61. " <td>24</td>\n",
  62. " <td>M</td>\n",
  63. " <td>technician</td>\n",
  64. " <td>85711</td>\n",
  65. " </tr>\n",
  66. " <tr>\n",
  67. " <th>1</th>\n",
  68. " <td>2</td>\n",
  69. " <td>53</td>\n",
  70. " <td>F</td>\n",
  71. " <td>other</td>\n",
  72. " <td>94043</td>\n",
  73. " </tr>\n",
  74. " <tr>\n",
  75. " <th>2</th>\n",
  76. " <td>3</td>\n",
  77. " <td>23</td>\n",
  78. " <td>M</td>\n",
  79. " <td>writer</td>\n",
  80. " <td>32067</td>\n",
  81. " </tr>\n",
  82. " <tr>\n",
  83. " <th>3</th>\n",
  84. " <td>4</td>\n",
  85. " <td>24</td>\n",
  86. " <td>M</td>\n",
  87. " <td>technician</td>\n",
  88. " <td>43537</td>\n",
  89. " </tr>\n",
  90. " <tr>\n",
  91. " <th>4</th>\n",
  92. " <td>5</td>\n",
  93. " <td>33</td>\n",
  94. " <td>F</td>\n",
  95. " <td>other</td>\n",
  96. " <td>15213</td>\n",
  97. " </tr>\n",
  98. " </tbody>\n",
  99. "</table>\n",
  100. "</div>"
  101. ],
  102. "text/plain": [
  103. " user_id age sex occupation zip_code\n",
  104. "0 1 24 M technician 85711\n",
  105. "1 2 53 F other 94043\n",
  106. "2 3 23 M writer 32067\n",
  107. "3 4 24 M technician 43537\n",
  108. "4 5 33 F other 15213"
  109. ]
  110. },
  111. "execution_count": 2,
  112. "metadata": {},
  113. "output_type": "execute_result"
  114. }
  115. ],
  116. "source": [
  117. "#Load the u.user file into a dataframe\n",
  118. "u_cols = ['user_id', 'age', 'sex', 'occupation', 'zip_code']\n",
  119. "\n",
  120. "users = pd.read_csv('../data/movielens/u.user', sep='|', names=u_cols,\n",
  121. " encoding='latin-1')\n",
  122. "\n",
  123. "users.head()"
  124. ]
  125. },
  126. {
  127. "cell_type": "code",
  128. "execution_count": 3,
  129. "metadata": {},
  130. "outputs": [
  131. {
  132. "data": {
  133. "text/html": [
  134. "<div>\n",
  135. "<style>\n",
  136. " .dataframe thead tr:only-child th {\n",
  137. " text-align: right;\n",
  138. " }\n",
  139. "\n",
  140. " .dataframe thead th {\n",
  141. " text-align: left;\n",
  142. " }\n",
  143. "\n",
  144. " .dataframe tbody tr th {\n",
  145. " vertical-align: top;\n",
  146. " }\n",
  147. "</style>\n",
  148. "<table border=\"1\" class=\"dataframe\">\n",
  149. " <thead>\n",
  150. " <tr style=\"text-align: right;\">\n",
  151. " <th></th>\n",
  152. " <th>movie_id</th>\n",
  153. " <th>title</th>\n",
  154. " <th>release date</th>\n",
  155. " <th>video release date</th>\n",
  156. " <th>IMDb URL</th>\n",
  157. " <th>unknown</th>\n",
  158. " <th>Action</th>\n",
  159. " <th>Adventure</th>\n",
  160. " <th>Animation</th>\n",
  161. " <th>Children's</th>\n",
  162. " <th>...</th>\n",
  163. " <th>Fantasy</th>\n",
  164. " <th>Film-Noir</th>\n",
  165. " <th>Horror</th>\n",
  166. " <th>Musical</th>\n",
  167. " <th>Mystery</th>\n",
  168. " <th>Romance</th>\n",
  169. " <th>Sci-Fi</th>\n",
  170. " <th>Thriller</th>\n",
  171. " <th>War</th>\n",
  172. " <th>Western</th>\n",
  173. " </tr>\n",
  174. " </thead>\n",
  175. " <tbody>\n",
  176. " <tr>\n",
  177. " <th>0</th>\n",
  178. " <td>1</td>\n",
  179. " <td>Toy Story (1995)</td>\n",
  180. " <td>01-Jan-1995</td>\n",
  181. " <td>NaN</td>\n",
  182. " <td>http://us.imdb.com/M/title-exact?Toy%20Story%2...</td>\n",
  183. " <td>0</td>\n",
  184. " <td>0</td>\n",
  185. " <td>0</td>\n",
  186. " <td>1</td>\n",
  187. " <td>1</td>\n",
  188. " <td>...</td>\n",
  189. " <td>0</td>\n",
  190. " <td>0</td>\n",
  191. " <td>0</td>\n",
  192. " <td>0</td>\n",
  193. " <td>0</td>\n",
  194. " <td>0</td>\n",
  195. " <td>0</td>\n",
  196. " <td>0</td>\n",
  197. " <td>0</td>\n",
  198. " <td>0</td>\n",
  199. " </tr>\n",
  200. " <tr>\n",
  201. " <th>1</th>\n",
  202. " <td>2</td>\n",
  203. " <td>GoldenEye (1995)</td>\n",
  204. " <td>01-Jan-1995</td>\n",
  205. " <td>NaN</td>\n",
  206. " <td>http://us.imdb.com/M/title-exact?GoldenEye%20(...</td>\n",
  207. " <td>0</td>\n",
  208. " <td>1</td>\n",
  209. " <td>1</td>\n",
  210. " <td>0</td>\n",
  211. " <td>0</td>\n",
  212. " <td>...</td>\n",
  213. " <td>0</td>\n",
  214. " <td>0</td>\n",
  215. " <td>0</td>\n",
  216. " <td>0</td>\n",
  217. " <td>0</td>\n",
  218. " <td>0</td>\n",
  219. " <td>0</td>\n",
  220. " <td>1</td>\n",
  221. " <td>0</td>\n",
  222. " <td>0</td>\n",
  223. " </tr>\n",
  224. " <tr>\n",
  225. " <th>2</th>\n",
  226. " <td>3</td>\n",
  227. " <td>Four Rooms (1995)</td>\n",
  228. " <td>01-Jan-1995</td>\n",
  229. " <td>NaN</td>\n",
  230. " <td>http://us.imdb.com/M/title-exact?Four%20Rooms%...</td>\n",
  231. " <td>0</td>\n",
  232. " <td>0</td>\n",
  233. " <td>0</td>\n",
  234. " <td>0</td>\n",
  235. " <td>0</td>\n",
  236. " <td>...</td>\n",
  237. " <td>0</td>\n",
  238. " <td>0</td>\n",
  239. " <td>0</td>\n",
  240. " <td>0</td>\n",
  241. " <td>0</td>\n",
  242. " <td>0</td>\n",
  243. " <td>0</td>\n",
  244. " <td>1</td>\n",
  245. " <td>0</td>\n",
  246. " <td>0</td>\n",
  247. " </tr>\n",
  248. " <tr>\n",
  249. " <th>3</th>\n",
  250. " <td>4</td>\n",
  251. " <td>Get Shorty (1995)</td>\n",
  252. " <td>01-Jan-1995</td>\n",
  253. " <td>NaN</td>\n",
  254. " <td>http://us.imdb.com/M/title-exact?Get%20Shorty%...</td>\n",
  255. " <td>0</td>\n",
  256. " <td>1</td>\n",
  257. " <td>0</td>\n",
  258. " <td>0</td>\n",
  259. " <td>0</td>\n",
  260. " <td>...</td>\n",
  261. " <td>0</td>\n",
  262. " <td>0</td>\n",
  263. " <td>0</td>\n",
  264. " <td>0</td>\n",
  265. " <td>0</td>\n",
  266. " <td>0</td>\n",
  267. " <td>0</td>\n",
  268. " <td>0</td>\n",
  269. " <td>0</td>\n",
  270. " <td>0</td>\n",
  271. " </tr>\n",
  272. " <tr>\n",
  273. " <th>4</th>\n",
  274. " <td>5</td>\n",
  275. " <td>Copycat (1995)</td>\n",
  276. " <td>01-Jan-1995</td>\n",
  277. " <td>NaN</td>\n",
  278. " <td>http://us.imdb.com/M/title-exact?Copycat%20(1995)</td>\n",
  279. " <td>0</td>\n",
  280. " <td>0</td>\n",
  281. " <td>0</td>\n",
  282. " <td>0</td>\n",
  283. " <td>0</td>\n",
  284. " <td>...</td>\n",
  285. " <td>0</td>\n",
  286. " <td>0</td>\n",
  287. " <td>0</td>\n",
  288. " <td>0</td>\n",
  289. " <td>0</td>\n",
  290. " <td>0</td>\n",
  291. " <td>0</td>\n",
  292. " <td>1</td>\n",
  293. " <td>0</td>\n",
  294. " <td>0</td>\n",
  295. " </tr>\n",
  296. " </tbody>\n",
  297. "</table>\n",
  298. "<p>5 rows × 24 columns</p>\n",
  299. "</div>"
  300. ],
  301. "text/plain": [
  302. " movie_id title release date video release date \\\n",
  303. "0 1 Toy Story (1995) 01-Jan-1995 NaN \n",
  304. "1 2 GoldenEye (1995) 01-Jan-1995 NaN \n",
  305. "2 3 Four Rooms (1995) 01-Jan-1995 NaN \n",
  306. "3 4 Get Shorty (1995) 01-Jan-1995 NaN \n",
  307. "4 5 Copycat (1995) 01-Jan-1995 NaN \n",
  308. "\n",
  309. " IMDb URL unknown Action \\\n",
  310. "0 http://us.imdb.com/M/title-exact?Toy%20Story%2... 0 0 \n",
  311. "1 http://us.imdb.com/M/title-exact?GoldenEye%20(... 0 1 \n",
  312. "2 http://us.imdb.com/M/title-exact?Four%20Rooms%... 0 0 \n",
  313. "3 http://us.imdb.com/M/title-exact?Get%20Shorty%... 0 1 \n",
  314. "4 http://us.imdb.com/M/title-exact?Copycat%20(1995) 0 0 \n",
  315. "\n",
  316. " Adventure Animation Children's ... Fantasy Film-Noir Horror \\\n",
  317. "0 0 1 1 ... 0 0 0 \n",
  318. "1 1 0 0 ... 0 0 0 \n",
  319. "2 0 0 0 ... 0 0 0 \n",
  320. "3 0 0 0 ... 0 0 0 \n",
  321. "4 0 0 0 ... 0 0 0 \n",
  322. "\n",
  323. " Musical Mystery Romance Sci-Fi Thriller War Western \n",
  324. "0 0 0 0 0 0 0 0 \n",
  325. "1 0 0 0 0 1 0 0 \n",
  326. "2 0 0 0 0 1 0 0 \n",
  327. "3 0 0 0 0 0 0 0 \n",
  328. "4 0 0 0 0 1 0 0 \n",
  329. "\n",
  330. "[5 rows x 24 columns]"
  331. ]
  332. },
  333. "execution_count": 3,
  334. "metadata": {},
  335. "output_type": "execute_result"
  336. }
  337. ],
  338. "source": [
  339. "#Load the u.item file into a dataframe\n",
  340. "i_cols = ['movie_id', 'title' ,'release date','video release date', 'IMDb URL', 'unknown', 'Action', 'Adventure',\n",
  341. " 'Animation', 'Children\\'s', 'Comedy', 'Crime', 'Documentary', 'Drama', 'Fantasy',\n",
  342. " 'Film-Noir', 'Horror', 'Musical', 'Mystery', 'Romance', 'Sci-Fi', 'Thriller', 'War', 'Western']\n",
  343. "\n",
  344. "movies = pd.read_csv('../data/movielens/u.item', sep='|', names=i_cols, encoding='latin-1')\n",
  345. "\n",
  346. "movies.head()"
  347. ]
  348. },
  349. {
  350. "cell_type": "code",
  351. "execution_count": 4,
  352. "metadata": {
  353. "collapsed": true
  354. },
  355. "outputs": [],
  356. "source": [
  357. "#Remove all information except Movie ID and title\n",
  358. "movies = movies[['movie_id', 'title']]"
  359. ]
  360. },
  361. {
  362. "cell_type": "code",
  363. "execution_count": 5,
  364. "metadata": {},
  365. "outputs": [
  366. {
  367. "data": {
  368. "text/html": [
  369. "<div>\n",
  370. "<style>\n",
  371. " .dataframe thead tr:only-child th {\n",
  372. " text-align: right;\n",
  373. " }\n",
  374. "\n",
  375. " .dataframe thead th {\n",
  376. " text-align: left;\n",
  377. " }\n",
  378. "\n",
  379. " .dataframe tbody tr th {\n",
  380. " vertical-align: top;\n",
  381. " }\n",
  382. "</style>\n",
  383. "<table border=\"1\" class=\"dataframe\">\n",
  384. " <thead>\n",
  385. " <tr style=\"text-align: right;\">\n",
  386. " <th></th>\n",
  387. " <th>user_id</th>\n",
  388. " <th>movie_id</th>\n",
  389. " <th>rating</th>\n",
  390. " <th>timestamp</th>\n",
  391. " </tr>\n",
  392. " </thead>\n",
  393. " <tbody>\n",
  394. " <tr>\n",
  395. " <th>0</th>\n",
  396. " <td>196</td>\n",
  397. " <td>242</td>\n",
  398. " <td>3</td>\n",
  399. " <td>881250949</td>\n",
  400. " </tr>\n",
  401. " <tr>\n",
  402. " <th>1</th>\n",
  403. " <td>186</td>\n",
  404. " <td>302</td>\n",
  405. " <td>3</td>\n",
  406. " <td>891717742</td>\n",
  407. " </tr>\n",
  408. " <tr>\n",
  409. " <th>2</th>\n",
  410. " <td>22</td>\n",
  411. " <td>377</td>\n",
  412. " <td>1</td>\n",
  413. " <td>878887116</td>\n",
  414. " </tr>\n",
  415. " <tr>\n",
  416. " <th>3</th>\n",
  417. " <td>244</td>\n",
  418. " <td>51</td>\n",
  419. " <td>2</td>\n",
  420. " <td>880606923</td>\n",
  421. " </tr>\n",
  422. " <tr>\n",
  423. " <th>4</th>\n",
  424. " <td>166</td>\n",
  425. " <td>346</td>\n",
  426. " <td>1</td>\n",
  427. " <td>886397596</td>\n",
  428. " </tr>\n",
  429. " </tbody>\n",
  430. "</table>\n",
  431. "</div>"
  432. ],
  433. "text/plain": [
  434. " user_id movie_id rating timestamp\n",
  435. "0 196 242 3 881250949\n",
  436. "1 186 302 3 891717742\n",
  437. "2 22 377 1 878887116\n",
  438. "3 244 51 2 880606923\n",
  439. "4 166 346 1 886397596"
  440. ]
  441. },
  442. "execution_count": 5,
  443. "metadata": {},
  444. "output_type": "execute_result"
  445. }
  446. ],
  447. "source": [
  448. "#Load the u.data file into a dataframe\n",
  449. "r_cols = ['user_id', 'movie_id', 'rating', 'timestamp']\n",
  450. "\n",
  451. "ratings = pd.read_csv('../data/movielens/u.data', sep='\\t', names=r_cols,\n",
  452. " encoding='latin-1')\n",
  453. "\n",
  454. "ratings.head()"
  455. ]
  456. },
  457. {
  458. "cell_type": "code",
  459. "execution_count": 6,
  460. "metadata": {
  461. "collapsed": true
  462. },
  463. "outputs": [],
  464. "source": [
  465. "#Drop the timestamp column\n",
  466. "ratings = ratings.drop('timestamp', axis=1)"
  467. ]
  468. },
  469. {
  470. "cell_type": "code",
  471. "execution_count": 7,
  472. "metadata": {
  473. "collapsed": true
  474. },
  475. "outputs": [],
  476. "source": [
  477. "#Import the train_test_split function\n",
  478. "from sklearn.model_selection import train_test_split\n",
  479. "\n",
  480. "#Assign X as the original ratings dataframe and y as the user_id column of ratings.\n",
  481. "X = ratings.copy()\n",
  482. "y = ratings['user_id']\n",
  483. "\n",
  484. "#Split into training and test datasets, stratified along user_id\n",
  485. "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.25, stratify=y, random_state=42)"
  486. ]
  487. },
  488. {
  489. "cell_type": "code",
  490. "execution_count": 8,
  491. "metadata": {
  492. "collapsed": true
  493. },
  494. "outputs": [],
  495. "source": [
  496. "#Import the mean_squared_error function\n",
  497. "from sklearn.metrics import mean_squared_error\n",
  498. "\n",
  499. "#Function that computes the root mean squared error (or RMSE)\n",
  500. "def rmse(y_true, y_pred):\n",
  501. " return np.sqrt(mean_squared_error(y_true, y_pred))"
  502. ]
  503. },
  504. {
  505. "cell_type": "code",
  506. "execution_count": 10,
  507. "metadata": {
  508. "collapsed": true
  509. },
  510. "outputs": [],
  511. "source": [
  512. "#Define the baseline model to always return 3.\n",
  513. "def baseline(user_id, movie_id):\n",
  514. " return 3.0"
  515. ]
  516. },
  517. {
  518. "cell_type": "code",
  519. "execution_count": 11,
  520. "metadata": {
  521. "collapsed": true
  522. },
  523. "outputs": [],
  524. "source": [
  525. "#Function to compute the RMSE score obtained on the testing set by a model\n",
  526. "def score(cf_model):\n",
  527. " \n",
  528. " #Construct a list of user-movie tuples from the testing dataset\n",
  529. " id_pairs = zip(X_test['user_id'], X_test['movie_id'])\n",
  530. " \n",
  531. " #Predict the rating for every user-movie tuple\n",
  532. " y_pred = np.array([cf_model(user, movie) for (user, movie) in id_pairs])\n",
  533. " \n",
  534. " #Extract the actual ratings given by the users in the test data\n",
  535. " y_true = np.array(X_test['rating'])\n",
  536. " \n",
  537. " #Return the final RMSE score\n",
  538. " return rmse(y_true, y_pred)"
  539. ]
  540. },
  541. {
  542. "cell_type": "code",
  543. "execution_count": 12,
  544. "metadata": {},
  545. "outputs": [
  546. {
  547. "data": {
  548. "text/plain": [
  549. "1.2470926188539486"
  550. ]
  551. },
  552. "execution_count": 12,
  553. "metadata": {},
  554. "output_type": "execute_result"
  555. }
  556. ],
  557. "source": [
  558. "score(baseline)"
  559. ]
  560. },
  561. {
  562. "cell_type": "markdown",
  563. "metadata": {
  564. "collapsed": true
  565. },
  566. "source": [
  567. "## User Based Collaborative Filtering\n",
  568. "\n",
  569. "### Ratings Matrix"
  570. ]
  571. },
  572. {
  573. "cell_type": "code",
  574. "execution_count": 58,
  575. "metadata": {},
  576. "outputs": [
  577. {
  578. "data": {
  579. "text/html": [
  580. "<div>\n",
  581. "<style>\n",
  582. " .dataframe thead tr:only-child th {\n",
  583. " text-align: right;\n",
  584. " }\n",
  585. "\n",
  586. " .dataframe thead th {\n",
  587. " text-align: left;\n",
  588. " }\n",
  589. "\n",
  590. " .dataframe tbody tr th {\n",
  591. " vertical-align: top;\n",
  592. " }\n",
  593. "</style>\n",
  594. "<table border=\"1\" class=\"dataframe\">\n",
  595. " <thead>\n",
  596. " <tr style=\"text-align: right;\">\n",
  597. " <th>movie_id</th>\n",
  598. " <th>1</th>\n",
  599. " <th>2</th>\n",
  600. " <th>3</th>\n",
  601. " <th>4</th>\n",
  602. " <th>5</th>\n",
  603. " <th>6</th>\n",
  604. " <th>7</th>\n",
  605. " <th>8</th>\n",
  606. " <th>9</th>\n",
  607. " <th>10</th>\n",
  608. " <th>...</th>\n",
  609. " <th>1669</th>\n",
  610. " <th>1670</th>\n",
  611. " <th>1671</th>\n",
  612. " <th>1673</th>\n",
  613. " <th>1674</th>\n",
  614. " <th>1675</th>\n",
  615. " <th>1676</th>\n",
  616. " <th>1679</th>\n",
  617. " <th>1681</th>\n",
  618. " <th>1682</th>\n",
  619. " </tr>\n",
  620. " <tr>\n",
  621. " <th>user_id</th>\n",
  622. " <th></th>\n",
  623. " <th></th>\n",
  624. " <th></th>\n",
  625. " <th></th>\n",
  626. " <th></th>\n",
  627. " <th></th>\n",
  628. " <th></th>\n",
  629. " <th></th>\n",
  630. " <th></th>\n",
  631. " <th></th>\n",
  632. " <th></th>\n",
  633. " <th></th>\n",
  634. " <th></th>\n",
  635. " <th></th>\n",
  636. " <th></th>\n",
  637. " <th></th>\n",
  638. " <th></th>\n",
  639. " <th></th>\n",
  640. " <th></th>\n",
  641. " <th></th>\n",
  642. " <th></th>\n",
  643. " </tr>\n",
  644. " </thead>\n",
  645. " <tbody>\n",
  646. " <tr>\n",
  647. " <th>1</th>\n",
  648. " <td>5.0</td>\n",
  649. " <td>3.0</td>\n",
  650. " <td>4.0</td>\n",
  651. " <td>3.0</td>\n",
  652. " <td>3.0</td>\n",
  653. " <td>5.0</td>\n",
  654. " <td>4.0</td>\n",
  655. " <td>1.0</td>\n",
  656. " <td>5.0</td>\n",
  657. " <td>3.0</td>\n",
  658. " <td>...</td>\n",
  659. " <td>NaN</td>\n",
  660. " <td>NaN</td>\n",
  661. " <td>NaN</td>\n",
  662. " <td>NaN</td>\n",
  663. " <td>NaN</td>\n",
  664. " <td>NaN</td>\n",
  665. " <td>NaN</td>\n",
  666. " <td>NaN</td>\n",
  667. " <td>NaN</td>\n",
  668. " <td>NaN</td>\n",
  669. " </tr>\n",
  670. " <tr>\n",
  671. " <th>2</th>\n",
  672. " <td>NaN</td>\n",
  673. " <td>NaN</td>\n",
  674. " <td>NaN</td>\n",
  675. " <td>NaN</td>\n",
  676. " <td>NaN</td>\n",
  677. " <td>NaN</td>\n",
  678. " <td>NaN</td>\n",
  679. " <td>NaN</td>\n",
  680. " <td>NaN</td>\n",
  681. " <td>2.0</td>\n",
  682. " <td>...</td>\n",
  683. " <td>NaN</td>\n",
  684. " <td>NaN</td>\n",
  685. " <td>NaN</td>\n",
  686. " <td>NaN</td>\n",
  687. " <td>NaN</td>\n",
  688. " <td>NaN</td>\n",
  689. " <td>NaN</td>\n",
  690. " <td>NaN</td>\n",
  691. " <td>NaN</td>\n",
  692. " <td>NaN</td>\n",
  693. " </tr>\n",
  694. " <tr>\n",
  695. " <th>3</th>\n",
  696. " <td>NaN</td>\n",
  697. " <td>NaN</td>\n",
  698. " <td>NaN</td>\n",
  699. " <td>NaN</td>\n",
  700. " <td>NaN</td>\n",
  701. " <td>NaN</td>\n",
  702. " <td>NaN</td>\n",
  703. " <td>NaN</td>\n",
  704. " <td>NaN</td>\n",
  705. " <td>NaN</td>\n",
  706. " <td>...</td>\n",
  707. " <td>NaN</td>\n",
  708. " <td>NaN</td>\n",
  709. " <td>NaN</td>\n",
  710. " <td>NaN</td>\n",
  711. " <td>NaN</td>\n",
  712. " <td>NaN</td>\n",
  713. " <td>NaN</td>\n",
  714. " <td>NaN</td>\n",
  715. " <td>NaN</td>\n",
  716. " <td>NaN</td>\n",
  717. " </tr>\n",
  718. " <tr>\n",
  719. " <th>4</th>\n",
  720. " <td>NaN</td>\n",
  721. " <td>NaN</td>\n",
  722. " <td>NaN</td>\n",
  723. " <td>NaN</td>\n",
  724. " <td>NaN</td>\n",
  725. " <td>NaN</td>\n",
  726. " <td>NaN</td>\n",
  727. " <td>NaN</td>\n",
  728. " <td>NaN</td>\n",
  729. " <td>NaN</td>\n",
  730. " <td>...</td>\n",
  731. " <td>NaN</td>\n",
  732. " <td>NaN</td>\n",
  733. " <td>NaN</td>\n",
  734. " <td>NaN</td>\n",
  735. " <td>NaN</td>\n",
  736. " <td>NaN</td>\n",
  737. " <td>NaN</td>\n",
  738. " <td>NaN</td>\n",
  739. " <td>NaN</td>\n",
  740. " <td>NaN</td>\n",
  741. " </tr>\n",
  742. " <tr>\n",
  743. " <th>5</th>\n",
  744. " <td>NaN</td>\n",
  745. " <td>3.0</td>\n",
  746. " <td>NaN</td>\n",
  747. " <td>NaN</td>\n",
  748. " <td>NaN</td>\n",
  749. " <td>NaN</td>\n",
  750. " <td>NaN</td>\n",
  751. " <td>NaN</td>\n",
  752. " <td>NaN</td>\n",
  753. " <td>NaN</td>\n",
  754. " <td>...</td>\n",
  755. " <td>NaN</td>\n",
  756. " <td>NaN</td>\n",
  757. " <td>NaN</td>\n",
  758. " <td>NaN</td>\n",
  759. " <td>NaN</td>\n",
  760. " <td>NaN</td>\n",
  761. " <td>NaN</td>\n",
  762. " <td>NaN</td>\n",
  763. " <td>NaN</td>\n",
  764. " <td>NaN</td>\n",
  765. " </tr>\n",
  766. " </tbody>\n",
  767. "</table>\n",
  768. "<p>5 rows × 1647 columns</p>\n",
  769. "</div>"
  770. ],
  771. "text/plain": [
  772. "movie_id 1 2 3 4 5 6 7 8 9 10 ... \\\n",
  773. "user_id ... \n",
  774. "1 5.0 3.0 4.0 3.0 3.0 5.0 4.0 1.0 5.0 3.0 ... \n",
  775. "2 NaN NaN NaN NaN NaN NaN NaN NaN NaN 2.0 ... \n",
  776. "3 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... \n",
  777. "4 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... \n",
  778. "5 NaN 3.0 NaN NaN NaN NaN NaN NaN NaN NaN ... \n",
  779. "\n",
  780. "movie_id 1669 1670 1671 1673 1674 1675 1676 1679 1681 1682 \n",
  781. "user_id \n",
  782. "1 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n",
  783. "2 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n",
  784. "3 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n",
  785. "4 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n",
  786. "5 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n",
  787. "\n",
  788. "[5 rows x 1647 columns]"
  789. ]
  790. },
  791. "execution_count": 58,
  792. "metadata": {},
  793. "output_type": "execute_result"
  794. }
  795. ],
  796. "source": [
  797. "#Build the ratings matrix using pivot_table function\n",
  798. "r_matrix = X_train.pivot_table(values='rating', index='user_id', columns='movie_id')\n",
  799. "\n",
  800. "r_matrix.head()"
  801. ]
  802. },
  803. {
  804. "cell_type": "markdown",
  805. "metadata": {},
  806. "source": [
  807. "### Mean"
  808. ]
  809. },
  810. {
  811. "cell_type": "code",
  812. "execution_count": 88,
  813. "metadata": {
  814. "collapsed": true
  815. },
  816. "outputs": [],
  817. "source": [
  818. "#User Based Collaborative Filter using Mean Ratings\n",
  819. "def cf_user_mean(user_id, movie_id):\n",
  820. " \n",
  821. " #Check if movie_id exists in r_matrix\n",
  822. " if movie_id in r_matrix:\n",
  823. " #Compute the mean of all the ratings given to the movie\n",
  824. " mean_rating = r_matrix[movie_id].mean()\n",
  825. " \n",
  826. " else:\n",
  827. " #Default to a rating of 3.0 in the absence of any information\n",
  828. " mean_rating = 3.0\n",
  829. " \n",
  830. " return mean_rating"
  831. ]
  832. },
  833. {
  834. "cell_type": "code",
  835. "execution_count": 89,
  836. "metadata": {},
  837. "outputs": [
  838. {
  839. "data": {
  840. "text/plain": [
  841. "1.0234701463131335"
  842. ]
  843. },
  844. "execution_count": 89,
  845. "metadata": {},
  846. "output_type": "execute_result"
  847. }
  848. ],
  849. "source": [
  850. "#Compute RMSE for the Mean model\n",
  851. "score(cf_user_mean)"
  852. ]
  853. },
  854. {
  855. "cell_type": "markdown",
  856. "metadata": {},
  857. "source": [
  858. "### Weighted Mean"
  859. ]
  860. },
  861. {
  862. "cell_type": "code",
  863. "execution_count": 61,
  864. "metadata": {
  865. "collapsed": true
  866. },
  867. "outputs": [],
  868. "source": [
  869. "#Create a dummy ratings matrix with all null values imputed to 0\n",
  870. "r_matrix_dummy = r_matrix.copy().fillna(0)"
  871. ]
  872. },
  873. {
  874. "cell_type": "code",
  875. "execution_count": 62,
  876. "metadata": {
  877. "collapsed": true
  878. },
  879. "outputs": [],
  880. "source": [
  881. "#Import the cosine_similarity function\n",
  882. "from sklearn.metrics.pairwise import cosine_similarity\n",
  883. "\n",
  884. "#Compute the cosine similarity matrix using the dummy ratings matrix\n",
  885. "cosine_sim = cosine_similarity(r_matrix_dummy, r_matrix_dummy)"
  886. ]
  887. },
  888. {
  889. "cell_type": "code",
  890. "execution_count": 63,
  891. "metadata": {},
  892. "outputs": [
  893. {
  894. "data": {
  895. "text/html": [
  896. "<div>\n",
  897. "<style>\n",
  898. " .dataframe thead tr:only-child th {\n",
  899. " text-align: right;\n",
  900. " }\n",
  901. "\n",
  902. " .dataframe thead th {\n",
  903. " text-align: left;\n",
  904. " }\n",
  905. "\n",
  906. " .dataframe tbody tr th {\n",
  907. " vertical-align: top;\n",
  908. " }\n",
  909. "</style>\n",
  910. "<table border=\"1\" class=\"dataframe\">\n",
  911. " <thead>\n",
  912. " <tr style=\"text-align: right;\">\n",
  913. " <th>user_id</th>\n",
  914. " <th>1</th>\n",
  915. " <th>2</th>\n",
  916. " <th>3</th>\n",
  917. " <th>4</th>\n",
  918. " <th>5</th>\n",
  919. " <th>6</th>\n",
  920. " <th>7</th>\n",
  921. " <th>8</th>\n",
  922. " <th>9</th>\n",
  923. " <th>10</th>\n",
  924. " <th>...</th>\n",
  925. " <th>934</th>\n",
  926. " <th>935</th>\n",
  927. " <th>936</th>\n",
  928. " <th>937</th>\n",
  929. " <th>938</th>\n",
  930. " <th>939</th>\n",
  931. " <th>940</th>\n",
  932. " <th>941</th>\n",
  933. " <th>942</th>\n",
  934. " <th>943</th>\n",
  935. " </tr>\n",
  936. " <tr>\n",
  937. " <th>user_id</th>\n",
  938. " <th></th>\n",
  939. " <th></th>\n",
  940. " <th></th>\n",
  941. " <th></th>\n",
  942. " <th></th>\n",
  943. " <th></th>\n",
  944. " <th></th>\n",
  945. " <th></th>\n",
  946. " <th></th>\n",
  947. " <th></th>\n",
  948. " <th></th>\n",
  949. " <th></th>\n",
  950. " <th></th>\n",
  951. " <th></th>\n",
  952. " <th></th>\n",
  953. " <th></th>\n",
  954. " <th></th>\n",
  955. " <th></th>\n",
  956. " <th></th>\n",
  957. " <th></th>\n",
  958. " <th></th>\n",
  959. " </tr>\n",
  960. " </thead>\n",
  961. " <tbody>\n",
  962. " <tr>\n",
  963. " <th>1</th>\n",
  964. " <td>1.000000</td>\n",
  965. " <td>0.118076</td>\n",
  966. " <td>0.029097</td>\n",
  967. " <td>0.011628</td>\n",
  968. " <td>0.264677</td>\n",
  969. " <td>0.312419</td>\n",
  970. " <td>0.308729</td>\n",
  971. " <td>0.224269</td>\n",
  972. " <td>0.026017</td>\n",
  973. " <td>0.286411</td>\n",
  974. " <td>...</td>\n",
  975. " <td>0.308475</td>\n",
  976. " <td>0.055872</td>\n",
  977. " <td>0.197862</td>\n",
  978. " <td>0.131367</td>\n",
  979. " <td>0.152449</td>\n",
  980. " <td>0.084456</td>\n",
  981. " <td>0.293293</td>\n",
  982. " <td>0.056765</td>\n",
  983. " <td>0.103536</td>\n",
  984. " <td>0.326491</td>\n",
  985. " </tr>\n",
  986. " <tr>\n",
  987. " <th>2</th>\n",
  988. " <td>0.118076</td>\n",
  989. " <td>1.000000</td>\n",
  990. " <td>0.099097</td>\n",
  991. " <td>0.107680</td>\n",
  992. " <td>0.034279</td>\n",
  993. " <td>0.152789</td>\n",
  994. " <td>0.086705</td>\n",
  995. " <td>0.078864</td>\n",
  996. " <td>0.068940</td>\n",
  997. " <td>0.092399</td>\n",
  998. " <td>...</td>\n",
  999. " <td>0.086927</td>\n",
  1000. " <td>0.259636</td>\n",
  1001. " <td>0.289092</td>\n",
  1002. " <td>0.318824</td>\n",
  1003. " <td>0.149105</td>\n",
  1004. " <td>0.186347</td>\n",
  1005. " <td>0.168034</td>\n",
  1006. " <td>0.106748</td>\n",
  1007. " <td>0.136796</td>\n",
  1008. " <td>0.080358</td>\n",
  1009. " </tr>\n",
  1010. " <tr>\n",
  1011. " <th>3</th>\n",
  1012. " <td>0.029097</td>\n",
  1013. " <td>0.099097</td>\n",
  1014. " <td>1.000000</td>\n",
  1015. " <td>0.252131</td>\n",
  1016. " <td>0.026893</td>\n",
  1017. " <td>0.062539</td>\n",
  1018. " <td>0.039767</td>\n",
  1019. " <td>0.089474</td>\n",
  1020. " <td>0.078162</td>\n",
  1021. " <td>0.037670</td>\n",
  1022. " <td>...</td>\n",
  1023. " <td>0.040918</td>\n",
  1024. " <td>0.019031</td>\n",
  1025. " <td>0.065417</td>\n",
  1026. " <td>0.055373</td>\n",
  1027. " <td>0.086503</td>\n",
  1028. " <td>0.018418</td>\n",
  1029. " <td>0.096993</td>\n",
  1030. " <td>0.109631</td>\n",
  1031. " <td>0.092574</td>\n",
  1032. " <td>0.018987</td>\n",
  1033. " </tr>\n",
  1034. " <tr>\n",
  1035. " <th>4</th>\n",
  1036. " <td>0.011628</td>\n",
  1037. " <td>0.107680</td>\n",
  1038. " <td>0.252131</td>\n",
  1039. " <td>1.000000</td>\n",
  1040. " <td>0.000000</td>\n",
  1041. " <td>0.045543</td>\n",
  1042. " <td>0.078812</td>\n",
  1043. " <td>0.095354</td>\n",
  1044. " <td>0.059498</td>\n",
  1045. " <td>0.053879</td>\n",
  1046. " <td>...</td>\n",
  1047. " <td>0.024226</td>\n",
  1048. " <td>0.050703</td>\n",
  1049. " <td>0.056561</td>\n",
  1050. " <td>0.107294</td>\n",
  1051. " <td>0.098892</td>\n",
  1052. " <td>0.000000</td>\n",
  1053. " <td>0.132900</td>\n",
  1054. " <td>0.142798</td>\n",
  1055. " <td>0.097066</td>\n",
  1056. " <td>0.015176</td>\n",
  1057. " </tr>\n",
  1058. " <tr>\n",
  1059. " <th>5</th>\n",
  1060. " <td>0.264677</td>\n",
  1061. " <td>0.034279</td>\n",
  1062. " <td>0.026893</td>\n",
  1063. " <td>0.000000</td>\n",
  1064. " <td>1.000000</td>\n",
  1065. " <td>0.202843</td>\n",
  1066. " <td>0.299619</td>\n",
  1067. " <td>0.163724</td>\n",
  1068. " <td>0.038474</td>\n",
  1069. " <td>0.153021</td>\n",
  1070. " <td>...</td>\n",
  1071. " <td>0.262547</td>\n",
  1072. " <td>0.048524</td>\n",
  1073. " <td>0.048312</td>\n",
  1074. " <td>0.022202</td>\n",
  1075. " <td>0.091910</td>\n",
  1076. " <td>0.066000</td>\n",
  1077. " <td>0.156172</td>\n",
  1078. " <td>0.115842</td>\n",
  1079. " <td>0.124297</td>\n",
  1080. " <td>0.267574</td>\n",
  1081. " </tr>\n",
  1082. " <tr>\n",
  1083. " <th>6</th>\n",
  1084. " <td>0.312419</td>\n",
  1085. " <td>0.152789</td>\n",
  1086. " <td>0.062539</td>\n",
  1087. " <td>0.045543</td>\n",
  1088. " <td>0.202843</td>\n",
  1089. " <td>1.000000</td>\n",
  1090. " <td>0.375963</td>\n",
  1091. " <td>0.131795</td>\n",
  1092. " <td>0.110944</td>\n",
  1093. " <td>0.400758</td>\n",
  1094. " <td>...</td>\n",
  1095. " <td>0.287549</td>\n",
  1096. " <td>0.080312</td>\n",
  1097. " <td>0.162988</td>\n",
  1098. " <td>0.182856</td>\n",
  1099. " <td>0.114262</td>\n",
  1100. " <td>0.092090</td>\n",
  1101. " <td>0.261859</td>\n",
  1102. " <td>0.097606</td>\n",
  1103. " <td>0.206104</td>\n",
  1104. " <td>0.187637</td>\n",
  1105. " </tr>\n",
  1106. " <tr>\n",
  1107. " <th>7</th>\n",
  1108. " <td>0.308729</td>\n",
  1109. " <td>0.086705</td>\n",
  1110. " <td>0.039767</td>\n",
  1111. " <td>0.078812</td>\n",
  1112. " <td>0.299619</td>\n",
  1113. " <td>0.375963</td>\n",
  1114. " <td>1.000000</td>\n",
  1115. " <td>0.211282</td>\n",
  1116. " <td>0.107795</td>\n",
  1117. " <td>0.328923</td>\n",
  1118. " <td>...</td>\n",
  1119. " <td>0.290002</td>\n",
  1120. " <td>0.074170</td>\n",
  1121. " <td>0.094619</td>\n",
  1122. " <td>0.084235</td>\n",
  1123. " <td>0.115620</td>\n",
  1124. " <td>0.100625</td>\n",
  1125. " <td>0.233843</td>\n",
  1126. " <td>0.039199</td>\n",
  1127. " <td>0.224227</td>\n",
  1128. " <td>0.296332</td>\n",
  1129. " </tr>\n",
  1130. " <tr>\n",
  1131. " <th>8</th>\n",
  1132. " <td>0.224269</td>\n",
  1133. " <td>0.078864</td>\n",
  1134. " <td>0.089474</td>\n",
  1135. " <td>0.095354</td>\n",
  1136. " <td>0.163724</td>\n",
  1137. " <td>0.131795</td>\n",
  1138. " <td>0.211282</td>\n",
  1139. " <td>1.000000</td>\n",
  1140. " <td>0.037040</td>\n",
  1141. " <td>0.183375</td>\n",
  1142. " <td>...</td>\n",
  1143. " <td>0.165008</td>\n",
  1144. " <td>0.066843</td>\n",
  1145. " <td>0.058766</td>\n",
  1146. " <td>0.068759</td>\n",
  1147. " <td>0.087159</td>\n",
  1148. " <td>0.129381</td>\n",
  1149. " <td>0.188662</td>\n",
  1150. " <td>0.121223</td>\n",
  1151. " <td>0.083910</td>\n",
  1152. " <td>0.273238</td>\n",
  1153. " </tr>\n",
  1154. " <tr>\n",
  1155. " <th>9</th>\n",
  1156. " <td>0.026017</td>\n",
  1157. " <td>0.068940</td>\n",
  1158. " <td>0.078162</td>\n",
  1159. " <td>0.059498</td>\n",
  1160. " <td>0.038474</td>\n",
  1161. " <td>0.110944</td>\n",
  1162. " <td>0.107795</td>\n",
  1163. " <td>0.037040</td>\n",
  1164. " <td>1.000000</td>\n",
  1165. " <td>0.155435</td>\n",
  1166. " <td>...</td>\n",
  1167. " <td>0.011708</td>\n",
  1168. " <td>0.000000</td>\n",
  1169. " <td>0.101710</td>\n",
  1170. " <td>0.034568</td>\n",
  1171. " <td>0.045002</td>\n",
  1172. " <td>0.052699</td>\n",
  1173. " <td>0.107486</td>\n",
  1174. " <td>0.055766</td>\n",
  1175. " <td>0.070065</td>\n",
  1176. " <td>0.088281</td>\n",
  1177. " </tr>\n",
  1178. " <tr>\n",
  1179. " <th>10</th>\n",
  1180. " <td>0.286411</td>\n",
  1181. " <td>0.092399</td>\n",
  1182. " <td>0.037670</td>\n",
  1183. " <td>0.053879</td>\n",
  1184. " <td>0.153021</td>\n",
  1185. " <td>0.400758</td>\n",
  1186. " <td>0.328923</td>\n",
  1187. " <td>0.183375</td>\n",
  1188. " <td>0.155435</td>\n",
  1189. " <td>1.000000</td>\n",
  1190. " <td>...</td>\n",
  1191. " <td>0.278558</td>\n",
  1192. " <td>0.049310</td>\n",
  1193. " <td>0.153506</td>\n",
  1194. " <td>0.065471</td>\n",
  1195. " <td>0.060088</td>\n",
  1196. " <td>0.033686</td>\n",
  1197. " <td>0.197107</td>\n",
  1198. " <td>0.085402</td>\n",
  1199. " <td>0.118945</td>\n",
  1200. " <td>0.162538</td>\n",
  1201. " </tr>\n",
  1202. " </tbody>\n",
  1203. "</table>\n",
  1204. "<p>10 rows × 943 columns</p>\n",
  1205. "</div>"
  1206. ],
  1207. "text/plain": [
  1208. "user_id 1 2 3 4 5 6 7 \\\n",
  1209. "user_id \n",
  1210. "1 1.000000 0.118076 0.029097 0.011628 0.264677 0.312419 0.308729 \n",
  1211. "2 0.118076 1.000000 0.099097 0.107680 0.034279 0.152789 0.086705 \n",
  1212. "3 0.029097 0.099097 1.000000 0.252131 0.026893 0.062539 0.039767 \n",
  1213. "4 0.011628 0.107680 0.252131 1.000000 0.000000 0.045543 0.078812 \n",
  1214. "5 0.264677 0.034279 0.026893 0.000000 1.000000 0.202843 0.299619 \n",
  1215. "6 0.312419 0.152789 0.062539 0.045543 0.202843 1.000000 0.375963 \n",
  1216. "7 0.308729 0.086705 0.039767 0.078812 0.299619 0.375963 1.000000 \n",
  1217. "8 0.224269 0.078864 0.089474 0.095354 0.163724 0.131795 0.211282 \n",
  1218. "9 0.026017 0.068940 0.078162 0.059498 0.038474 0.110944 0.107795 \n",
  1219. "10 0.286411 0.092399 0.037670 0.053879 0.153021 0.400758 0.328923 \n",
  1220. "\n",
  1221. "user_id 8 9 10 ... 934 935 936 \\\n",
  1222. "user_id ... \n",
  1223. "1 0.224269 0.026017 0.286411 ... 0.308475 0.055872 0.197862 \n",
  1224. "2 0.078864 0.068940 0.092399 ... 0.086927 0.259636 0.289092 \n",
  1225. "3 0.089474 0.078162 0.037670 ... 0.040918 0.019031 0.065417 \n",
  1226. "4 0.095354 0.059498 0.053879 ... 0.024226 0.050703 0.056561 \n",
  1227. "5 0.163724 0.038474 0.153021 ... 0.262547 0.048524 0.048312 \n",
  1228. "6 0.131795 0.110944 0.400758 ... 0.287549 0.080312 0.162988 \n",
  1229. "7 0.211282 0.107795 0.328923 ... 0.290002 0.074170 0.094619 \n",
  1230. "8 1.000000 0.037040 0.183375 ... 0.165008 0.066843 0.058766 \n",
  1231. "9 0.037040 1.000000 0.155435 ... 0.011708 0.000000 0.101710 \n",
  1232. "10 0.183375 0.155435 1.000000 ... 0.278558 0.049310 0.153506 \n",
  1233. "\n",
  1234. "user_id 937 938 939 940 941 942 943 \n",
  1235. "user_id \n",
  1236. "1 0.131367 0.152449 0.084456 0.293293 0.056765 0.103536 0.326491 \n",
  1237. "2 0.318824 0.149105 0.186347 0.168034 0.106748 0.136796 0.080358 \n",
  1238. "3 0.055373 0.086503 0.018418 0.096993 0.109631 0.092574 0.018987 \n",
  1239. "4 0.107294 0.098892 0.000000 0.132900 0.142798 0.097066 0.015176 \n",
  1240. "5 0.022202 0.091910 0.066000 0.156172 0.115842 0.124297 0.267574 \n",
  1241. "6 0.182856 0.114262 0.092090 0.261859 0.097606 0.206104 0.187637 \n",
  1242. "7 0.084235 0.115620 0.100625 0.233843 0.039199 0.224227 0.296332 \n",
  1243. "8 0.068759 0.087159 0.129381 0.188662 0.121223 0.083910 0.273238 \n",
  1244. "9 0.034568 0.045002 0.052699 0.107486 0.055766 0.070065 0.088281 \n",
  1245. "10 0.065471 0.060088 0.033686 0.197107 0.085402 0.118945 0.162538 \n",
  1246. "\n",
  1247. "[10 rows x 943 columns]"
  1248. ]
  1249. },
  1250. "execution_count": 63,
  1251. "metadata": {},
  1252. "output_type": "execute_result"
  1253. }
  1254. ],
  1255. "source": [
  1256. "#Convert into pandas dataframe \n",
  1257. "cosine_sim = pd.DataFrame(cosine_sim, index=r_matrix.index, columns=r_matrix.index)\n",
  1258. "\n",
  1259. "cosine_sim.head(10)"
  1260. ]
  1261. },
  1262. {
  1263. "cell_type": "code",
  1264. "execution_count": 140,
  1265. "metadata": {
  1266. "collapsed": true
  1267. },
  1268. "outputs": [],
  1269. "source": [
  1270. "#User Based Collaborative Filter using Weighted Mean Ratings\n",
  1271. "def cf_user_wmean(user_id, movie_id):\n",
  1272. " \n",
  1273. " #Check if movie_id exists in r_matrix\n",
  1274. " if movie_id in r_matrix:\n",
  1275. " \n",
  1276. " #Get the similarity scores for the user in question with every other user\n",
  1277. " sim_scores = cosine_sim[user_id]\n",
  1278. " \n",
  1279. " #Get the user ratings for the movie in question\n",
  1280. " m_ratings = r_matrix[movie_id]\n",
  1281. " \n",
  1282. " #Extract the indices containing NaN in the m_ratings series\n",
  1283. " idx = m_ratings[m_ratings.isnull()].index\n",
  1284. " \n",
  1285. " #Drop the NaN values from the m_ratings Series\n",
  1286. " m_ratings = m_ratings.dropna()\n",
  1287. " \n",
  1288. " #Drop the corresponding cosine scores from the sim_scores series\n",
  1289. " sim_scores = sim_scores.drop(idx)\n",
  1290. " \n",
  1291. " #Compute the final weighted mean (NOTE: the result is NaN when the remaining similarity scores sum to 0)\n",
  1292. " wmean_rating = np.dot(sim_scores, m_ratings)/ sim_scores.sum()\n",
  1293. " \n",
  1294. " else:\n",
  1295. " #Default to a rating of 3.0 in the absence of any information\n",
  1296. " wmean_rating = 3.0\n",
  1297. " \n",
  1298. " return wmean_rating"
  1299. ]
  1300. },
  1301. {
  1302. "cell_type": "code",
  1303. "execution_count": 139,
  1304. "metadata": {},
  1305. "outputs": [
  1306. {
  1307. "data": {
  1308. "text/plain": [
  1309. "1.0174483808407588"
  1310. ]
  1311. },
  1312. "execution_count": 139,
  1313. "metadata": {},
  1314. "output_type": "execute_result"
  1315. }
  1316. ],
  1317. "source": [
  1318. "score(cf_user_wmean)"
  1319. ]
  1320. },
  1321. {
  1322. "cell_type": "markdown",
  1323. "metadata": {},
  1324. "source": [
  1325. "### Demographics"
  1326. ]
  1327. },
  1328. {
  1329. "cell_type": "code",
  1330. "execution_count": 145,
  1331. "metadata": {},
  1332. "outputs": [
  1333. {
  1334. "data": {
  1335. "text/html": [
  1336. "<div>\n",
  1337. "<style>\n",
  1338. " .dataframe thead tr:only-child th {\n",
  1339. " text-align: right;\n",
  1340. " }\n",
  1341. "\n",
  1342. " .dataframe thead th {\n",
  1343. " text-align: left;\n",
  1344. " }\n",
  1345. "\n",
  1346. " .dataframe tbody tr th {\n",
  1347. " vertical-align: top;\n",
  1348. " }\n",
  1349. "</style>\n",
  1350. "<table border=\"1\" class=\"dataframe\">\n",
  1351. " <thead>\n",
  1352. " <tr style=\"text-align: right;\">\n",
  1353. " <th></th>\n",
  1354. " <th>user_id</th>\n",
  1355. " <th>movie_id</th>\n",
  1356. " <th>rating</th>\n",
  1357. " <th>age</th>\n",
  1358. " <th>sex</th>\n",
  1359. " <th>occupation</th>\n",
  1360. " <th>zip_code</th>\n",
  1361. " </tr>\n",
  1362. " </thead>\n",
  1363. " <tbody>\n",
  1364. " <tr>\n",
  1365. " <th>0</th>\n",
  1366. " <td>889</td>\n",
  1367. " <td>684</td>\n",
  1368. " <td>2</td>\n",
  1369. " <td>24</td>\n",
  1370. " <td>M</td>\n",
  1371. " <td>technician</td>\n",
  1372. " <td>78704</td>\n",
  1373. " </tr>\n",
  1374. " <tr>\n",
  1375. " <th>1</th>\n",
  1376. " <td>889</td>\n",
  1377. " <td>279</td>\n",
  1378. " <td>2</td>\n",
  1379. " <td>24</td>\n",
  1380. " <td>M</td>\n",
  1381. " <td>technician</td>\n",
  1382. " <td>78704</td>\n",
  1383. " </tr>\n",
  1384. " <tr>\n",
  1385. " <th>2</th>\n",
  1386. " <td>889</td>\n",
  1387. " <td>29</td>\n",
  1388. " <td>3</td>\n",
  1389. " <td>24</td>\n",
  1390. " <td>M</td>\n",
  1391. " <td>technician</td>\n",
  1392. " <td>78704</td>\n",
  1393. " </tr>\n",
  1394. " <tr>\n",
  1395. " <th>3</th>\n",
  1396. " <td>889</td>\n",
  1397. " <td>190</td>\n",
  1398. " <td>3</td>\n",
  1399. " <td>24</td>\n",
  1400. " <td>M</td>\n",
  1401. " <td>technician</td>\n",
  1402. " <td>78704</td>\n",
  1403. " </tr>\n",
  1404. " <tr>\n",
  1405. " <th>4</th>\n",
  1406. " <td>889</td>\n",
  1407. " <td>232</td>\n",
  1408. " <td>3</td>\n",
  1409. " <td>24</td>\n",
  1410. " <td>M</td>\n",
  1411. " <td>technician</td>\n",
  1412. " <td>78704</td>\n",
  1413. " </tr>\n",
  1414. " </tbody>\n",
  1415. "</table>\n",
  1416. "</div>"
  1417. ],
  1418. "text/plain": [
  1419. " user_id movie_id rating age sex occupation zip_code\n",
  1420. "0 889 684 2 24 M technician 78704\n",
  1421. "1 889 279 2 24 M technician 78704\n",
  1422. "2 889 29 3 24 M technician 78704\n",
  1423. "3 889 190 3 24 M technician 78704\n",
  1424. "4 889 232 3 24 M technician 78704"
  1425. ]
  1426. },
  1427. "execution_count": 145,
  1428. "metadata": {},
  1429. "output_type": "execute_result"
  1430. }
  1431. ],
  1432. "source": [
  1433. "#Merge the original users dataframe with the training set \n",
  1434. "merged_df = pd.merge(X_train, users)\n",
  1435. "\n",
  1436. "merged_df.head()"
  1437. ]
  1438. },
  1439. {
  1440. "cell_type": "code",
  1441. "execution_count": 150,
  1442. "metadata": {},
  1443. "outputs": [
  1444. {
  1445. "data": {
  1446. "text/plain": [
  1447. "sex\n",
  1448. "F 3.827586\n",
  1449. "M 3.918919\n",
  1450. "Name: rating, dtype: float64"
  1451. ]
  1452. },
  1453. "execution_count": 150,
  1454. "metadata": {},
  1455. "output_type": "execute_result"
  1456. }
  1457. ],
  1458. "source": [
  1459. "#Compute the mean rating of every movie by gender\n",
  1460. "gender_mean = merged_df[['movie_id', 'sex', 'rating']].groupby(['movie_id', 'sex'])['rating'].mean()"
  1461. ]
  1462. },
  1463. {
  1464. "cell_type": "code",
  1465. "execution_count": null,
  1466. "metadata": {
  1467. "collapsed": true
  1468. },
  1469. "outputs": [],
  1470. "source": [
  1471. "#Set the index of the users dataframe to the user_id\n",
  1472. "users = users.set_index('user_id')"
  1473. ]
  1474. },
  1475. {
  1476. "cell_type": "code",
  1477. "execution_count": 165,
  1478. "metadata": {
  1479. "collapsed": true
  1480. },
  1481. "outputs": [],
  1482. "source": [
  1483. "#Gender Based Collaborative Filter using Mean Ratings\n",
  1484. "def cf_gender(user_id, movie_id):\n",
  1485. " \n",
  1486. " #Check if movie_id exists in r_matrix (or training set)\n",
  1487. " if movie_id in r_matrix:\n",
  1488. " #Identify the gender of the user\n",
  1489. " gender = users.loc[user_id]['sex']\n",
  1490. " \n",
  1491. " #Check if the gender has rated the movie\n",
  1492. " if gender in gender_mean[movie_id]:\n",
  1493. " \n",
  1494. " #Look up the precomputed mean rating given by that gender to the movie\n",
  1495. " gender_rating = gender_mean[movie_id][gender]\n",
  1496. " \n",
  1497. " else:\n",
  1498. " gender_rating = 3.0\n",
  1499. " \n",
  1500. " else:\n",
  1501. " #Default to a rating of 3.0 in the absence of any information\n",
  1502. " gender_rating = 3.0\n",
  1503. " \n",
  1504. " return gender_rating"
  1505. ]
  1506. },
  1507. {
  1508. "cell_type": "code",
  1509. "execution_count": 166,
  1510. "metadata": {},
  1511. "outputs": [
  1512. {
  1513. "name": "stderr",
  1514. "output_type": "stream",
  1515. "text": [
  1516. "/usr/local/lib/python3.6/site-packages/pandas/core/indexes/multi.py:819: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
  1517. " return self._engine.get_value(s, k)\n"
  1518. ]
  1519. },
  1520. {
  1521. "data": {
  1522. "text/plain": [
  1523. "1.0330308800874282"
  1524. ]
  1525. },
  1526. "execution_count": 166,
  1527. "metadata": {},
  1528. "output_type": "execute_result"
  1529. }
  1530. ],
  1531. "source": [
  1532. "score(cf_gender)"
  1533. ]
  1534. },
  1535. {
  1536. "cell_type": "code",
  1537. "execution_count": 174,
  1538. "metadata": {},
  1539. "outputs": [
  1540. {
  1541. "data": {
  1542. "text/html": [
  1543. "<div>\n",
  1544. "<style>\n",
  1545. " .dataframe thead tr:only-child th {\n",
  1546. " text-align: right;\n",
  1547. " }\n",
  1548. "\n",
  1549. " .dataframe thead th {\n",
  1550. " text-align: left;\n",
  1551. " }\n",
  1552. "\n",
  1553. " .dataframe tbody tr th {\n",
  1554. " vertical-align: top;\n",
  1555. " }\n",
  1556. "</style>\n",
  1557. "<table border=\"1\" class=\"dataframe\">\n",
  1558. " <thead>\n",
  1559. " <tr>\n",
  1560. " <th>occupation</th>\n",
  1561. " <th colspan=\"2\" halign=\"left\">administrator</th>\n",
  1562. " <th colspan=\"2\" halign=\"left\">artist</th>\n",
  1563. " <th>doctor</th>\n",
  1564. " <th colspan=\"2\" halign=\"left\">educator</th>\n",
  1565. " <th colspan=\"2\" halign=\"left\">engineer</th>\n",
  1566. " <th>entertainment</th>\n",
  1567. " <th>...</th>\n",
  1568. " <th colspan=\"2\" halign=\"left\">salesman</th>\n",
  1569. " <th colspan=\"2\" halign=\"left\">scientist</th>\n",
  1570. " <th colspan=\"2\" halign=\"left\">student</th>\n",
  1571. " <th colspan=\"2\" halign=\"left\">technician</th>\n",
  1572. " <th colspan=\"2\" halign=\"left\">writer</th>\n",
  1573. " </tr>\n",
  1574. " <tr>\n",
  1575. " <th>sex</th>\n",
  1576. " <th>F</th>\n",
  1577. " <th>M</th>\n",
  1578. " <th>F</th>\n",
  1579. " <th>M</th>\n",
  1580. " <th>M</th>\n",
  1581. " <th>F</th>\n",
  1582. " <th>M</th>\n",
  1583. " <th>F</th>\n",
  1584. " <th>M</th>\n",
  1585. " <th>F</th>\n",
  1586. " <th>...</th>\n",
  1587. " <th>F</th>\n",
  1588. " <th>M</th>\n",
  1589. " <th>F</th>\n",
  1590. " <th>M</th>\n",
  1591. " <th>F</th>\n",
  1592. " <th>M</th>\n",
  1593. " <th>F</th>\n",
  1594. " <th>M</th>\n",
  1595. " <th>F</th>\n",
  1596. " <th>M</th>\n",
  1597. " </tr>\n",
  1598. " <tr>\n",
  1599. " <th>movie_id</th>\n",
  1600. " <th></th>\n",
  1601. " <th></th>\n",
  1602. " <th></th>\n",
  1603. " <th></th>\n",
  1604. " <th></th>\n",
  1605. " <th></th>\n",
  1606. " <th></th>\n",
  1607. " <th></th>\n",
  1608. " <th></th>\n",
  1609. " <th></th>\n",
  1610. " <th></th>\n",
  1611. " <th></th>\n",
  1612. " <th></th>\n",
  1613. " <th></th>\n",
  1614. " <th></th>\n",
  1615. " <th></th>\n",
  1616. " <th></th>\n",
  1617. " <th></th>\n",
  1618. " <th></th>\n",
  1619. " <th></th>\n",
  1620. " <th></th>\n",
  1621. " </tr>\n",
  1622. " </thead>\n",
  1623. " <tbody>\n",
  1624. " <tr>\n",
  1625. " <th>1</th>\n",
  1626. " <td>4.0</td>\n",
  1627. " <td>4.222222</td>\n",
  1628. " <td>4.25</td>\n",
  1629. " <td>3.500000</td>\n",
  1630. " <td>3.666667</td>\n",
  1631. " <td>3.50</td>\n",
  1632. " <td>3.923077</td>\n",
  1633. " <td>4.0</td>\n",
  1634. " <td>3.970588</td>\n",
  1635. " <td>5.0</td>\n",
  1636. " <td>...</td>\n",
  1637. " <td>4.0</td>\n",
  1638. " <td>4.000000</td>\n",
  1639. " <td>3.5</td>\n",
  1640. " <td>3.888889</td>\n",
  1641. " <td>3.833333</td>\n",
  1642. " <td>3.709091</td>\n",
  1643. " <td>4.0</td>\n",
  1644. " <td>4.200000</td>\n",
  1645. " <td>4.166667</td>\n",
  1646. " <td>3.142857</td>\n",
  1647. " </tr>\n",
  1648. " <tr>\n",
  1649. " <th>2</th>\n",
  1650. " <td>3.0</td>\n",
  1651. " <td>3.750000</td>\n",
  1652. " <td>NaN</td>\n",
  1653. " <td>NaN</td>\n",
  1654. " <td>NaN</td>\n",
  1655. " <td>NaN</td>\n",
  1656. " <td>3.250000</td>\n",
  1657. " <td>NaN</td>\n",
  1658. " <td>3.363636</td>\n",
  1659. " <td>NaN</td>\n",
  1660. " <td>...</td>\n",
  1661. " <td>NaN</td>\n",
  1662. " <td>NaN</td>\n",
  1663. " <td>NaN</td>\n",
  1664. " <td>NaN</td>\n",
  1665. " <td>2.333333</td>\n",
  1666. " <td>3.333333</td>\n",
  1667. " <td>NaN</td>\n",
  1668. " <td>2.714286</td>\n",
  1669. " <td>5.000000</td>\n",
  1670. " <td>2.666667</td>\n",
  1671. " </tr>\n",
  1672. " <tr>\n",
  1673. " <th>3</th>\n",
  1674. " <td>3.5</td>\n",
  1675. " <td>2.500000</td>\n",
  1676. " <td>NaN</td>\n",
  1677. " <td>NaN</td>\n",
  1678. " <td>NaN</td>\n",
  1679. " <td>4.00</td>\n",
  1680. " <td>2.500000</td>\n",
  1681. " <td>NaN</td>\n",
  1682. " <td>3.625000</td>\n",
  1683. " <td>NaN</td>\n",
  1684. " <td>...</td>\n",
  1685. " <td>NaN</td>\n",
  1686. " <td>1.000000</td>\n",
  1687. " <td>NaN</td>\n",
  1688. " <td>NaN</td>\n",
  1689. " <td>2.000000</td>\n",
  1690. " <td>3.217391</td>\n",
  1691. " <td>NaN</td>\n",
  1692. " <td>4.000000</td>\n",
  1693. " <td>NaN</td>\n",
  1694. " <td>1.000000</td>\n",
  1695. " </tr>\n",
  1696. " <tr>\n",
  1697. " <th>4</th>\n",
  1698. " <td>3.0</td>\n",
  1699. " <td>3.888889</td>\n",
  1700. " <td>NaN</td>\n",
  1701. " <td>4.666667</td>\n",
  1702. " <td>3.000000</td>\n",
  1703. " <td>2.75</td>\n",
  1704. " <td>3.636364</td>\n",
  1705. " <td>NaN</td>\n",
  1706. " <td>3.555556</td>\n",
  1707. " <td>NaN</td>\n",
  1708. " <td>...</td>\n",
  1709. " <td>4.0</td>\n",
  1710. " <td>3.666667</td>\n",
  1711. " <td>NaN</td>\n",
  1712. " <td>3.600000</td>\n",
  1713. " <td>3.285714</td>\n",
  1714. " <td>3.724138</td>\n",
  1715. " <td>NaN</td>\n",
  1716. " <td>3.200000</td>\n",
  1717. " <td>4.250000</td>\n",
  1718. " <td>3.500000</td>\n",
  1719. " </tr>\n",
  1720. " <tr>\n",
  1721. " <th>5</th>\n",
  1722. " <td>4.0</td>\n",
  1723. " <td>2.333333</td>\n",
  1724. " <td>NaN</td>\n",
  1725. " <td>NaN</td>\n",
  1726. " <td>NaN</td>\n",
  1727. " <td>4.00</td>\n",
  1728. " <td>1.500000</td>\n",
  1729. " <td>NaN</td>\n",
  1730. " <td>2.666667</td>\n",
  1731. " <td>NaN</td>\n",
  1732. " <td>...</td>\n",
  1733. " <td>NaN</td>\n",
  1734. " <td>NaN</td>\n",
  1735. " <td>NaN</td>\n",
  1736. " <td>3.500000</td>\n",
  1737. " <td>4.333333</td>\n",
  1738. " <td>3.272727</td>\n",
  1739. " <td>NaN</td>\n",
  1740. " <td>3.333333</td>\n",
  1741. " <td>4.000000</td>\n",
  1742. " <td>2.666667</td>\n",
  1743. " </tr>\n",
  1744. " </tbody>\n",
  1745. "</table>\n",
  1746. "<p>5 rows × 41 columns</p>\n",
  1747. "</div>"
  1748. ],
  1749. "text/plain": [
  1750. "occupation administrator artist doctor educator \\\n",
  1751. "sex F M F M M F \n",
  1752. "movie_id \n",
  1753. "1 4.0 4.222222 4.25 3.500000 3.666667 3.50 \n",
  1754. "2 3.0 3.750000 NaN NaN NaN NaN \n",
  1755. "3 3.5 2.500000 NaN NaN NaN 4.00 \n",
  1756. "4 3.0 3.888889 NaN 4.666667 3.000000 2.75 \n",
  1757. "5 4.0 2.333333 NaN NaN NaN 4.00 \n",
  1758. "\n",
  1759. "occupation engineer entertainment ... salesman \\\n",
  1760. "sex M F M F ... F \n",
  1761. "movie_id ... \n",
  1762. "1 3.923077 4.0 3.970588 5.0 ... 4.0 \n",
  1763. "2 3.250000 NaN 3.363636 NaN ... NaN \n",
  1764. "3 2.500000 NaN 3.625000 NaN ... NaN \n",
  1765. "4 3.636364 NaN 3.555556 NaN ... 4.0 \n",
  1766. "5 1.500000 NaN 2.666667 NaN ... NaN \n",
  1767. "\n",
  1768. "occupation scientist student technician \\\n",
  1769. "sex M F M F M F \n",
  1770. "movie_id \n",
  1771. "1 4.000000 3.5 3.888889 3.833333 3.709091 4.0 \n",
  1772. "2 NaN NaN NaN 2.333333 3.333333 NaN \n",
  1773. "3 1.000000 NaN NaN 2.000000 3.217391 NaN \n",
  1774. "4 3.666667 NaN 3.600000 3.285714 3.724138 NaN \n",
  1775. "5 NaN NaN 3.500000 4.333333 3.272727 NaN \n",
  1776. "\n",
  1777. "occupation writer \n",
  1778. "sex M F M \n",
  1779. "movie_id \n",
  1780. "1 4.200000 4.166667 3.142857 \n",
  1781. "2 2.714286 5.000000 2.666667 \n",
  1782. "3 4.000000 NaN 1.000000 \n",
  1783. "4 3.200000 4.250000 3.500000 \n",
  1784. "5 3.333333 4.000000 2.666667 \n",
  1785. "\n",
  1786. "[5 rows x 41 columns]"
  1787. ]
  1788. },
  1789. "execution_count": 174,
  1790. "metadata": {},
  1791. "output_type": "execute_result"
  1792. }
  1793. ],
  1794. "source": [
  1795. "#Compute the mean rating by gender and occupation\n",
  1796. "gen_occ_mean = merged_df[['sex', 'rating', 'movie_id', 'occupation']].pivot_table(\n",
  1797. " values='rating', index='movie_id', columns=['occupation', 'sex'], aggfunc='mean')\n",
  1798. "\n",
  1799. "gen_occ_mean.head()"
  1800. ]
  1801. },
  1802. {
  1803. "cell_type": "code",
  1804. "execution_count": 198,
  1805. "metadata": {
  1806. "collapsed": true
  1807. },
  1808. "outputs": [],
  1809. "source": [
  1810. "#Gender and Occupation Based Collaborative Filter using Mean Ratings\n",
  1811. "def cf_gen_occ(user_id, movie_id):\n",
  1812. " \n",
  1813. " #Check if movie_id exists in gen_occ_mean\n",
  1814. " if movie_id in gen_occ_mean.index:\n",
  1815. " \n",
  1816. " #Identify the user\n",
  1817. " user = users.loc[user_id]\n",
  1818. " \n",
  1819. " #Identify the gender and occupation\n",
  1820. " gender = user['sex']\n",
  1821. " occ = user['occupation']\n",
  1822. " \n",
  1823. " #Check if the occupation has rated the movie\n",
  1824. " if occ in gen_occ_mean.loc[movie_id]:\n",
  1825. " \n",
  1826. " #Check if the gender has rated the movie\n",
  1827. " if gender in gen_occ_mean.loc[movie_id][occ]:\n",
  1828. " \n",
  1829. " #Extract the required rating\n",
  1830. " rating = gen_occ_mean.loc[movie_id][occ][gender]\n",
  1831. " \n",
  1832. " #Default to 3.0 if the rating is null\n",
  1833. " if np.isnan(rating):\n",
  1834. " rating = 3.0\n",
  1835. " \n",
  1836. " return rating\n",
  1837. " \n",
  1838. " #Return the default rating \n",
  1839. " return 3.0"
  1840. ]
  1841. },
  1842. {
  1843. "cell_type": "code",
  1844. "execution_count": 199,
  1845. "metadata": {},
  1846. "outputs": [
  1847. {
  1848. "data": {
  1849. "text/plain": [
  1850. "1.1391976012043645"
  1851. ]
  1852. },
  1853. "execution_count": 199,
  1854. "metadata": {},
  1855. "output_type": "execute_result"
  1856. }
  1857. ],
  1858. "source": [
  1859. "score(cf_gen_occ)"
  1860. ]
  1861. },
  1862. {
  1863. "cell_type": "markdown",
  1864. "metadata": {},
  1865. "source": [
  1866. "## Model Based Approaches"
  1867. ]
  1868. },
  1869. {
  1870. "cell_type": "code",
  1871. "execution_count": 231,
  1872. "metadata": {},
  1873. "outputs": [
  1874. {
  1875. "name": "stdout",
  1876. "output_type": "stream",
  1877. "text": [
  1878. "Evaluating RMSE of algorithm KNNBasic.\n",
  1879. "\n",
  1880. "------------\n",
  1881. "Fold 1\n",
  1882. "Computing the msd similarity matrix...\n",
  1883. "Done computing similarity matrix.\n",
  1884. "RMSE: 0.9776\n",
  1885. "------------\n",
  1886. "Fold 2\n",
  1887. "Computing the msd similarity matrix...\n",
  1888. "Done computing similarity matrix.\n",
  1889. "RMSE: 0.9789\n",
  1890. "------------\n",
  1891. "Fold 3\n",
  1892. "Computing the msd similarity matrix...\n",
  1893. "Done computing similarity matrix.\n",
  1894. "RMSE: 0.9695\n",
  1895. "------------\n",
  1896. "Fold 4\n",
  1897. "Computing the msd similarity matrix...\n",
  1898. "Done computing similarity matrix.\n",
  1899. "RMSE: 0.9810\n",
  1900. "------------\n",
  1901. "Fold 5\n",
  1902. "Computing the msd similarity matrix...\n",
  1903. "Done computing similarity matrix.\n",
  1904. "RMSE: 0.9849\n",
  1905. "------------\n",
  1906. "------------\n",
  1907. "Mean RMSE: 0.9784\n",
  1908. "------------\n",
  1909. "------------\n"
  1910. ]
  1911. },
  1912. {
  1913. "data": {
  1914. "text/plain": [
  1915. "CaseInsensitiveDefaultDict(list,\n",
  1916. " {'rmse': [0.97764007686097709,\n",
  1917. " 0.97889035204999741,\n",
  1918. " 0.9694859699934969,\n",
  1919. " 0.98099811511904433,\n",
  1920. " 0.98488926832497381]})"
  1921. ]
  1922. },
  1923. "execution_count": 231,
  1924. "metadata": {},
  1925. "output_type": "execute_result"
  1926. }
  1927. ],
  1928. "source": [
  1929. "#Import the required classes and methods from the surprise library\n",
  1930. "from surprise import Reader, Dataset, KNNBasic, evaluate\n",
  1931. "\n",
  1932. "#Define a Reader object\n",
  1933. "#The Reader object helps in parsing the file or dataframe containing ratings\n",
  1934. "reader = Reader()\n",
  1935. "\n",
  1936. "#Create the dataset to be used for building the filter\n",
  1937. "data = Dataset.load_from_df(ratings, reader)\n",
  1938. "\n",
  1939. "#Define the algorithm object; in this case kNN\n",
  1940. "knn = KNNBasic()\n",
  1941. "\n",
  1942. "#Evaluate the performance in terms of RMSE (NOTE: newer Surprise releases replace evaluate() with surprise.model_selection.cross_validate)\n",
  1943. "evaluate(knn, data, measures=['RMSE'])"
  1944. ]
  1945. },
  1946. {
  1947. "cell_type": "code",
  1948. "execution_count": 232,
  1949. "metadata": {},
  1950. "outputs": [
  1951. {
  1952. "name": "stdout",
  1953. "output_type": "stream",
  1954. "text": [
  1955. "Evaluating RMSE of algorithm SVD.\n",
  1956. "\n",
  1957. "------------\n",
  1958. "Fold 1\n",
  1959. "RMSE: 0.9371\n",
  1960. "------------\n",
  1961. "Fold 2\n",
  1962. "RMSE: 0.9417\n",
  1963. "------------\n",
  1964. "Fold 3\n",
  1965. "RMSE: 0.9289\n",
  1966. "------------\n",
  1967. "Fold 4\n",
  1968. "RMSE: 0.9379\n",
  1969. "------------\n",
  1970. "Fold 5\n",
  1971. "RMSE: 0.9379\n",
  1972. "------------\n",
  1973. "------------\n",
  1974. "Mean RMSE: 0.9367\n",
  1975. "------------\n",
  1976. "------------\n"
  1977. ]
  1978. },
  1979. {
  1980. "data": {
  1981. "text/plain": [
  1982. "CaseInsensitiveDefaultDict(list,\n",
  1983. " {'rmse': [0.93714337825960081,\n",
  1984. " 0.9417378198331483,\n",
  1985. " 0.92893737314257874,\n",
  1986. " 0.93793761103739881,\n",
  1987. " 0.93789928866069328]})"
  1988. ]
  1989. },
  1990. "execution_count": 232,
  1991. "metadata": {},
  1992. "output_type": "execute_result"
  1993. }
  1994. ],
  1995. "source": [
  1996. "#Import SVD\n",
  1997. "from surprise import SVD\n",
  1998. "\n",
  1999. "#Define the SVD algorithm object\n",
  2000. "svd = SVD()\n",
  2001. "\n",
  2002. "#Evaluate the performance in terms of RMSE (NOTE: newer Surprise releases replace evaluate() with surprise.model_selection.cross_validate)\n",
  2003. "evaluate(svd, data, measures=['RMSE'])"
  2004. ]
  2005. },
  2006. {
  2007. "cell_type": "code",
  2008. "execution_count": null,
  2009. "metadata": {
  2010. "collapsed": true
  2011. },
  2012. "outputs": [],
  2013. "source": []
  2014. }
  2015. ],
  2016. "metadata": {
  2017. "kernelspec": {
  2018. "display_name": "Python 3",
  2019. "language": "python",
  2020. "name": "python3"
  2021. },
  2022. "language_info": {
  2023. "codemirror_mode": {
  2024. "name": "ipython",
  2025. "version": 3
  2026. },
  2027. "file_extension": ".py",
  2028. "mimetype": "text/x-python",
  2029. "name": "python",
  2030. "nbconvert_exporter": "python",
  2031. "pygments_lexer": "ipython3",
  2032. "version": "3.6.0"
  2033. }
  2034. },
  2035. "nbformat": 4,
  2036. "nbformat_minor": 2
  2037. }